In [42]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import numpy as np
import torch

from collections import defaultdict

from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm.auto import tqdm

from loader.dataset import VizwizDataset
from loader.images import ImageS3
from commons.utils import embedding_matrix, tensor_to_word_fn
from models.resnext101 import TransformerAttention
from eval.metrics import bleu, cider, rouge, spice, meteor, bleu_score_fn


from IPython.core.display import HTML

# Use the first CUDA device when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# %pip install ipywidgets nltk

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Training split returned as tensors; partial=100 limits the run to 100 items
# (the loader below yields 10 batches of 10).
train = VizwizDataset(
    dtype='train',
    ret_type='tensor',
    copy_img_to_mem=False,
    device=device,
    partial=100,
)
# The vocabulary is built from the training split and shared with every other split.
vocabulary = train.getVocab()

loading annotations into memory...
Done (t=0.30s)
creating index...
index created! imgs = 23431, anns = 100575
tokenizing caption... done!!


In [49]:
# Validation split, tokenized with the training vocabulary.
val = VizwizDataset(
    dtype='val',
    ret_type='tensor',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=50,
)

loading annotations into memory...
Done (t=0.10s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [50]:
# Same splits as above, but returned as raw caption corpora for metric computation.
train_eval = VizwizDataset(
    dtype='train',
    ret_type='corpus',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=100,
)
val_eval = VizwizDataset(
    dtype='val',
    ret_type='corpus',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=50,
)

loading annotations into memory...
Done (t=0.41s)
creating index...
index created! imgs = 23431, anns = 100575
tokenizing caption... done!!
loading annotations into memory...
Done (t=0.09s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [23]:
# Test split; its annotation file carries no captions (see "anns = 0" in the output).
test = VizwizDataset(
    dtype='test',
    ret_type='tensor',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=50,
)

loading annotations into memory...
Done (t=0.02s)
creating index...
index created! imgs = 8000, anns = 0
no caption found... done!!


In [37]:
# --- Model / data configuration ---------------------------------------------
MODEL = 'resnext101_attention'   # run label; not referenced by the visible cells
GLOVE_DIR = 'annotations/glove'  # passed to embedding_matrix() as glove_dir
ENCODER_DIM = 2048               # passed to TransformerAttention as encoder_dim
EMBEDDING_DIM = 300              # word-embedding width (embedding_mtx is (vocab, 300))
ATTENTION_DIM = 256
DECODER_DIM = 256

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 10  # 128 was the previous value (kept in the original comment)
# Log about every 25 * 256 samples. Clamp to >= 1 so the `% LOG_INTERVAL`
# check in fit() cannot divide by zero when BATCH_SIZE > 256.
LOG_INTERVAL = max(1, 25 * (256 // BATCH_SIZE))
LR = 5e-4
NUM_EPOCHS = 2
SAVE_FREQ = 10   # not referenced by the visible cells


In [33]:
# Pretrained GloVe vectors aligned to the training vocabulary's word -> index map.
embedding_mtx = embedding_matrix(
    embedding_dim=EMBEDDING_DIM,
    word2idx=vocabulary.word2idx,
    glove_dir=GLOVE_DIR,
)
# Display (rows, EMBEDDING_DIM); rows presumably equal the vocabulary size.
embedding_mtx.shape

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=378.0), HTML(value='')))




(378, 300)

In [None]:
def _batch_loss_and_acc(model, batch, loss_fn):
    """Forward one batch and return (loss tensor, token accuracy as a float).

    The model returns captions re-sorted to match `decode_len`. Targets drop the
    leading <start> token, and both scores and targets are packed so that only
    decoded (non-pad) timesteps are scored.
    """
    images, captions, lengths, fname = batch
    # `alphas` (attention weights) and `sort_idx` are returned but not used here.
    scores, captions_sorted, decode_len, alphas, sort_idx = model(images, captions, lengths)
    # Exclude <start>; predictions are compared against tokens after it up to <end>.
    targets = captions_sorted[:, 1:]
    # NOTE(review): pack_padded_sequence expects decode_len sorted in decreasing
    # order and, when it is a tensor, on the CPU. Confirm the model guarantees both.
    scores = pack_padded_sequence(scores, decode_len, batch_first=True)[0]
    targets = pack_padded_sequence(targets, decode_len, batch_first=True)[0]

    loss = loss_fn(scores, targets)
    acc = (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0)
    return loss, acc


def _run_phase(loader, model, loss_fn, phase, desc, optimizer=None):
    """Make one pass over `loader` and return the mean per-batch loss.

    The pass trains (zero_grad / backward / step) when `optimizer` is given;
    otherwise it evaluates with autograd disabled. It returns NaN for an empty
    loader instead of dividing by zero.
    """
    is_train = optimizer is not None
    model.train(is_train)  # model.train(False) is exactly model.eval()

    running_acc = 0.0
    running_loss = 0.0
    num_batches = 0

    t = tqdm(loader, desc=f'{desc}_{phase}')
    # Gradients are only needed while training. Skipping the graph during
    # validation avoids holding activations for a backward pass that never runs.
    with torch.set_grad_enabled(is_train):
        for batch_idx, batch in enumerate(t):
            if is_train:
                optimizer.zero_grad()

            loss, acc = _batch_loss_and_acc(model, batch, loss_fn)

            if is_train:
                loss.backward()
                optimizer.step()

            running_acc += acc
            running_loss += loss.item()
            num_batches = batch_idx + 1

            t.set_postfix({
                'loss': running_loss / num_batches,
                'acc': running_acc / num_batches
            }, refresh=True)

            if num_batches % LOG_INTERVAL == 0:
                print(f'{desc}_{phase} {num_batches}/{len(loader)} '
                      f'{phase}_loss: {running_loss / num_batches:.4f} '
                      f'{phase}_acc: {running_acc / num_batches:.4f}')

    return running_loss / num_batches if num_batches else float('nan')


def fit(train_loader, val_loader, model, loss_fn, optimizer, desc=''):
    """Train `model` for one epoch on `train_loader`, then validate on `val_loader`.

    Args:
        train_loader: iterable of (images, captions, lengths, fname) batches used for training.
        val_loader: iterable of batches with the same layout, used for validation.
        model: captioning model returning (scores, captions_sorted, decode_len, alphas, sort_idx).
        loss_fn: criterion applied to packed scores and targets.
        optimizer: optimizer stepped once per training batch.
        desc: prefix for progress-bar and log labels.

    Returns:
        (train_loss_mean, val_loss_mean): mean per-batch loss of each phase.
    """
    train_loss_mean = _run_phase(train_loader, model, loss_fn, 'train', desc, optimizer=optimizer)
    # TODO: the validation path still needs a unit test.
    val_loss_mean = _run_phase(val_loader, model, loss_fn, 'val', desc)
    return train_loss_mean, val_loss_mean


def evaluate(dataset, model, loss_fn, tensor_to_word_fn, 
             transformation, dtype='val', device=torch.device('cpu'), desc=''):
    """Caption every distinct image in `dataset` with `model.sample` and score it.

    Args:
        dataset: object exposing a `df` DataFrame with `file_name` and `caption` columns.
        model: captioning model providing `sample(img, startseq_idx=...)`.
        loss_fn: unused; kept so existing call sites keep working.
        tensor_to_word_fn: maps an array of sampled token ids to a list of words.
        transformation: callable applied to each loaded image before sampling.
        dtype: split sub-directory under 'vizwiz' to load images from.
        device: device the transformed image is moved to.
        desc: progress-bar label.

    Returns:
        dict with 'bleu', 'cider' and 'rouge' (each a one-element list of the
        corpus score) and 'running_bleu' (index 0 unused; indices 1-4 hold
        the mean per-image BLEU-n).
    """
    model.eval()

    pred_byfname = dict()
    caps_byfname = defaultdict(list)
    evals = dict()

    running_bleu = [0.0] * 5

    # One row per image, with all of its reference captions collected into a list.
    captions_df = dataset.df.groupby('file_name').agg(captions=('caption', list)).reset_index()
    n_images = len(captions_df)

    # `vocabulary` is the notebook-level global built from the training split.
    startseq_idx = vocabulary.word2idx['<start>']

    t = tqdm(captions_df.iterrows(), total=n_images, desc=desc)
    with torch.no_grad():  # inference only; no autograd graph needed
        for idx, (_, row) in enumerate(t):
            fname = row['file_name']
            captions = row['captions']

            fpath = os.path.join('vizwiz', dtype, fname)
            img = transformation(ImageS3.getImage(fpath)).to(device)

            # NOTE(review): img has no batch dimension here. Confirm that
            # model.sample expects a single unbatched image.
            outputs = tensor_to_word_fn(model.sample(img, startseq_idx=startseq_idx).cpu().numpy())
            sentence = ' '.join(outputs)

            # NOTE(review): the setup cell calls bleu_score_fn(4, 'corpus') as a
            # factory. Confirm this keyword form (reference strings vs. candidate
            # word list) matches its real signature.
            for i in range(1, 5):
                running_bleu[i] += bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=i)
            t.set_postfix({
                'bleu1': running_bleu[1] / (idx + 1),
                'bleu4': running_bleu[4] / (idx + 1),
            }, refresh=True)

            pred_byfname[fname] = [sentence]
            caps_byfname[fname] = captions

    _bleu = bleu(caps_byfname, pred_byfname)
    _cider = cider(caps_byfname, pred_byfname)
    _rouge = rouge(caps_byfname, pred_byfname)

    # Average over images (the unit the loop iterates), matching the postfix above.
    if n_images:
        for i in range(1, 5):
            running_bleu[i] /= n_images

    evals['bleu'] = [_bleu]
    evals['cider'] = [_cider]
    evals['rouge'] = [_rouge]
    evals['running_bleu'] = running_bleu

    return evals
    

In [47]:
vocab_size = len(vocabulary.vocab)

# Attention captioner initialised with the GloVe matrix; embeddings stay trainable.
# NOTE(review): encoded_image_size=14 is presumably the encoder's spatial grid side. Confirm.
transformer = TransformerAttention(
    encoded_image_size=14,
    attention_dim=ATTENTION_DIM,
    embedding_dim=EMBEDDING_DIM,
    decoder_dim=DECODER_DIM,
    vocab_size=vocab_size,
    encoder_dim=ENCODER_DIM,
    embedding_matrix=embedding_mtx,
    train_embedding=True,
)
transformer = transformer.to(device)

# Padding tokens are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train.pad_value)
loss_fn = loss_fn.to(device)

corpus_bleu_score_fn = bleu_score_fn(4, 'corpus')
tensor2word_fn = tensor_to_word_fn(idx2word=vocabulary.idx2word)

params = transformer.parameters()
optimizer = torch.optim.Adam(params=params, lr=LR)


In [None]:
# Standard ImageNet channel mean/std, expected by the pretrained encoder.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Training preprocessing with light augmentation.
# NOTE(review): horizontal flips mirror any text in the photo and swap left/right
# relations a caption may mention. Confirm this augmentation is wanted here.
train_transformations = transforms.Compose([
    transforms.Resize(256),                  # smaller edge of image resized to 256
    transforms.RandomCrop(256),              # 256x256 crop from a random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),                   # convert the PIL Image to a tensor
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Deterministic preprocessing for validation/test: same geometry, no randomness.
eval_transformations = transforms.Compose([
    transforms.Resize(256),                  # smaller edge of image resized to 256
    transforms.CenterCrop(256),              # fixed 256x256 crop from the image centre
    transforms.ToTensor(),                   # convert the PIL Image to a tensor
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

In [25]:
# Training preprocessing: resize the shorter edge to 256, take a random 256x256 crop,
# flip horizontally half the time, then convert to a tensor normalised with
# the statistics of the pretrained encoder.
train_transformations = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(256),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
    ]
)

# Attach the pipeline to the training dataset.
train.transformations = train_transformations

In [26]:
# Shuffled training batches; sampler and pin_memory are spelled out at their defaults.
train_loader = DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    sampler=None,
    pin_memory=False,
)

In [28]:
# Sanity check: one pass over train_loader, printing the file names in each batch.
# Repeated names within a batch are expected if the dataset yields one item per
# caption (presumably the case here; confirm against VizwizDataset).
t = tqdm(iter(train_loader), desc='train')
for batch_idx, batch in enumerate(t):
    (images, captions, lengths, fname) = batch
    print(batch_idx, fname)

HBox(children=(HTML(value='train'), FloatProgress(value=0.0, max=10.0), HTML(value='')))

0 ('VizWiz_train_00000022.jpg', 'VizWiz_train_00000016.jpg', 'VizWiz_train_00000022.jpg', 'VizWiz_train_00000004.jpg', 'VizWiz_train_00000002.jpg', 'VizWiz_train_00000010.jpg', 'VizWiz_train_00000017.jpg', 'VizWiz_train_00000013.jpg', 'VizWiz_train_00000015.jpg', 'VizWiz_train_00000012.jpg')
1 ('VizWiz_train_00000021.jpg', 'VizWiz_train_00000011.jpg', 'VizWiz_train_00000006.jpg', 'VizWiz_train_00000003.jpg', 'VizWiz_train_00000006.jpg', 'VizWiz_train_00000014.jpg', 'VizWiz_train_00000008.jpg', 'VizWiz_train_00000014.jpg', 'VizWiz_train_00000007.jpg', 'VizWiz_train_00000013.jpg')
2 ('VizWiz_train_00000001.jpg', 'VizWiz_train_00000018.jpg', 'VizWiz_train_00000004.jpg', 'VizWiz_train_00000017.jpg', 'VizWiz_train_00000023.jpg', 'VizWiz_train_00000005.jpg', 'VizWiz_train_00000002.jpg', 'VizWiz_train_00000018.jpg', 'VizWiz_train_00000005.jpg', 'VizWiz_train_00000007.jpg')
3 ('VizWiz_train_00000013.jpg', 'VizWiz_train_00000010.jpg', 'VizWiz_train_00000020.jpg', 'VizWiz_train_00000005.jpg', 'V