
# Art Trainer Captioning Model

This notebook details the process of fine-tuning an image captioning model on a dataset of art descriptions. We use Microsoft's GIT-large-COCO model (`microsoft/git-large-coco`) and the Hugging Face `Trainer` for fine-tuning. The notebook covers data preprocessing, model selection, evaluation, and hyperparameter optimization.

## Table of Contents
1. [Installation](#Install-the-needed-packages)
2. [Loading the Dataset](#Load-the-dataset)
3. [Model Selection](#The-Model)
4. [Data Preprocessing](#Data-Preprocessing)
5. [Evaluation](#Evaluation)
6. [Fine Tuning](#Fine-tuning)
7. [Hyperparameters Optimization](#Hyperparameters-Optimization)
8. [Fine Tuning with Best Parameters](#Fine-Tune-with-best-parameters)

---



# Install the needed packages

%pip install datasets
%pip install transformers
%pip install sentencepiece
%pip install diffusers --upgrade
%pip install invisible_watermark accelerate safetensors
%pip install jiwer
%pip install evaluate

In [None]:
import os

# The allocator setting is an environment variable and must be set
# before the first CUDA allocation to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import pandas as pd
import datasets
import torch
import numpy as np
from PIL import Image, ImageFile
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# Let PIL load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Load the dataset

Before loading, we check the notebook's local variables to see whether the dataset is already in memory. If it is not, we load it; otherwise we skip this step, because loading the dataset is slow.
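
The check described above can be sketched as follows. This is only an illustration under stated assumptions: `load_once` and `fake_loader` are hypothetical names that do not appear in the notebook, and in practice the loader would be the `pd.read_csv` call and the namespace would be `globals()`.

```python
def load_once(namespace, name, loader):
    """Return namespace[name], calling loader() only when the name is not defined yet."""
    if name not in namespace:
        namespace[name] = loader()
    return namespace[name]


calls = []


def fake_loader():
    # Hypothetical stand-in for an expensive load such as pd.read_csv(...).
    calls.append(1)
    return ["row-1", "row-2"]


namespace = {}  # in a notebook this would be globals()
first = load_once(namespace, "data", fake_loader)
second = load_once(namespace, "data", fake_loader)
print(len(calls))  # number of times the loader actually ran
```

Re-running the cell then reuses the object already in memory instead of reading the file again.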

## Data to Keep
Since this notebook only fine-tunes the captioning model, we keep only the descriptions of the paintings, which we scraped from the URLs provided in the original dataset.

In [None]:
data = pd.read_csv('../described_dataset_label.csv',sep='\t',encoding='latin-1')
data = data.sample(frac=1).reset_index(drop=True)
data = data.iloc[:20000]
data = data.rename(columns={'FILE':'image','AUTHOR':'author', 'TECHNIQUE':'style','URL':'description'})
data = data[['image','description']]
data['image'] = [f'.{x}' for x in data['image']]
data.head()

In [None]:
print(data.columns)

In [None]:
dataset = datasets.Dataset.from_pandas(data).cast_column('image', datasets.Image())
print(dataset)

In [None]:
sample = dataset[53]

image = sample['image']
width, height = image.size  # PIL returns (width, height)
display(image.resize((int(0.3 * width), int(0.3 * height))))
caption = sample['description']
print(caption)

# The Model

For our task, we chose Microsoft's GIT-large-COCO model (`microsoft/git-large-coco`). GIT (Generative Image-to-text Transformer) is an image-to-text model, and this checkpoint is fine-tuned on the COCO dataset, which resembles our dataset in that it pairs images with textual descriptions. Among the models available on Hugging Face, it consistently gave the best results for our task.

Alongside the model we use its associated processor, the same one used when the model was originally trained. Keeping the two consistent is crucial because the processor prepares the images and tokenizes the captions; with a mismatched processor the model could not interpret its inputs correctly. Reusing the existing processor also saves the time and effort of building our own.


In [None]:
checkpoint_capt= "microsoft/git-large-coco"
processor_capt = AutoProcessor.from_pretrained(checkpoint_capt)

# Data Preprocessing

To feed data to the processor and obtain tokenized inputs for the model, we define a function named `transforms`. It converts the raw data into the format the model expects. Here is a breakdown of the process:

## Function Description
The `transforms` function takes a batch of examples from the dataset and performs the following steps:

1. **Extract Images and Captions**:
   - Extracts the images and captions from the example batch.

2. **Tokenization**:
   - Uses the pre-trained processor (`processor_capt`) to preprocess the images and tokenize the captions.
   - Pads every caption to the maximum sequence length and truncates longer ones.
   - Returns the processor output (`input_ids`, `attention_mask`, and `pixel_values`) with an added `labels` entry.

## Input and Output
- **Input**: Example batch containing images and captions.
- **Output**: Model-ready inputs, comprising pixel values, input IDs, an attention mask, and the corresponding labels.

## Data Processing Improvements
- **On-the-fly Transformation**: The function is attached with `with_transform`, so examples are processed when they are accessed rather than all at once.
- **Padding and Truncation**: Padding and truncation handle captions of varying lengths, giving the model inputs of uniform size.
- **Label Generation**: The labels are a copy of the input IDs, which is what the model needs to compute its language-modelling loss.
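
To make the padding, truncation, and label steps concrete, here is a toy sketch in plain Python. It does not call the real processor: `pad_or_truncate` and `toy_transform` are hypothetical helpers, the token IDs are made up, and a pad ID of 0 is an assumption chosen for illustration.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Cut a sequence to max_length, or right-pad it with pad_id up to max_length."""
    clipped = token_ids[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


def toy_transform(batch_token_ids, max_length, pad_id=0):
    """Mimic the shape of the preprocessing output for already-tokenized captions."""
    input_ids = [pad_or_truncate(ids, max_length, pad_id) for ids in batch_token_ids]
    # 1 marks a real token, 0 marks padding.
    attention_mask = [
        [int(i < min(len(ids), max_length)) for i in range(max_length)]
        for ids in batch_token_ids
    ]
    # The labels are a copy of the input IDs, as in the notebook's function.
    labels = [list(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = toy_transform([[101, 7, 8, 102], [101, 5, 6, 7, 8, 9, 102]], max_length=5)
print(batch["input_ids"])       # [[101, 7, 8, 102, 0], [101, 5, 6, 7, 8]]
print(batch["attention_mask"])  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```

The first caption is padded and the second is truncated, so both end up with the same length.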


In [None]:
def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["description"]]
    inputs = processor_capt(images=images, text=captions, padding='max_length', truncation=True, return_tensors="pt")
    inputs.update({"labels": inputs["input_ids"]})
    return inputs

In [None]:
#Caption
capt_dataset = dataset.train_test_split(test_size=0.3)
capt_dataset = capt_dataset.with_transform(transforms)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint_capt)

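During evaluation the `Trainer` accumulates the model outputs for the whole evaluation set before computing the metrics. Keeping the full logits would exhaust memory, so we reduce them to predicted token IDs first, using the `preprocess_logits_for_metrics` hook.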

In [None]:
def preprocess_logits_for_metrics(logits,labels):
    predictions = torch.argmax(logits, dim=-1)
    return predictions,labels
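
A back-of-the-envelope calculation shows why the argmax step matters. The numbers below are assumptions for illustration only: 6,000 evaluation samples follows from the 30% split of 20,000 rows, while the sequence length of 512 and the vocabulary size of 30,522 are guesses rather than values read from the model. `eval_buffer_bytes` is a hypothetical helper.

```python
def eval_buffer_bytes(num_samples, seq_len, vocab_size, logit_bytes=4, id_bytes=8):
    """Bytes needed to keep full logits versus argmax token IDs for a whole eval set."""
    full_logits = num_samples * seq_len * vocab_size * logit_bytes  # float32 scores
    token_ids = num_samples * seq_len * id_bytes                    # int64 IDs
    return full_logits, token_ids


# Assumed figures, see the note above.
full, ids = eval_buffer_bytes(num_samples=6000, seq_len=512, vocab_size=30522)
print(f"full logits: {full / 1e9:.1f} GB, token IDs: {ids / 1e6:.1f} MB")
```

Under these assumptions the full logits would need hundreds of gigabytes, while the token IDs fit in a few tens of megabytes.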

# Evaluation

To evaluate our model's performance, we use the ROUGE metric, which was designed for summarization and is also commonly applied to generated captions. ROUGE measures the similarity between generated captions and ground truth captions by comparing overlapping n-grams, which gives an indication of the quality and accuracy of the generated captions.

To compute the ROUGE score, we load the metric with the `load` function of the `evaluate` library. We then define a custom function named `capt_compute_metrics`. It takes `eval_pred` as input, which contains the model predictions and the labels (ground truth captions).

The function performs the following steps:

1. Unpacks `eval_pred` into the predictions and the labels (ground truth captions). The predictions arrive as the tuple returned by `preprocess_logits_for_metrics`, so the predicted token IDs are its first element.
2. Decodes the labels and predictions using the associated processor (`processor_capt`), skipping special tokens to obtain human-readable text.
3. Computes the ROUGE score using the decoded predictions and references (ground truth captions).
4. Returns the dictionary produced by the metric, with the keys `rouge1`, `rouge2`, `rougeL`, and `rougeLsum`.

By utilizing the ROUGE score and implementing a custom evaluation function, we gain valuable insights into the quality of our model’s generated captions compared to ground truth captions. This facilitates quantitative assessment and refinement of the model’s performance, ultimately contributing to its effectiveness in generating accurate and relevant captions for images.
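
To show what "overlapping n-grams" means in the simplest case, here is a stripped-down ROUGE-1 F1 in plain Python. It is a sketch, not the implementation behind `evaluate.load('rouge')`: it splits on whitespace only and handles unigrams only, and `rouge1_f1` is a hypothetical helper.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between two whitespace-tokenized, lower-cased strings."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: a word counts at most as often as it appears in both texts.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("a portrait of a woman", "portrait of a young woman in blue")
print(round(score, 3))  # 0.667
```

Four of the five predicted words match the seven-word reference, giving a precision of 0.8, a recall of 4/7, and an F1 of 2/3.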



In [None]:
import evaluate

In [None]:
rouge = evaluate.load('rouge')

def capt_compute_metrics(eval_pred):
    # predictions is the (token_ids, labels) tuple returned by preprocess_logits_for_metrics
    predictions, labels = eval_pred
    pred_ids = predictions[0]
    # Decode token IDs back to text, dropping padding and other special tokens
    pred_texts = processor_capt.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = processor_capt.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=pred_texts, references=label_texts, use_aggregator=True)

In [None]:
# from transformers import CLIPProcessor, CLIPModel

# clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor_clip = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# def clip_compute_metrics(eval_pred):
#     references = eval_pred.label_ids
#     generated_text = eval_pred.predictions[1]

#     references = processor_clip.batch_decode(references, skip_special_tokens=True)
#     generated_text = processor_clip.batch_decode(generated_text, skip_special_tokens=True)

#     # Calculate the BERTScore
#     result = clip.compute(predictions=generated_text, references=references)
#     print(result)
#     return result["f1"].mean()

# Fine tuning

In this notebook we use the Hugging Face `Trainer`, an easy-to-use and powerful tool for fine-tuning models. It saves us from writing a custom training loop here. The resulting Hugging Face model serves as a comparison for our custom models, which we train in a different notebook.

In [None]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

## Hyperparameters Optimization



In [None]:
def model_init(trial):
    return AutoModelForCausalLM.from_pretrained(checkpoint_capt).to(device)

In [None]:
torch.cuda.empty_cache()

capt_training_args = TrainingArguments(
    output_dir="model_checkpoints/captioning",
    learning_rate=1e-5,
    num_train_epochs=5,
    warmup_ratio=0.2,
    fp16=False,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    save_total_limit=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
)

capt_trainer = Trainer(
    # model=model_capt,
    model_init=model_init,
    args=capt_training_args,
    data_collator=data_collator,
    train_dataset=capt_dataset["train"],
    eval_dataset=capt_dataset["test"],
    compute_metrics=capt_compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

torch.cuda.empty_cache()
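
The training arguments above determine how many optimizer updates a run performs. The sketch below works this out with plain arithmetic, assuming a single device and 14,000 training samples (the full 20,000-row sample with the 70/30 split). `schedule_summary` is a hypothetical helper, not part of `transformers`, and its rounding is an approximation of what the `Trainer` does.

```python
import math


def schedule_summary(num_train_samples, per_device_batch_size,
                     gradient_accumulation_steps, num_epochs, warmup_ratio):
    """Derive optimizer-update counts from the training arguments (single device assumed)."""
    batches_per_epoch = math.ceil(num_train_samples / per_device_batch_size)
    # One optimizer update happens every `gradient_accumulation_steps` batches.
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    total_updates = updates_per_epoch * num_epochs
    return {
        "effective_batch_size": per_device_batch_size * gradient_accumulation_steps,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "warmup_updates": math.ceil(total_updates * warmup_ratio),
    }


# Batch size 8, accumulation 2, 5 epochs, warmup ratio 0.2, as in the cell above.
summary = schedule_summary(14000, 8, 2, 5, 0.2)
print(summary)
```

With these values the effective batch size is 16, each epoch has 875 updates, and the learning rate warms up over the first 875 of 4,375 updates.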

In [None]:
#HyperParameter Search

def optuna_hp_space(trial):
    return {
        # Optuna does not allow `step` together with `log=True`, so only the log scale is kept.
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 4,step=1),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [6,8]),
        "gradient_accumulation_steps": trial.suggest_int("gradient_accumulation_steps", 1, 4, step = 1),
        "per_device_eval_batch_size": trial.suggest_categorical("per_device_eval_batch_size", [6,8]),
        "warmup_ratio": trial.suggest_float("warmup_ratio", 0.1, 0.3, step=0.1),
    }
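
The `log=True` option for the learning rate means that values are drawn uniformly on a logarithmic scale, so small learning rates are explored as thoroughly as large ones. The sketch below illustrates the idea with the standard library only. It is not Optuna's sampler, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random


def sample_log_uniform(rng, low, high):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the draws are reproducible
samples = [sample_log_uniform(rng, 1e-5, 1e-4) for _ in range(1000)]

# On a log scale about half of the draws fall below the geometric mean
# sqrt(low * high), which is roughly 3.16e-5 for this range.
below = sum(s < math.sqrt(1e-5 * 1e-4) for s in samples)
print(min(samples), max(samples), below)
```

With a plain uniform draw over the same range, only about a quarter of the samples would fall below 3.16e-5.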

In [None]:
best_trials = capt_trainer.hyperparameter_search(n_trials=100,
                                                 backend="optuna",
                                                 hp_space=optuna_hp_space, 
                                                 direction="maximize",)

## Fine Tune with best parameters

In [None]:
model_capt = AutoModelForCausalLM.from_pretrained(checkpoint_capt).to(device)

In [None]:
torch.cuda.empty_cache()
best_hyperparameters = best_trials.hyperparameters

capt_training_args = TrainingArguments(
    output_dir="model_checkpoints/captioning",
    learning_rate=best_hyperparameters["learning_rate"],
    num_train_epochs=best_hyperparameters["num_train_epochs"],
    warmup_ratio=best_hyperparameters["warmup_ratio"],
    fp16=False,
    per_device_train_batch_size=best_hyperparameters["per_device_train_batch_size"],
    per_device_eval_batch_size=best_hyperparameters["per_device_eval_batch_size"],
    gradient_accumulation_steps=best_hyperparameters["gradient_accumulation_steps"],
    save_total_limit=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
)

capt_trainer = Trainer(
    model=model_capt,
    args=capt_training_args,
    data_collator=data_collator,
    train_dataset=capt_dataset["train"],
    eval_dataset=capt_dataset["test"],
    compute_metrics=capt_compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

torch.cuda.empty_cache()
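
The `EarlyStoppingCallback` used above stops training once the monitored metric has failed to improve for a given number of consecutive evaluations. The following plain-Python sketch shows that patience logic in simplified form. `early_stop_index` is a hypothetical helper, the loss values are made up, and the real callback also supports an improvement threshold that is omitted here.

```python
def early_stop_index(metric_history, patience, greater_is_better=False):
    """Return the evaluation index at which training would stop, or None if it never does."""
    best = None
    bad_evals = 0
    for index, value in enumerate(metric_history):
        improved = best is None or (value > best if greater_is_better else value < best)
        if improved:
            best = value
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return index
    return None


# Hypothetical per-epoch validation losses.
losses = [2.1, 1.8, 1.9, 1.85, 1.95]
print(early_stop_index(losses, patience=2))  # 3
print(early_stop_index(losses, patience=5))  # None
```

Under this logic, with one evaluation per epoch and at most five epochs, a patience of 5 is never reached, so a smaller patience would be needed for early stopping to take effect.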

In [None]:
# torch.cuda.empty_cache()
# from transformers import EarlyStoppingCallback
# capt_training_args = TrainingArguments(
#     output_dir="model_checkpoints/captioning",
#     learning_rate=1e-5,
#     num_train_epochs=5,
#     warmup_ratio=0.2,
#     fp16=False,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     gradient_accumulation_steps=2,
#     eval_accumulation_steps=1,
#     save_total_limit=2,
#     evaluation_strategy="no",
#     save_strategy="no",
#     remove_unused_columns=False,
#     push_to_hub=False,
#     label_names=["description"],
#     load_best_model_at_end=True,
# )

# capt_trainer = Trainer(
#     model=model_capt,
#     # model_init=model_init,
#     args=capt_training_args,
#     data_collator=data_collator,
#     train_dataset=capt_dataset["train"],
#     eval_dataset=capt_dataset["test"],
#     compute_metrics=capt_compute_metrics,
#     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
# )

# torch.cuda.empty_cache()

In [None]:
capt_trainer.train()

In [None]:
sample = dataset[89]
image = sample['image']
width, height = image.size  # PIL returns (width, height)
display(image.resize((int(0.3 * width), int(0.3 * height))))
desc = sample['description']
print(f'Description: {desc}')

In [None]:
inputs = processor_capt(images = image, return_tensors='pt').to(device)
pixel_values = inputs.pixel_values

generated_ids = model_capt.generate(pixel_values=pixel_values, max_length=5000)
generated_caption = processor_capt.batch_decode(generated_ids,skip_special_tokens=True)[0]
print(generated_caption)

In [None]:
model_capt.push_to_hub("Art_huggingface_caption")