In [11]:
%load_ext autoreload
%autoreload 2

import os
import random

from fastai.text.all import *
from fastai.vision.all import *
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Seed every RNG the notebook uses so the random spot-check indices below
# (and any other stochastic step) are reproducible from a fresh kernel.
seed = 42

random.seed(seed)                     # Python stdlib RNG
np.random.seed(seed)                  # legacy NumPy global RNG (np.random.randint is used later)
torch.manual_seed(seed)               # PyTorch default RNG
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)  # every visible GPU

# cuDNN: `deterministic` selects deterministic kernels, but autotuning
# (`benchmark=True`) may still choose different algorithms between runs,
# so switch it off explicitly.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [13]:
# Root of the `small_flow` image dataset. Set SMALL_FLOW_PATH to point at another
# copy instead of editing this machine-specific absolute path.
DEFAULT_DATA_PATH = "/mnt/nas/backups/08-07-2020/desktopg01/lisa/Data/small_flow"
path = Path(os.environ.get("SMALL_FLOW_PATH", DEFAULT_DATA_PATH))

In [None]:
# Loader settings, named instead of passed as bare literals.
# NOTE(review): assumed to be (batch size, image size) — confirm against utils.get_dls.
BATCH_SIZE = 64
IMAGE_SIZE = 224
dls = get_dls(path, BATCH_SIZE, IMAGE_SIZE)

In [None]:
# Image files found under the `test` folder of `path`.
test_items = get_image_files(path, folders=["test"])

In [None]:
# Loader over the test images built from `dls`; labels are kept so that
# `learn.validate` can compute a loss on it later.
test_dl = dls.test_dl(test_items=test_items, with_labels=True)

In [None]:
class GetActs(HookCallback):
    """Forward-hook callback that accumulates the outputs of `modules` into `self.acts`.

    After each batch's prediction step, the outputs of every hooked module for that
    batch are appended to `self.acts` (an `L`). With a single hooked module,
    `self.acts` therefore holds one tensor per batch, in iteration order.
    `self.acts` is reset whenever `before_fit` fires.

    Args:
        modules: modules to hook (passed through to `HookCallback`).
        remove_end: remove the hooks when the run ends.
        detach: detach stored outputs from the autograd graph.
        cpu: move stored outputs to the CPU.
    """
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        # Pass HookCallback options by keyword rather than position: its positional
        # order (modules, every, remove_end, is_forward, detach, cpu, ...) is an
        # implementation detail of fastai, and keywords keep this call correct
        # if parameters are added or reordered.
        super().__init__(modules=modules, every=None, remove_end=remove_end,
                         is_forward=True, detach=detach, cpu=cpu)
        self.acts = L()

    def hook(self, m, i, o):
        "Store each hooked module's forward output unchanged."
        return o

    def after_pred(self):
        # `hooks.stored` holds one stored output per hooked module for the current
        # batch; with several modules they are appended in module order.
        self.acts += self.hooks.stored

    def before_fit(self):
        # Start every run with an empty buffer. NOTE(review): the cells below rely
        # on fastai firing `before_fit` for validate/predict too, so `acts` only
        # holds the latest run — confirm for the installed fastai version.
        super().before_fit()
        self.acts = L()

In [None]:
# Learner skeleton (ResNet-50 body + fastai head); its weights are replaced by `learn.load` below.
# NOTE(review): keep the default `pretrained=True` — fastai appears to add ImageNet
# normalization to `dls` only for pretrained models, which the saved weights presumably expect.
loss_fn = CrossEntropyLossFlat()
learn = cnn_learner(dls, resnet50, loss_func=loss_fn)

In [None]:
# Hook one layer so GetActs records its output for every batch.
# NOTE(review): `learn.model[1][1]` is presumably the Flatten layer of fastai's
# default head (the pooled backbone features) — confirm by inspecting `learn.model[1]`.
hooked_layer = learn.model[1][1]
learn.add_cb(GetActs(modules=[hooked_layer]))

In [None]:
# Name of the saved fine-tuned checkpoint to restore into `learn`.
WEIGHTS_NAME = "resnet50-fine-tuned-1E-disc-6E-224_class_weights"
learn.load(WEIGHTS_NAME)

In [None]:
# Run the validation split (ds_idx=1) through the model; GetActs records the hooked outputs.
learn.validate(ds_idx=1)

In [None]:
acts_cb = learn.get_acts    # the GetActs callback, reached via its snake_case name
valid_preds = acts_cb.acts  # per-batch hooked outputs from the validation run

In [None]:
# Concatenate the per-batch activations into a single tensor (one row per item).
# Guarded so that re-running this cell is a no-op: concatenating the rows of an
# already-concatenated 2-D tensor would silently flatten it to 1-D.
if not isinstance(valid_preds, torch.Tensor):
    valid_preds = torch.cat(list(valid_preds))
valid_preds.shape

In [None]:
# Spot-check one random validation image: running it alone through `learn.predict`
# (which refills `learn.get_acts.acts` with that single batch) should reproduce the
# row captured in the batched pass.
n = np.random.randint(0, valid_preds.shape[0])
learn.predict(dls.valid_ds[n][0])
single_act = torch.cat(list(learn.get_acts.acts))[0]
# Compare numerically rather than with `==`. `get_acts.acts` is an `L` of batch
# tensors, not a row, and `assert tensor == ...` is not a single truth value.
# Batched and single-image inference may also differ by float rounding.
assert torch.allclose(valid_preds[n], single_act, rtol=1e-3, atol=1e-4), \
    f"activation mismatch at validation index {n}"

In [None]:
# Write one activation file per validation image, mirroring the dataset layout:
#   <path>/<split>/<class>/<name>.<ext>  ->  ./activations/img/<split>/<class>/<name>.pt
# `relative_to` + `with_suffix` keep the full file name, so names containing extra
# dots (e.g. `img.1.png` vs `img.2.png`) no longer collide. Items outside `path`
# now fail loudly instead of with an opaque `None[0]`.
# NOTE(review): assumes `valid_preds` rows follow `dls.valid.items` order (unshuffled loader).
acts_root = Path("./activations/img")
assert len(valid_preds) == len(dls.valid.items), "activation rows do not line up with dls.valid.items"
for idx, item in enumerate(tqdm(dls.valid.items, desc="saving valid activations")):
    out_file = acts_root / item.relative_to(path).with_suffix(".pt")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    # `.clone()` so only this row is serialized; saving a view would otherwise
    # write the whole backing storage of `valid_preds`.
    torch.save(valid_preds[idx].clone(), out_file)

In [None]:
# Collect training-set activations in `dls.train.items` order.
# `learn.validate(0)` would iterate the training loader, which (with fastai's
# defaults) shuffles, drops the last partial batch and applies train-time
# augmentation, so its rows could not be matched back to `dls.train.items`
# when saving. `test_dl` runs the same items, in order, through the
# validation-time pipeline instead.
train_ordered_dl = dls.test_dl(dls.train.items, with_labels=True)
learn.validate(dl=train_ordered_dl)

In [None]:
acts_cb = learn.get_acts    # the GetActs callback, reached via its snake_case name
train_preds = acts_cb.acts  # per-batch hooked outputs from the training-set run

In [None]:
# Concatenate the per-batch activations into a single tensor (one row per item).
# Guarded so that re-running this cell is a no-op instead of flattening to 1-D.
if not isinstance(train_preds, torch.Tensor):
    train_preds = torch.cat(list(train_preds))
train_preds.shape

In [None]:
# Spot-check one random training image against the batched pass.
# NOTE: only meaningful if `train_preds` rows follow `dls.train_ds` order, i.e.
# were collected without shuffling; otherwise row n belongs to a different image.
n = np.random.randint(0, train_preds.shape[0])
learn.predict(dls.train_ds[n][0])
single_act = torch.cat(list(learn.get_acts.acts))[0]
# Numerical comparison: `get_acts.acts` is an `L` of batch tensors, and batched
# vs. single-image inference may differ by float rounding.
assert torch.allclose(train_preds[n], single_act, rtol=1e-3, atol=1e-4), \
    f"activation mismatch at training index {n}"

In [None]:
# Write one activation file per training image, mirroring the dataset layout:
#   <path>/<split>/<class>/<name>.<ext>  ->  ./activations/img/<split>/<class>/<name>.pt
# Requires `train_preds` to come from an in-order pass over every training item
# (not the shuffled training loader). The length check below catches dropped
# batches but cannot detect shuffling.
acts_root = Path("./activations/img")
assert len(train_preds) == len(dls.train.items), "activation rows do not line up with dls.train.items"
for idx, item in enumerate(tqdm(dls.train.items, desc="saving train activations")):
    out_file = acts_root / item.relative_to(path).with_suffix(".pt")
    out_file.parent.mkdir(parents=True, exist_ok=True)
    # `.clone()` so only this row is serialized, not the whole backing storage.
    torch.save(train_preds[idx].clone(), out_file)

In [None]:
# Run the labelled test loader through the model; `dl` takes precedence over `ds_idx`.
learn.validate(ds_idx=0, dl=test_dl)

In [None]:
acts_cb = learn.get_acts   # the GetActs callback, reached via its snake_case name
test_preds = acts_cb.acts  # per-batch hooked outputs from the test-set run

In [None]:
# Concatenate the per-batch activations into a single tensor (one row per item).
# Guarded so that re-running this cell is a no-op instead of flattening to 1-D.
if not isinstance(test_preds, torch.Tensor):
    test_preds = torch.cat(list(test_preds))
test_preds.shape