# Import Libraries

In [3]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split

from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import GPT2Tokenizer

from transformers import ViTModel

In [4]:
BATCH_SIZE = 32

# Dataset

In [5]:
df = pd.read_csv("data/results.csv", delimiter="|")

df.head()

Unnamed: 0,image_name,comment_number,comment
0,1000092795.jpg,0,Two young guys with shaggy hair look at their...
1,1000092795.jpg,1,"Two young , White males are outside near many..."
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .
4,1000092795.jpg,4,Two friends enjoy time spent together .


In [6]:
train_df, test_df = train_test_split(df, test_size=.2, random_state=42)

train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

train_df.shape, test_df.shape

((127132, 3), (31783, 3))

In [7]:
class Flickr30kDataset(Dataset):
    def __init__(self, image_dir, caption_file, transform=None, tokenizer=None, max_length=50):
        
        self.image_dir = image_dir
        self.caption_file = caption_file
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.captions = self._load_captions()


    def _load_captions(self):

        if isinstance(self.caption_file, str):
            df = pd.read_csv(self.caption_file, delimiter="|")
        else:
            df = self.caption_file
        return df


    def __len__(self):
        return len(self.captions)


    def __getitem__(self, i):

        image_name = self.captions.loc[i, "image_name"]
        caption = self.captions.loc[i, " comment"]

        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        encoding = self.tokenizer.encode_plus(
            caption,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        caption_tokenized = encoding['input_ids'].squeeze(0)

        return image, caption_tokenized


In [8]:
image_dir = "data/flickr30k_images"
caption_file = "data/results.csv"


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

train_dataset = Flickr30kDataset(image_dir, train_df, transform=transform, tokenizer=tokenizer, max_length=50)
test_dataset = Flickr30kDataset(image_dir, test_df, transform=transform, tokenizer=tokenizer, max_length=50)

train_laoder = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_laoder = DataLoader(train_dataset, batch_size=BATCH_SIZE*2)

In [9]:
a, b = next(iter(train_laoder))

a.size(), b.size()

(torch.Size([32, 3, 224, 224]), torch.Size([32, 50]))

# Model

In [14]:
class ViTFeatureExtractor(nn.Module):
    def __init__(self, use_all_tokens=True):
        super(ViTFeatureExtractor, self).__init__()

        self.use_all_tokens = use_all_tokens
        self.vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    def forward(self, x):
        features = self.vit_model(x).last_hidden_state

        if self.use_all_tokens:
            return features
        else:
            return features[:, 0, :]

In [15]:
feature_extractor = ViTFeatureExtractor()

feature_extractor(a).size()

torch.Size([32, 197, 768])