In [1]:
%load_ext autoreload
%autoreload 2

from fastai.text.all import *
from fastai.vision.all import *
import pandas as pd
import torch
from tqdm.notebook import tqdm

from utils import get_dls

In [2]:
seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

# # tensorflow RNG
# tf.random.set_seed(seed)

In [3]:
path = Path("/mnt/nas/databases/Tobacco800/unziped/page_imgs/raw")

In [4]:
df = pd.read_csv('/mnt/nas/databases/Tobacco800/unziped/train.csv', delimiter=';',  usecols=['binder','docid','class'])

In [37]:
df['split'] = 'train';print(df.shape); df.head()

(1031, 4)


Unnamed: 0,binder,docid,class,split
0,Tobacco800,aah97e00-page02_1,FirstPage,train
1,Tobacco800,aah97e00-page02_2,NextPage,train
2,Tobacco800,aam09c00,FirstPage,train
3,Tobacco800,aao54e00_1,FirstPage,train
4,Tobacco800,aao54e00_2,NextPage,train


In [38]:
df['split'][-200:] = 'valid'; df.iloc[-202:-198];

In [39]:
df

Unnamed: 0,binder,docid,class,split
0,Tobacco800,aah97e00-page02_1,FirstPage,train
1,Tobacco800,aah97e00-page02_2,NextPage,train
2,Tobacco800,aam09c00,FirstPage,train
3,Tobacco800,aao54e00_1,FirstPage,train
4,Tobacco800,aao54e00_2,NextPage,train
...,...,...,...,...
1026,Tobacco800,thl51a00-page02_2,NextPage,valid
1027,Tobacco800,tji44a00,FirstPage,valid
1028,Tobacco800,tjr72f00-page02_1,FirstPage,valid
1029,Tobacco800,tjr72f00-page02_2,NextPage,valid


In [40]:
df_test = pd.read_csv('/mnt/nas/databases/Tobacco800/unziped/test.csv', delimiter=';',  usecols=['binder','docid','class'])

In [41]:
df_test['split'] = 'test'

In [42]:
df = pd.concat([df, df_test], axis=0); df.reset_index(drop=True, inplace=True); df

Unnamed: 0,binder,docid,class,split
0,Tobacco800,aah97e00-page02_1,FirstPage,train
1,Tobacco800,aah97e00-page02_2,NextPage,train
2,Tobacco800,aam09c00,FirstPage,train
3,Tobacco800,aao54e00_1,FirstPage,train
4,Tobacco800,aao54e00_2,NextPage,train
...,...,...,...,...
1285,Tobacco800,zrz94a00-page02_2,NextPage,test
1286,Tobacco800,zss86d00,FirstPage,test
1287,Tobacco800,ztz52d00-page02_1,FirstPage,test
1288,Tobacco800,ztz52d00-page02_2,NextPage,test


In [43]:
def splitter(df):
    train = df[df['split']=='train'].index.tolist()
    valid = df[df['split']=='valid'].index.tolist()
    test = df[df['split']=='test'].index.tolist()
    return train,valid, test

In [44]:
def get_x(r): return path/f'{r["docid"]}.tif'
def get_y(r): return r['class']

In [45]:
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=get_x,
                   get_y=get_y,
                   splitter=splitter,
                   item_tfms=Resize(460),
                   batch_tfms=[*aug_transforms(size=224, min_scale=0.9,
                                               do_flip=False, max_rotate=0,
                                               max_warp=0),
                               Normalize.from_stats(*imagenet_stats)])

In [46]:
dls = dblock.dataloaders(df, bs=64)

In [47]:
#test_dl = dls.test_dl(test_items, with_labels=True)

In [48]:
class GetActs(HookCallback):
    def __init__(self, modules=None, remove_end=True, detach=True, cpu=True):
        super().__init__(modules, None, remove_end, True, detach, cpu)
        self.acts = L()
    def hook(self, m, i, o): return o
    def after_pred(self): self.acts += self.hooks.stored
    def before_fit(self):
        super().before_fit()
        self.acts = L()

In [49]:
learn = cnn_learner(dls, resnet50, loss_func=CrossEntropyLossFlat())

In [50]:
learn.add_cb(GetActs([learn.model[1][1]]))

<fastai.learner.Learner at 0x7f065c20d250>

In [51]:
learn.load("best_image_no_weights_224")

<fastai.learner.Learner at 0x7f065c20d250>

In [52]:
learn.validate()

(#1) [0.6184462904930115]

In [53]:
valid_preds = learn.get_acts.acts

In [54]:
valid_preds = torch.cat(list(valid_preds)); valid_preds.shape

torch.Size([200, 4096])

In [55]:
sample = list(dls.valid.items)

In [56]:
sample[2]

'class'

In [58]:
assert len(valid_preds) == len(dls.valid_ds)

In [59]:
for idx, item in enumerate(dls.valid.items['docid']):
    filename = Path("./activations/img")/item
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(valid_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 200

In [60]:
dls.train.shuffle = False

In [61]:
dls.train.get_idxs()[:5]

[0, 1, 2, 3, 4]

In [62]:
learn.get_preds(0)

(tensor([[0.8982, 0.1018],
         [0.0591, 0.9409],
         [0.8528, 0.1472],
         ...,
         [0.9824, 0.0176],
         [0.9203, 0.0797],
         [0.9355, 0.0645]]),
 TensorCategory([0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
         0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,
         0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
         1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1,
         1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,
         0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0,
         0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1,
         1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0,
    

In [63]:
train_preds = learn.get_acts.acts

In [64]:
train_preds = torch.cat(list(train_preds)); train_preds.shape

torch.Size([831, 4096])

In [65]:
assert len(train_preds) == len(dls.train_ds)

In [66]:
for idx, item in enumerate(dls.train.items['docid']):
    filename = Path("./activations/img")/item
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(train_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 831

In [None]:
learn.get_preds(ds_idx=2)

In [69]:
test_preds = learn.get_acts.acts

In [70]:
test_preds = torch.cat(list(test_preds)); test_preds.shape

torch.Size([259, 4096])

In [89]:
dls[2].items

Unnamed: 0,binder,docid,class,split
1031,Tobacco800,tkj51f00_1,FirstPage,test
1032,Tobacco800,tkj51f00_2,NextPage,test
1033,Tobacco800,tkj51f00_3,NextPage,test
1034,Tobacco800,tkj51f00_4,NextPage,test
1035,Tobacco800,tkj51f00_5,NextPage,test
...,...,...,...,...
1285,Tobacco800,zrz94a00-page02_2,NextPage,test
1286,Tobacco800,zss86d00,FirstPage,test
1287,Tobacco800,ztz52d00-page02_1,FirstPage,test
1288,Tobacco800,ztz52d00-page02_2,NextPage,test


In [92]:
assert len(test_preds) == len(dls[2].items)

In [90]:
for idx, item in enumerate(dls[2].items['docid']):
    filename = Path("./activations/img")/item
    filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save(test_preds[idx].clone(), filename.as_posix() + ".pt")
    print(f"Saving example {idx+1}", end='\r', flush=True)

Saving example 259