In [None]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

In [None]:
# Seed every RNG this notebook touches so model init and extraction are reproducible.
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
# deterministic=True alone is not enough: cuDNN benchmark mode autotunes and may
# pick different convolution algorithms run-to-run, so disable it explicitly.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# numpy RNG
import numpy as np
np.random.seed(seed)

In [None]:
# Dataset root. Set the SMALL_FLOW_PATH environment variable to run on a machine
# where the data is not mounted at the original NAS location.
import os
path = Path(os.environ.get("SMALL_FLOW_PATH",
                           "/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/small_flow"))

In [14]:
# Build the train/valid DataLoaders from the dataset root via the project helper.
dls = get_dls(
    path,
    64,   # presumably the batch size — confirm against utils.get_dls
    224,  # presumably the image size (the checkpoint loaded later is named "...-224_...")
)

In [15]:
# Collect the image files under the top-level "test" folder of the dataset root.
test_items = get_image_files(path, folders=["test"])

In [16]:
# DataLoader over the test images, built from `dls` (fastai derives it from the
# validation loader); with_labels=True keeps the targets so get_preds returns them.
test_dl = dls.test_dl(
    test_items,
    with_labels=True,
)

In [17]:
class GetActs(HookCallback):
    "Collect the outputs of the hooked `modules` batch by batch into `self.acts` (cleared at `before_fit`)."
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        # Positional order of HookCallback: modules, every, remove_end, is_forward, detach, cpu.
        # every=None and is_forward=True are fixed here; the rest are forwarded.
        super().__init__(modules, None, remove_end, True, detach, cpu)
        self._reset_acts()

    def _reset_acts(self):
        "Start a fresh, empty collection of activations."
        self.acts = L()

    def hook(self, m, i, o):
        "Keep the module output as-is; the detach/cpu flags given to HookCallback govern how it is stored."
        return o

    def before_fit(self):
        super().before_fit()
        self._reset_acts()

    def after_pred(self):
        # `hooks.stored` holds this batch's stored outputs, one entry per hooked
        # module; appending after every batch builds a list of per-batch tensors.
        # NOTE(review): with several modules their outputs interleave in `acts`.
        self.acts += self.hooks.stored

In [18]:
# ResNet-50 classifier over `dls`; its weights are replaced by the fine-tuned
# checkpoint loaded further down.
learn = cnn_learner(
    dls,
    resnet50,
    loss_func=CrossEntropyLossFlat(),
)

In [19]:
# Hook learn.model[1][1]; the activations it yields below are 4096-wide
# (presumably the flattened pooled features inside the cnn_learner head — confirm).
hooked_layer = learn.model[1][1]
learn.add_cb(GetActs([hooked_layer]))

<fastai.learner.Learner at 0x7f7f721f2520>

In [20]:
# Restore the fine-tuned weights that the activations are extracted from.
checkpoint_name = "resnet50-fine-tuned-1E-disc-6E-224_class_weights"
learn.load(checkpoint_name)

<fastai.learner.Learner at 0x7f7f721f2520>

In [21]:
# Run the validation set through the model. GetActs.after_pred fires on every
# batch, so afterwards learn.get_acts.acts holds the hooked activations per batch.
learn.validate()

(#1) [1.2211800813674927]

In [22]:
# Hooked activations from validate() above — still a list of per-batch tensors
# (concatenated further down). Despite the name, these are features, not class predictions.
valid_preds = learn.get_acts.acts

In [61]:
# Check that one activation row was recorded per validation item. In top-to-bottom
# order `valid_preds` is still a list of per-batch tensors here (torch.cat runs in the
# next cell), so count rows rather than batches; also accept the concatenated tensor.
n_valid_rows = len(valid_preds) if isinstance(valid_preds, torch.Tensor) else sum(len(batch) for batch in valid_preds)
assert n_valid_rows == len(dls.valid_ds)

In [23]:
# Stack the per-batch tensors into a single (n_items, n_features) tensor.
valid_preds = torch.cat(list(valid_preds))
valid_preds.shape

torch.Size([102997, 4096])

In [27]:
# Save one activation vector per validation image to
# ./activations/img/val/<class>/<stem>.pt, mirroring the dataset folder layout.
assert isinstance(valid_preds, torch.Tensor) and len(valid_preds) == len(dls.valid.items), \
    "valid_preds must be the concatenated (n_items, n_features) tensor aligned with dls.valid.items"
acts_root = Path("./activations/img")
created_dirs = set()
for idx, item in enumerate(tqdm(dls.valid.items, desc="Saving val activations")):
    parts = item.parts
    if "val" not in parts:
        raise ValueError(f"No 'val' folder in item path: {item}")
    # Name relative to the first path component that is exactly 'val'. Unlike a
    # substring regex, this cannot match e.g. 'interval/'. with_suffix drops only the
    # final extension, so stems containing dots ('a.b.jpg') do not collide.
    filename = (acts_root / Path(*parts[parts.index("val"):])).with_suffix(".pt")
    if filename.parent not in created_dirs:
        filename.parent.mkdir(parents=True, exist_ok=True)
        created_dirs.add(filename.parent)
    # clone(): valid_preds[idx] is a view, and torch.save would otherwise write the
    # entire underlying storage of valid_preds into every file.
    torch.save(valid_preds[idx].clone(), filename)

Saving example 102996

In [169]:
# Disable shuffling on the shared training DataLoader so its index order is the
# dataset order (checked in the next cell).
# NOTE(review): this mutates `dls` in place — any later training with these dls
# would be unshuffled. fastai's get_preds(ds_idx) already builds its own
# non-shuffled copy of the loader, so this is only needed for the sanity check.
dls.train.shuffle = False

In [190]:
# Sanity check: with shuffle disabled, the first indices should be 0, 1, 2, ...
dls.train.get_idxs()[:5]

[0, 1, 2, 3, 4]

In [183]:
# Extract train-set activations through a deterministic, augmentation-free pipeline.
# NOTE(review): get_preds(0) reuses the *training* DataLoader and its train-split
# transforms. If utils.get_dls adds random train-time transforms (fastai's default
# Resize crop and aug_transforms are random on the train split), the saved train
# activations would be non-reproducible and inconsistent with val/test.
# dls.test_dl applies the validation-time pipeline and keeps dls.train.items order,
# which the saving loop below indexes by.
train_eval_dl = dls.test_dl(dls.train.items, with_labels=True)
learn.get_preds(dl=train_eval_dl)

(tensor([[0.8197, 0.0093, 0.1101, 0.0164, 0.0068, 0.0377],
         [0.8459, 0.0065, 0.0865, 0.0528, 0.0041, 0.0042],
         [0.3050, 0.0246, 0.5869, 0.0494, 0.0226, 0.0115],
         ...,
         [0.0133, 0.0957, 0.0119, 0.1682, 0.1593, 0.5516],
         [0.0784, 0.3088, 0.0565, 0.3083, 0.2176, 0.0304],
         [0.0440, 0.0098, 0.0142, 0.0569, 0.0333, 0.8419]]),
 tensor([0, 0, 0,  ..., 5, 5, 5]))

In [184]:
# Hooked activations recorded by the get_preds call above (GetActs clears them at
# before_fit) — a list of per-batch tensors until concatenated in the next cell.
train_preds = learn.get_acts.acts

In [185]:
# Stack the per-batch tensors into a single (n_items, n_features) tensor.
train_preds = torch.cat(list(train_preds))
train_preds.shape

torch.Size([158308, 4096])

In [186]:
# One activation row per training item (train_preds is already concatenated here).
assert len(train_preds) == len(dls.train_ds)

In [188]:
# Save one activation vector per training image to
# ./activations/img/train/<class>/<stem>.pt, mirroring the dataset folder layout.
assert isinstance(train_preds, torch.Tensor) and len(train_preds) == len(dls.train.items), \
    "train_preds must be the concatenated (n_items, n_features) tensor aligned with dls.train.items"
acts_root = Path("./activations/img")
created_dirs = set()
for idx, item in enumerate(tqdm(dls.train.items, desc="Saving train activations")):
    parts = item.parts
    if "train" not in parts:
        raise ValueError(f"No 'train' folder in item path: {item}")
    # Name relative to the first path component that is exactly 'train'. Unlike a
    # substring regex, this cannot match inside another folder name. with_suffix drops
    # only the final extension, so stems containing dots ('a.b.jpg') do not collide.
    filename = (acts_root / Path(*parts[parts.index("train"):])).with_suffix(".pt")
    if filename.parent not in created_dirs:
        filename.parent.mkdir(parents=True, exist_ok=True)
        created_dirs.add(filename.parent)
    # clone(): train_preds[idx] is a view, and torch.save would otherwise write the
    # entire underlying storage of train_preds into every file.
    torch.save(train_preds[idx].clone(), filename)

Saving example 158308

In [82]:
# Run the test DataLoader through the model so GetActs records its activations.
# Pass the loader by keyword: get_preds ignores the positional ds_idx (previously 0,
# i.e. "train") whenever an explicit `dl` is given, so that argument only misled readers.
learn.get_preds(dl=test_dl)

(tensor([[0.0225, 0.5520, 0.0682, 0.1537, 0.1347, 0.0689],
         [0.2001, 0.2781, 0.0585, 0.2991, 0.1331, 0.0310],
         [0.6417, 0.1008, 0.0783, 0.0907, 0.0288, 0.0596],
         ...,
         [0.0146, 0.3446, 0.2006, 0.0970, 0.1984, 0.1448],
         [0.0521, 0.0578, 0.0330, 0.4747, 0.2682, 0.1142],
         [0.0390, 0.2792, 0.1027, 0.1181, 0.0407, 0.4203]]),
 tensor([0, 0, 0,  ..., 5, 5, 5]))

In [83]:
# Hooked activations recorded by the test-set get_preds call above — a list of
# per-batch tensors until concatenated in the next cell.
test_preds = learn.get_acts.acts

In [84]:
# Stack the per-batch tensors into a single (n_items, n_features) tensor.
test_preds = torch.cat(list(test_preds))
test_preds.shape

torch.Size([98577, 4096])

In [85]:
# One activation row per test item (test_preds is already concatenated here).
assert len(test_preds) == len(test_dl.dataset)

In [124]:
# Save one activation vector per test image to
# ./activations/img/test/<class>/<stem>.pt, mirroring the dataset folder layout.
assert isinstance(test_preds, torch.Tensor) and len(test_preds) == len(test_dl.items), \
    "test_preds must be the concatenated (n_items, n_features) tensor aligned with test_dl.items"
acts_root = Path("./activations/img")
created_dirs = set()
for idx, item in enumerate(tqdm(test_dl.items, desc="Saving test activations")):
    parts = item.parts
    if "test" not in parts:
        raise ValueError(f"No 'test' folder in item path: {item}")
    # Name relative to the first path component that is exactly 'test'. Unlike a
    # substring regex, this cannot match e.g. 'latest/'. with_suffix drops only the
    # final extension, so stems containing dots ('a.b.jpg') do not collide.
    filename = (acts_root / Path(*parts[parts.index("test"):])).with_suffix(".pt")
    if filename.parent not in created_dirs:
        filename.parent.mkdir(parents=True, exist_ok=True)
        created_dirs.add(filename.parent)
    # clone(): test_preds[idx] is a view, and torch.save would otherwise write the
    # entire underlying storage of test_preds into every file.
    torch.save(test_preds[idx].clone(), filename)

Saving example 98577