In [None]:
! pip install git+https://github.com/huggingface/transformers.git

In [None]:
!wget https://raw.githubusercontent.com/rathiankit03/ImageCaptionHindi/master/Flickr8kHindiDataset/Flickr8k-Hindi.txt

In [1]:
# Show which transformers build the kernel actually picked up.
import transformers

print(f"{transformers.__version__}")

In [17]:
import pandas as pd
base_path = '../input/flickr8k/Images/'  # NOTE(review): not referenced by the later cells shown; they use `root_dir`


def load_hindi_captions(path: str) -> pd.DataFrame:
    """Parse the Flickr8k-Hindi caption file into a DataFrame.

    Each non-blank line is ``<image_id> <caption words ...>``. The image id gets a
    ``.jpg`` suffix so it can be used directly as an image file name.

    Args:
        path: location of the caption text file.

    Returns:
        DataFrame with columns ``images`` (file name) and ``text`` (caption).
    """
    records = []
    # Hindi text: decode explicitly as UTF-8 instead of relying on the platform default.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Drop the trailing newline so it does not end up inside the caption/tokens.
            line = line.rstrip()
            if not line:
                continue  # a blank line would otherwise become a bogus '.jpg' row
            image_id, _, caption = line.partition(' ')
            records.append([image_id + '.jpg', caption])
    return pd.DataFrame(records, columns=['images', 'text'])


hindi = load_hindi_captions('./Flickr8k-Hindi.txt')
hindi.head()

In [18]:
# Drop the caption rows of one image id that is excluded from the data
# (NOTE(review): presumably its image file is missing/unreadable — confirm against the image folder).
# The `images` column carries the '.jpg' suffix appended when the captions were loaded, so the
# suffix must be part of the comparison; without it this filter matched no rows at all.
hindi = hindi[hindi['images'] != '2258277193_58694969e2.jpg']
hindi

In [19]:
from sklearn.model_selection import train_test_split

SEED = 42  # fixed so the split is reproducible across kernel restarts

# Split on unique image ids rather than on caption rows. When an image has several caption rows
# (Flickr8k provides several per image), a row-level split puts captions of the same image in both
# train and test, which leaks test images into training.
image_ids = hindi['images'].unique()
train_ids, test_ids = train_test_split(image_ids, test_size=0.2, random_state=SEED)

# Fresh frames with a 0..n-1 index (the dataset class indexes rows by position).
train_df = hindi[hindi['images'].isin(train_ids)].reset_index(drop=True)
test_df = hindi[hindi['images'].isin(test_ids)].reset_index(drop=True)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

class Image_Caption_Dataset(Dataset):
    """Map-style dataset pairing images with tokenized captions.

    Args:
        root_dir: directory holding the image files named in ``df['images']``.
        df: DataFrame with an ``images`` column (file names) and a ``text`` column (captions).
        feature_extractor: callable turning a PIL image into an object with ``pixel_values``.
        tokenizer: callable encoding a caption into an object with ``input_ids``.
        max_target_length: length every label sequence is padded/truncated to.
    """

    def __init__(self, root_dir, df, feature_extractor, tokenizer, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_length = max_target_length

    def __len__(self):
        # `len(dataset)` calls __len__() with no extra argument; the previous `(self, df)`
        # signature raised TypeError there, which breaks DataLoader(shuffle=True).
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for the row at position ``idx``."""
        # Positional lookup, so frames whose index is not 0..n-1 also work.
        row = self.df.iloc[idx]

        image = Image.open(self.root_dir + '/' + row['images']).convert("RGB")
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values

        # truncation=True keeps every label at exactly max_length; without it long captions
        # overflow and the default collate function cannot stack the batch.
        caption_ids = self.tokenizer(
            row['text'],
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
        ).input_ids

        # Padding positions become -100, the default ignore_index of PyTorch's cross-entropy loss.
        # NOTE(review): if pad and EOS share an id, the EOS label is masked out too — confirm.
        labels = [tok if tok != self.tokenizer.pad_token_id else -100 for tok in caption_ids]
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [None]:
from transformers import ViTFeatureExtractor,AutoTokenizer

encoder_checkpoint = 'google/vit-base-patch16-224'
decoder_checkpoint = 'surajp/gpt2-hindi'

feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)

# The dataset pads captions with padding='max_length', which requires a pad token.
# GPT-2-style tokenizers often ship without one, so reuse EOS as the pad token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
root_dir = "../input/flickr8k/Images"

# Both splits share the image folder, feature extractor and tokenizer; only the caption frame differs.
train_dataset, val_dataset = (
    Image_Caption_Dataset(
        root_dir=root_dir,
        df=split_df,
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
    )
    for split_df in (train_df, test_df)
)

In [None]:
# Inspect one training sample: print every field name together with its tensor shape.
encoding = train_dataset[0]
for field_name, field_value in encoding.items():
    print(field_name, field_value.shape)

In [None]:
# Decode the sample's label ids back to text as a sanity check.
# Positions holding -100 (masked padding) are filtered out rather than overwritten in place,
# so `encoding['labels']` is left unchanged and no pad token id is needed here.
labels = encoding['labels']
label_str = tokenizer.decode(labels[labels != -100], skip_special_tokens=True)
print(label_str)

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training stream; evaluation keeps the dataset order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
eval_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

In [38]:
from transformers import VisionEncoderDecoderModel
# Build an image-captioning model from the pretrained ViT encoder (`encoder_checkpoint`) and the
# pretrained language-model decoder (`decoder_checkpoint`). The cross-attention layers that connect
# them are newly (randomly) initialized.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # was never defined before use
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_checkpoint, decoder_checkpoint)
model = model.to(device)  # assignment instead of a bare expression: avoids dumping the full model repr

In [32]:
# Special token ids used for creating the decoder_input_ids from the labels.
# The decoder tokenizer may not define [CLS]/[SEP]/[PAD] (GPT-2-style tokenizers usually don't);
# the plain lookups then yield None, so fall back to BOS/EOS instead.
def _first_defined(*token_ids):
    """Return the first id that is not None (an id of 0 is valid, so `or` is avoided)."""
    return next((token_id for token_id in token_ids if token_id is not None), None)


model.config.decoder_start_token_id = _first_defined(tokenizer.cls_token_id, tokenizer.bos_token_id)
model.config.pad_token_id = _first_defined(tokenizer.pad_token_id, tokenizer.eos_token_id)
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search parameters for model.generate().
# NOTE(review): recent transformers versions read generation settings from model.generation_config
# and may warn about / ignore them on model.config — confirm these values take effect.
model.config.eos_token_id = _first_defined(tokenizer.sep_token_id, tokenizer.eos_token_id)
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [40]:
from datasets import load_metric

# BLEU scorer used to evaluate generated captions.
# NOTE(review): `load_metric` is deprecated in later `datasets` releases (metrics moved to `evaluate`).
bleu_metric = load_metric(path="bleu")

In [39]:
def compute_metrics(pred):
    """Compute BLEU between generated captions and reference captions.

    Args:
        pred: object exposing ``predictions`` (generated token ids) and ``label_ids``
            (reference token ids, with -100 at ignored positions), e.g. an
            ``EvalPrediction`` from a transformers Trainer.

    Returns:
        ``{"bleu": score}``, where ``score`` is the ``"bleu"`` entry returned by
        ``bleu_metric.compute``.
    """
    def _strip_ignored(sequences):
        # -100 is not a valid token id (it marks ignored label positions and, presumably, padding
        # added when predictions are gathered across batches), so drop it before decoding.
        # Building new lists also leaves the caller's arrays unmodified.
        return [[int(tok) for tok in seq if tok != -100] for seq in sequences]

    pred_str = tokenizer.batch_decode(_strip_ignored(pred.predictions), skip_special_tokens=True)
    label_str = tokenizer.batch_decode(_strip_ignored(pred.label_ids), skip_special_tokens=True)

    # BLEU takes a list of references per prediction; here each prediction has exactly one.
    bleu = bleu_metric.compute(predictions=pred_str, references=[[ref] for ref in label_str])

    return {"bleu": bleu["bleu"]}

In [None]:
from transformers import AdamW
from tqdm.notebook import tqdm


In [35]:
from transformers import AdamW
from tqdm.notebook import tqdm

NUM_EPOCHS = 2

# Take the device from the model itself so this cell does not depend on a separately defined `device`.
device = next(model.parameters()).device

# torch's AdamW with eps/weight_decay set to the defaults of the deprecated `transformers.AdamW`.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)


def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over `dataloader`; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    return total_loss / len(dataloader)


def evaluate_bleu(model, dataloader, tokenizer, metric, device):
    """Generate captions for all of `dataloader`; return corpus-level BLEU against the labels.

    BLEU is computed once over the whole split (averaging per-batch scores is not equivalent).
    """
    model.eval()
    predictions, references = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            generated = model.generate(batch["pixel_values"].to(device))
            predictions.extend(tokenizer.batch_decode(generated.tolist(), skip_special_tokens=True))
            # -100 marks ignored label positions and is not a valid token id; drop it before decoding.
            label_ids = [[tok for tok in seq if tok != -100] for seq in batch["labels"].tolist()]
            references.extend([ref] for ref in tokenizer.batch_decode(label_ids, skip_special_tokens=True))
    return metric.compute(predictions=predictions, references=references)["bleu"]


for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    train_loss = train_one_epoch(model, train_dataloader, optimizer, device)
    print(f"Loss after epoch {epoch}:", train_loss)

    valid_bleu = evaluate_bleu(model, eval_dataloader, tokenizer, bleu_metric, device)
    print("Validation BLEU:", valid_bleu)

model.save_pretrained(".")