In [1]:
%load_ext autoreload
%autoreload 2

import copy
import os
import string
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from IPython.core.display import HTML
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from commons.utils import embedding_matrix, tensor_to_word_fn
from eval.metrics import bleu, cider, rouge, spice, meteor, bleu_score_fn
from loader.dataset import VizwizDataset
from loader.images import ImageS3
from loader.model import ModelS3
from models.resnext101 import TransformerAttention

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Manual environment setup, kept for reference (not executed by this notebook).
# NOTE(review): presumably Java is needed by some of the caption metrics - confirm.
# %pip install ipywidgets nltk
# sudo yum install java-1.8.0-openjdk

In [2]:
# Training split; the vocabulary taken from it is passed to every other split.
train = VizwizDataset(
    dtype='train',
    ret_type='tensor',
    copy_img_to_mem=False,
    device=device,
    partial=100,
)
vocabulary = train.getVocab()

loading annotations into memory...
Done (t=0.45s)
creating index...
index created! imgs = 23431, anns = 100575
tokenizing caption... done!!


In [3]:
# Validation split returned as tensors (feeds the 'val' phase of fit); reuses the training vocabulary.
val = VizwizDataset(
    dtype='val',
    ret_type='tensor',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=50,
)

loading annotations into memory...
Done (t=0.15s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [4]:
# Same validation split, but with ret_type='corpus' so captions come back as token lists for metric scoring.
val_eval = VizwizDataset(
    dtype='val',
    ret_type='corpus',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=50,
)

loading annotations into memory...
Done (t=0.25s)
creating index...
index created! imgs = 7750, anns = 33145
tokenizing caption... done!!


In [5]:
# Test split (no reference captions are loaded for it); used only for caption generation.
test = VizwizDataset(
    dtype='test',
    ret_type='corpus',
    copy_img_to_mem=False,
    vocabulary=vocabulary,
    device=device,
    partial=10,
)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created! imgs = 8000, anns = 0
no caption found... done!!


In [6]:
# --- Storage / model identity ---
BUCKET = 'assistive-version'
MODEL = 'resnext101_attention'
GLOVE_DIR = 'annotations/glove'

# --- Model dimensions ---
ENCODER_DIM = 2048
EMBEDDING_DIM = 300
ATTENTION_DIM = 256
DECODER_DIM = 256

# --- Training hyper-parameters ---
BATCH_SIZE = 10 # 128
# Number of batches between log lines. Clamped to at least 1 so that
# `(batch_idx + 1) % LOG_INTERVAL` cannot divide by zero when BATCH_SIZE > 256.
LOG_INTERVAL = max(1, 25 * (256 // BATCH_SIZE))
LR = 5e-4
NUM_EPOCHS = 2

# --- Output locations and versioning ---
LOCAL_PATH = 'bin/'
KEY_PATH = 'bin/'
CAPTIONS_PATH = 'captions/'
VERSION = 1.0

# Base name shared by the checkpoint and caption files written below.
MODEL_NAME = f'{MODEL}_b{BATCH_SIZE}_emb{EMBEDDING_DIM}'

In [7]:
# Embedding matrix for the vocabulary; the displayed shape is (vocabulary size, EMBEDDING_DIM).
embedding_mtx = embedding_matrix(
    embedding_dim=EMBEDDING_DIM,
    word2idx=vocabulary.word2idx,
    glove_dir=GLOVE_DIR,
)
embedding_mtx.shape

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=378.0), HTML(value='')))




(378, 300)

In [8]:
def fit(dataloaders, model, loss_fn, optimizer, desc=''):
    """Run one training pass followed by one validation pass.

    Args:
        dataloaders: mapping with 'train' and 'val' entries; every batch must
            unpack into (images, captions, lengths, fname, image_id).
        model: called as ``model(images, captions, lengths)`` and expected to
            return (scores, captions_sorted, decode_len, alphas, sort_idx).
        loss_fn: criterion applied to the packed scores and packed targets.
        optimizer: stepped once per training batch; untouched during validation.
        desc: prefix used for the progress bar and the periodic log lines.

    Returns:
        (train_loss, val_loss): mean per-batch loss of each phase as plain floats.
    """
    means = dict()

    for phase in ['train', 'val']:
        is_train = phase == 'train'
        # Running sums live in their own names so the per-batch loss tensor
        # can never overwrite the accumulator.
        running_acc = 0.0
        running_loss = 0.0

        if is_train:
            model.train()
        else:
            model.eval()

        t = tqdm(iter(dataloaders[phase]), desc=f'{desc} ::: {phase}')
        for batch_idx, batch in enumerate(t):
            images, captions, lengths, fname, image_id = batch

            if is_train:
                optimizer.zero_grad()

            # Autograd is only needed while training; validation skips the graph.
            with torch.set_grad_enabled(is_train):
                scores, captions_sorted, decode_len, alphas, sort_idx = model(images, captions, lengths)
                # exclude <start> and only includes after <start> to <end>
                targets = captions_sorted[:, 1:]
                # remove pads or timesteps that were not decoded
                scores = pack_padded_sequence(scores, decode_len, batch_first=True)[0]
                targets = pack_padded_sequence(targets, decode_len, batch_first=True)[0]

                loss = loss_fn(scores, targets)

                if is_train:
                    loss.backward()
                    optimizer.step()

            # Token-level accuracy of this batch over the packed (non-pad) positions.
            running_acc += (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0)
            running_loss += loss.item()

            t.set_postfix({
                'loss': running_loss / (batch_idx + 1),
                'acc': running_acc / (batch_idx + 1)
            }, refresh=True)

            if (batch_idx + 1) % LOG_INTERVAL == 0:
                print(f'{desc}_{phase} {batch_idx + 1}/{len(dataloaders[phase])} '
                      f'{phase}_loss: {running_loss / (batch_idx + 1):.4f} '
                      f'{phase}_acc: {running_acc / (batch_idx + 1):.4f}')

        means[phase] = running_loss / len(dataloaders[phase])

    return means['train'], means['val']
    

def detokenize(tokens):
    """Join tokens into one string.

    Tokens that start with an apostrophe, or that are found in
    ``string.punctuation``, are attached to the preceding text without a
    space; every other token is preceded by a single space. Leading and
    trailing whitespace is stripped from the result.
    """
    pieces = []
    for token in tokens:
        attach = token.startswith("'") or token in string.punctuation
        pieces.append(token if attach else ' ' + token)
    return ''.join(pieces).strip()

def evaluate(data_loader, model, bleu_score_fn, tensor_to_word_fn, desc=''):
    """Score sampled captions against the reference captions.

    Args:
        data_loader: yields (images, captions, lengths, fname, image_id) batches.
        model: put into eval mode here; must provide ``sample(images, startseq_idx=...)``.
        bleu_score_fn: called with reference_corpus, candidate_corpus and n.
        tensor_to_word_fn: converts the sampled index array into token lists.
        desc: progress-bar label.

    Returns:
        dict with 'avg_bleu' (five-element list, index n holding the batch-averaged
        BLEU-n and index 0 unused) plus the 'bleu', 'cider' and 'rouge' results
        computed over captions grouped by file name.
    """
    model.eval()

    pred_byfname = dict()
    caps_byfname = defaultdict(list)

    bleu_sums = [0.0 for _ in range(5)]
    ngram_orders = range(1, 5)

    progress = tqdm(iter(data_loader), desc=f'{desc}')
    for step, (images, captions, lengths, fname, image_id) in enumerate(progress, start=1):
        sampled = model.sample(images, startseq_idx=vocabulary.word2idx['<start>'])
        outputs = tensor_to_word_fn(sampled.cpu().numpy())

        for n in ngram_orders:
            bleu_sums[n] += bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=n)
        progress.set_postfix({
            'bleu1': bleu_sums[1] / step,
            'bleu4': bleu_sums[4] / step
        }, refresh=True)

        for name, predicted, reference in zip(fname, outputs, captions):
            # Only the first prediction seen for a file is kept; every reference is collected.
            if name not in pred_byfname:
                pred_byfname[name] = [detokenize(predicted)]
            caps_byfname[name].append(detokenize(reference))

    # Turn the per-batch sums into means over the number of batches.
    num_batches = len(data_loader)
    for n in ngram_orders:
        bleu_sums[n] /= num_batches

    # Overall scores over the captions grouped by file name.
    scores = {
        'avg_bleu': bleu_sums,
        'bleu': bleu(caps_byfname, pred_byfname, verbose=0),
        'cider': cider(caps_byfname, pred_byfname),
        'rouge': rouge(caps_byfname, pred_byfname),
    }
    #scores['spice'] = spice(caps_byfname, pred_byfname)
    #scores['meteor'] = meteor(caps_byfname, pred_byfname)

    return scores

def generate_captions(dataloader, model, desc='', tensor_to_word_fn=None):
    """Sample a caption for every image served by ``dataloader``.

    Args:
        dataloader: yields (images, fname, image_id) batches.
        model: put into eval mode here; must provide ``sample(images, startseq_idx=...)``.
        desc: progress-bar label.
        tensor_to_word_fn: converts the sampled index array into token lists.
            Defaults to the notebook-level ``tensor2word_fn`` so existing calls
            keep working.

    Returns:
        dict of the form ``{'results': [{'image_id': int, 'caption': str}, ...]}``.
    """
    if tensor_to_word_fn is None:
        tensor_to_word_fn = tensor2word_fn

    rlist = []

    # Inference only: use eval-mode layers and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        t = tqdm(iter(dataloader), desc=f'{desc}')
        for batch_idx, batch in enumerate(t):
            images, fname, image_id = batch
            sampled = model.sample(images, startseq_idx=vocabulary.word2idx['<start>'])
            outputs = tensor_to_word_fn(sampled.cpu().numpy())

            for out, img in zip(outputs, image_id):
                rlist.append(dict(
                    image_id = int(img),
                    caption = detokenize(out)
                ))

    return dict(
        results = rlist
    )

In [9]:
vocab_size = len(vocabulary.vocab)

# Captioning model, moved to the selected device.
transformer = TransformerAttention(
    encoded_image_size=14,
    attention_dim=ATTENTION_DIM,
    embedding_dim=EMBEDDING_DIM,
    decoder_dim=DECODER_DIM,
    vocab_size=vocab_size,
    encoder_dim=ENCODER_DIM,
    embedding_matrix=embedding_mtx,
    train_embedding=True,
).to(device)

# Padding positions are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train.pad_value).to(device)

# Helpers used when scoring and decoding sampled captions.
corpus_bleu_score_fn = bleu_score_fn(4, 'corpus')
tensor2word_fn = tensor_to_word_fn(idx2word=vocabulary.idx2word)

params = transformer.parameters()
optimizer = torch.optim.Adam(params=params, lr=LR)


In [10]:
# Training pipeline: random crop and horizontal flip act as augmentation.
train_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(256),  # get 256x256 crop from random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])

# Evaluation pipeline: deterministic, no augmentation.
eval_transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.CenterCrop(256),  # get 256x256 crop from the image centre
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])

train.transformations = train_transformations
# Validation loss must be measured on un-augmented images, otherwise it changes
# from run to run with the random crop/flip and is not comparable across epochs.
val.transformations = eval_transformations
val_eval.transformations = eval_transformations
test.transformations = eval_transformations

In [11]:
# Loaders consumed by fit(): one per phase, both shuffled.
dataloaders = {
    'train': DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False),
    'val': DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False),
}


def eval_collate_fn(batch):
    """Stack field 0 of every sample into one tensor; keep fields 1-4 as plain lists."""
    stacked = torch.stack([sample[0] for sample in batch])
    return (stacked,
            [sample[1] for sample in batch],
            [sample[2] for sample in batch],
            [sample[3] for sample in batch],
            [sample[4] for sample in batch])


def test_collate_fn(batch):
    """Stack field 0 of every sample into one tensor; keep fields 1-2 as plain lists."""
    stacked = torch.stack([sample[0] for sample in batch])
    return (stacked,
            [sample[1] for sample in batch],
            [sample[2] for sample in batch])


# Evaluation / test loaders keep dataset order and use the list-preserving collate functions.
val_eval_loader = DataLoader(
    val_eval,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=None,
    pin_memory=False,
    collate_fn=eval_collate_fn,
)
test_loader = DataLoader(
    test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=None,
    pin_memory=False,
    collate_fn=test_collate_fn,
)

In [12]:
# Best values seen so far across epochs.
train_loss_min = float('inf')
val_loss_min = float('inf')
val_avg_bleu4_max = 0.0

model_bin = ModelS3()
transformer_best = None
state = None  # stays None when NUM_EPOCHS == 0, in which case nothing is saved

for epoch in range(NUM_EPOCHS):
    train_loss, val_loss = fit(dataloaders, model=transformer, loss_fn=loss_fn, optimizer=optimizer,
                              desc=f'Epoch {epoch+1} of {NUM_EPOCHS}')
    # Store plain numbers in the checkpoint, never tensors attached to a graph.
    train_loss, val_loss = float(train_loss), float(val_loss)

    # Carry the running minima forward so later epochs compare against them.
    train_loss_min = min(train_loss, train_loss_min)
    val_loss_min = min(val_loss, val_loss_min)

    with torch.no_grad():
        scores =  evaluate(val_eval_loader, model=transformer, bleu_score_fn=corpus_bleu_score_fn,
                           tensor_to_word_fn=tensor2word_fn, desc='Eval Score')

    # Decide "best" against the previous maximum, then advance the maximum.
    is_best = scores['avg_bleu'][4] > val_avg_bleu4_max
    val_avg_bleu4_max = max(scores['avg_bleu'][4], val_avg_bleu4_max)

    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    print('=' * 95)
    print(''.join([f'val_avg_bleu{i}: {scores["avg_bleu"][i]:.4f} ' for i in range(1, 5)]))
    print(''.join([f'val_bleu{i + 1}{":":>5} {scores["bleu"][0][i]:.4f} ' for i in range(0, 4)]))
    print(f'val_cider{":":>5} {scores["cider"][0]:.4f}')
    print(f'val_rouge{":":>5} {scores["rouge"][0]:.4f}')
    #print(f'val_spcie{":":>5} {scores["spice"][0]:.4f}')
    #print(f'val_meteor{":":>5} {scores["meteor"][0]:.4f}')
    print('-' * 95)

    state = dict(
        epoch = epoch + 1,
        state_dict = transformer.state_dict(),
        train_loss_latest = train_loss,
        val_loss_latest = val_loss,
        train_loss_min = train_loss_min,
        val_loss_min = val_loss_min,
        val_avg_bleu1 = scores['avg_bleu'][1],
        val_avg_bleu4 = scores['avg_bleu'][4],
        val_avg_bleu4_max = val_avg_bleu4_max,
        val_bleu1 = scores['bleu'][0][0],
        val_bleu4 = scores['bleu'][0][3],
        val_cider = scores['cider'][0],
        val_rouge = scores['rouge'][0]
    )

    if is_best:
        fname = f'{MODEL_NAME}_best_v{VERSION}.pt'
        # Snapshot the best model: a plain assignment would only alias the live
        # model, which keeps changing in later epochs.
        transformer_best = copy.deepcopy(transformer)
        model_bin.save(state, os.path.join(LOCAL_PATH, fname), os.path.join(KEY_PATH, fname))

fname = f'{MODEL_NAME}_ep{NUM_EPOCHS}_latest_v{VERSION}.pt'
if state is not None:
    model_bin.save(state, os.path.join(LOCAL_PATH, fname), os.path.join(KEY_PATH, fname))


[2021-06-25 17:53:35.181 pytorch-1-6-cpu-py36--ml-t3-medium-370ee60fbc7a856e8f67ac271515:1024 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2021-06-25 17:53:35.346 pytorch-1-6-cpu-py36--ml-t3-medium-370ee60fbc7a856e8f67ac271515:1024 INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


HBox(children=(HTML(value='Epoch 1 of 2 ::: train'), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value='Epoch 1 of 2 ::: val'), FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(HTML(value='Eval Score'), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Epoch 1/2
val_avg_bleu1: 0.1861 val_avg_bleu2: 0.0367 val_avg_bleu3: 0.0181 val_avg_bleu4: 0.0114 
val_bleu1    : 0.2948 val_bleu2    : 0.0000 val_bleu3    : 0.0000 val_bleu4    : 0.0000 
val_cider    : 0.0000
val_rouge    : 0.2798
-----------------------------------------------------------------------------------------------


HBox(children=(HTML(value='Epoch 2 of 2 ::: train'), FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(HTML(value='Epoch 2 of 2 ::: val'), FloatProgress(value=0.0, max=5.0), HTML(value='')))




HBox(children=(HTML(value='Eval Score'), FloatProgress(value=0.0, max=5.0), HTML(value='')))


Epoch 2/2
val_avg_bleu1: 0.1913 val_avg_bleu2: 0.0378 val_avg_bleu3: 0.0187 val_avg_bleu4: 0.0118 
val_bleu1    : 0.3018 val_bleu2    : 0.0570 val_bleu3    : 0.0000 val_bleu4    : 0.0000 
val_cider    : 0.0044
val_rouge    : 0.2892
-----------------------------------------------------------------------------------------------


In [13]:
# Caption every image served by the test loader and display the payload.
results = generate_captions(
    test_loader,
    model=transformer,
    desc='captioning ::: test',
)
results

HBox(children=(HTML(value='captioning ::: test'), FloatProgress(value=0.0, max=1.0), HTML(value='')))




{'results': [{'image_id': 31181, 'caption': 'a a a a a a'},
  {'image_id': 31182, 'caption': 'a a a a a'},
  {'image_id': 31183, 'caption': 'a a of a a'},
  {'image_id': 31184, 'caption': 'a a a a a'},
  {'image_id': 31185, 'caption': 'a a of of a a a a'},
  {'image_id': 31186, 'caption': 'a a a a a a'},
  {'image_id': 31187, 'caption': 'a a a a a a a'},
  {'image_id': 31188, 'caption': 'a a of a a'},
  {'image_id': 31189, 'caption': 'a a a a a a a'},
  {'image_id': 31190, 'caption': 'a a a a a a'}]}

In [14]:

# save captions to s3
fname = f'{MODEL_NAME}_ep{NUM_EPOCHS}_latest_v{VERSION}.json'
captions_key = os.path.join(CAPTIONS_PATH, fname)
model_bin.save_captions(results, captions_key)


In [15]:
# load captions from s3 (same key the captions were saved under)
captions_key = os.path.join(CAPTIONS_PATH, fname)
results = model_bin.load_captions(captions_key)
results

{'results': [{'image_id': 31181, 'caption': 'a a a a a a'},
  {'image_id': 31182, 'caption': 'a a a a a'},
  {'image_id': 31183, 'caption': 'a a of a a'},
  {'image_id': 31184, 'caption': 'a a a a a'},
  {'image_id': 31185, 'caption': 'a a of of a a a a'},
  {'image_id': 31186, 'caption': 'a a a a a a'},
  {'image_id': 31187, 'caption': 'a a a a a a a'},
  {'image_id': 31188, 'caption': 'a a of a a'},
  {'image_id': 31189, 'caption': 'a a a a a a a'},
  {'image_id': 31190, 'caption': 'a a a a a a'}]}

In [16]:

# prepare for inference
imageS3 = ImageS3()
fpath = os.path.join('vizwiz', 'test', 'VizWiz_test_00000002.jpg')
img = eval_transformations(imageS3.getImage(fpath)).to(device)

# Sample in eval mode and without autograd, as the validation scoring does.
transformer.eval()
with torch.no_grad():
    # unsqueeze(0) adds the batch dimension for the single image
    sampled = transformer.sample(img.unsqueeze(0), startseq_idx=vocabulary.word2idx['<start>'])
caption = tensor2word_fn(sampled.cpu().numpy())[0]
detokenize(caption)


'a a of a a'

In [25]:
# for experiment only
# Rebuilds the training pipeline and re-attaches it to the training dataset.
_experiment_normalize = transforms.Normalize(
    (0.485, 0.456, 0.406),  # normalize image for pre-trained model
    (0.229, 0.224, 0.225),
)
_experiment_steps = [
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(256),  # get 256x256 crop from random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    _experiment_normalize,
]
train_transformations = transforms.Compose(_experiment_steps)

train.transformations = train_transformations

In [36]:
# Rebuild the evaluation loader (dataset order kept, list-preserving collate function).
val_eval_loader = DataLoader(
    val_eval,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=None,
    pin_memory=False,
    collate_fn=eval_collate_fn,
)

In [37]:
# Peek at the first evaluation batch to check file names against their reference captions.
t = tqdm(iter(val_eval_loader), desc='train')
for batch_idx, batch in enumerate(t):
    # eval_collate_fn returns five fields per batch; unpacking only four raises ValueError.
    images, captions, lengths, fname, image_id = batch
    if batch_idx == 0:
        print(batch_idx, fname, captions)

HBox(children=(HTML(value='train'), FloatProgress(value=0.0, max=5.0), HTML(value='')))

0 ['VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000000.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg', 'VizWiz_val_00000001.jpg'] [['a', 'computer', 'screen', 'shows', 'a', 'repair', 'prompt', 'on', 'the', 'screen', '.'], ['a', 'computer', 'screen', 'with', 'a', 'repair', 'automatically', 'pop', 'up'], ['partial', 'computer', 'screen', 'showing', 'the', 'need', 'of', 'repairs'], ['part', 'of', 'a', 'computer', 'monitor', 'showing', 'a', 'computer', 'repair', 'message', '.'], ['the', 'top', 'of', 'a', 'laptop', 'with', 'a', 'blue', 'background', 'and', 'dark', 'blue', 'text', '.'], ['a', 'person', 'is', 'holding', 'a', 'bottle', 'that', 'has', 'medicine', 'for', 'the', 'night', 'time', '.'], ['a', 'bottle', 'of', 'medication', 'has', 'a', 'white', 'twist', 'top', '.'], ['night', 'time', 'medication', 'bottle', 'being', 'held', 'by', 'someone'], [