In [6]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Seed every RNG this notebook may touch so repeated runs are comparable.
seed = 42

# python RNG
import random
random.seed(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

# pytorch RNGs (CPU generator and, when a GPU is present, all CUDA devices)
import torch
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# cuDNN: request deterministic kernels AND disable the autotuner. With
# benchmark mode on, cuDNN may pick a different algorithm from run to run,
# which defeats `deterministic = True`.
# NOTE(review): some libraries switch benchmark mode on at import time, so it
# is reset explicitly here rather than relying on the PyTorch default.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [8]:
OUT_DIM = 6  # number of outputs produced by the fusion head (passed to create_head as n_out)

In [9]:
class GetActs(Transform):
    """Load the pre-computed image and text activations for one item.

    `encodes` reads the keys ``"has_text"``, ``"has_image"`` and
    ``"activation_path"`` from `x` and returns the 4-tuple
    ``(img_act, text_act, img_none, text_none)``:

    * ``img_act``   - tensor loaded from a ``.pt`` file, or ``torch.zeros(4096)``
      when the item has no image activation.
    * ``text_act``  - tensor built from a ``.npy`` file, or ``torch.zeros(3840)``
      when the item has no text activation.
    * ``img_none`` / ``text_none`` - Python bools flagging which of the two
      activations is the zero placeholder.
    """
    def encodes(self, x):
        img_file = text_file = None

        if x["has_text"]:
            text_file = Path(x["activation_path"] + ".npy")
            if x["has_image"]:
                # The image activation path is derived from the text path by
                # string substitution.
                # NOTE(review): this replaces *every* "text"/"npy" substring in
                # the path, not only the intended component - confirm the
                # directory layout never contains them elsewhere.
                img_file = Path(text_file.as_posix().replace("text", "img").replace("npy", "pt"))
        else:
            # Items without text are treated as image-only; "has_image" is not
            # consulted on this branch.
            img_file = Path(x["activation_path"] + ".pt")

        # Compute the "missing" flags once and reuse them for loading below.
        img_none = img_file is None
        text_none = text_file is None

        img_act = torch.zeros((4096)) if img_none else torch.load(img_file)
        text_act = torch.zeros((3840)) if text_none else tensor(np.load(text_file))

        return (img_act, text_act, img_none, text_none)

In [10]:
class ImgTextFusion(Module):
    """Fuse image and text activations and feed them through `head`.

    Parameters
    ----------
    head : module applied to the concatenated ``[img_act, text_act]`` features.
    embs_for_none : when True, a learned embedding replaces the activation of
        every sample whose image (or text) is flagged as missing.
    img_emb_dim, text_emb_dim : sizes of the two learned "missing" embeddings;
        they must match the last dimension of the respective activations.

    `forward` expects ``x = (img_act, text_act, img_none, text_none)`` where the
    two ``*_none`` entries are masks used to index the first dimension of the
    corresponding activation tensor.
    """
    def __init__(self, head, embs_for_none=True, img_emb_dim=4096, text_emb_dim=3840):
        # Use the GPU when there is one (the original behaviour), but fall back
        # to the CPU instead of raising on machines without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.head = head.to(device)
        self.embs_for_none = embs_for_none
        if embs_for_none:
            self.img_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=img_emb_dim).to(device)
            self.text_none_emb = torch.nn.Embedding(num_embeddings=1, embedding_dim=text_emb_dim).to(device)
            # Deliberately a plain attribute rather than a registered buffer so
            # the state_dict keys stay identical to already-saved checkpoints.
            self.index = tensor(0).to(device)

    def forward(self, x):
        img_act, text_act, img_none, text_none = x
        if self.embs_for_none:
            # `self.index` does not follow `.to()`/`.cpu()` calls on the model,
            # so align it with the embedding weights before the lookup.
            index = self.index.to(self.img_none_emb.weight.device)
            # Work on copies: the masked assignment below is in-place and must
            # not overwrite the tensors owned by the caller / DataLoader batch.
            img_act, text_act = img_act.clone(), text_act.clone()
            img_act[img_none] = self.img_none_emb(index)
            text_act[text_none] = self.text_none_emb(index)
        return self.head(torch.cat([img_act, text_act], dim=-1))

In [11]:
def create_head(nf, n_out, lin_ftrs=None, ps=0.5, bn_final=False, lin_first=False):
    """Build a fully connected head mapping `nf` features to `n_out` outputs.

    nf:        number of input features.
    n_out:     number of outputs of the last layer.
    lin_ftrs:  list of hidden layer widths; a single 512-wide layer when None.
    ps:        dropout probability (or several). A single value is used for the
               last block and halved for every hidden block.
    bn_final:  append a `BatchNorm1d(momentum=0.01)` after the last layer.
    lin_first: forwarded to `LinBnDrop`; additionally puts a `Dropout` in
               front of the blocks and a plain `Linear` after them.

    Returns an `nn.Sequential` of the assembled layers.
    """
    if lin_ftrs is None:
        sizes = [nf, 512, n_out]
    else:
        sizes = [nf] + lin_ftrs + [n_out]
    n_hidden = len(sizes) - 2

    ps = L(ps)
    if len(ps) == 1:
        ps = [ps[0]/2] * n_hidden + ps

    # A single in-place ReLU instance is shared by all hidden blocks; the last
    # block gets no activation.
    relu = nn.ReLU(inplace=True)
    activations = [relu for _ in range(n_hidden)] + [None]

    layers = []
    if lin_first:
        layers.append(nn.Dropout(ps.pop(0)))
    for n_in, n_next, p, act in zip(sizes, sizes[1:], ps, activations):
        layers.extend(LinBnDrop(n_in, n_next, bn=True, p=p, act=act, lin_first=lin_first))
    if lin_first:
        layers.append(nn.Linear(sizes[-2], n_out))
    if bn_final:
        layers.append(nn.BatchNorm1d(sizes[-1], momentum=0.01))
    return nn.Sequential(*layers)

In [12]:
# Load the previously saved DataLoaders for the fusion model.
# NOTE(review): torch.load unpickles arbitrary objects - only load files you
# produced yourself. Presumably the pickle refers to classes defined in this
# notebook, so those cells must run first - confirm.
dls = torch.load("./data/fusion_dl_v2.pth")

In [13]:
# Disable shuffling on the training loader so activations collected later can
# be indexed per item; the displayed first indices serve as a sanity check.
dls.train.shuffle = False; dls.train.get_idxs()[:5]

[0, 1, 2, 3, 4]

In [14]:
# Keep the final partial batch so every training item yields an activation
# (a later cell asserts len(train_acts) == len(dls.train_ds)).
dls.train.drop_last = False; dls.train.drop_last

False

In [15]:
# Load the saved test DataLoader.
# NOTE(review): torch.load unpickles arbitrary objects - trusted files only.
test_dl = torch.load("./data/test_dl_fusion_text.pth")

In [16]:
# Head input width = 4096 + 3840, the same sizes used for the image and text
# activations elsewhere in this notebook; one hidden layer of 128 units.
head = create_head(4096 + 3840, OUT_DIM, lin_ftrs=[128])

In [17]:
# Wrap the head in the fusion module (default: learned embeddings for missing modalities).
model = ImgTextFusion(head)

In [54]:
# Learner is only used for inference below; no loss/metrics are passed explicitly.
learn = Learner(dls, model)

In [55]:
# Restore the trained weights from the named checkpoint.
learn.load("best_fusion_128_moreEpochs")

<fastai.learner.Learner at 0x7fb551f80dc0>

In [56]:
# NOTE(review): this class reuses the name of the `GetActs` Transform defined
# earlier in the notebook and shadows it from here on. The name is kept because
# the cells below construct `GetActs(...)` and read `learn.get_acts`.
class GetActs(HookCallback):
    """Callback that accumulates the outputs of the hooked `modules`.

    After each prediction step the values stored by the hooks are appended to
    `self.acts`; the list is emptied again whenever `before_fit` runs.
    """
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        # The two literal positional arguments (None, True) sit between
        # `modules` and `remove_end` in the parent signature - presumably
        # `every` and `is_forward`; confirm against the installed fastai.
        super().__init__(modules, None, remove_end, True, detach, cpu)
        self.acts = L()

    def before_fit(self):
        # Let the parent do its setup first, then start from an empty list.
        super().before_fit()
        self.acts = L()

    def hook(self, m, i, o):
        # Store the module output unchanged.
        return o

    def after_pred(self):
        # Append this batch's stored hook outputs.
        self.acts += self.hooks.stored

In [57]:
# Hook layer 3 of the head to capture the intermediate representation.
# NOTE(review): presumably the activation after the 128-unit linear layer;
# the [N, 128] shapes displayed further down are consistent with that.
learn.add_cb(GetActs([learn.model.head[3]]))

<fastai.learner.Learner at 0x7fb551f80dc0>

In [58]:
# Run a validation pass; besides returning the displayed result, it makes the
# GetActs callback collect the hooked activations for the validation set.
learn.validate()

(#1) [0.34672975540161133]

In [61]:
# Grab the per-batch activations collected by the callback (reached as `learn.get_acts`).
valid_acts = learn.get_acts.acts

In [62]:
# Concatenate the per-batch tensors into one tensor and show its shape.
# Note: rebinds `valid_acts`, so this cell is not safe to run twice in a row.
valid_acts = torch.cat(list(valid_acts)); valid_acts.shape

torch.Size([107577, 128])

In [63]:
# Predict on dataset index 0 (the training set, per the length assert below).
# The returned predictions are only displayed; the pass is needed so the
# callback collects the training activations.
learn.get_preds(0)

(TensorImage([[1.0691e-05, 3.0703e-04, 1.8721e-02, 9.8094e-01, 1.1829e-05, 9.8704e-06],
         [2.2204e-08, 1.0903e-06, 6.2452e-08, 9.9998e-01, 2.2554e-05, 2.6479e-07],
         [1.0612e-08, 2.4040e-06, 2.9739e-08, 9.9992e-01, 7.7347e-05, 2.0426e-07],
         ...,
         [5.4627e-04, 1.5725e-01, 6.8867e-04, 6.6404e-01, 1.7646e-01, 1.0221e-03],
         [5.9546e-05, 1.2693e-02, 4.5226e-05, 6.1158e-01, 3.7543e-01, 1.8918e-04],
         [6.4312e-04, 3.2100e-02, 2.9593e-04, 8.6563e-01, 1.0007e-01, 1.2532e-03]]),
 TensorCategory([3, 3, 3,  ..., 4, 4, 4]))

In [65]:
# Activations collected during the training-set pass above.
# NOTE(review): relies on the callback having reset `acts` at the start of
# that pass; the length assert two cells below guards this.
train_acts = learn.get_acts.acts

In [66]:
# Concatenate the per-batch tensors into one tensor and show its shape.
# Note: rebinds `train_acts`, so this cell is not safe to run twice in a row.
train_acts = torch.cat(list(train_acts)); train_acts.shape

torch.Size([162187, 128])

In [67]:
# Sanity check: exactly one activation row per training item.
assert len(train_acts) == len(dls.train_ds)

In [68]:
# Predict on the test loader; as above, the pass is needed for the
# activations gathered by the callback, the predictions are only displayed.
learn.get_preds(dl=test_dl)

(TensorImage([[2.8076e-11, 8.1904e-08, 1.8519e-11, 1.0000e+00, 2.9645e-07, 6.3525e-10],
         [1.7164e-11, 3.6945e-08, 2.5709e-11, 1.0000e+00, 1.1012e-06, 1.4324e-10],
         [1.1687e-08, 1.5438e-03, 3.3145e-07, 9.9845e-01, 7.9952e-06, 4.3045e-09],
         ...,
         [9.1729e-12, 2.2881e-07, 4.8685e-11, 1.0000e+00, 6.5150e-07, 1.8891e-10],
         [2.4747e-11, 3.0487e-08, 7.3309e-12, 1.0000e+00, 1.5077e-07, 9.6348e-10],
         [2.4747e-11, 3.0487e-08, 7.3309e-12, 1.0000e+00, 1.5077e-07, 9.6348e-10]]),
 TensorCategory([3, 3, 3,  ..., 3, 3, 3]))

In [69]:
# Activations collected during the test pass above.
test_acts = learn.get_acts.acts

In [70]:
# Concatenate the per-batch tensors into one tensor and show its shape.
# Note: rebinds `test_acts`, so this cell is not safe to run twice in a row.
test_acts = torch.cat(list(test_acts)); test_acts.shape

torch.Size([95526, 128])

In [71]:
# Sanity check: exactly one activation row per test item.
assert len(test_acts) == len(test_dl.dataset)

In [72]:
# Index labels of the training rows that have text.
# NOTE(review): these labels are used as *positions* into train_acts in the
# next cell - assumes the items index equals the row order; confirm.
train_idx = tensor(dls.train.items[dls.train.items["has_text"]].index.values)

In [73]:
# Keep only the activation rows selected by train_idx (rows with text).
train_acts_filtered = torch.index_select(train_acts, 0, train_idx)

In [74]:
# Index labels of the validation rows that have text.
# NOTE(review): these labels are used as *positions* into valid_acts in the
# next cell - assumes the items index equals the row order; confirm.
valid_idx = tensor(dls.valid.items[dls.valid.items["has_text"]].index.values)

In [75]:
# Keep only the activation rows selected by valid_idx (rows with text).
valid_acts_filtered = torch.index_select(valid_acts, 0, valid_idx)

In [77]:
# Display the shapes of the filtered train/valid activations and the test activations.
train_acts_filtered.shape, valid_acts_filtered.shape, test_acts.shape

(torch.Size([149217, 128]), torch.Size([94735, 128]), torch.Size([95526, 128]))

In [97]:
# Save the fused activation of every validation item that has text, one .pt
# file per item, under a path derived from the item's "activation_path".
# Only that column is needed, so iterate it directly instead of building a
# full row Series per item with iterrows().
# NOTE(review): `idx` is the DataFrame index label and is used as a position
# into valid_acts - assumes labels equal row order; confirm.
# NOTE(review): replace("text", "fusion") rewrites every "text" substring in
# the path, not just one component - confirm against the directory layout.
valid_text_paths = dls.valid.items.loc[dls.valid.items["has_text"], "activation_path"]
for idx, act_path in valid_text_paths.items():
    filename = Path(act_path.replace("text", "fusion"))
    filename.parent.mkdir(parents=True, exist_ok=True)
    # clone() so the saved file holds only this row, not the whole backing tensor
    torch.save(valid_acts[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 94735

In [98]:
# Save the fused activation of every training item that has text, one .pt
# file per item, under a path derived from the item's "activation_path".
# Only that column is needed, so iterate it directly instead of building a
# full row Series per item with iterrows().
# NOTE(review): `idx` is the DataFrame index label and is used as a position
# into train_acts - assumes labels equal row order; confirm.
# NOTE(review): replace("text", "fusion") rewrites every "text" substring in
# the path, not just one component - confirm against the directory layout.
train_text_paths = dls.train.items.loc[dls.train.items["has_text"], "activation_path"]
for idx, act_path in train_text_paths.items():
    filename = Path(act_path.replace("text", "fusion"))
    filename.parent.mkdir(parents=True, exist_ok=True)
    # clone() so the saved file holds only this row, not the whole backing tensor
    torch.save(train_acts[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 149217

In [99]:
# Save the fused activation of every test item that has text, one .pt file
# per item, under a path derived from the item's "activation_path".
# Only that column is needed, so iterate it directly instead of building a
# full row Series per item with iterrows().
# NOTE(review): `idx` is the DataFrame index label and is used as a position
# into test_acts - assumes labels equal row order; confirm.
# NOTE(review): replace("text", "fusion") rewrites every "text" substring in
# the path, not just one component - confirm against the directory layout.
test_text_paths = test_dl.items.loc[test_dl.items["has_text"], "activation_path"]
for idx, act_path in test_text_paths.items():
    filename = Path(act_path.replace("text", "fusion"))
    filename.parent.mkdir(parents=True, exist_ok=True)
    # clone() so the saved file holds only this row, not the whole backing tensor
    torch.save(test_acts[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 95526