In [1]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

In [2]:
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

# # tensorflow RNG
# tf.random.set_seed(seed)

In [3]:
path = Path("/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/small_flow")

In [4]:
dls = get_dls(path, 64, 224)

In [5]:
test_items = get_image_files(path, folders="test")

In [6]:
test_dl = dls.test_dl(test_items, with_labels=True)

In [7]:
class GetActs(HookCallback):
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        super().__init__(modules, None, remove_end, True, detach, cpu)
        self.acts = L()
    def hook(self, m, i, o): return o
    def after_pred(self): self.acts += self.hooks.stored
    def before_fit(self):
        super().before_fit()
        self.acts = L()

In [30]:
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat())

In [9]:
learn.add_cb(GetActs([learn.model[1][1]]))

<fastai.learner.Learner at 0x7f9199de5430>

In [10]:
learn.load("best_image_weights_224")

<fastai.learner.Learner at 0x7f9199de5430>

In [11]:
learn.validate()

(#1) [1.1870307922363281]

In [12]:
valid_preds = learn.get_acts.acts

In [13]:
valid_preds = torch.cat(list(valid_preds)); valid_preds.shape

torch.Size([102997, 4096])

In [14]:
assert len(valid_preds) == len(dls.valid_ds)

In [15]:
for idx, item in enumerate(dls.valid.items):
    filename = Path("./activations/img")/re.search(r'val\/[^\/]*\/[^.]*',item.as_posix())[0]
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(valid_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 102997

In [16]:
dls.train.shuffle = False

In [17]:
dls.train.get_idxs()[:5]

[0, 1, 2, 3, 4]

In [18]:
learn.get_preds(0)

(TensorImage([[9.4377e-01, 4.6239e-04, 2.6300e-03, 2.7966e-03, 2.3224e-04, 5.0113e-02],
         [9.9129e-01, 4.6016e-04, 1.4170e-03, 5.0892e-03, 3.4972e-04, 1.3904e-03],
         [9.6137e-01, 1.6679e-03, 7.2716e-03, 1.0404e-02, 1.2957e-03, 1.7993e-02],
         ...,
         [1.7435e-04, 3.8101e-05, 2.0885e-04, 3.6850e-04, 2.2262e-04, 9.9899e-01],
         [4.3517e-02, 1.9101e-01, 9.1897e-02, 1.7730e-01, 3.0745e-01, 1.8883e-01],
         [1.7257e-02, 1.8537e-03, 1.1543e-02, 2.6991e-02, 1.7678e-02, 9.2468e-01]]),
 TensorCategory([0, 0, 0,  ..., 5, 5, 5]))

In [19]:
train_preds = learn.get_acts.acts

In [20]:
train_preds = torch.cat(list(train_preds)); train_preds.shape

torch.Size([158308, 4096])

In [21]:
assert len(train_preds) == len(dls.train_ds)

In [22]:
for idx, item in enumerate(dls.train.items):
    filename = Path("./activations/img")/re.search(r'train\/[^\/]*\/[^.]*',item.as_posix())[0]
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(train_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 158308

In [23]:
learn.get_preds(dl=test_dl)

(TensorImage([[0.1436, 0.3127, 0.0666, 0.1007, 0.0941, 0.2823],
         [0.6757, 0.0645, 0.0299, 0.1373, 0.0516, 0.0410],
         [0.9433, 0.0076, 0.0178, 0.0195, 0.0044, 0.0075],
         ...,
         [0.0601, 0.2577, 0.0815, 0.1288, 0.0345, 0.4374],
         [0.0750, 0.0357, 0.0938, 0.2608, 0.1918, 0.3430],
         [0.0069, 0.0378, 0.0884, 0.0292, 0.0106, 0.8271]]),
 TensorCategory([0, 0, 0,  ..., 5, 5, 5]))

In [24]:
test_preds = learn.get_acts.acts

In [25]:
test_preds = torch.cat(list(test_preds)); test_preds.shape

torch.Size([98577, 4096])

In [26]:
assert len(test_preds) == len(test_dl.dataset)

In [27]:
for idx, item in enumerate(test_dl.items):
    filename = Path("./activations/img")/re.search(r'test\/[^\/]*\/[^.]*',item.as_posix())[0]
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(test_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 98577