In [None]:
! pip install transformers

In [None]:
import os
from PIL import Image
from collections import Counter

import numpy as np
import pandas as pd

import torch
from torch import nn, optim
from torchvision import transforms

In [None]:
# Directory that holds the Flickr8k image files.
base_path = '../input/flickr8k/Images/'

# Every non-empty line of the caption file is parsed as
# "<image id> <caption words...>": the first space-separated token is the
# image id (without extension) and the rest of the line is the caption.
data = []
with open('/kaggle/input/hindi-caption/Flickr8k-Hindi.txt', encoding='utf-8') as f:
    for line in f:
        # Drop the trailing newline; otherwise it stays glued to the last
        # word of every caption and pollutes the vocabulary.
        line = line.strip()
        if not line:
            continue
        image_id, _, caption = line.partition(' ')
        data.append([image_id + '.jpg', caption])

hindi = pd.DataFrame(data, columns = ['images', 'text'])
hindi.head()

# Language model from scratch for Swin and EfficientNet (the encoder in this section uses ViT)

In [None]:
# Collect every caption, glue them into one space-separated corpus string and
# cut that corpus back into tokens on single spaces.
paragraphs = [caption for caption in hindi['text']]
text = " ".join(paragraphs)
words = text.split(" ")

class Tokenizer:
    """Whitespace tokenizer that turns a sentence into a fixed-length id list.

    Ids 0 and 1 are reserved for the padding token ``_pad_`` and the
    unknown-word token ``_unk_``. All other words receive ids in order of
    decreasing frequency once ``build_vocab`` has been called.

    Attributes:
        maxlen: Length of every id list returned by ``__call__``.
        vocab: Mapping word -> id.
        idx2word: Inverse mapping id -> word.
    """

    PAD_TOKEN, UNK_TOKEN = '_pad_', '_unk_'
    PAD_ID, UNK_ID = 0, 1

    def __init__(self, maxlen = 50):
        self.maxlen = maxlen
        # Start with the reserved tokens only, so the tokenizer is usable
        # (everything maps to _unk_) even before build_vocab is called.
        self.vocab = {self.PAD_TOKEN: self.PAD_ID, self.UNK_TOKEN: self.UNK_ID}
        self.idx2word = {idx: word for word, idx in self.vocab.items()}

    def build_vocab(self, texts):
        """Build ``vocab`` and ``idx2word`` from an iterable of sentences.

        The vocabulary is rebuilt from scratch on every call, so calling this
        method repeatedly is safe.

        Args:
            texts: Iterable of strings.
        """
        counts = Counter()
        for sent in texts:
            # Split on arbitrary whitespace, exactly like __call__ does, so
            # that empty strings or "word\n" never become vocabulary entries.
            counts.update(sent.split())

        vocab = {self.PAD_TOKEN: self.PAD_ID, self.UNK_TOKEN: self.UNK_ID}
        # sorted() is stable, so equally frequent words keep first-seen order.
        for word in sorted(counts, key = counts.get, reverse = True):
            if word in vocab:
                # Never let a literal "_pad_" / "_unk_" in the corpus steal
                # a reserved id.
                continue
            vocab[word] = len(vocab)

        self.vocab = vocab
        self.idx2word = {idx: word for word, idx in vocab.items()}

    def __call__(self, text):
        """Encode ``text`` as exactly ``maxlen`` ids.

        Unknown words map to the ``_unk_`` id; the result is truncated or
        right-padded with the ``_pad_`` id.

        Args:
            text: Sentence to encode.

        Returns:
            List of ``maxlen`` integer ids.
        """
        ids = [self.vocab.get(word, self.UNK_ID) for word in text.split()]
        ids = ids[:self.maxlen]
        ids.extend([self.PAD_ID] * (self.maxlen - len(ids)))
        return ids

In [None]:
class CaptionDataset:
    """Map-style dataset yielding a transformed image and its encoded caption.

    Args:
        df: DataFrame with an ``images`` column (file names) and a ``text``
            column (captions).
        tranform: Callable applied to the RGB PIL image; the spelling of the
            parameter name is kept for backward compatibility.
        image_dir: Directory that contains the image files.
    """

    def __init__(self, df, tranform, image_dir = '../input/flickr8k/Images/'):
        self.df = df
        self.transform = tranform
        self.image_dir = image_dir

        # Captions are encoded to 32 ids each, with a vocabulary built from
        # every caption in df.
        self.tokenizer = Tokenizer(32)
        self.tokenizer.build_vocab(df['text'])

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``{'image': tensor, 'text': LongTensor}`` for row ``idx``.

        ``idx`` is positional (0 .. len - 1), which is what samplers and
        ``random_split`` subsets hand to a dataset.
        """
        row = self.df.iloc[idx]

        # The context manager closes the file handle; convert('RGB')
        # guarantees three channels even for grayscale or palette files.
        with Image.open(os.path.join(self.image_dir, row['images'])) as img:
            img = self.transform(img.convert('RGB'))

        caption = self.tokenizer(row['text'])

        return {
            # as_tensor does not copy an existing tensor, which avoids the
            # copy-construct warning of torch.tensor(tensor).
            'image': torch.as_tensor(img),
            'text': torch.tensor(caption, dtype = torch.long)
        }

In [None]:
from torch.utils.data import DataLoader, random_split

# Resize to the 224x224 input used by the ViT checkpoint and normalise with
# the mean / std values given below.
transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = CaptionDataset(hindi, transform)

# Hold out 4000 samples each for validation and test; everything else is used
# for training. Deriving the training size from len(dataset) keeps the split
# valid even if the number of caption rows changes.
N_VALID, N_TEST = 4000, 4000
n_train = len(dataset) - N_VALID - N_TEST
assert n_train > 0, f'dataset too small for the hold-out sizes: {len(dataset)} rows'

# A seeded generator makes the split reproducible across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_data, valid_data, test_data = random_split(
    dataset, [n_train, N_VALID, N_TEST], generator = split_generator
)

# Reshuffle the training samples every epoch; evaluation order stays fixed.
trainloader = DataLoader(train_data, batch_size = 8, shuffle = True)
validloader = DataLoader(valid_data, batch_size = 8)

In [None]:
from transformers import ViTModel
import torch.nn.functional as F

class Encoder(nn.Module):
    """Image encoder: pretrained ViT backbone followed by a linear projection.

    ``forward`` averages the backbone's ``last_hidden_state`` over axis 1,
    then applies Linear(768, 768) -> ReLU -> Dropout(0.3).
    """

    def __init__(self):
        super().__init__()
        self.img_model = ViTModel.from_pretrained('google/vit-base-patch16-224', add_pooling_layer=False)
        self.features = nn.Linear(768,768)
        self.dropout = nn.Dropout(0.3)

    def forward(self, image):
        """Return one feature vector per image in the batch."""
        token_states = self.img_model(image).last_hidden_state
        # Collapse axis 1 of the backbone output by averaging.
        pooled = token_states.mean(dim=1)
        projected = self.features(pooled)
        return self.dropout(F.relu(projected))
    
class Decoder(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens.

    Args:
        vocab_size: Number of entries in the vocabulary.
        embed_dim: Size of the token embeddings; must equal the size of the
            image feature vectors passed to ``forward``.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embed_dim = 768, hidden_dim = 1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1)
        self.out = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, text):
        """Return vocabulary logits for every time step.

        Args:
            features: Image features of shape (batch, embed_dim).
            text: Token ids of shape (batch, seq_len).

        Returns:
            Tensor of shape (seq_len + 1, batch, vocab_size). The layout is
            time-major because the LSTM is created without batch_first; step
            0 corresponds to the image feature input.
        """
        embeds = self.dropout(self.embed(text))
        # (batch, seq, dim) -> (seq, batch, dim), with the image feature
        # prepended as an extra first step.
        steps = torch.cat((features.unsqueeze(0), embeds.permute(1,0,2)), dim=0)
        hc, _ = self.lstm(steps)
        outputs = self.out(hc)
        return outputs

class CaptionModel(nn.Module):
    """Encoder-decoder captioning model (image encoder + LSTM decoder).

    Args:
        vocab_size: Number of entries in the caption vocabulary.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(vocab_size)

    def forward(self, images, text):
        """Return the decoder logits for a batch of images and captions."""
        features = self.encoder(images)
        outputs = self.decoder(features, text)
        return outputs

    def generate(self, image, vocab, max_length=50):
        """Greedily decode a caption for a single image.

        Args:
            image: Image batch holding exactly one image (``.item()`` is
                called on the prediction of every step).
            vocab: Either an id -> word mapping, or an object exposing such a
                mapping as ``idx2word`` (or ``itos``).
            max_length: Maximum number of tokens to generate.

        Returns:
            List of generated words; generation stops right after the first
            ``"_pad_"`` token, which is included in the result.
        """
        if hasattr(vocab, 'idx2word'):
            idx2word = vocab.idx2word
        elif hasattr(vocab, 'itos'):
            idx2word = vocab.itos
        else:
            idx2word = vocab

        result_caption = []

        # no_grad() does not switch dropout off, so inference must run in
        # eval mode; the previous mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = self.encoder(image).unsqueeze(0)
                states = None

                for _ in range(max_length):
                    hiddens, states = self.decoder.lstm(x, states)
                    # The decoder's output projection is stored as `out`.
                    output = self.decoder.out(hiddens.squeeze(0))
                    predicted = output.argmax(1)
                    result_caption.append(predicted.item())

                    # NOTE(review): "_pad_" is the only stop signal; if the
                    # training loss ignores the pad id the model may never
                    # emit it and decoding then runs to max_length.
                    if idx2word[predicted.item()] == "_pad_":
                        break

                    x = self.decoder.embed(predicted).unsqueeze(0)
        finally:
            self.train(was_training)

        return [idx2word[idx] for idx in result_caption]
    
# Put the model on the GPU when one is available and fall back to the CPU
# otherwise, instead of failing on kernels without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CaptionModel(len(dataset.tokenizer.vocab)).to(device)

In [None]:
# Targets equal to 0 (the tokenizer's padding id) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Adam over all model parameters with the library's default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

In [None]:
from tqdm.notebook import tqdm

epochs = 5
min_valid_loss = np.inf
# File that receives the best checkpoint (PATH was not defined in any cell above).
PATH = 'caption_model.pt'
# Send batches to whichever device the model lives on.
device = next(model.parameters()).device


def _caption_loss(batch):
    """Return the cross-entropy loss of ``model`` on one batch."""
    img, text = batch['image'].to(device), batch['text'].to(device)

    # The decoder output is time-major with one extra leading step for the
    # image feature: (seq + 1, batch, vocab). Dropping the last step leaves
    # step t as the prediction for token t.
    pred = model(img, text)[:-1]
    # Targets are (batch, seq); transpose them to time-major so that the
    # flattened predictions and targets line up element by element.
    target = text.transpose(0, 1)
    return criterion(pred.reshape(-1, pred.shape[2]), target.reshape(-1))


for e in range(epochs):
    model.train()
    train_loss = 0.0

    for batch in tqdm(trainloader):
        optimizer.zero_grad(set_to_none = True)

        loss = _caption_loss(batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(validloader):
            valid_loss += _caption_loss(batch).item()

    # Report the mean loss per batch so the two numbers are comparable.
    train_loss /= max(len(trainloader), 1)
    valid_loss /= max(len(validloader), 1)

    print(f'Training Loss:{train_loss:.4f}\t\t\t Validation Loss:{valid_loss:.4f}')
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased From {min_valid_loss}----->{valid_loss}     ....Saving Model')
        # Remember the best value so that only real improvements are saved.
        min_valid_loss = valid_loss
        torch.save(model.state_dict(), PATH)

# GPT-2 and ViT (the snippet below currently pairs ViT with BERT)

In [None]:
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

In [None]:
class CaptionDataset:
    """Map-style dataset yielding a transformed image and its encoded caption.

    This cell redefines the class of the same name from the earlier section
    and shadows it from here on.

    Args:
        df: DataFrame with an ``images`` column (file names) and a ``text``
            column (captions).
        tranform: Callable applied to the RGB PIL image; the spelling of the
            parameter name is kept for backward compatibility.
        image_dir: Directory that contains the image files.
    """

    def __init__(self, df, tranform, image_dir = '../input/flickr8k/Images/'):
        self.df = df
        self.transform = tranform
        self.image_dir = image_dir

        # Captions are encoded to 32 ids each, with a vocabulary built from
        # every caption in df.
        self.tokenizer = Tokenizer(32)
        self.tokenizer.build_vocab(df['text'])

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``{'image': tensor, 'text': LongTensor}`` for row ``idx``.

        ``idx`` is positional (0 .. len - 1), which is what samplers and
        ``random_split`` subsets hand to a dataset.
        """
        row = self.df.iloc[idx]

        # The context manager closes the file handle; convert('RGB')
        # guarantees three channels even for grayscale or palette files.
        with Image.open(os.path.join(self.image_dir, row['images'])) as img:
            img = self.transform(img.convert('RGB'))

        caption = self.tokenizer(row['text'])

        return {
            # as_tensor does not copy an existing tensor, which avoids the
            # copy-construct warning of torch.tensor(tensor).
            'image': torch.as_tensor(img),
            'text': torch.tensor(caption, dtype = torch.long)
        }

In [None]:
from transformers import VisionEncoderDecoderModel

# Initialize a ViT-BERT model from a pretrained ViT and a pretrained BERT model.
# Note that the cross-attention layers will be randomly initialized.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224-in21k', 'bert-base-uncased')
# Save the model; this step is meant to run after fine-tuning.
model.save_pretrained("./vit-bert")
# Load the saved model back from disk.
model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")