# Show, Attend and Tell - Implementation

 ## Mount Google Drive

In [None]:
# Mount the user's Google Drive inside the Colab VM; the project folder that
# the later cells change into lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

## Import libraries

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import warnings
warnings.filterwarnings('ignore')

import torch
from torch import nn
import json
from torchvision import transforms
from fastai.callbacks import lr_finder, SaveModelCallback, EarlyStoppingCallback,ReduceLROnPlateauCallback
from torch.nn.utils.rnn import pack_padded_sequence 
from torch.utils.data import Dataset, DataLoader
from fastai.vision import learner, data
from fastai.metrics import top_k_accuracy
from fastai.text import *
from fastai.vision import *
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import sentence_bleu
from sklearn.utils import shuffle
from PIL import Image

# Use the GPU when one is visible and fall back to the CPU otherwise, then
# point fastai's global default at the same device so the two cannot disagree
# (the default used to be forced to 'cuda' even on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
defaults.device = device


In [None]:
import copy

def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

In [None]:
import os
# CUDA_LAUNCH_BLOCKING is CUDA's switch for synchronous kernel launches;
# "0" leaves it disabled. Setting it to "1" is the usual way to make
# device-side errors surface at the offending call while debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = "0"

In [None]:
device

## Config

In [None]:
# Special-token spellings of the FastText vocabulary - DO NOT CHANGE.
_FASTTEXT_TOKENS = {
    'start_token': '<start>',
    'end_token': '<end>',
    'unknown_token': '<unk>',
    'pad_token': '<pad>',
}

# Special-token spellings of the BERT vocabularies - DO NOT CHANGE.
_BERT_TOKENS = {
    'start_token': '[CLS]',
    'end_token': '[SEP]',
    'unknown_token': '[UNK]',
    'pad_token': '[PAD]',
}


def _make_config(embedding, embedding_dim, append_sos_eos, tokens, cnn, encoder_dim, encoder):
    """Assemble one experiment configuration as a fresh dict.

    :param embedding: embedding name, also used as the caption folder name
    :param embedding_dim: width of the word embedding vectors
    :param append_sos_eos: whether start/end tokens are added during numericalization
    :param tokens: mapping with the four special-token spellings
    :param cnn: image backbone, one of {resnet101,vgg16,resnext101}
    :param encoder_dim: channel count of the backbone's feature map
    :param encoder: encoder flavour, one of {lstm|transformer}
    :return: configuration dict; search defaults to greedy with beam size 3
    """
    return {
        'embedding.name': embedding,
        'embedding.folder': embedding,
        'append_sos_eos_tokens': append_sos_eos,
        'embedding.dim': embedding_dim,
        **tokens,
        'cnn': cnn,
        'encoder_dim': encoder_dim,
        'encoder': encoder,
        'search': 'greedy',  # {greedy|beam}
        'beam_size': 3,
    }


def _bert_config(embedding, embedding_dim, cnn, encoder):
    """Configuration for a BERT embedding; every BERT setup uses a 2048-channel backbone."""
    return _make_config(embedding, embedding_dim, False, _BERT_TOKENS, cnn, 2048, encoder)


config_baseline = _make_config(
    'fasttext-word-freq-2-plus', 300, True, _FASTTEXT_TOKENS, 'vgg16', 512, 'lstm')

config_bert_resnet101 = _bert_config('bert-768', 768, 'resnet101', 'lstm')

config_bert_filtered_resnet101_transformer = _bert_config(
    'bert-768-word-freq-2-plus', 768, 'resnet101', 'transformer')

config_bert_resnext101_transformer = _bert_config('bert-768', 768, 'resnext101', 'transformer')

config_bert_resnext101 = _bert_config('bert-768', 768, 'resnext101', 'lstm')

config_bert1024_resnet101 = _bert_config('bert-1024', 1024, 'resnet101', 'lstm')

config_bert1024_filtered_resnet101 = _bert_config(
    'bert-1024-word-freq-2-plus', 1024, 'resnet101', 'lstm')


# Pick the active experiment by leaving exactly one assignment uncommented.
# conf = config_baseline
#conf = config_bert_resnet101
#conf = config_bert_filtered_resnet101_transformer
#conf = config_bert_resnext101_transformer
#conf = config_bert_resnext101
# conf = config_bert1024_resnet101
conf = config_bert1024_filtered_resnet101

conf['dataset_size'] = 'small' # {small|medium|full}

START_TOKEN = conf['start_token']
END_TOKEN = conf['end_token']
UNKNOWN_TOKEN = conf['unknown_token']
PAD_TOKEN = conf['pad_token']

# Echo the active configuration and its special tokens, one per line.
for _shown in (conf, START_TOKEN, END_TOKEN, UNKNOWN_TOKEN, PAD_TOKEN):
    print(_shown)

## Create Experiment Folders

In [None]:
# Work from the shared project folder on the mounted Drive.
% cd "/content/drive/MyDrive/cs7643/project/"
import os, datetime
# Remember when this run started.
start_time = datetime.datetime.now()
# One folder per run, named after the current timestamp and the active
# embedding / CNN / encoder / search settings.
exp_dir = os.path.join(os.getcwd(),  "Experiment_" + datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + "_" + conf['embedding.name'] + "_" + conf['cnn'] + "_" + conf['encoder'] + "_" + conf['search'])
os.makedirs(exp_dir)
# Enter the new folder and seed it with the contents of the experiment
# template, so the relative paths used by later cells resolve inside it.
% cd $exp_dir
% cp -r ../ExperimentTemplate/* .
% ls

## Image Databunch  Preparation

### Load Captions

In [None]:
# Read the caption annotations stored for the active embedding folder.
captions_file = f"captions/{conf['embedding.folder']}/captions.json"
print(captions_file)
with open(captions_file, mode='r') as f:
    captions_data = json.loads(f.read())

In [None]:
# Turn the JSON payload into a DataFrame with one row per image.
image_meta = pd.DataFrame(captions_data['images'])
# For every image, collect the token lists of all of its reference captions.
image_meta['references'] = [
    [sentence['tokens'] for sentence in sentences]
    for sentences in image_meta['sentences']
]

In [None]:
captions_data['images'][0]['sentences'][0]

In [None]:
image_meta.head()

### Load Embedding and Preprocess

In [None]:
# Token -> id mapping that belongs to the active embedding folder.
vocab_file = f"captions/{conf['embedding.folder']}/vocab.json"
print(vocab_file)
with open(vocab_file, 'rb') as f:
    vocab = json.loads(f.read())

def unnesting(df, explode_cols):
    """Explode list-valued columns so that every list element gets its own row.

    :param df: DataFrame whose ``explode_cols`` hold equally long lists per row
    :param explode_cols: names of the list-valued columns to flatten together
    :return: DataFrame with the flattened columns followed by the remaining
        columns of ``df``; each original row label is repeated once per element
    """
    # Repeat every row label as often as its list in the first column is long.
    idx = df.index.repeat(df[explode_cols[0]].str.len())
    df1 = pd.concat(
        [pd.DataFrame({col: np.concatenate(df[col].values)}) for col in explode_cols],
        axis=1,
    )
    df1.index = idx

    # Use the `columns=` keyword: pandas 2.x no longer accepts `axis` as a
    # positional argument to `drop`.
    return df1.join(df.drop(columns=explode_cols), how='left')

def _tokens_to_ids(tokens):
    """Map tokens to vocabulary ids; out-of-vocabulary tokens get the unknown-token id."""
    return [vocab[tok] if tok in vocab else vocab[UNKNOWN_TOKEN] for tok in tokens]


# One row per caption instead of one row per image.
metadata = unnesting(image_meta, ['sentids', 'sentences'])
metadata['labels'] = metadata.sentences.apply(lambda sentence: sentence['raw'])
metadata['tokens'] = metadata.sentences.apply(lambda sentence: sentence['tokens'])
metadata.reset_index(inplace=True)


# Prefix every filename with the image folder so the column holds usable paths.
metadata['filename'] = metadata.filename.apply(
    lambda name: f'Flicker8k/Flickr8k_Dataset/Flicker8k_Dataset/{name}')
metadata.reset_index(inplace=True)

# Numericalize the caption tokens; the start/end ids are added only for
# vocabularies whose captions do not already carry them.
if conf['append_sos_eos_tokens'] == True:
    metadata['numericalized'] = metadata.tokens.apply(
        lambda toks: [vocab[START_TOKEN]] + _tokens_to_ids(toks) + [vocab[END_TOKEN]])
else:
    metadata['numericalized'] = metadata.tokens.apply(_tokens_to_ids)

# Store the ids as NumPy arrays and remember each caption's length.
metadata['numericalized'] = metadata.numericalized.apply(lambda ids: np.array(ids))
metadata['SeqLen'] = metadata.numericalized.apply(len)


def ref_numericalize(lst):
    """Numericalize every reference caption in ``lst`` (no start/end ids are added here)."""
    return [_tokens_to_ids(reference) for reference in lst]


# Keep the numeric form of all reference captions next to each row.
metadata['numericalized_ref'] = metadata.references.apply(ref_numericalize)

# Shuffle the rows and renumber them.
metadata = shuffle(metadata)
metadata.reset_index(drop=True, inplace=True)

In [None]:
metadata.numericalized_ref[2]

### Train/Validation Split

In [None]:
# Row labels of the captions assigned to the training and validation splits.
train_idx = metadata.index[metadata['split'] == 'train']
valid_idx = metadata.index[metadata['split'] == 'val']

In [None]:
# Optional down-sampling for quick runs: (train rows, validation rows) kept per
# size; any other size (e.g. 'full') leaves the data untouched.
_subset_sizes = {'small': (10, 2), 'medium': (1000, 200)}
if conf['dataset_size'] in _subset_sizes:
    _n_train, _n_valid = _subset_sizes[conf['dataset_size']]
    train_idx = train_idx[:_n_train].to_list()
    valid_idx = valid_idx[:_n_valid].to_list()
    metadata = metadata[metadata.index.isin(train_idx + valid_idx)]
    # NOTE: after this renumbering train_idx / valid_idx still hold the row
    # labels from before the reset.
    metadata.reset_index(drop=True, inplace=True)

### Custom Datasets

In [None]:
# Dataset that yields (image, numericalized caption, row label) triples
class ImageCaptionDataset(Dataset):
    """Caption rows of one split, served as ``(image, caption, ref_ind)``.

    ``ref_ind`` is the row label of the sample in the DataFrame passed to the
    constructor.
    """

    def __init__(self, data, split = 'train',transform=None):
        """
        :param data: DataFrame with ``filename``, ``numericalized`` and ``split`` columns
        :param split: value of the ``split`` column to keep
        :param transform: optional callable applied to every loaded image
        """
        subset = data[data['split'] == split]
        self.filenames = list(subset['filename'])
        self.captions = list(subset['numericalized'])
        self.inds = subset.index.tolist()
        self.transform = transform

    def __len__(self):
        """Number of caption rows in the selected split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return it with its caption and row label."""
        # Force three channels: grayscale files would otherwise produce tensors
        # whose shapes cannot be stacked with those of the colour images.
        image = Image.open(self.filenames[idx]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return (image, self.captions[idx], self.inds[idx])

In [None]:
# ImageNet channel statistics, shared by the forward and inverse normalization
# so the two can never drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to 350x350 px, convert to a tensor, then normalize with the ImageNet stats.
trans = transforms.Compose([
    transforms.Resize((350,350)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

# Exact inverse of the Normalize step above. Normalize computes (x - mean) / std,
# so using mean' = -mean/std and std' = 1/std yields x * std + mean.
# (The previous constants belonged to different statistics and therefore did
# not undo the normalization applied in `trans`.)
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD]
)

# Tensor -> displayable PIL image.
denorm = transforms.Compose([
    inv_normalize,
    transforms.functional.to_pil_image
])

In [None]:
# Create one Dataset instance per split; both apply the same `trans` pipeline.
train_dataset = ImageCaptionDataset(metadata,'train',trans)
valid_dataset = ImageCaptionDataset(metadata,'val',trans)

In [None]:
# check random sample
#idx = 500
#d = train_dataset
#img, cap,ref = d[idx]

#print(' '.join([dict(zip(vocab.values(),vocab.keys()))[x] for x in cap])+'\n')


#for cap in metadata.numericalized_ref.loc[ref]:
    #print(' '.join([dict(zip(vocab.values(),vocab.keys()))[i] for i in cap]))
#print(cap)
#plt.imshow(denorm(img));

## Model Architecture

### Baseline Encoder/Decoder

In [None]:
from torch import nn
from torch.nn import functional as F, init
from torchvision import transforms, models
import torch
import random
from pdb import set_trace


device =torch.device("cuda" if torch.cuda.is_available() else "cpu")



# create an embedding layer from a pretrained weight matrix
def create_emb(embedding_array):
    """Build an ``nn.Embedding`` initialised with the given weights.

    The layer size is taken from the array itself instead of from module-level
    names, so the function works for any vocabulary / embedding width.

    :param embedding_array: 2-D NumPy array with one row per vocabulary entry
    :return: ``nn.Embedding`` whose float32 weights equal ``embedding_array``
    """
    num_embeddings, embedding_dim = embedding_array.shape
    emb = nn.Embedding(num_embeddings, embedding_dim)
    emb.weight.data = torch.from_numpy(embedding_array).float()
    return emb

class Encoder(nn.Module):
    """Pretrained CNN that turns images into a sequence of location features.

    The backbone is selected by the module-level ``conf['cnn']`` setting.
    """

    def __init__(self,encode_img_size, fine_tune = False):
        """
        :param encode_img_size: height and width of the pooled feature map
        :param fine_tune: whether the later backbone blocks stay trainable
        :raises ValueError: if ``conf['cnn']`` names an unsupported backbone
        """
        super(Encoder, self).__init__()
        self.enc_imgsize = encode_img_size
        if conf['cnn'] == 'resnet101':
          cnn = models.resnet101(pretrained=True)
        elif conf['cnn'] == 'vgg16':
          cnn = models.vgg16(pretrained=True)
        elif conf['cnn'] == 'resnext101':
          cnn =  models.resnext101_32x8d(pretrained=True)
        else:
          # Fail with a clear message instead of an unbound-local error below.
          raise ValueError(f"Unsupported conf['cnn'] value: {conf['cnn']!r}")

        # Keep everything except the last two child modules of the backbone
        # (presumably its pooling and classification head - verify per model).
        self.encoder = nn.Sequential(*list(cnn.children())[:-2])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encode_img_size,encode_img_size))
        self.fine_tune = fine_tune
        self.fine_tune_h()

    def fine_tune_h(self):
        """
        Freeze the whole backbone, then set ``requires_grad`` of every child
        module from index 5 onwards to ``self.fine_tune``.
        """
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Only the children from index 5 on are ever trainable; a backbone
        # with fewer children therefore stays completely frozen.
        for c in list(self.encoder.children())[5:]:
            for p in c.parameters():
                p.requires_grad = self.fine_tune

    def forward(self,X):
        """
        :param X: image batch of shape (batch, 3, height, width)
        :return: features of shape (batch, encode_img_size**2, channels)
        """
        out = self.encoder(X)
        out = self.adaptive_pool(out)  # (batch, channels, enc_size, enc_size)
        out = out.permute(0, 2, 3, 1)  # channels last
        out = out.view(out.size(0), -1, out.size(3))  # flatten the spatial grid
        return out
    
class Decoder(nn.Module):
    """LSTM caption decoder with soft attention over the encoder output.

    At every time step the decoder attends over the encoder locations, gates
    the attended feature vector, concatenates it with the embedding of the
    input word and advances an ``nn.LSTMCell``; the new hidden state is
    projected to scores over the vocabulary.
    """

    def __init__(self,attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5, pretrained_embedding = None,teacher_forcing_ratio = 0):
        """
        :param attention_dim: hidden size of the attention network
        :param embed_dim: word embedding width
        :param decoder_dim: hidden size of the LSTM cell
        :param vocab_size: number of entries in the vocabulary
        :param encoder_dim: feature width of the encoder output
        :param dropout: dropout probability applied before the output layer
        :param pretrained_embedding: optional NumPy weight matrix, installed by ``pretrained()``
        :param teacher_forcing_ratio: per-step probability of feeding the ground-truth word
        """
        super(Decoder, self).__init__()
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.encoder_dim = encoder_dim
        self.vocab_size = vocab_size
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size,embed_dim)
        self.lstm = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial hidden state of the LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial cell state of the LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # produces the sigmoid gate
        self.pretrained_embedding = pretrained_embedding
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # scores over the vocabulary
        self.init_weights()

    def init_weights(self):
        """
        Initialise the embedding and output weights from U(-0.1, 0.1) and zero
        the output bias.
        """
        # The lower bound must be negative: uniform_(0.1, 0.1) would set every
        # embedding weight to the same constant 0.1.
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1,0.1)

    def pretrained(self):
        """Replace the embedding weights with ``pretrained_embedding`` when one was given."""
        if self.pretrained_embedding is not None:
            self.embedding.weight.data = torch.from_numpy(self.pretrained_embedding)

    def init_hidden_state(self, encoder_out):
        """
        Derive the initial LSTM state from the mean encoder feature.

        :param encoder_out: encoder features, (batch, num_pixels, encoder_dim)
        :return: ``(h, c)``, each of shape (batch, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self,encoder_out, encoded_captions,decode_lengths,inds):
        """
        Decode a batch of captions step by step.

        :param encoder_out: encoder features, (batch, num_pixels, encoder_dim)
        :param encoded_captions: token ids, indexed as ``[sample, time step]``
        :param decode_lengths: caption lengths as a tensor; the batch is
            sliced with ``[:batch_size_t]``, so samples are expected to be
            ordered by decreasing length
        :param inds: passed through unchanged
        :return: ``(predictions, decode_lengths - 1, alphas, inds)``
        """
        batch_size = encoder_out.size(0)
        vocab_size = self.vocab_size
        num_pixels = encoder_out.size(1)

        h, c = self.init_hidden_state(encoder_out)

        # The last token of a caption is never used as an input.
        decode_lengths = decode_lengths - 1

        max_len = max(decode_lengths).item()

        # Buffers for the word scores and attention weights of every step.
        # NOTE(review): both are allocated on torch's default device, not on
        # the device of `encoder_out` - confirm that downstream code expects this.
        predictions = torch.zeros(batch_size, max_len, vocab_size)
        alphas = torch.zeros(batch_size, max_len, num_pixels)

        for t in range(max_len):
            # Number of captions that still have a token at step t.
            batch_size_t = sum([l.item() > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding

            # Feed the ground-truth word with probability teacher_forcing_ratio
            # (and always at the first step); otherwise feed the previous prediction.
            use_teacher_forcing = random.random() < self.teacher_forcing_ratio
            if use_teacher_forcing or t == 0:
                step_tokens = encoded_captions[:batch_size_t, t]
            else:
                step_tokens = prev_word[:batch_size_t]
            inp_emb = self.embedding(step_tokens).float()

            h, c = self.lstm(
                torch.cat([inp_emb, attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t,t , :] = preds
            alphas[:batch_size_t, t, :] = alpha

            # Highest-scoring word of this step, the candidate input for the next one.
            _,prev_word = preds.max(dim=-1)
        return predictions,decode_lengths, alphas, inds
        
class Attention(nn.Module):
    """Additive soft attention over the locations of the encoder output."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature width of the encoder output
        :param decoder_dim: width of the decoder hidden state
        :param attention_dim: width of the shared attention space
        """
        super().__init__()

        self.enc_att = nn.Linear(encoder_dim, attention_dim)
        self.dec_att = nn.Linear(decoder_dim, attention_dim)
        self.att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self,encoder_out, decoder_hidden):
        """
        :param encoder_out: encoder features, (batch, num_pixels, encoder_dim)
        :param decoder_hidden: decoder hidden state, (batch, decoder_dim)
        :return: ``(context, alpha)`` - the attention-weighted sum of the
            encoder features, (batch, encoder_dim), and the attention weights,
            (batch, num_pixels)
        """
        projected_enc = self.enc_att(encoder_out)
        # Extra axis so the hidden state broadcasts over all locations.
        projected_dec = self.dec_att(decoder_hidden).unsqueeze(1)
        energies = self.att(self.relu(projected_enc + projected_dec)).squeeze(2)
        alpha = self.softmax(energies)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)

        return context, alpha

### Transformer Encoder/Decoder

In [None]:
### Transformer Encoder/Decoder

import torch
import torch.nn as nn
import math
from torch.autograd import Variable
import torchvision.models

class Embedder(nn.Module):
    """Linear projection followed by ReLU, used as the model's input embedding.

    The projection is an ``nn.Linear`` (not an ``nn.Embedding`` lookup), so
    the input must be a float tensor whose last dimension is ``vocab_size``.
    """

    def __init__(self, vocab_size, d_model):
        """
        :param vocab_size: size of the last input dimension
        :param d_model: size of the last output dimension
        """
        super().__init__()
        print (f'Embedder init {vocab_size, d_model}')
        self.d_model = d_model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.embed = nn.Linear(vocab_size, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Project ``x`` to ``d_model`` features and clamp negatives to zero."""
        projected = self.embed(x)
        return self.relu(projected)

class PositionalEncoder(nn.Module):
    """Add a fixed sinusoidal position signal to a (batch, seq, d_model) input."""

    def __init__(self, d_model, max_seq_len = 200, dropout = 0.1):
        """
        :param d_model: feature width of the inputs
        :param max_seq_len: longest sequence the precomputed table covers
        :param dropout: dropout probability applied after adding the signal
        """
        super().__init__()
        self.d_model = d_model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dropout = nn.Dropout(dropout)
        # Constant table: sine on even feature indices, cosine on odd ones.
        # NOTE(review): `i` already advances in steps of 2 and is doubled again
        # in the exponent, so the wavelengths differ from the formula in
        # "Attention Is All You Need"; kept as is so existing results and
        # checkpoints stay reproducible.
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                # An odd d_model has no cosine partner for its last feature.
                if i + 1 < d_model:
                    pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        pe = pe.unsqueeze(0)
        # A buffer moves with the module across devices and is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: embeddings of shape (batch, seq_len, d_model)
        :return: scaled embeddings plus position signal, after dropout
        """
        # Make the embeddings relatively larger than the position signal.
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # The registered buffer already lives on the module's device and does
        # not require grad, so no Variable wrapper or device copy is needed.
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Norm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift."""

    def __init__(self, d_model, eps = 1e-6):
        """
        :param d_model: size of the normalized (last) dimension
        :param eps: constant added to the standard deviation to avoid division by zero
        """
        super().__init__()

        self.size = d_model

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Learnable scale (starts at one) and shift (starts at zero).
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))

        self.eps = eps

    def forward(self, x):
        """Standardize ``x`` along its last dimension, then scale and shift it."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    :param q: queries, (..., query_len, d_k)
    :param k: keys, (..., key_len, d_k)
    :param v: values, (..., key_len, d_v)
    :param d_k: per-head width; scores are divided by its square root
    :param mask: optional mask; positions equal to 0 are excluded. It gains an
        axis at position 1 so it broadcasts over the heads
    :param dropout: optional callable applied to the attention weights
    :return: weighted sum of the values, (..., query_len, d_v)
    """
    weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # A large negative score turns into a weight of ~0 after the softmax.
        weights = weights.masked_fill(mask.unsqueeze(1) == 0, -1e9)

    weights = F.softmax(weights, dim=-1)

    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, v)

class MultiHeadAttentionT(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project."""

    def __init__(self, heads, d_model, dropout = 0.1):
        """
        :param heads: number of attention heads
        :param d_model: model width; each head works on ``d_model // heads`` features
        :param dropout: dropout probability applied to the attention weights
        """
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout).to(self.device)
        self.out = nn.Linear(d_model, d_model).to(self.device)

    def forward(self, q, k, v, mask=None):
        """
        :param q: queries, (batch, query_len, d_model)
        :param k: keys, (batch, key_len, d_model)
        :param v: values, (batch, key_len, d_model)
        :param mask: optional mask handed to ``attention``
        :return: attended features, (batch, query_len, d_model)
        """
        bs = q.size(0)

        def split_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, heads, len, d_k)
            return projection(tensor).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        k = split_heads(self.k_linear, k)
        q = split_heads(self.q_linear, q)
        v = split_heads(self.v_linear, v)

        attended = attention(q, k, v, self.d_k, mask, self.dropout)

        # Put the heads back next to each other before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(merged)

class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        """
        :param d_model: input and output width
        :param d_ff: width of the hidden layer
        :param dropout: dropout probability applied to the hidden activations
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the two-layer network to the last dimension of ``x``."""
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)


import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention and feed-forward, each with a residual."""

    def __init__(self, d_model, heads, dropout=0.1):
        """
        :param d_model: model width
        :param heads: number of attention heads
        :param dropout: dropout probability used throughout the layer
        """
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)

        self.attn = MultiHeadAttentionT(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        :param x: input features
        :param mask: attention mask handed to the self-attention
        :return: features of the same shape as ``x``
        """
        # Self-attention on the normalized input, added back to the input.
        normed = self.norm_1(x)
        attended = self.attn(normed, normed, normed, mask)
        x = x + self.dropout_1(attended)
        # Feed-forward on the normalized result, again with a residual.
        normed = self.norm_2(x)
        x = x + self.dropout_2(self.ff(normed))
        return x
    
# build a decoder layer with two multi-head attention layers and
# one feed-forward layer
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer.

    Applies masked self-attention, attention over the encoder output and a
    feed-forward block, each wrapped in dropout and a residual connection.
    """

    def __init__(self, d_model, heads, dropout=0.1):
        """
        :param d_model: model width
        :param heads: number of attention heads
        :param dropout: dropout probability used throughout the layer
        """
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

        self.attn_1 = MultiHeadAttentionT(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadAttentionT(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        """
        :param x: target-side features
        :param e_outputs: encoder output used as keys and values of the second attention
        :param src_mask: mask for the attention over the encoder output
        :param trg_mask: mask for the self-attention over the targets
        :return: features of the same shape as ``x``
        """
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, \
        src_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        # Debug output stays disabled so that a forward pass writes nothing
        # to stdout, matching the other layers.
        #print (f' DecoderLayer forward x shape {x.shape}')
        return x


import torch
import torch.nn as nn 
import torchvision.models


class TransformerEncoder(nn.Module):
    """CNN backbone followed by a stack of transformer encoder layers.

    The backbone is selected by the module-level ``conf['cnn']`` setting; its
    pooled feature map is flattened into a sequence of locations.
    """

    def __init__(self, encode_img_size=14, d_model=2048, N=6, heads=8, dropout=0.1):
        """
        :param encode_img_size: height and width of the pooled feature map
        :param d_model: model width; the CNN features are fed to an
            ``Embedder(d_model, d_model)``, so it must equal their channel count
        :param N: number of encoder layers
        :param heads: number of attention heads per layer
        :param dropout: dropout probability
        :raises ValueError: if ``conf['cnn']`` names an unsupported backbone
        """
        super().__init__()
        self.N = N
        self.enc_imgsize = encode_img_size

        if conf['cnn'] == 'resnet101':
          cnn = torchvision.models.resnet101(pretrained=True)
        elif conf['cnn'] == 'vgg16':
          cnn = torchvision.models.vgg16(pretrained=True)
        elif conf['cnn'] == 'resnext101':
          cnn =  torchvision.models.resnext101_32x8d(pretrained=True)
        else:
          # Fail with a clear message instead of an unbound-local error below.
          raise ValueError(f"Unsupported conf['cnn'] value: {conf['cnn']!r}")

        # Keep everything except the last two child modules of the backbone.
        self.encoder = nn.Sequential(*list(cnn.children())[:-2])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encode_img_size,encode_img_size))
        self.embed = Embedder(d_model, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(EncoderLayer(d_model, heads, dropout=dropout), N)
        self.norm = Norm(d_model)

    def forward(self, src, mask=None):
        """
        :param src: image batch of shape (batch, 3, height, width)
        :param mask: optional attention mask handed to every encoder layer
        :return: encoded locations, (batch, encode_img_size**2, d_model)
        """
        out = self.encoder(src)
        out = self.adaptive_pool(out)  # (batch, channels, enc_size, enc_size)
        out = out.permute(0, 2, 3, 1)  # channels last
        out = out.view(out.size(0), -1, out.size(3))  # flatten the spatial grid
        x = self.embed(out)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, mask)
        # Normalize exactly once; the output used to pass through `self.norm` twice.
        return self.norm(x)
    
class TransformerDecoder(nn.Module):
    """Embedding, positional encoding and a stack of ``N`` decoder layers with a final norm."""

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        """
        :param vocab_size: size of the last dimension of the target input
        :param d_model: model width
        :param N: number of decoder layers
        :param heads: number of attention heads per layer
        :param dropout: dropout probability
        """
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        """
        :param trg: target-side input handed to the embedder
        :param e_outputs: encoder output attended to by every layer
        :param src_mask: mask for the attention over the encoder output
        :param trg_mask: mask for the target self-attention
        :return: normalized decoder features
        """
        hidden = self.pe(self.embed(trg))
        for layer_idx in range(self.N):
            decoder_layer = self.layers[layer_idx]
            hidden = decoder_layer(hidden, e_outputs, src_mask, trg_mask)
        return self.norm(hidden)

class Transformer(nn.Module):
    """Image-to-caption transformer: encoder, decoder and a projection to vocabulary scores."""

    def __init__(self, img, tokens, d_model=512, N=6, heads=8, dropout=0.1):
        """
        :param img: height and width of the encoder's pooled feature map
        :param tokens: vocabulary size
        :param d_model: model width
        :param N: number of encoder and decoder layers
        :param heads: number of attention heads per layer
        :param dropout: dropout probability
        """
        super().__init__()
        # Use the transformer encoder / decoder defined in this file. The
        # names used before (`TEncoder`, and the LSTM `Decoder` with its
        # different signature) could not build this model.
        # NOTE(review): TransformerEncoder feeds the CNN features to an
        # Embedder(d_model, d_model), so d_model presumably has to equal the
        # backbone's channel count - confirm before relying on the default 512.
        self.encoder = TransformerEncoder(img, d_model, N, heads, dropout)
        self.decoder = TransformerDecoder(tokens, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, tokens)

    def forward(self, src, trg, src_mask, trg_mask):
        """
        :param src: source batch handed to the encoder
        :param trg: target batch handed to the decoder
        :param src_mask: mask for attention over the source
        :param trg_mask: mask for the target self-attention
        :return: output of the final linear projection over the vocabulary
        """
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(d_output)
        return output

    


## DataLoaders  preparations

### Custom collate function

Create a collate function that takes a batch of (image, caption, reference index) samples, pads the captions to a common length, and sorts the batch by caption length in descending order.

In [None]:

from functools import partial
def pad_collate_ImgCap(samples, pad_idx = 0, pad_first:bool=True, backwards:bool=False, transpose:bool=False, device = None):
    """Collate (image, caption, reference index) samples into padded, length-sorted tensors.

    :param samples: sequence of ``(image tensor, caption ids, reference index)`` triples
    :param pad_idx: id used to fill captions up to the longest one in the batch
    :param pad_first: put the padding in front of the caption instead of behind it
    :param backwards: reverse the token order of every caption
    :param transpose: swap the batch and time axes of the caption tensor.
        NOTE(review): the length-based reordering below still indexes the first
        axis, so combining it with ``transpose=True`` looks unsupported - confirm.
    :param device: target device; ``None`` selects CUDA when available, else CPU
    :return: ``((images, captions, decode_lengths, ref_ind), targets)`` where all
        tensors are ordered by decreasing caption length and ``targets`` is the
        caption tensor without its first column
    """
    if device is None:
        # Resolved at call time instead of capturing a module-level global
        # when the function is defined.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    images, captions, ref_ind = zip(*samples)
    lengths = [len(c) for c in captions]
    decode_lengths = torch.tensor(lengths)
    res_cap = torch.full((len(samples), max(lengths)), pad_idx, dtype=torch.long)
    ref_ind = torch.tensor(ref_ind)

    # Flipping afterwards moves front padding to the back and vice versa, so
    # pad on the opposite side first.
    if backwards: pad_first = not pad_first
    for i,c in enumerate(captions):
        if len(c) == 0:
            # Nothing to copy; `-0:` would otherwise address the whole row.
            continue
        tokens = torch.as_tensor(c, dtype=torch.long)
        if pad_first:
            res_cap[i,-len(c):] = tokens
        else:
            res_cap[i,:len(c)] = tokens

    if backwards:
        res_cap = res_cap.flip(1)
    if transpose:
        res_cap.transpose_(0,1)

    images = torch.stack(images, 0)

    # Order the batch by decreasing caption length: the decoder shrinks its
    # working batch with `[:batch_size_t]` as the shorter captions finish.
    decode_lengths, sort_ind = decode_lengths.sort(dim=0, descending=True)

    images = images[sort_ind].to(device)
    res_cap = res_cap[sort_ind].to(device)
    ref_ind = ref_ind[sort_ind].to(device)
    decode_lengths = decode_lengths.to(device)

    return (images, res_cap, decode_lengths,ref_ind), res_cap[:, 1:]

In [None]:
# Collate function shared by both loaders: pad with id 0 behind the caption
# and keep the batch-first layout.
imgcap_collate_func = partial(pad_collate_ImgCap, pad_idx=0, pad_first=False, transpose=False)

In [None]:
# Number of (image, caption) samples per mini-batch.
data_loader_batch_size=25

### Create databunch with transformations

In [None]:
### Numericalized captions of each split; the samplers below only use their lengths
train_sam = list(metadata.loc[metadata.split == 'train','numericalized'])
val_sam = list(metadata.loc[metadata.split == 'val','numericalized'])

### Samplers keyed on caption length, so each batch holds captions of similar length
val_sampler = SortSampler(val_sam, key=lambda x: len(val_sam[x]))
trn_sampler = SortishSampler(train_sam, key=lambda x: len(train_sam[x]), bs=data_loader_batch_size)

### Data loaders that feed collated, padded batches to the network
val_dl = DataLoader(dataset=valid_dataset, batch_size=data_loader_batch_size, sampler=val_sampler, collate_fn=imgcap_collate_func,pin_memory=False)
trn_dl = DataLoader(dataset=train_dataset, batch_size=data_loader_batch_size, sampler=trn_sampler, collate_fn=imgcap_collate_func,pin_memory=False)


# fastai augmentation transforms
# NOTE(review): `tfms` is not passed to anything in this cell - confirm whether it is still needed
tfms = get_transforms(flip_vert=False, max_lighting=0.1, max_zoom=1.05, max_warp=0.)

# fastai DataBunch wrapping both loaders
dataBunch = DataBunch(train_dl=trn_dl, valid_dl=val_dl ,device=device,collate_fn=imgcap_collate_func)

In [None]:
dataBunch.valid_ds[0];

In [None]:
# visualize sample 
##idx = 10000
##d = dataBunch.valid_ds
#img, cap,_ = d[-1]

#print(' '.join([dict(zip(vocab.values(),vocab.keys()))[x] for x in cap]))
#print(cap)
#plt.imshow(denorm(img));

## Model Initialization

### HyperParams

Tune HyperParams

In [None]:
# Hyper-parameters shared by the encoder/decoder construction below.
emb_dim = conf['embedding.dim']  # word-embedding width, taken from the active config
attention_dim = 512 # encoder_dim transformed to attention_dim
decoder_dim = 512  #  word_emb_dim transformed to decoder_dim
dropout = 0.5
encoder_dim = conf['encoder_dim']  # channel width of the CNN feature map, taken from the active config
vocab_size = len(vocab)
fine_tune_encoder = True  # forwarded to Encoder(fine_tune=...)
# NOTE(review): `criterion` is assigned again in the model-initialization cell, and the
# Learner is built with `loss_func` instead -- this object may be unused.
criterion = nn.CrossEntropyLoss().to(device)

Load Word Embeddings

In [None]:
# load word embeddings
embedding_file = 'captions/{}/embedding.pkl'.format(conf['embedding.folder'])
print(embedding_file)

# NOTE(review): pickle.load executes arbitrary code from the file -- only load embedding
# files from a trusted source.
with open(embedding_file,'rb') as f:
    embedding = pickle.load(f)
    if conf['embedding.name'] == 'fasttext-word-freq-2-plus':
      # Overwrite rows 4021 onward with Gaussian noise drawn with the table's overall mean/std.
      # NOTE(review): the hard-coded 4021 and (4, 300) assume exactly 4025 rows of width 300
      # (presumably the four special tokens sit at the end) -- confirm against the vocab.
      embedding[4021:] = np.random.normal(embedding.mean(),embedding.std(),(4, 300))

In [None]:
# Display the vocabulary size (notebook cell output).
vocab_size

In [None]:
#import seaborn as sns
#from mpl_toolkits.mplot3d import Axes3D
#import numpy as np


#x = range(embedding.shape[0])
#y = range(embedding.shape[1])


#hf = plt.figure(figsize=(15,20))
#ha = hf.add_subplot(111, projection='3d')

#X, Y = np.meshgrid(x, y)

##ha.scatter(X, Y, embedding)

#plt.show()

### Model initialization

In [None]:
###########   Layer Initializations ##########

# Spatial size argument handed to Encoder.
encode_img_size = 14

# Pick the image encoder according to the active config.
if conf['encoder'] == 'transformer':
  enc = TransformerEncoder()
else:
  enc = Encoder(encode_img_size,fine_tune=fine_tune_encoder)

# Caption decoder initialised from the pre-trained embedding table;
# teacher_forcing_ratio=1 is the value the decoder starts with.
dec = Decoder(attention_dim, 
              emb_dim, 
              decoder_dim, 
              vocab_size, 
              encoder_dim=encoder_dim, 
              dropout=dropout, 
              pretrained_embedding = embedding,teacher_forcing_ratio=1)


# Models Ensemble 
class Ensemble(nn.Module):
    """Chain an image encoder and a caption decoder into a single module.

    `forward` feeds the images through `encoder`, passes the resulting features
    together with the captions, decode lengths and reference indices to
    `decoder`, and returns the decoder's four outputs unchanged.
    """

    def __init__(self,encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self,x1,caps,decode_lengths,inds):
        # x1 holds the input images; the encoder turns them into feature maps
        features = self.encoder(x1)
        # the decoder is expected to return exactly four values
        decoded = self.decoder(features, caps, decode_lengths, inds)
        scores, out_lengths, alphas, out_inds = decoded
        return scores, out_lengths, alphas, out_inds

# Move the sub-modules to the target device and assemble the full model.
# NOTE(review): `criterion` repeats the assignment from the hyper-parameter cell; the
# Learner is later built with `loss_func`, so this object may be unused.
criterion = nn.CrossEntropyLoss().to(device)
enc = enc.to(device)
dec = dec.to(device)
arch = Ensemble(enc, dec).to(device)

In [None]:
# Display the encoder module (notebook cell output).
enc

In [None]:
# Display the active torch device (notebook cell output).
device

### Model Evaluation

In [None]:
def eval(img):
    """Draft beam-search decoding loop for one encoded image.

    NOTE(review): this function cannot run as written and appears to be an
    early draft of `beam_search` defined later in the file:
      * `self` is referenced but is not a parameter;
      * `k` is read before any assignment (the later `k -= ...` makes it a
        local name), and `pad_idx` is not defined in this function;
      * there is no `return`, so the collected sequences are discarded;
      * the name shadows the builtin `eval`.
    Consider deleting it or replacing its uses with `beam_search`.
    """
    encoder_out = self.model.encoder(img)

    num_pixels = encoder_out.size(1)
    # NOTE(review): beam width is hard-coded to 3 here but taken from `k` elsewhere.
    encoder_out = encoder_out.expand(3, num_pixels, encoder_dim)


    # Tensor to store top k previous words at each step; now they're just 0
    k_prev_words = torch.LongTensor([[0]] * k).to(device)  # (k, 1)

    # Tensor to store top k sequences; now they're just <start>
    seqs = k_prev_words  # (k, 1)

    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)


    #dec_inp = torch.zeros(h.size(1), requires_grad=False).long()
    #dec_inp = dec_inp.to(self.device)

    complete_seqs = list()
    complete_seqs_scores = list()

    step = 1

    h,c = self.model.decoder.init_hidden_state(encoder_out)

    # NOTE(review): `a` is never read.
    a = True
    while True:
        # NOTE(review): resetting k_prev_words on every iteration discards the words
        # selected at the end of the previous iteration.
        k_prev_words = torch.LongTensor([[0]] * 3).to(device)
        embeddings = self.model.decoder.embedding(k_prev_words).squeeze(1)
        awe, _ = self.model.decoder.attention(encoder_out, h)
        gate = self.model.decoder.sigmoid(self.model.decoder.f_beta(h))
        awe = gate * awe
        h, c = self.model.decoder.lstm(torch.cat([embeddings, awe], dim=1), (h, c))

        scores = self.model.decoder.fc(h)  # (s, vocab_size)
        scores = F.log_softmax(scores, dim=1)
        # NOTE(review): zeroing top_k_scores here drops the accumulated beam scores.
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)
        scores = top_k_scores.expand_as(scores) + scores
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)



        # Convert unrolled indices to actual indices of scores
        # NOTE(review): `/` is presumably meant as integer division; on recent torch it
        # yields a float tensor, which cannot be used as an index below -- confirm.
        prev_word_inds = top_k_words / vocab_size  # (s)
        next_word_inds = top_k_words % vocab_size  # (s)

        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)


        # If any sequence is not complete
        # NOTE(review): completion is tested against `pad_idx`, whereas `beam_search`
        # tests against the end token.
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                                       next_word != pad_idx]
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        # if at step any of seq completes we store complete_seqs and proceeds with
        # remaining k - len(complete_seqs) incomplete seqs 

        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds])
        k -= len(complete_inds)  # reduce beam length accordingly

        if k == 0:
            break

        # keep only the state belonging to the still-incomplete beams
        seqs = seqs[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]

        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > 50:
            break
        step += 1

### Callback handler

Defining Callbacks

In [None]:
from statistics import mean
from fastai.callback import Callback
import copy as cp
from torch import nn
from fastai.vision import *
from pathlib import  Path, posixpath
from pdb import set_trace
from nltk.translate.bleu_score import corpus_bleu
from torch.nn.utils.rnn import pack_padded_sequence





device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


from statistics import mean
from fastai.callback import Callback
import copy as cp
from torch import nn
from fastai.vision import *
from pathlib import  Path, posixpath
from pdb import set_trace
from nltk.translate.bleu_score import corpus_bleu
from torch.nn.utils.rnn import pack_padded_sequence
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def beam_search(mod, img,vocab = None, beam_size = 3):
    """Decode a caption for a single image with beam search.

    Args:
        mod: model exposing `encoder` and `decoder`; the decoder must provide
            `embedding`, `attention`, `f_beta`, `sigmoid`, `lstm`, `fc` and
            `init_hidden_state`.
        img: one image tensor without a batch dimension (one is added here).
        vocab: token -> index mapping. Required despite the `None` default,
            which is kept only for signature compatibility.
        beam_size: number of candidate sequences kept at each step.

    Returns:
        A one-element list containing the best completed sequence as a list of
        token indices with the start, end and pad tokens removed. If no beam
        reaches the end token within the step limit the inner list is empty.

    Raises:
        ValueError: if `vocab` is not supplied.
    """
    if vocab is None:
        raise ValueError("beam_search requires a `vocab` mapping tokens to indices")

    # Loop invariants, looked up once instead of on every decoding step.
    vocab_size = len(vocab)
    start_idx = vocab[START_TOKEN]
    end_idx = vocab[END_TOKEN]
    special_tokens = {start_idx, end_idx, vocab[PAD_TOKEN]}

    with torch.no_grad():
        k = beam_size

        # Treat the image as a batch of size 1 and flatten the spatial grid.
        encoder_out = mod.encoder(img.unsqueeze(0))
        encoder_size = encoder_out.size(-1)
        encoder_out = encoder_out.view(1, -1, encoder_size)
        num_pixels = encoder_out.size(1)
        # Repeat the encoding for each of the k beams.
        encoder_out = encoder_out.expand(k, num_pixels, encoder_size)  # (k, num_pixels, encoder_dim)

        # Previous word of every beam.
        # NOTE(review): the start index is hard-coded to 1 rather than read from
        # vocab[START_TOKEN]; kept as in the original -- confirm they agree.
        k_prev_words = torch.LongTensor([[1]] * k).to(device)  # (k, 1)
        seqs = k_prev_words  # (k, 1) sequences built so far
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1) accumulated log-probabilities

        complete_seqs = list()
        complete_seqs_scores = list()

        step = 1
        h, c = mod.decoder.init_hidden_state(encoder_out)

        # s is a number less than or equal to k, because sequences are removed
        # from this process once they hit <end>
        while True:
            embeddings = mod.decoder.embedding(k_prev_words).squeeze(1).float()  # (s, embed_dim)
            awe, _ = mod.decoder.attention(encoder_out, h)  # attention-weighted encoding
            gate = mod.decoder.sigmoid(mod.decoder.f_beta(h))
            awe = gate * awe

            h, c = mod.decoder.lstm(torch.cat([embeddings, awe], dim=1), (h, c))
            scores = F.log_softmax(mod.decoder.fc(h), dim=1)

            # Add the running score of each beam to its next-word scores.
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            if step == 1:
                # All beams are identical on the first step, so rank one row only.
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Rank over the flattened (beam, word) grid.
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Split the flattened indices back into (source beam, word).
            prev_word_inds = torch.true_divide(top_k_words, vocab_size).long().cpu()
            next_word_inds = top_k_words % vocab_size  # (s)

            # Extend each selected beam with its new word.
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Beams that have not produced the end token yet.
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds.tolist())
                               if next_word != end_idx]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside finished sequences together with their scores.
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
            k -= len(complete_inds)  # reduce beam length accordingly

            if k == 0:
                break

            # Carry on with the unfinished beams only.
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

    if len(complete_seqs_scores) > 0:
        best = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[best]
    else:
        # NOTE(review): unfinished beams are discarded here, so the caption
        # becomes empty after filtering -- confirm this is the intended fallback.
        seq = [start_idx, end_idx]

    return [[w for w in seq if w not in special_tokens]]

# Loss Function
def loss_func(input,targets, lamb=1):
    """Cross-entropy over the non-padded time steps plus an attention penalty.

    Args:
        input: 4-tuple `(pred, decode_lengths, alphas, _)` as returned by the
            model; `pred` holds per-step vocabulary scores, `decode_lengths`
            is a tensor of per-sample decode lengths and `alphas` the
            attention weights.
        targets: padded target token indices, batch first.
        lamb: weight of the attention penalty.

    Returns:
        Scalar loss tensor on the same device as `pred`.

    The lengths must be sorted in decreasing order, because
    `pack_padded_sequence` is called with its default `enforce_sorted=True`.
    """
    pred, decode_lengths, alphas, _ = input
    lengths = decode_lengths.cpu()  # pack_padded_sequence wants CPU lengths

    # Packing drops the padded positions, so they do not contribute to the loss.
    packed_pred = pack_padded_sequence(pred, lengths, batch_first=True)
    packed_targ = pack_padded_sequence(targets, lengths, batch_first=True)
    loss = nn.functional.cross_entropy(packed_pred.data, packed_targ.data.long())

    # Penalise attention weights whose sum over dim 1 deviates from 1
    # (presumably dim 1 is the time axis -- confirm against the decoder).
    loss = loss + lamb * ((1. - alphas.sum(dim=1)) ** 2.).mean()
    return loss



def topK_accuracy(input, targets, k=5):
    """
    Computes top-k accuracy over the non-padded time steps.
    :param input: model output tuple (pred, decode_lengths, alphas, _); only the
                  scores `pred` and the tensor of decode lengths are used
    :param targets: padded true token indices, batch first
    :param k: k in top-k accuracy
    :return: 0-d tensor with the percentage of tokens whose true index is among
             the k highest-scoring predictions
    """
    pred, decode_lengths, _, _ = input
    lengths = decode_lengths.cpu()  # pack_padded_sequence wants CPU lengths

    # Packing removes the padded positions from both scores and targets.
    scores = pack_padded_sequence(pred, lengths, batch_first=True).data
    targ = pack_padded_sequence(targets, lengths, batch_first=True).data
    num_tokens = targ.size(0)

    _, top_idx = scores.topk(k, 1, True, True)
    correct = top_idx.eq(targ.view(-1, 1).expand_as(top_idx))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total * (100.0 / num_tokens)


class TeacherForcingCallback(Callback):
    """Set the decoder's `teacher_forcing_ratio` around every batch.

    Before a batch the ratio is `(10 - epoch) * 0.1` for epochs below 10 and
    0 afterwards; after the batch it is reset to 0.
    """

    def __init__(self, learn:Learner):
        super().__init__()
        self.learn = learn

    def on_batch_begin(self, epoch,**kwargs):
        # linear decay over the first ten epochs, then switched off
        if epoch < 10:
            ratio = (10 - epoch) * 0.1
        else:
            ratio = 0
        self.learn.model.decoder.teacher_forcing_ratio = ratio

    def on_batch_end(self,**kwargs):
        decoder = self.learn.model.decoder
        decoder.teacher_forcing_ratio = 0.

class GradientClipping(LearnerCallback):
    "Clip the gradient norm of the model's parameters after each backward pass."
    def __init__(self, learn:Learner, clip:float = 0.3):
        super().__init__(learn)
        self.clip = clip

    def on_backward_end(self, **kwargs):
        "Clip the gradient norm to `self.clip`; a falsy `clip` disables clipping."
        if not self.clip:
            return
        parameters = self.learn.model.parameters()
        nn.utils.clip_grad_norm_(parameters, self.clip)

        

class _BleuMetricBase(Callback):
    """Shared implementation of the corpus-level BLEU metrics.

    Per batch, the highest-scoring word at every step is taken as the
    hypothesis, truncated to the sample's decode length and stripped of the
    start/end/pad tokens. References are looked up in
    `metadata.numericalized_ref` by the indices carried in the model output.
    At epoch end `corpus_bleu` is computed with the subclass's `weights`.

    Args:
        metadata: DataFrame-like object with a `numericalized_ref` column.
        vocab: token -> index mapping containing the special tokens.
    """

    # n-gram weights passed to corpus_bleu; overridden by every subclass.
    weights = (0.25, 0.25, 0.25, 0.25)

    def __init__(self,metadata = None, vocab = None):
        super().__init__()
        self.vocab = vocab
        self.metadata = metadata

    def on_epoch_begin(self, **kwargs):
        self.bleureferences = list()
        self.bleucandidates = list()

    def on_batch_end(self, last_output, last_target, **kwargs):
        pred, decode_lengths, _, inds = last_output
        references = list(self.metadata.numericalized_ref.loc[inds.tolist()])
        _, pred_words = pred.max(dim=-1)
        # Built once per batch rather than once per token.
        special_tokens = {self.vocab[START_TOKEN], self.vocab[END_TOKEN], self.vocab[PAD_TOKEN]}
        hypotheses = [
            [tok for tok in cap.tolist()[:decode_lengths[i]] if tok not in special_tokens]
            for i, cap in enumerate(pred_words)
        ]
        self.bleureferences.extend(references)
        self.bleucandidates.extend(hypotheses)

    def on_epoch_end(self, last_metrics, **kwargs):
        assert len(self.bleureferences) == len(self.bleucandidates)
        score = corpus_bleu(self.bleureferences, self.bleucandidates, weights=self.weights)
        return add_metrics(last_metrics, score)


class Bleu1Metric(_BleuMetricBase):
    "Corpus BLEU using unigram precision only."
    weights = (1.0, 0, 0, 0)


class Bleu2Metric(_BleuMetricBase):
    "Corpus BLEU with equal weight on 1- and 2-gram precision."
    weights = (1.0/2.0, 1.0/2.0, 0, 0)


class Bleu3Metric(_BleuMetricBase):
    "Corpus BLEU with equal weight on 1-, 2- and 3-gram precision."
    weights = (1.0/3.0, 1.0/3.0, 1.0/3.0, 0)


class Bleu4Metric(_BleuMetricBase):
    "Corpus BLEU with equal weight on 1- to 4-gram precision."
    weights = (0.25, 0.25, 0.25, 0.25)


class BeamSearch(LearnerCallback):
    """Collect beam-search captions during an epoch and append their corpus BLEU.

    Args:
        learn: the Learner whose model is decoded.
        metadata: DataFrame-like object with a `numericalized_ref` column.
        vocab: token -> index mapping forwarded to `beam_fn`.
        beam_fn: callable `(model, img, vocab)` returning a one-element list
            with the decoded token indices.
    """

    def __init__(self,learn:Learner,metadata = None, vocab = None, beam_fn = beam_search):
        super().__init__(learn)
        self.beam_fn = beam_fn
        self.vocab = vocab
        self.metadata = metadata

    def on_epoch_begin(self, **kwargs):
        # start every epoch with empty accumulators
        self.beamreferences = []
        self.beamcandidates = []

    def on_batch_end(self,last_input, last_target, **kwargs):
        # Decode on a private copy of the model.
        # NOTE(review): copying the whole model on every batch is expensive, and
        # this hook does not check the train/validation flag -- confirm both are intended.
        decoder_model = cp.deepcopy(self.learn.model)
        imgs, _, _, inds = last_input
        refs = list(self.metadata.numericalized_ref.loc[inds.tolist()])
        hyps = [self.beam_fn(decoder_model, img, self.vocab)[0] for img in imgs]
        self.beamreferences.extend(refs)
        self.beamcandidates.extend(hyps)

    def on_epoch_end(self, last_metrics, **kwargs):
        assert len(self.beamreferences) == len(self.beamcandidates)
        # NOTE(review): the extra metric name is not registered with the recorder
        # anywhere in this class -- verify that the value shows up in the output.
        beam_bleu = corpus_bleu(self.beamreferences, self.beamcandidates)
        return add_metrics(last_metrics, beam_bleu)

### Create the Learner object

In [None]:
# Display the current loss function object (notebook cell output).
loss_func

In [None]:
# Adam with custom betas, used as the Learner's optimizer factory.
opt_fn = partial(optim.Adam, betas=(0.8, 0.99))


# metrics functions
# NOTE(review): this rebinds the name `loss_func` to a partial of itself, so
# re-running the cell wraps the previous partial again (the result is the same).
loss_func = partial(loss_func, lamb = 1)
#BleuMetric(metadata,vocab)
# Beam search with the beam width taken from the active config; used by the BeamSearch callback.
beam_fn = partial(beam_search,beam_size = conf['beam_size'])

# Learner with top-k accuracy and BLEU-1..4 as metrics.
learn = Learner(dataBunch,arch,loss_func= loss_func,opt_func=opt_fn,  metrics=[topK_accuracy,Bleu1Metric(metadata,vocab),Bleu2Metric(metadata,vocab),Bleu3Metric(metadata,vocab),Bleu4Metric(metadata,vocab)],callback_fns=[ShowGraph]) #,TeacherForcingTurnOff,TeacherForcingCallback

# split model into encoder and decoder layer groups
split_list = [learn.model.encoder,learn.model.decoder]
learn.split(split_list)
# Display the number of layer groups (notebook cell output).
len(learn.layer_groups)

In [None]:
# Display the optimizer factory (notebook cell output).
opt_fn

In [None]:
# def summary_trainable(learner):
#     result = []
#     total_params_element = 0
#     def check_trainable(module):
#         nonlocal total_params_element
#         if len(list(module.children())) == 0:
#             num_param = 0
#             num_trainable_param = 0
#             num_param_numel = 0
#             for parameter in module.parameters():
#                 num_param += 1
#                 if parameter.requires_grad:
#                     num_param_numel += parameter.numel()
#                     total_params_element += parameter.numel()
#                     num_trainable_param += 1

#             result.append({'module': module, 'num_param': num_param , 'num_trainable_param' : num_trainable_param, 'num_param_numel': num_param_numel})
#     learner.model.apply(check_trainable)

#     print("{: <85} {: <17} {: <20} {: <40}".format('Module Name', 'Total Parameters', 'Trainable Parameters', '# Elements in Trainable Parametrs'))
#     for row in result:
#         print("{: <85} {: <17} {: <20} {: <40,}".format(row['module'].__str__(), row['num_param'], row['num_trainable_param'], row['num_param_numel']))
#     print('Total number of parameters elements {:,}'.format(total_params_element))



# # uncomment below to print summary of trainable layers 
# #learn.freeze()
# summary_trainable(learn)

## Training

### Stage 1: with the encoder part frozen

In [None]:
# Batch size assigned to learn.data.batch_size before stage-2 training.
BATCH_SIZE = 6

In [None]:
from PIL import Image

# learn.freeze()
# learn.lr_find(end_lr = 1)
# learn.recorder.plot(suggestion=True)

In [None]:
### train for 10 epochs
# Stage 1: fit after learn.freeze(). Both search modes share the same setup;
# beam mode only adds the BeamSearch callback.
use_beam = conf['search'] == 'beam'

learn.freeze()
EPOCH = 10
TUNED_LEARNING_RATE = 5e-4

# Keep the checkpoint with the best BLEU-4 and clip gradients at norm 5.
stage1_callbacks = [SaveModelCallback(learn, monitor='bleu4_metric',name='Stage_1_Best_Model'),
                    GradientClipping(learn = learn, clip=5.)]
if use_beam:
  stage1_callbacks.append(BeamSearch(learn = learn,metadata = metadata, vocab = vocab, beam_fn = beam_fn))

learn.fit(EPOCH,TUNED_LEARNING_RATE,callbacks = stage1_callbacks)

### Stage-2: Unfreeze encoder part as well

In [None]:
# learn.data.batch_size = BATCH_SIZE
# learn.unfreeze()
# learn.load('Stage_1_Best_Model');
# learn.lr_find(start_lr=1e-11,end_lr = 1e-03)
# learn.recorder.plot(suggestion=True)

In [None]:
# Stage 2: unfreeze all layer groups, reload the best stage-1 weights and
# train with the one-cycle policy. Both search modes share the same setup;
# beam mode only adds the BeamSearch callback.
use_beam = conf['search'] == 'beam'

learn.data.batch_size = BATCH_SIZE
learn.unfreeze()
learn.load('Stage_1_Best_Model');
EPOCH = 10
TUNED_LEARNING_RATE = 1e-4

# Keep the checkpoint with the best BLEU-4 and clip gradients at norm 5.
stage2_callbacks = [SaveModelCallback(learn, monitor='bleu4_metric',name='Stage_2_Best_Model'),
                    GradientClipping(learn = learn, clip=5.)]
if use_beam:
  stage2_callbacks.append(BeamSearch(learn = learn,metadata = metadata, vocab = vocab, beam_fn = beam_fn))

learn.fit_one_cycle(EPOCH, TUNED_LEARNING_RATE, callbacks = stage2_callbacks)


In [None]:
# Report the wall-clock time elapsed since `start_time` (set earlier in the notebook).
end_time = datetime.datetime.now()
print("total time={}".format(end_time-start_time))

In [None]:
# learn.save('last_epoch_stage2');