In [None]:
from fastai.vision.all import *
import pandas as pd
import numpy as np
import albumentations
import cv2
from matplotlib import pyplot as plt

In [None]:
IMG_SIZE = 448   # images are rescaled/padded to IMG_SIZE x IMG_SIZE by `aug`
SAMPLE = False   # True -> quick smoke-test run on a 10% subset
SEED = 42
# Training is stochastic (weight init, random bag subsampling, shuffling);
# fastai's set_seed seeds python `random`, numpy and torch for repeatable runs.
set_seed(SEED)

In [None]:
# Pre-computed CV folds (column `kfold`) from abhishek's notebook:
# https://www.kaggle.com/abhishek/step-1-create-folds
df = pd.read_csv('../input/step-1-create-folds/train_folds.csv')
# SAMPLE=True: smoke test on 10% of rows, 1 epoch, 2 folds.
if SAMPLE:
    frac, EPOCHS, nfold = 0.1, 1, 2
else:
    frac, EPOCHS, nfold = 1, 10, 5
# Shuffle (and optionally subsample) reproducibly, then renumber rows 0..N-1
# so the datasets' `.loc[i]` lookups line up with positional indices.
df = df.sample(frac=frac, random_state=42).reset_index(drop=True)
path = Path('../input/commonlit-spacy-images/images')

In [None]:
class BagOfImagesModel(Module):
    """Regress one scalar from a "bag" of images per example.

    Every image in a bag is encoded independently by ``encoder``. The per-image
    features are pooled over the bag with both max and mean, and the two are
    concatenated. The result goes through a BatchNorm/Dropout/Linear head ending
    in a single output.

    Args:
        encoder: module mapping ``(N, ch, h, w)`` images to per-image features whose
            non-batch dims flatten to ``num_ftrs`` values (e.g. a CNN body ending in
            global pooling -- TODO confirm for other encoders).
        num_ftrs: per-image feature size produced by ``encoder``. Defaults to 2048,
            which is what the head was originally hard-coded for (4096 = 2 * 2048
            after concatenating max- and mean-pooled features).

    Layer attribute names are unchanged, so previously saved weights still load.
    """
    def __init__(self, encoder, num_ftrs=2048):
        self.encoder = encoder
        pooled = 2 * num_ftrs  # max-pooled and mean-pooled features, concatenated
        self.bn1 = nn.BatchNorm1d(pooled, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.d1 = nn.Dropout(p=0.25, inplace=False)
        self.l1 = nn.Linear(in_features=pooled, out_features=512, bias=False)
        self.bn2 = nn.BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.d2 = nn.Dropout(p=0.5, inplace=False)
        self.l2 = nn.Linear(in_features=512, out_features=1, bias=False)

    def forward(self, imgs):
        """Map ``imgs`` of shape ``(b, n, ch, h, w)`` to predictions of shape ``(b, 1)``."""
        b, n, ch, h, w = imgs.shape
        # Encode all b*n images in one pass. flatten(1) (rather than squeeze())
        # keeps the batch axis even when b*n == 1.
        ftrs = self.encoder(imgs.reshape(b * n, ch, h, w)).flatten(1)
        ftrs = ftrs.reshape(b, n, -1)
        # Pool over the bag axis. Reducing without keepdim needs no squeeze, so
        # there is no special case for b == 1.
        ftrs_cat = torch.cat([ftrs.max(dim=1).values, ftrs.mean(dim=1)], dim=1)
        x = self.d1(self.bn1(ftrs_cat))
        x = F.relu(self.l1(x))
        x = self.d2(self.bn2(x))
        return self.l2(x)

In [None]:
# Letterbox every image: rescale so its longer side equals IMG_SIZE, pad the
# remainder with constant zeros up to IMG_SIZE x IMG_SIZE, then normalize with
# albumentations' default mean/std.
aug = albumentations.Compose(
    [
        albumentations.LongestMaxSize(max_size=IMG_SIZE, p=1.0),
        albumentations.PadIfNeeded(
            min_height=IMG_SIZE,
            min_width=IMG_SIZE,
            border_mode=cv2.BORDER_CONSTANT,  # == 0
            value=0.,
            p=1.0,
        ),
        albumentations.Normalize(p=1.0),
    ],
    p=1.,
)

In [None]:
class ImageBagDataset(torch.utils.data.Dataset):
    """Dataset yielding a fixed-size "bag" of images for each row of ``df``.

    Images for row ``i`` live in ``path/<id>/``. Bag size depends on how many
    images the example has:

    * at most ``max_imgs`` images: all are loaded and the bag is zero-padded to
      ``max_imgs``;
    * more than ``max_imgs`` images: ``max_imgs`` distinct images are drawn
      uniformly at random, anew on every access.

    Args:
        df: frame with ``id`` and ``target`` columns. ``.loc[i]`` is used, so the
            index must be 0..len-1 (call ``reset_index(drop=True)`` first).
        max_imgs: number of images in every returned bag.
        path: root image folder. Must support ``/`` and fastcore's ``.ls()``.
        aug: albumentations transform applied to each image.

    Each item is ``(imgs, target)``: ``imgs`` is ``(max_imgs, 3, H, W)`` and
    ``target`` is a float scalar tensor. Padding images are
    ``IMG_SIZE x IMG_SIZE``, so ``aug`` is expected to output that size.
    """
    def __init__(self, df, max_imgs, path, aug):
        self.df = df
        self.max_imgs = max_imgs
        self.path = path
        self.aug = aug

    def __getitem__(self, i):
        image_id = self.df['id'].loc[i]
        target = torch.tensor(self.df['target'].loc[i], dtype=torch.float)
        # Assumes the folder holds exactly 0.png .. (n-1).png -- TODO confirm.
        num_imgs = len((self.path/image_id).ls())
        if num_imgs <= self.max_imgs:
            idxs = list(range(num_imgs))
        else:
            # Uniform weights + multinomial (no replacement) = random subset.
            idxs = torch.multinomial(torch.ones(num_imgs), self.max_imgs).tolist()
        imgs = [self._open_img(self.path/f'{image_id}/{j}.png') for j in idxs]
        npad = self.max_imgs - len(imgs)  # 0 when the bag was subsampled
        imgs += [torch.zeros(3, IMG_SIZE, IMG_SIZE) for _ in range(npad)]
        return (torch.stack(imgs), target)

    def __len__(self):
        return len(self.df)

    def _open_img(self, x):
        """Read image file ``x``, apply ``self.aug`` and return a CHW float tensor."""
        img = cv2.imread(str(x), cv2.IMREAD_UNCHANGED)
        if img is None:  # cv2 signals unreadable/missing files with None
            raise FileNotFoundError(f'could not read image {x}')
        if img.ndim == 2:
            # Grayscale: `[..., :3]` below would slice pixel columns, not channels.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        # Keep the first 3 channels (drops alpha on 4-channel images).
        # NOTE(review): cv2 loads BGR, while ImageNet-pretrained encoders and
        # albumentations' default Normalize statistics assume RGB. Kept as-is so
        # train/eval preprocessing match; consider cv2.COLOR_BGR2RGB.
        img = img[..., :3]
        img = self.aug(image=img)['image']
        return torch.tensor(img, dtype=torch.float).permute(2, 0, 1)

In [None]:
# Evaluation-time dataset: each example uses *all* of its sentence images (no
# sampling, no padding). Bag sizes therefore vary, so load it with batch size 1.

class TestImageBagDataset(ImageBagDataset):
    """Variant of ``ImageBagDataset`` that returns every image of an example.

    Image loading (``_open_img``) and ``__len__`` are inherited, so evaluation
    inputs are preprocessed exactly like training inputs.

    Args:
        df: frame with ``id`` and ``target`` columns, indexed 0..len-1.
        path: root image folder (fastcore ``Path``; ``.ls()`` is used).
        aug: albumentations transform applied to each image.
    """
    def __init__(self, df, path, aug):
        # max_imgs is unused here: every image is returned.
        super().__init__(df, max_imgs=None, path=path, aug=aug)

    def __getitem__(self, i):
        image_id = self.df['id'].loc[i]
        target = torch.tensor(self.df['target'].loc[i], dtype=torch.float)
        num_imgs = len((self.path/image_id).ls())
        imgs = [self._open_img(self.path/f'{image_id}/{j}.png') for j in range(num_imgs)]
        if not imgs:
            # torch.stack([]) raises. Mirror the training dataset, which turns
            # an empty folder into all-zero padding images.
            imgs = [torch.zeros(3, IMG_SIZE, IMG_SIZE)]
        return (torch.stack(imgs), target)

In [None]:
def visualize(image):
    """Display one channels-last ``(H, W, C)`` image on a large figure without axes.

    ``image`` may be a torch tensor or anything ``np.asarray`` accepts. Float data
    outside [0, 1], such as the output of ``albumentations.Normalize``, is
    min-max rescaled for display, because ``imshow`` would otherwise clip it
    (and emit a warning).
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    img = np.asarray(image)
    if np.issubdtype(img.dtype, np.floating) and img.size and (img.min() < 0 or img.max() > 1):
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.axis('off')
    ax.imshow(img)

# Sanity check: first image of example 1, from a 3-image bag.
# Channels are in cv2's BGR order, so colours may appear swapped.
dataset = ImageBagDataset(df, 3, path, aug)
visualize(dataset[1][0][0].permute(1, 2, 0))

In [None]:
def train_fold(k, max_imgs=3, bs=8, lr_max=3e-3):
    """Train a fresh ``BagOfImagesModel`` on every fold except ``k``, then score fold ``k``.

    During ``fit_one_cycle``, both training and validation use bags of
    ``max_imgs`` images (random subset or zero-padded). After training, the
    weights are saved as ``model_{k}`` under ``./model/``. Fold ``k`` is then
    re-scored using *every* image of each example, with batch size 1 because
    bag sizes vary.

    Reads the notebook globals ``df``, ``path``, ``aug`` and ``EPOCHS``.

    Args:
        k: value of ``df.kfold`` held out for validation.
        max_imgs: images per bag during training (default 3, the original value).
        bs: training batch size.
        lr_max: peak learning rate of the one-cycle schedule.

    Returns:
        RMSE on fold ``k``, computed with full image bags.
    """
    df_train = df[df.kfold != k].reset_index(drop=True)
    df_valid = df[df.kfold == k].reset_index(drop=True)

    train_ds = ImageBagDataset(df_train, max_imgs, path, aug)
    valid_ds = ImageBagDataset(df_valid, max_imgs, path, aug)
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs).cuda()

    encoder = create_body(resnet50, cut=-1)
    net = BagOfImagesModel(encoder).cuda()

    learn = Learner(dls, net, loss_func=MSELossFlat(), metrics=rmse, model_dir="./model/").to_fp16()
    learn.fit_one_cycle(EPOCHS, lr_max=lr_max)
    learn.save(f'model_{k}')

    # The in-training metric only saw `max_imgs` random images per example.
    # Re-score the held-out fold with every image.
    test_ds = TestImageBagDataset(df_valid, path, aug)
    # from_dsets needs a training set too; passing test_ds twice is a workaround.
    # validate() only uses the second (validation) loader.
    learn.dls = DataLoaders.from_dsets(test_ds, test_ds, bs=1).cuda()
    _, fin_rmse = learn.validate()
    return fin_rmse
    

In [None]:
# Cross-validation: one freshly trained model per fold; collect each fold's
# full-bag RMSE, then report the mean followed by the per-fold values.
rmses = [train_fold(k) for k in range(nfold)]

print(np.mean(rmses))
print(rmses)

In [None]:
!zip -qr models.zip model
!rm -r model