In [37]:
data = []
with open("/kaggle/input/image-captioning-for-visually-impaired/caption_eng/token-english.txt",'r') as file:
    for line in file:
        parts = line.strip().split("\t")
        if len(parts) == 2:
            image_name_id , caption =  parts
            image_name , id = image_name_id.split("#")
            data.append({"image_name":image_name,"id":id,"caption":caption})

In [38]:
import pandas as pd
df = pd.DataFrame(data)

In [39]:
data_2 = []
with open("/kaggle/input/image-captioning-for-visually-impaired/caption_eng/token-english2.txt",'r') as file:
    for line in file:
        parts = line.strip().split("\t")
        if len(parts) == 2:
            image_name_id , caption =  parts
            image_name , id = image_name_id.split("#")
            data.append({"image_name":image_name,"id":id,"caption":caption})

In [40]:
import pandas as pd
df_2 = pd.DataFrame(data_2)

In [41]:
data = pd.concat([df,df_2])

In [42]:
grouped_captions = data.groupby("image_name")["caption"].apply(list)

In [43]:
from transformers import ViTFeatureExtractor , AutoTokenizer

In [44]:
class config:
    encoder = "google/vit-base-patch16-224-in21k"
    decoder = "gpt2"
    train_batch_size = 4

In [45]:
def build_inputs_with_special_tokens(self,token_ids_0):
    ouputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return ouputs
AutoTokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens

In [46]:
feature_extractor = ViTFeatureExtractor.from_pretrained(config.encoder)
tokenizer = AutoTokenizer.from_pretrained(config.decoder)



In [47]:
import os
import random
import torch
from torch.utils.data import Dataset
from PIL import Image

class ImageCaptioningDataset(Dataset):
    def __init__(self, df, tokenizer, img_dir, feature_extractor, transform=None):
        self.data = df
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.feature_extractor = feature_extractor
        self.group_data = self.data.groupby("image_name")["caption"].apply(list)
        self.max_length = 20
    
    def __len__(self):
        return len(self.group_data)
    
    def __getitem__(self, idx):
        img_name = self.group_data.index[idx]
        captions = self.group_data.iloc[idx]
        caption = random.choice(captions)
        
        self.tokenizer.pad_token = self.tokenizer.eos_token
        
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        pixels = self.feature_extractor(image, return_tensors="pt").pixel_values
        
        tokenized_caption = self.tokenizer(
            caption, 
            padding="max_length", 
            truncation=True,
            max_length=self.max_length, 
            return_tensors="pt"
        ).input_ids.squeeze(0)
        
        labels = tokenized_caption.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {"pixel_values": pixels.squeeze(), "labels": labels}

In [48]:
from torchvision import transforms
transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
img_dir = "/kaggle/input/image-captioning-for-visually-impaired/images/images"
DATA = ImageCaptioningDataset(data,tokenizer,img_dir,feature_extractor=feature_extractor,transform=transforms)

In [49]:
import os ,random

In [50]:
from torch.utils.data import random_split

# Calculate the sizes for train and test splits
total_size = len(DATA)
train_size = int(0.8 * total_size)  # 80% for training
test_size = total_size - train_size  # Remaining 20% for testing

# Split the dataset
train_dataset, test_dataset = random_split(DATA, [train_size, test_size])



In [51]:
from transformers import VisionEncoderDecoderModel 
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.encoder,config.decoder)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

In [62]:
from transformers import Seq2SeqTrainingArguments , Seq2SeqTrainer

train_args = Seq2SeqTrainingArguments(per_device_train_batch_size=config.train_batch_size,eval_strategy="epoch",
                                     do_train=True,do_eval=True,num_train_epochs=5,output_dir="/kaggle/working/",per_device_eval_batch_size=4,report_to=["none"]  
)

In [56]:
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id

model.config.vocab_size = 72
model.config.max_lenght = 20
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.lenght_penalty = 2.0
model.config.num_beams = 4


In [54]:
import torch

In [63]:
trainer = Seq2SeqTrainer(model=model,tokenizer=feature_extractor,train_dataset=train_dataset,args=train_args,
                         eval_dataset=test_dataset)
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.3603,0.41686
2,0.3533,0.410268
3,0.3493,0.417573
4,0.3458,0.408083
5,0.3513,0.402623


TrainOutput(global_step=2790, training_loss=0.3523199730876526, metrics={'train_runtime': 1366.2453, 'train_samples_per_second': 8.161, 'train_steps_per_second': 2.042, 'total_flos': 2.0121723859894272e+18, 'train_loss': 0.3523199730876526, 'epoch': 5.0})