In [1]:
%load_ext autoreload
%autoreload 2

import os
import string
import pandas as pd
import numpy as np
import torch

from collections import defaultdict

from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm.auto import tqdm

from loader.dataset import VizwizDataset
from loader.images import ImageS3
from commons.utils import embedding_matrix, tensor_to_word_fn
from models.resnext101 import TransformerAttention
from eval.metrics import bleu, cider, rouge, spice, meteor, bleu_score_fn


from IPython.core.display import HTML

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# %pip install ipywidgets nltk

In [2]:
train = VizwizDataset(dtype='train', ret_type='tensor', copy_img_to_mem=False, device=device, partial=100)
vocabulary = train.getVocab()

loading annotations into memory...
Done (t=0.46s)
creating index...
index created! imgs = 23431, anns = 100575
tokenizing caption... done!!


In [3]:
val = VizwizDataset(dtype='val', ret_type='tensor', copy_img_to_mem=False, vocabulary=vocabulary, device=device, partial=50)

loading annotations into memory...
Done (t=0.17s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [4]:
val_eval = VizwizDataset(dtype='val', ret_type='corpus', copy_img_to_mem=False, vocabulary=vocabulary, device=device, partial=50)

loading annotations into memory...
Done (t=0.19s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [5]:
test = VizwizDataset(dtype='test', ret_type='tensor', copy_img_to_mem=False, vocabulary=vocabulary, device=device, partial=50)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created! imgs = 8000, anns = 0
no caption found... done!!


In [6]:
MODEL = 'resnext101_attention'
GLOVE_DIR = 'annotations/glove'
ENCODER_DIM = 2048
EMBEDDING_DIM = 300
ATTENTION_DIM = 256
DECODER_DIM = 256

BATCH_SIZE = 10 # 128
LOG_INTERVAL = 25 * (256 // BATCH_SIZE)
LR = 5e-4
NUM_EPOCHS = 2
SAVE_FREQ = 10


In [7]:
embedding_mtx = embedding_matrix(embedding_dim=EMBEDDING_DIM, word2idx=vocabulary.word2idx, glove_dir=GLOVE_DIR)
embedding_mtx.shape

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=378.0), HTML(value='')))




(378, 300)

In [17]:
def fit(train_loader, val_loader, model, loss_fn, optimizer, desc=''):
    
    # train model
    train_acc = 0.0
    train_loss = 0.0
    
    model.train()
    t = tqdm(iter(train_loader), desc=f'{desc} ::: train')
    for batch_idx, batch in enumerate(t):
        images, captions, lengths, fname = batch
        
        optimizer.zero_grad()
        scores, captions_sorted, decode_len, alphas, sort_idx = model(images, captions, lengths)
        # exclude <start> and only includes after <start> to <end>
        targets = captions_sorted[:, 1:]
        # remove pads or timesteps that were not decoded
        scores = pack_padded_sequence(scores, decode_len, batch_first=True)[0]
        targets = pack_padded_sequence(targets, decode_len, batch_first=True)[0]
        
        loss = loss_fn(scores, targets)
        loss.backward()
        optimizer.step()
        
        train_acc += (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0)
        train_loss += loss.item()
        
        t.set_postfix({
            'loss': train_loss / (batch_idx + 1),
            'acc': train_acc / (batch_idx + 1)
        }, refresh=True)
        
        if (batch_idx + 1) % LOG_INTERVAL == 0:
            print(f'{desc}_train {batch_idx + 1}/{len(train_loader)} '
                  f'train_loss: {train_loss / (batch_idx + 1):.4f} '
                  f'train_acc: {train_acc / (batch_idx + 1):.4f}')
            
    train_loss_mean = train_loss / len(train_loader)

    # validate model
    val_acc = 0.0
    val_loss = 0.0
    
    model.eval()
    t_val = tqdm(iter(val_loader), desc=f'{desc} ::: val')
    for batch_idx, batch in enumerate(t_val):
        images, captions, lengths, fname = batch
        
        scores, captions_sorted, decode_len, alphas, sort_idx = model(images, captions, lengths)
        # exclude <start> and only includes after <start> to <end>
        targets = captions_sorted[:, 1:]
        # remove pads or timesteps that were not decoded
        scores = pack_padded_sequence(scores, decode_len, batch_first=True)[0]
        targets = pack_padded_sequence(targets, decode_len, batch_first=True)[0]
        
        loss = loss_fn(scores, targets)
        val_acc += (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0)
        val_loss += loss.item()
        
        t_val.set_postfix({
            'loss': val_loss / (batch_idx + 1),
            'acc': val_acc / (batch_idx + 1)
        }, refresh=True)
        
        if (batch_idx + 1) % LOG_INTERVAL == 0:
            print(f'{desc}_val {batch_idx + 1}/{len(val_loader)} '
                  f'val_loss: {val_loss / (batch_idx + 1):.4f} '
                  f'val_acc: {val_acc / (batch_idx + 1):.4f}')
    
    val_loss_mean = val_loss / len(val_loader)
    
    return train_loss_mean, val_loss_mean

def detokenize(tokens):
    return ''.join([' ' + i if not i.startswith("'") and i not in string.punctuation else i for i in tokens]).strip()

def evaluate(data_loader, model, bleu_score_fn, tensor_to_word_fn, device=torch.device('cpu'), desc=''):
    
    model.eval()
    
    pred_byfname = dict()
    caps_byfname = defaultdict(list)
    scores = dict()
    
    bleus, ciders, rouges = [], [], []
    running_bleu = [0.0] * 5
    
    t = tqdm(iter(data_loader), desc=f'{desc}')
    for batch_idx, batch in enumerate(t):
        images, captions, lengths, fname = batch
        outputs = tensor_to_word_fn(model.sample(images, startseq_idx=vocabulary.word2idx['<start>']).cpu().numpy())
        
        for i in range(1, 5):
            running_bleu[i] += bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=i)
        t.set_postfix({
            'bleu1': running_bleu[1] / (batch_idx + 1),
            'bleu4': running_bleu[4] / (batch_idx + 1)
        }, refresh=True)
        
        for f, o, c in zip(fname, outputs, captions):
            if not f in pred_byfname:
                pred_byfname[f] = [detokenize(o)]
            caps_byfname[f].append(detokenize(c))

    # mean running_bleu score
    for i in range(1, 5):
        running_bleu[i] /= len(data_loader)
    scores['running_bleu'] = running_bleu

    # calculate overall score
    _bleu = bleu(caps_byfname, pred_byfname)
    _cider = cider(caps_byfname, pred_byfname)
    _rouge = rouge(caps_byfname, pred_byfname)
    
    bleus.append(_bleu)
    ciders.append(_cider)
    rouges.append(_rouge)
    
    scores['bleu'] = bleus
    scores['cider'] = ciders
    scores['rouge'] = rouges
    
    print('bleu', scores['bleu'])
    print('cider', scores['cider'])
    print('rouge', scores['rouge'])
    
    return scores
    

In [9]:
vocab_size = len(vocabulary.vocab)
transformer = TransformerAttention(encoded_image_size=14, attention_dim=ATTENTION_DIM, embedding_dim=EMBEDDING_DIM, 
                                   decoder_dim=DECODER_DIM, vocab_size=vocab_size, encoder_dim=ENCODER_DIM, 
                                   embedding_matrix=embedding_mtx, train_embedding=True).to(device) 

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train.pad_value).to(device)
corpus_bleu_score_fn = bleu_score_fn(4, 'corpus')
tensor2word_fn = tensor_to_word_fn(idx2word=vocabulary.idx2word)

params = transformer.parameters()
optimizer = torch.optim.Adam(params=params, lr=LR)


In [10]:
train_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(256),  # get 256x256 crop from random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])

eval_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.CenterCrop(256),  # get 256x256 crop from random location
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])

train.transformations = train_transformations
val.transformations = train_transformations
val_eval.transformations = eval_transformations

In [11]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False)

eval_collate_fn = lambda batch: (torch.stack([x[0] for x in batch]), [x[1] for x in batch], [x[2] for x in batch], [x[3] for x in batch])
val_eval_loader = DataLoader(val_eval, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, collate_fn=eval_collate_fn)

In [18]:
train_loss_min = 100
val_loss_min = 100
val_bleu4_max = 0.0

for epoch in range(NUM_EPOCHS):
    train_loss, val_loss = fit(train_loader, val_loader, model=transformer, loss_fn=loss_fn, optimizer=optimizer,
                              desc=f'Epoch {epoch+1} of {NUM_EPOCHS}')
    
    with torch.no_grad():
        scores =  evaluate(val_eval_loader, model=transformer, bleu_score_fn=corpus_bleu_score_fn, 
                           tensor_to_word_fn=tensor2word_fn, desc='Eval Score')
        
        print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
        print('================================')
        print(''.join([f'val_running_bleu{i}: {scores["running_bleu"][i]:.4f} ' for i in (1, 4)]))
        print(''.join([f'val_bleu{i + 1}: {scores["bleu"][i]:.4f} ' for i in range(0, 4)]))

HBox(children=(HTML(value='Epoch 1 of 2 ::: train'), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value='Epoch 1 of 2 ::: val'), FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(HTML(value='Eval Score'), FloatProgress(value=0.0, max=5.0), HTML(value='')))


{'testlen': 91, 'reflen': 107, 'guess': [91, 79, 67, 55], 'correct': [35, 2, 0, 0]}
ratio: 0.8504672897116778
bleu [([0.3226020944606647, 0.08276663693869638, 4.4098665638249194e-07, 1.0693959787638436e-09], [[0.23884377011164867, 0.23884377011164867, 0.5515605639774822, 0.29999999994000015, 0.3715190997867868, 0.24767939985785795, 0.33093633838648934, 0.2388437701381868, 0.33093633838648934, 0.5367694949004131, 0.22222222217283966, 0.23884377011164867], [5.85045365006396e-09, 5.85045365006396e-09, 0.26369638633624565, 5.773502690709485e-09, 0.23168286400471463, 5.982025825945981e-09, 6.459215935396237e-09, 4.62518972042083e-09, 6.459215935396237e-09, 8.199289470316909e-09, 5.270462765739487e-09, 5.85045365006396e-09], [1.830282338978762e-11, 1.830282338978762e-11, 2.170651479684559e-06, 1.60914897400152e-11, 2.103416377214603e-06, 1.837503456341076e-11, 1.8307983629515338e-11, 1.2985701392916413e-11, 1.8307983629515338e-11, 2.1618858818179448e-11, 1.5831904189457608e-11, 1.8302823389

TypeError: unsupported format string passed to tuple.__format__

In [25]:
train_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(256),  # get 256x256 crop from random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])

train.transformations = train_transformations

In [36]:
val_eval_loader = DataLoader(val_eval, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, collate_fn=eval_collate_fn)

In [37]:
t = tqdm(iter(val_eval_loader), desc=f'train')
for batch_idx, batch in enumerate(t):
    images, captions, lengths, fname = batch
    if batch_idx == 0:
        print(batch_idx, fname, captions)

HBox(children=(HTML(value='train'), FloatProgress(value=0.0, max=5.0), HTML(value='')))

0 ['VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg'] [['a', 'computer', 'screen', 'shows', 'a', 'repair', 'prompt', 'on', 'the', 'screen', '.'], ['a', 'computer', 'screen', 'with', 'a', 'repair', 'automatically', 'pop', 'up'], ['partial', 'computer', 'screen', 'showing', 'the', 'need', 'of', 'repairs'], ['part', 'of', 'a', 'computer', 'monitor', 'showing', 'a', 'computer', 'repair', 'message', '.'], ['the', 'top', 'of', 'a', 'laptop', 'with', 'a', 'blue', 'background', 'and', 'dark', 'blue', 'text', '.'], ['a', 'person', 'is', 'holding', 'a', 'bottle', 'that', 'has', 'medicine', 'for', 'the', 'night', 'time', '.'], ['a', 'bottle', 'of', 'medication', 'has', 'a', 'white', 'twist', 'top', '.'], ['night', 'time', 'medication', 'bottle', 'being', 'held', 'by', 'someone'], [