#### Image captioning

<font size = 3><span style = "color:#3A3E59;font-family:'Times New Roman'">Image Captioning is the task of describing the content of an image in words. This task lies at the intersection of computer vision and natural language processing. Most image captioning systems use an encoder-decoder framework, where an input image is encoded into an intermediate representation of the information in the image, and then decoded into a descriptive text sequence.</span></font>
<br>

<img src="https://raw.githubusercontent.com/danieljl/keras-image-captioning/master/results-without-errors.jpg">



# Table of contents


- [1. Imports](#1)
- [2. Hyperparameters](#2)
- [3. Dataset](#3)
  * [3.1 Feature Extractor and Tokenizer](#3.1)
  * [3.2 Transforms and dataframe](#3.2)
  * [3.3 Dataset Class](#3.3)
  * [3.4 Train and validation dataset](#3.4)
- [4. Model Building](#4)
  * [4.1 Model Initialization](#4.1)
- [5. Training](#5)
  * [5.1 Training Arguments](#5.1)
  * [5.2 Training using Seq2SeqTrainer](#5.2)
- [6. Predictions](#6)

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
from IPython.display import clear_output
!pip install deep-phonemizer
!pip install transformers[torch]
!pip install datasets
!pip install google-cloud-storage

# <a id='1'></a>1. Imports

In [None]:
import os

import datasets
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import multiprocessing as mp
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import io, transforms
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoFeatureExtractor
from transformers import AutoTokenizer, GPT2Config, default_data_collator


if torch.cuda.is_available():

    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

<a id="2"></a>
# 2. Hyperparameters

In [None]:
class config :
    ENCODER = "google/vit-base-patch16-224-in21k"
    DECODER = "gpt2"
    TRAIN_BATCH_SIZE = 20
    VAL_BATCH_SIZE = 20
    LR = 3e-4
    MAX_LEN = 128
    EPOCHS = 2
    IMG_SIZE = (224,224)

<a id="3"></a>
# 3. Dataset

In [None]:
!wget https://storage.googleapis.com/gen-ai-workshop-1/dataset.zip
!unzip -q dataset.zip -d dataset

<a id="3.1"></a>
## 3.1 Feature Extractor and Tokenizer

1. <font size = 3><span style="color:#3A3E59;font-family:'Times New Roman'"> The feature extractor for the ViT encoder is loaded using <b>AutoFeatureExtractor</b>  </span></font>
2. <font size = 3><span style="color:#3A3E59;font-family:'Times New Roman'">The tokenizer for GPT-2 is loaded using <b>AutoTokenizer</b>. GPT-2 has no padding token, so the end-of-sequence token is reused for padding  </span></font>
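
As a rough illustration of what the feature extractor does to an image, the sketch below rescales pixels to [0, 1], normalizes them, and moves channels first. It is a simplified stand-in, not the library's implementation: `normalize_image` is a hypothetical helper, resizing to 224x224 is omitted, and the mean and standard deviation of 0.5 are an assumption about this checkpoint (the loaded extractor's `image_mean` and `image_std` attributes give the actual values).

```python
import numpy as np

def normalize_image(pixels, mean=0.5, std=0.5):
    """Rescale uint8 pixels (H, W, C) to [0, 1], normalize, return (C, H, W)."""
    scaled = pixels.astype(np.float32) / 255.0
    normalized = (scaled - mean) / std  # mean/std of 0.5 are assumed values
    return normalized.transpose(2, 0, 1)

# A synthetic pure-red image stands in for a real photo
pixels = np.zeros((224, 224, 3), dtype=np.uint8)
pixels[..., 0] = 255

pixel_values = normalize_image(pixels)
print(pixel_values.shape)  # (3, 224, 224)
```

With these assumed statistics, pixel values end up in the range [-1, 1].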

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(config.ENCODER)
tokenizer = AutoTokenizer.from_pretrained(config.DECODER)
tokenizer.pad_token = tokenizer.eos_token

<a id="3.2"></a>
## 3.2 Transforms and dataframe

In [None]:
df = pd.read_csv("dataset/captions.txt")
train_image_names, test_image_names = train_test_split(df['image'].unique(), test_size=0.2)
train_df = df[ df['image'].isin(train_image_names) ]
val_df = df[ df['image'].isin(test_image_names) ]
train_df.sample(10)

<a id="3.3"></a>
## 3.3 Dataset Class

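The dataset class builds each sample using the following steps:
1. The image is read using the PIL library and converted to RGB
2. The image is preprocessed using the feature extractor
3. The caption is loaded from the dataframe
4. The caption is tokenized
5. The tokenized caption is padded to the maximum length, and padding tokens are replaced with -100 in the labels so that the loss ignores them, except for the first one, which is kept as the end-of-sequence token
6. The pixel values and the labels are returned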
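
The handling of padded captions can be sketched without loading a tokenizer. This is a minimal illustration, assuming a hypothetical helper `mask_padding` and arbitrary example token ids. The pad id is GPT-2's end-of-text id, 50256, because the tokenizer above reuses the EOS token for padding. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so positions set to it do not contribute to the loss.

```python
PAD_ID = 50256       # GPT-2 <|endoftext|> id, reused as the pad id
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_padding(token_ids, pad_id=PAD_ID, ignore_index=IGNORE_INDEX):
    """Replace padding with ignore_index, keeping the first pad position as EOS."""
    labels = [tok if tok != pad_id else ignore_index for tok in token_ids]
    if pad_id in token_ids:
        # The first pad position doubles as the end-of-sequence target
        labels[token_ids.index(pad_id)] = pad_id
    return labels

# Arbitrary token ids followed by padding (illustrative values only)
padded = [32, 3290, 4539, PAD_ID, PAD_ID, PAD_ID]
labels = mask_padding(padded)
print(labels)  # [32, 3290, 4539, 50256, -100, -100]
```

Keeping one EOS token in the labels lets the decoder learn where a caption ends; a caption that fills the whole sequence has no padding and is returned unchanged.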

In [None]:
class ImgDataset(Dataset):
    def __init__(self, df, root_dir, tokenizer, feature_extractor, max_length=50, custom_verbose=False):
        self.df = df
        self.custom_verbose = custom_verbose
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.caption.iloc[idx]
        image = self.df.image.iloc[idx]
        img_path = os.path.join(self.root_dir, image)
        img = Image.open(img_path).convert("RGB")
        pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
        # Truncate long captions and pad short ones so every sample has max_length tokens
        token_ids = self.tokenizer(caption, padding='max_length', truncation=True,
                                   max_length=self.max_length).input_ids
        pad_id = self.tokenizer.pad_token_id
        # Replace padding with -100 so the loss ignores it
        labels = [tok if tok != pad_id else -100 for tok in token_ids]
        # Pad and EOS share one id, so keep the first pad position as the EOS target
        if pad_id in token_ids:
            labels[token_ids.index(pad_id)] = pad_id
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

<a id="3.4"></a>
## 3.4 Train and validation dataset

In [None]:
train_dataset = ImgDataset(train_df, root_dir = "dataset/Images",tokenizer=tokenizer,feature_extractor = feature_extractor)
val_dataset = ImgDataset(val_df , root_dir = "dataset/Images",tokenizer=tokenizer,feature_extractor = feature_extractor)

<a id="4"></a>
# 4. Model Building

**ENCODER**

<img src = "https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png">

The Vision Transformer, or ViT, is a model for image classification that applies a Transformer-like architecture to patches of the image. An image is split into fixed-size patches, each of which is linearly embedded; position embeddings are added, and the resulting sequence of vectors is fed to a standard Transformer encoder. To perform classification, the standard approach of adding an extra learnable “classification token” to the sequence is used.
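
The patch pipeline described above can be sketched in NumPy. This is only an illustration of the shapes involved, assuming the 16x16 patch size and 224x224 input implied by the encoder name and the ViT-Base hidden size of 768; `image_to_patches` is a hypothetical helper, and random matrices stand in for the learned projection, classification token, and position embeddings.

```python
import numpy as np

def image_to_patches(image, patch_size=16):
    """Split an (H, W, C) array into flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    rows, cols = h // patch_size, w // patch_size
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (rows, cols, p, p, c)
    return patches.reshape(rows * cols, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))
patches = image_to_patches(image)  # 14 * 14 = 196 patches of 16 * 16 * 3 = 768 values

hidden_size = 768  # ViT-Base hidden size
# Random stand-ins for parameters that are learned in the real model
projection = rng.normal(size=(patches.shape[1], hidden_size)) * 0.02
cls_token = np.zeros((1, hidden_size))
position_embeddings = rng.normal(size=(patches.shape[0] + 1, hidden_size)) * 0.02

# Embed patches, prepend the classification token, add position embeddings
tokens = np.concatenate([cls_token, patches @ projection], axis=0) + position_embeddings
print(patches.shape, tokens.shape)  # (196, 768) (197, 768)
```

The encoder therefore sees a sequence of 197 vectors per image: one classification token plus 196 patch embeddings.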

**DECODER**

<img src = "https://i.stack.imgur.com/7J4O7.png" >


GPT-2 is a Transformer model pretrained on a very large corpus of English text in a self-supervised fashion. Inputs are sequences of continuous text of a certain length, and the targets are the same sequences shifted one token (a word or piece of a word) to the right. Internally, the model uses a masking mechanism to ensure that the prediction for token i uses only the inputs from 1 to i and never future tokens.

This way, the model learns an inner representation of the English language that can then be used to extract features useful for downstream tasks. However, the model is best at what it was pretrained for: generating text from a prompt.
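
The shifted targets and the causal mask can be illustrated in plain Python. This is a toy sketch on word-level tokens, not GPT-2's actual implementation; `shift_for_language_modeling` and `causal_mask` are hypothetical helpers introduced for the example.

```python
def shift_for_language_modeling(token_ids):
    """Return (inputs, targets) where targets[i] is the token that follows inputs[i]."""
    return token_ids[:-1], token_ids[1:]

def causal_mask(length):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(length)] for i in range(length)]

tokens = ["a", "dog", "runs", "on", "grass"]
inputs, targets = shift_for_language_modeling(tokens)
print(inputs)   # ['a', 'dog', 'runs', 'on']
print(targets)  # ['dog', 'runs', 'on', 'grass']

# Each row is one position; 1 marks the positions it is allowed to see
for row in causal_mask(4):
    print([int(allowed) for allowed in row])
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

The lower-triangular mask is what prevents a position from using later tokens when predicting its target.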
    

<a id="4.1"></a>
## 4.1 Model Initialization

In [None]:
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(config.ENCODER, config.DECODER).to(device)

In [None]:
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

<a id="5"></a>
# 5. Training

<a id="5.1"></a>
## 5.1 Training Arguments


In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir='VIT_large_gpt2',
    per_device_train_batch_size=config.TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=config.VAL_BATCH_SIZE,
    predict_with_generate=False,
    evaluation_strategy="epoch",
    save_steps=2048,
    warmup_steps=128,
    learning_rate = config.LR,
    num_train_epochs = config.EPOCHS,
    overwrite_output_dir=True,
    save_total_limit=1,
    weight_decay=0.01
)

In [None]:
len(train_dataset), len(val_dataset)

<a id="5.2"></a>
## 5.2 Training using Seq2SeqTrainer

In [None]:
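# instantiate trainer
trainer = Seq2SeqTrainer(
    # The feature extractor is passed here so it is saved alongside the model;
    # captions are already tokenized inside the dataset class
    tokenizer=feature_extractor,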
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=default_data_collator,
)

In [None]:
trainer.train()

In [None]:
trainer.save_model('VIT_large_gpt2_v2')

<a id="6"></a>
# 6. Predictions

In [None]:
from transformers import logging
logging.set_verbosity_error()

In [None]:
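# Load a publicly available ViT + GPT-2 captioning checkpoint.
# To use the model trained above instead, load the directory written by trainer.save_model:
#loaded_model = VisionEncoderDecoderModel.from_pretrained("VIT_large_gpt2_v2").to(device)
loaded_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)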

In [None]:
def show_pred(idx, model):
    # Note: samples are drawn from train_df, so this is not a held-out evaluation
    row = train_df.reset_index(drop=True).iloc[idx]
    img = Image.open(f"dataset/Images/{row['image']}").convert("RGB")
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to(device)
    output_ids = model.generate(pixel_values, max_length=100)[0]
    # skip_special_tokens drops <|endoftext|> from the decoded caption
    generated_caption = tokenizer.decode(output_ids, skip_special_tokens=True)

    plt.imshow(img)
    plt.axis("off")
    print(f"Prediction: {generated_caption}")
    print(f"Label: {row['caption']}")

In [None]:
show_pred(10, loaded_model)