In [None]:
#!g1.1
# Reload edited project modules (datasets, metrics, glove, utils_torch, ...) automatically before
# each cell executes, so changes to the repository's .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

#### Make sure you have downloaded COCO, Flickr8k and GloVe. You can fetch them by running the scripts `load_coco.sh`, `load_flickr.sh` and `load_glove.sh` with `bash` (e.g. `bash load_flickr.sh`).

In [None]:
# Idempotent setup: clone only when the checkout is missing, and move its contents into /content
# only while the checkout still has files in it (after the first run it is empty, which is what
# made the unguarded `mv` fail on re-runs).
! [ -d image_captioning ] || git clone https://github.com/tojiboyevf/image_captioning.git
! ls -A /content/image_captioning 2>/dev/null | grep -q . && mv -v /content/image_captioning/* /content/

fatal: destination path 'image_captioning' already exists and is not an empty directory.
mv: cannot stat '/content/image_captioning/*': No such file or directory


In [None]:
#!g1.1
import nltk
# Fetch NLTK's Punkt tokenizer models (presumably used by the project's caption tokenisation);
# the call's return value (True on success) is what the cell displays.
nltk.download(info_or_id='punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
#!g1.1
import torch

# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda:0


In [None]:
#!g1.1
import pickle

from datasets.flickr8k import Flickr8kDataset


DATASET_BASE_PATH = 'data/flickr8k/'
VOCAB = 'vocab_set.pkl'


def _flickr_split(dist, return_type, **extra):
    """Build one Flickr8k split with the settings every split in this notebook shares."""
    return Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist=dist, device=device,
                           return_type=return_type, load_img_to_memory=False, **extra)


# The training split builds the vocabulary; it is pickled to VOCAB and reused by every other split.
train_set = _flickr_split('train', 'tensor')

vocab_set = train_set.get_vocab()
with open(VOCAB, 'wb') as vocab_fh:
    pickle.dump(vocab_set, vocab_fh)

vocab, word2idx, idx2word, max_len = vocab_set
vocab_size = len(vocab)

val_set = _flickr_split('val', 'corpus', vocab_set=vocab_set)
test_set = _flickr_split('test', 'corpus', vocab_set=vocab_set)
# The training images again, but in the evaluation ('corpus') format.
train_eval_set = _flickr_split('train', 'corpus', vocab_set=vocab_set)

print(
    "The number of samples in:\n"
    f"train: {len(train_set)}; validation: {len(val_set)}; test: {len(test_set)}\n"
    f"Vocabulary size: {vocab_size}; Max length of a sentence: {max_len};"
)

The number of samples in:
train: 30000; validation: 1000; test: 1000
Vocabulary size: 7708; Max length of a sentence: 40;


In [None]:
#!g1.1
from torchvision import transforms
from torch.utils.data import DataLoader
from datasets.coco import CoCoDataloader

BATCH_SIZE = 50

# Deterministic preprocessing for every evaluation split: resize, centre-crop to 224 px and
# normalise with the standard ImageNet per-channel mean / std.
eval_transformations = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])

val_set.transformations = eval_transformations
test_set.transformations = eval_transformations
train_eval_set.transformations = eval_transformations


def eval_collate_fn(batch):
    """Stack the per-sample images into one tensor; keep the other two fields as plain lists."""
    images = torch.stack([sample[0] for sample in batch])
    return images, [sample[1] for sample in batch], [sample[2] for sample in batch]


def make_eval_loader(dataset):
    """Sequential (unshuffled) loader over an evaluation split.

    drop_last=False so no samples are silently skipped when len(dataset) is not a multiple of
    BATCH_SIZE; evaluation should cover the whole split.
    """
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False,
                      collate_fn=eval_collate_fn, drop_last=False)


val_loader = make_eval_loader(val_set)
test_loader = make_eval_loader(test_set)
train_eval_loader = make_eval_loader(train_eval_set)

vocab_from_file = True     # if True, load existing vocab file
split_size = 0.1           # passed as `size=`; presumably the fraction of val2014 to use — verify in datasets.coco

# Reuse the Flickr8k vocabulary pickled to VOCAB, so COCO captions share the same token ids.
coco_val_loader = CoCoDataloader(transform=eval_transformations,
                                 batch_size=BATCH_SIZE,
                                 vocab_from_file=vocab_from_file,
                                 vocab_file=VOCAB,
                                 size=split_size,
                                 img_folder='data/coco/val2014',
                                 annotations_file='data/coco/annotations/captions_val2014.json',
                                 shuffle=True,
                                 random_seed=42)

In [None]:
#!g1.1
# Special-token ids from the training vocabulary; Captioner.sample uses them as its defaults.
start_token, end_token, pad_token = (word2idx[tok] for tok in ('<start>', '<end>', '<pad>'))
# Longest sentence length reported by the dataset; also the number of positions sample() decodes.
max_seq_len = max_len

In [None]:
#!g1.1
import random
import math
import torch
from torch import nn as nn
import numpy as np

# One seed shared by every RNG in play (NumPy, PyTorch, Python's `random`) so runs are repeatable
# and the three generators cannot drift apart when the value is changed.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

def embedding_layer(trainable=True, embedding_matrix=None, **kwargs):
    """Build an nn.Embedding, optionally initialised from a pretrained matrix.

    Args:
        trainable: whether the weights receive gradients. Only honoured when
            `embedding_matrix` is given; a randomly initialised layer is always trainable.
        embedding_matrix: optional 2-D array-like (e.g. a numpy array) of shape
            (num_embeddings, embedding_dim) used as the initial weights. It is copied, so
            training or later in-place re-initialisation never writes back into the caller's array.
        **kwargs: forwarded to nn.Embedding (num_embeddings, embedding_dim, padding_idx, ...).

    Returns:
        The configured nn.Embedding.

    Raises:
        ValueError: if `embedding_matrix` does not match the layer's (num_embeddings, embedding_dim).
    """
    emb_layer = nn.Embedding(**kwargs)
    if embedding_matrix is not None:
        # torch.tensor always copies; torch.from_numpy(m).float() would share memory with a
        # float32 array, letting init_weights()/training silently mutate the caller's matrix.
        weights = torch.tensor(embedding_matrix, dtype=torch.float32)
        expected_shape = tuple(emb_layer.weight.shape)
        if tuple(weights.shape) != expected_shape:
            raise ValueError(
                f"embedding_matrix has shape {tuple(weights.shape)}, expected {expected_shape}"
            )
        emb_layer.weight = nn.Parameter(weights)
    # Freezing only makes sense for pretrained vectors.
    trainable = (embedding_matrix is None) or trainable
    if not trainable:
        emb_layer.weight.requires_grad = False
    return emb_layer


class Encoder(nn.Module):  # the encoder we mainly trained with
    """Image encoder: a pretrained DeiT-Base/16 (224 px) from torch.hub plus a trainable projection.

    The ViT runs under torch.no_grad(), so no gradients flow into it; only `embed` learns.
    `embed` maps the ViT head's output (``vit.head.out_features`` wide) to ``embed_size``.
    The attribute names `vit` and `embed` are part of the state_dict keys and must stay as-is.
    """

    def __init__(self, embed_size):
        super().__init__()
        # Fetched via torch.hub (network access / hub cache needed on first use).
        self.vit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
        head_width = self.vit.head.out_features
        self.embed = nn.Sequential(
            nn.Linear(in_features=head_width, out_features=embed_size),
            nn.GELU(),
            nn.BatchNorm1d(embed_size, momentum=0.01),
            nn.Dropout(0.1),
        )

    def forward(self, images):
        with torch.no_grad():
            backbone_out = self.vit(images)
        return self.embed(backbone_out)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first embedding tensor, then apply dropout.

    Args:
        d_model: embedding width (last dimension of the input); odd widths are supported.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence the encoding table covers.
    """

    def __init__(self, d_model, dropout=0.1, max_len=40):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model); the leading 1 broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return dropout(x + pe) for x of shape (batch, seq_len, d_model) with seq_len <= max_len.

        The buffer is only read, never reassigned, so its shape (and the state_dict) stays fixed
        across calls with different batch sizes. Only row 0 of the table is used, which also
        works for legacy checkpoints whose `pe` was tiled over a batch dimension (all rows equal).
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:1, :seq_len])

class Decoder(nn.Module):
    """Transformer caption decoder that cross-attends to encoder features.

    Args:
        num_heads: attention heads per decoder layer (must divide embed_size).
        decoder_layers: number of stacked nn.TransformerDecoderLayer blocks.
        embed_size: token-embedding / model width.
        vocab_size: number of output classes.
        embedding_matrix: optional pretrained (vocab_size, embed_size) matrix for the token embedding.
        train_embd: whether a pretrained embedding keeps receiving gradients.
        max_len: longest caption (in tokens) the positional encoding covers.
        dropout: dropout used by the positional encoding.
    """

    def __init__(
        self,
        num_heads,
        decoder_layers,
        embed_size,
        vocab_size,
        embedding_matrix=None,
        train_embd=True,
        max_len=40,
        dropout=0.1
    ):
        super(Decoder, self).__init__()
        # Recorded before init_weights() runs so that it leaves pretrained vectors untouched.
        self._has_pretrained_embedding = embedding_matrix is not None

        self.embedding = embedding_layer(num_embeddings=vocab_size, embedding_dim=embed_size,
                                         embedding_matrix=embedding_matrix, trainable=train_embd)
        # Forward max_len so the positional table covers the configured caption length.
        self.pos_encoder = PositionalEncoding(embed_size, dropout, max_len=max_len)

        # nn.TransformerDecoder stacks deep copies of this layer; the template itself also stays
        # registered here, so its parameters appear in the state_dict as `decoder_layer.*`.
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer = nn.TransformerDecoder(self.decoder_layer, num_layers=decoder_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embedding_size = embed_size
        self.init_weights()

    def init_weights(self):
        """Initialise the output projection uniformly in [-0.1, 0.1] with a zero bias.

        The token embedding gets the same uniform init only when no pretrained matrix was
        supplied. Overwriting it unconditionally discarded the GloVe vectors (and, with
        train_embd=False, left the embedding frozen at random values).
        """
        initrange = 0.1
        # getattr default: instances created without the flag behave as before.
        if not getattr(self, '_has_pretrained_embedding', False):
            self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def generate_Mask(self, size, decoder_inp):
        """Build the attention masks for a batch of token ids, on decoder_inp's device.

        Args:
            size: target sequence length.
            decoder_inp: (batch, seq_len) token ids.

        Returns:
            (causal_mask, pad_mask_float, pad_mask_bool):
            causal_mask    -- (size, size) float; 0.0 on/below the diagonal, -inf above it, so
                              position i attends only to positions <= i.
            pad_mask_float -- same shape as decoder_inp; 1.0 where id > 0, 0.0 where id == 0.
            pad_mask_bool  -- True where id == 0 (used as tgt_key_padding_mask).
        """
        # NOTE(review): id 0 is treated as padding here, while sample() pads with `pad_token`;
        # confirm that word2idx['<pad>'] == 0.
        causal = (torch.triu(torch.ones(size, size, device=decoder_inp.device)) == 1).transpose(0, 1)
        causal_mask = causal.float().masked_fill(causal == 0, float('-inf')).masked_fill(causal == 1, float(0.0))

        pad_mask_float = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
        pad_mask_bool = decoder_inp == 0

        return causal_mask, pad_mask_float, pad_mask_bool

    def forward(self, features, captions):
        """Score every caption position against the vocabulary.

        Args:
            features: batch-first encoder memory (batch, n_mem, embed_size); permuted to
                sequence-first for the (batch_first=False) transformer.
            captions: (batch, seq_len) token ids, on the model's device.

        Returns:
            (logits of shape (seq_len, batch, vocab_size), float padding mask (batch, seq_len)).
        """
        memory = features.permute(1, 0, 2)
        tgt = self.embedding(captions) * math.sqrt(self.embedding_size)
        # Positional encoding is applied batch-first, then switched to sequence-first.
        tgt = self.pos_encoder(tgt).permute(1, 0, 2)

        causal_mask, pad_mask_float, pad_mask_bool = self.generate_Mask(captions.size(1), captions)

        decoder_output = self.transformer(tgt=tgt, memory=memory, tgt_mask=causal_mask,
                                          tgt_key_padding_mask=pad_mask_bool)
        return self.fc_out(decoder_output), pad_mask_float

    def sample(self, features, max_len=40, topk=3, start_token=59, end_token=57, pad_token=58):
        """Greedy (argmax) decoding.

        Starts every sequence with `start_token`, pads the rest with `pad_token`, and fills
        position i + 1 from the logits at position i for max_len - 1 steps. `topk` and
        `end_token` are accepted for interface compatibility but are not used: decoding is
        plain argmax and does not stop early at the end token.

        Returns:
            (batch, max_len) long tensor of token ids on features' device.
        """
        input_seq = torch.full((features.size(0), max_len), pad_token,
                               dtype=torch.long, device=features.device)
        input_seq[:, 0] = start_token
        for i in range(max_len - 1):
            logits, _ = self.forward(features, input_seq)
            input_seq[:, i + 1] = logits[i].argmax(1)
        return input_seq


class Captioner(nn.Module):
    """Image captioner: Encoder (DeiT + projection) feeding a Transformer Decoder.

    The encoder output gets a length-1 axis at dim 1 and serves as the decoder's memory.
    The attribute names `encoder` / `decoder` are part of the state_dict keys.
    """

    def __init__(self, num_heads, decoder_layers, embed_size, vocab_size, embedding_matrix=None, train_embd=True):
        super().__init__()
        self.encoder = Encoder(embed_size)
        self.decoder = Decoder(num_heads, decoder_layers, embed_size, vocab_size,
                               embedding_matrix=embedding_matrix, train_embd=train_embd)

    def _memory(self, images):
        """Encode images and insert a length-1 axis at dim 1 for use as decoder memory."""
        return self.encoder(images).unsqueeze(1)

    def forward(self, images, captions):
        output, padding_mask = self.decoder(self._memory(images), captions)
        return output, padding_mask

    def sample(self, images, max_len=max_seq_len, topk=3, start_token=start_token, end_token=end_token, pad_token=pad_token):
        # Defaults are the notebook-level token ids / max length captured when the class is defined.
        return self.decoder.sample(features=self._memory(images), max_len=max_len, topk=topk,
                                   start_token=start_token, end_token=end_token,
                                   pad_token=pad_token)

In [None]:
#!g1.1
from glove import embedding_matrix_creator
EMBEDDING_DIM = 200
# Short tag for this embedding setup; it matches the `emdGLV200` part of the checkpoint file name used below.
EMBEDDING = "GLV" + str(EMBEDDING_DIM)

# GloVe-initialised matrix for the Flickr8k vocabulary (row layout defined by
# glove.embedding_matrix_creator — presumably indexed by word2idx ids).
embedding_matrix = embedding_matrix_creator(embedding_dim=EMBEDDING_DIM, word2idx=word2idx)
print("Embedding matrix shape:", embedding_matrix.shape)

  0%|          | 0/7708 [00:00<?, ?it/s]

Embedding matrix shape: (7708, 200)


In [None]:
#!g1.1
from metrics import *
from utils_torch import words_from_tensors_fn
import numpy as np
from tqdm import tqdm

# BLEU scorers from the project's metrics module (first argument 4 — presumably the maximum n-gram
# order); the evaluation cells below use the corpus-level one.
sentence_bleu_score_fn, corpus_bleu_score_fn = (bleu_score_fn(4, level) for level in ('sentence', 'corpus'))
# Maps arrays of token ids back to words through idx2word.
tensor_to_word_fn = words_from_tensors_fn(idx2word=idx2word)
# Padding targets are excluded from the loss. NOTE(review): loss_fn is not used by any cell shown here.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_set.pad_value)

def evaluate_model(data_loader, model, data, bleu_score_fn, tensor_to_word_fn, desc=''):
    """Average per-batch BLEU-1..4 of the model's greedy captions over a data loader.

    Args:
        data_loader: iterable of batches. When `data` is 'coco' each batch is
            (images, captions); otherwise (images, captions, lengths) as built by eval_collate_fn.
        model: captioner exposing `.sample(images)` that returns token-id tensors.
        data: dataset tag; 'coco' (or the legacy spelling 'coco:') selects the two-field layout.
        bleu_score_fn: called as bleu_score_fn(reference_captions, generated_captions, n=k).
        tensor_to_word_fn: converts a numpy array of token ids into word sequences.
        desc: tqdm progress-bar label.

    Returns:
        List of 5 floats; index k (1..4) holds the mean BLEU-k over the batches seen, and
        index 0 is unused (0.0). All zeros if the loader yields no batches.
    """
    running_bleu = [0.0] * 5
    model.eval()
    # Send inputs to wherever the model lives instead of relying on a notebook-global device.
    model_device = next(model.parameters()).device
    n_batches = 0
    progress = tqdm(data_loader, desc=f'{desc}')
    with torch.no_grad():
        for batch in progress:
            if data in ('coco', 'coco:'):
                images, captions = batch
            else:
                images, captions, _lengths = batch
            # No-op when the dataset already returns tensors on the model's device.
            images = images.to(model_device)
            outputs = tensor_to_word_fn(model.sample(images).cpu().numpy())

            n_batches += 1
            for n in (1, 2, 3, 4):
                running_bleu[n] += bleu_score_fn(captions, outputs, n=n)
            progress.set_postfix({
                'bleu1': running_bleu[1] / n_batches,
                'bleu4': running_bleu[4] / n_batches,
            }, refresh=True)
    # Average over batches actually processed (works for loaders without __len__, avoids /0).
    if n_batches:
        for n in (1, 2, 3, 4):
            running_bleu[n] /= n_batches
    return running_bleu

In [None]:
# %pip installs into the running kernel's environment (a bare `!pip` may target another Python).
# NOTE(review): timm is presumably required by the DeiT torch.hub model; pin the version it was used with.
%pip install timm
# TODO(review): `123` is a placeholder, not a real Google Drive file ID — replace it with the ID of
# the checkpoint loaded in the next cell. The ID is passed positionally because `--id` is deprecated
# in recent gdown releases; -O writes the file to the path the next cell loads.
! gdown 123 -O ./vit_transformer_b50_emdGLV200_best_val_bleu.pt

In [None]:
NUM_HEADS_DEC = 10
NUM_LAYERS_DEC = 6

path = './vit_transformer_b50_emdGLV200_best_val_bleu.pt'
# map_location lets a checkpoint saved during a GPU run load on a CPU-only machine as well.
checkpoint = torch.load(path, map_location=device)
model = Captioner(NUM_HEADS_DEC, NUM_LAYERS_DEC, EMBEDDING_DIM, vocab_size,
                  embedding_matrix=embedding_matrix, train_embd=False).to(device)
# NOTE(review): the checkpoint's positional-encoding buffer `pe` appears to have been saved with a
# leading dimension of BATCH_SIZE (PositionalEncoding.forward reassigned it per batch during
# training), so the freshly built buffer is tiled to that shape before load_state_dict.
model.decoder.pos_encoder.pe = model.decoder.pos_encoder.pe.repeat(BATCH_SIZE, 1, 1).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval();

### Flickr8k: BLEU scores on the train / val / test splits

In [None]:
#!g1.1

# evaluate_model arguments shared by the three Flickr8k splits.
inter_params = dict(
    model=model,
    bleu_score_fn=corpus_bleu_score_fn,
    tensor_to_word_fn=tensor_to_word_fn,
    data='flickr',
)

with torch.no_grad():
    model.eval()
    # Evaluated in the order train -> val -> test, matching the printed table.
    train_bleu, val_bleu, test_bleu = [
        evaluate_model(desc=label, data_loader=loader, **inter_params)
        for label, loader in (('Train: ', train_eval_loader), ('Val: ', val_loader), ('Test: ', test_loader))
    ]
    for setname, result in zip(('train', 'val', 'test'), (train_bleu, val_bleu, test_bleu)):
        scores = ' '.join(f'Bleu-{ngram}: {result[ngram]}' for ngram in (1, 2, 3, 4))
        print(f'{setname} {scores} ')

### COCO: BLEU scores on val2014 captions

In [None]:
with torch.no_grad():
    model.eval()
    # data='coco' makes evaluate_model unpack batches as (images, captions).
    # Stored separately so the Flickr8k `val_bleu` from the previous section is not overwritten.
    coco_val_bleu = evaluate_model(
        desc='\tValidation Bleu Score: ',
        model=model,
        data='coco',
        bleu_score_fn=corpus_bleu_score_fn,
        tensor_to_word_fn=tensor_to_word_fn,
        data_loader=coco_val_loader,  # the COCO val2014 loader built alongside the Flickr8k loaders
    )

    print('val', end=' ')
    for ngram in (1, 2, 3, 4):
        print(f'Bleu-{ngram}: {coco_val_bleu[ngram]}', end=' ')
    print()

	Validation Bleu Score: 100%|██████████| 406/406 [08:44<00:00,  1.29s/it, bleu1=0.000106, bleu4=0.0119]
