In [6]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Seed every RNG this notebook touches so that a Restart & Run All is repeatable.
seed = 42

import random

import numpy as np
import torch

# python RNG
random.seed(seed)

# numpy RNG
np.random.seed(seed)

# pytorch RNGs (CPU, plus every visible CUDA device when one is present)
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# cuDNN: request deterministic kernels AND switch off the autotuner. With
# `benchmark` enabled cuDNN may pick a different algorithm from run to run, which
# defeats `deterministic = True`; set it explicitly in case a library turned it on.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# # tensorflow RNG
# tf.random.set_seed(seed)

In [8]:
# Number of outputs of the fusion head (passed as `n_out` to `create_head` below).
OUT_DIM=6

In [9]:
class GetActs(Transform):
    """Load the pre-computed image and text activations for a single item.

    `encodes` reads three keys from the item `x`: `has_text`, `has_image` and
    `activation_path` (presumably a DataFrame row -- confirm against the caller).

    Returns a tuple `(img_act, text_act, img_none, text_none)`. A missing modality is
    represented by an all-zero tensor (4096 values for image, 3840 for text) together
    with its `*_none` flag set to True.
    """
    def encodes(self, x):
        img_file = text_file = None

        if x["has_text"]:
            text_file = Path(x["activation_path"] + ".npy")
            if x["has_image"]:
                # The image file name is derived from the text one by plain string
                # replacement; note that this rewrites EVERY "text"/"npy" occurrence
                # in the path, not only the last components.
                img_file = Path(text_file.as_posix().replace("text", "img").replace("npy", "pt"))
        else:
            # No text: `activation_path` is used as the image activation path,
            # regardless of `has_image`.
            img_file = Path(x["activation_path"] + ".pt")

        img_none = img_file is None
        text_none = text_file is None

        # NOTE(review): torch.load unpickles the file -- only use trusted activation files.
        img_act = torch.zeros((4096)) if img_none else torch.load(img_file)
        text_act = torch.zeros((3840)) if text_none else tensor(np.load(text_file))

        return (img_act, text_act, img_none, text_none)

In [10]:
class ImgTextFusion(Module):
    """Concatenate image and text activations and feed them through `head`.

    With `embs_for_none=True` a learned embedding (one per modality) is written over
    the activations of samples flagged as missing that modality. The constructor moves
    `head`, both embeddings and the lookup index to CUDA, so a GPU is required.
    """
    def __init__(self, head, embs_for_none=True, img_emb_dim=4096, text_emb_dim=3840):
        self.embs_for_none = embs_for_none
        self.head = head.cuda()
        if not embs_for_none:
            return
        # Single-row embedding tables: row 0 is the "modality missing" vector.
        self.img_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=img_emb_dim).cuda()
        self.text_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=text_emb_dim).cuda()
        self.index= tensor(0).cuda()

    def forward(self, x):
        img_feats, txt_feats, img_missing, txt_missing = x
        if self.embs_for_none:
            # In-place indexed writes: the tensors passed in by the caller are modified.
            img_feats[img_missing] = self.img_none_emb(self.index)
            txt_feats[txt_missing] = self.text_none_emb(self.index)
        fused = torch.cat([img_feats, txt_feats], dim=-1)
        return self.head(fused)

In [11]:
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, bn_final=False, lin_first=False):
    """Build a fully connected head mapping `nf` input features to `n_out` outputs.

    `lin_ftrs` lists the hidden layer widths (a single hidden layer of 512 when
    omitted). `ps` is one dropout probability or a sequence of them; a single value is
    used for the last block and halved for every earlier block. `bn_final` appends a
    final `BatchNorm1d`. `lin_first` is forwarded to every `LinBnDrop` block and also
    adds a leading `Dropout` and a trailing `Linear`.
    """
    if lin_ftrs is None:
        widths = [nf, 512, n_out]
    else:
        widths = [nf] + lin_ftrs + [n_out]
    n_hidden = len(widths) - 2

    ps = L(ps)
    if len(ps) == 1:
        ps = [ps[0]/2] * n_hidden + ps

    # List multiplication: every hidden block shares the same ReLU instance;
    # the last block gets no activation.
    activations = [nn.ReLU(inplace=True)] * n_hidden + [None]

    layers = []
    if lin_first:
        layers.append(nn.Dropout(ps.pop(0)))
    for n_in, n_next, p, act in zip(widths[:-1], widths[1:], ps, activations):
        layers.extend(LinBnDrop(n_in, n_next, bn=True, p=p, act=act, lin_first=lin_first))
    if lin_first:
        layers.append(nn.Linear(widths[-2], n_out))
    if bn_final:
        layers.append(nn.BatchNorm1d(widths[-1], momentum=0.01))
    return nn.Sequential(*layers)

In [12]:
# Load the saved `dls` object (its `.train` / `.valid` loaders are used below).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
# Presumably the pickle references the `GetActs` Transform defined above, so that cell
# must have run first -- TODO confirm.
dls = torch.load("./data/fusion_dl_v2.pth")

In [13]:
# Turn shuffling off on the training loader, then display the first five indices it
# will yield as a quick check that the order is sequential.
dls.train.shuffle = False
dls.train.get_idxs()[:5]

[0, 1, 2, 3, 4]

In [14]:
# Keep the final partial batch of the training loader, then echo the flag back.
dls.train.drop_last = False
dls.train.drop_last

False

In [15]:
# Load the saved test loader (passed to `learn.get_preds(dl=test_dl)` further down).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
test_dl = torch.load("./data/test_dl_fusion_text.pth")

In [16]:
# Head input width = 4096 + 3840, i.e. the default `img_emb_dim` + `text_emb_dim` of
# ImgTextFusion, whose forward concatenates both activations; one hidden layer of 128.
head = create_head(4096 + 3840, OUT_DIM, lin_ftrs=[128])

In [17]:
# Wrap the head in the fusion model with its defaults (`embs_for_none=True`).
# The constructor calls `.cuda()`, so this cell needs a GPU.
model = ImgTextFusion(head)

In [54]:
# No training happens in this notebook: the Learner is only used to load a checkpoint
# and to run validation / prediction passes.
learn = Learner(dls, model)

In [55]:
# Restore the weights saved under the checkpoint name "best_fusion_128_moreEpochs".
learn.load("best_fusion_128_moreEpochs")

<fastai.learner.Learner at 0x7fb551f80dc0>

In [56]:
class GetActs(HookCallback):
    """Callback that accumulates the hooked modules' outputs, batch by batch, in `acts`.

    NOTE: this reuses the name of the `GetActs` Transform defined earlier in the
    notebook; from this cell on, `GetActs` refers to the callback.
    """
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        # The 2nd and 4th positional arguments are pinned to None / True
        # (presumably `every` and `is_forward` -- confirm against HookCallback's signature).
        super().__init__(modules, None, remove_end, True, detach, cpu)
        self.acts = L()

    def hook(self, m, i, o):
        # Store the module output as-is.
        return o

    def after_pred(self):
        # Append what the hooks stored for the batch that was just predicted.
        self.acts += self.hooks.stored

    def before_fit(self):
        super().before_fit()
        # Start every run with an empty collection.
        self.acts = L()

In [57]:
# Hook the 4th module of the head; its output is 128 wide (see the
# torch.Size([107577, 128]) printed for the validation activations below).
# NOTE(review): re-running this cell adds the callback a second time.
learn.add_cb(GetActs([learn.model.head[3]]))

<fastai.learner.Learner at 0x7fb551f80dc0>

In [58]:
# Run a validation pass; GetActs.after_pred appends the hooked activations per batch.
learn.validate()

(#1) [0.34672975540161133]

In [61]:
# Per-batch activations recorded by the GetActs callback during the pass above.
# `learn.get_acts` is presumably fastai's snake_case attribute for the callback -- confirm.
valid_acts = learn.get_acts.acts

In [62]:
# Stack the per-batch activations into one tensor along the batch axis and show its
# shape. `valid_acts` is rebound from a list to a tensor, so run this cell only once
# per pass.
valid_acts = torch.cat(list(valid_acts), dim=0)
valid_acts.shape

torch.Size([107577, 128])

In [None]:
# Prediction pass over dataset 0 (the training set, judging by `train_acts` below),
# run for the hook side effect; the returned predictions are not kept.
# NOTE(review): this relies on GetActs.before_fit emptying `acts` at the start of the
# pass -- confirm fastai fires `before_fit` for get_preds in the installed version.
learn.get_preds(0)

In [None]:
# Per-batch activations recorded by the GetActs callback during the pass above.
train_acts = learn.get_acts.acts

In [None]:
# Stack the per-batch activations into one tensor along the batch axis and show its
# shape. `train_acts` is rebound from a list to a tensor, so run this cell only once
# per pass.
train_acts = torch.cat(list(train_acts), dim=0)
train_acts.shape

In [None]:
# Sanity check: exactly one activation row per training item (catches a dropped last
# batch or activations left over from another pass).
assert len(train_acts) == len(dls.train_ds)

In [None]:
# Prediction pass over the test loader, run for the hook side effect; the returned
# predictions are not kept.
learn.get_preds(dl=test_dl)

In [None]:
# Per-batch activations recorded by the GetActs callback during the pass above.
test_acts = learn.get_acts.acts

In [None]:
# Stack the per-batch activations into one tensor along the batch axis and show its
# shape. `test_acts` is rebound from a list to a tensor, so run this cell only once
# per pass.
test_acts = torch.cat(list(test_acts), dim=0)
test_acts.shape

In [None]:
# Sanity check: exactly one activation row per test item.
assert len(test_acts) == len(test_dl.dataset)

In [None]:
# Row POSITIONS of the training items that have text. `train_acts` is ordered by
# position, so positions are what `torch.index_select` needs; DataFrame index labels
# (`.index.values`) only coincide with positions when the frame has a default
# 0..n-1 index, which a split subset of a larger frame generally does not.
train_idx = tensor(np.flatnonzero(dls.train.items["has_text"].to_numpy(dtype=bool)))

In [None]:
# Keep only the activation rows selected by `train_idx` (rows are axis 0).
train_acts_filtered = train_acts.index_select(0, train_idx)

In [None]:
# Row POSITIONS of the validation items that have text. `valid_acts` is ordered by
# position, so positions are what `torch.index_select` needs; DataFrame index labels
# (`.index.values`) only coincide with positions when the frame has a default
# 0..n-1 index, which a split subset of a larger frame generally does not.
valid_idx = tensor(np.flatnonzero(dls.valid.items["has_text"].to_numpy(dtype=bool)))

In [None]:
# Keep only the activation rows selected by `valid_idx` (rows are axis 0).
valid_acts_filtered = valid_acts.index_select(0, valid_idx)

In [None]:
# Show the shapes of the three activation sets side by side.
train_acts_filtered.shape, valid_acts_filtered.shape, test_acts.shape

In [None]:
# Save one tensor per validation item under ./activations/fusion/, mirroring the
# `val/<dir>/<name up to the first dot>` part of the item's path, with ".pt" appended.
# NOTE(review): `valid_preds` is not defined in any cell visible in this notebook, so
# this cell raises NameError on a fresh Restart & Run All unless it is defined elsewhere.
# NOTE(review): earlier cells index `dls.valid.items` like a DataFrame
# (`dls.valid.items["has_text"]`); iterating a DataFrame yields column labels, which
# have no `.as_posix()` -- confirm what `items` holds before running this cell.
for idx, item in enumerate(dls.valid.items):
    # `re.search(...)[0]` raises TypeError when the path contains no `val/<dir>/` part.
    filename = Path("./activations/fusion")/re.search(r'val\/[^\/]*\/[^.]*',item.as_posix())[0]
    filename.parent.mkdir(parents=True, exist_ok=True)
    # clone() so only this row is serialized -- presumably to avoid saving the storage
    # shared with the whole `valid_preds` tensor; confirm.
    torch.save(valid_preds[idx].clone(), filename.as_posix() + ".pt")
    # '\r' keeps the progress message on a single, continuously overwritten line.
    print(f"Saving example {idx+1}", end='\r', flush=True)