## Image-captioning model

### Overview

The code below trains an image-captioning model on the COCO dataset using pre-trained vision and language transformers. The model uses Microsoft's Swin-Tiny-Patch4-Window7-224 as the encoder and DistilGPT-2 as the decoder.
The implementation relies on the following libraries/packages:
 - Transformers library from Hugging Face to implement the model.
 - PyTorch for tensor operations.
 - PIL and torchvision for image processing.


### Instructions to Run this notebook

To run this code on a local machine, follow these steps:

- Install the requirements listed in `requirements.txt`. Make sure they include the Transformers, datasets, scikit-learn, PIL, and PyTorch libraries/packages, as well as `evaluate`, `nltk`, `ipywidgets`, and torchvision, which the notebook imports.

- Download the COCO dataset and place it in the directory specified by the `COCO_DIR` variable.

- Run the code in a Python environment with access to a GPU, either in a Jupyter notebook or as a Python script. This notebook was run on a MacBook M1.

In [1]:
import nltk
from transformers import VisionEncoderDecoderModel, GPT2TokenizerFast, AutoFeatureExtractor, \
                          TrainingArguments, Trainer

import datasets
from PIL import Image
import torch
import torchvision.transforms as transforms
import ipywidgets as widgets

## Model

In [2]:
# The decoder's cross-attention weights are newly (randomly) initialized because the DistilGPT-2 checkpoint has none.
# A main objective of fine-tuning is to optimize them, i.e. how GPT-2 attends to the encoder's output.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    'microsoft/swin-tiny-patch4-window7-224',
    'distilgpt2'
)

print(f'This model uses a pre-trained encoder of type {type(model.encoder)} and pre-trained decoder of type {type(model.decoder)}')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.3.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.3.crossattention.bias', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.

This model uses a pre-trained encoder of type <class 'transformers.models.swin.modeling_swin.SwinModel'> and pre-trained decoder of type <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>


## Device


In [3]:
device_type = widgets.RadioButtons(
    options=['M1', 'Other'],
    value='M1',
    description='Select device',
    disabled=False
)
device_type

RadioButtons(description='Select device', options=('M1', 'Other'), value='M1')

In [4]:
# train on gpu
if device_type.value == 'M1':

    device = torch.device("mps")

else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model.to(device)

VisionEncoderDecoderModel(
  (encoder): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
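
The device cell above selects `mps` whenever 'M1' is chosen, without checking that the backend is actually available. A minimal sketch of a fallback order follows. `choose_device` is a hypothetical helper, not part of the notebook, and it assumes the two flags would be obtained from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`.

```python
def choose_device(mps_available: bool, cuda_available: bool) -> str:
    """Return a device name, preferring Apple MPS, then CUDA, then CPU."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


# Example with hard-coded flags; in the notebook the flags would come from torch.
print(choose_device(mps_available=False, cuda_available=True))  # prints "cuda"
```

The returned string could then be passed to `torch.device(...)` before calling `model.to(device)`.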


In [5]:
# ViT and GPT-2 have 182,485,248 combined parameters. The Swin version tuned here has fewer parameters.
from torch import numel

combined_params = 0
for param in model.parameters():
    combined_params += numel(param)
    
print(f"Total number of parameters: {combined_params:,}")

Total number of parameters: 123,615,354


## Dataset
The image-captioning model is trained on the COCO dataset, which consists of images paired with captions. Only a small slice of the training split (`train[10:1000]`, 990 rows) is loaded here.

In [6]:
# COCO_DIR = input('Path to COCO dataset')

In [7]:
# `datasets.load_dataset` manages everything related to caching, so the dataset is loaded through it.
COCO_DIR = '/Users/yesidcano/repos/image-captioning/data/coco'


#ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017",data_dir=COCO_DIR, cache_dir='/Users/yesidcano/repos/db_coco_cache')

# Load only a slice of the dataset; see https://huggingface.co/docs/datasets/loading for the split-slicing syntax.

ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017",data_dir=COCO_DIR, split="train[10:1000]")
ds

Found cached dataset coco_dataset_script (/Users/yesidcano/.cache/huggingface/datasets/ydshieh___coco_dataset_script/2017-6b5176efb5303df4/0.0.0/e033205c0266a54c10be132f9264f2a39dcf893e798f6756d224b1ff5078998f)


Dataset({
    features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
    num_rows: 990
})

## Data processing

In [8]:
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
#gpt-2 tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
tokenizer.pad_token = tokenizer.eos_token



In [9]:
# Model config

model.config.pad_token = tokenizer.pad_token
model.config.pad_token_id = tokenizer.pad_token_id

model.config.decoder_start_token = tokenizer.bos_token
model.config.decoder_start_token_id = tokenizer.bos_token_id

In [10]:
feature_extractor

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "size": 224
}
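
The `image_mean` and `image_std` values printed above are the same per-channel statistics that the `transforms.Normalize` step in the preprocessing cell uses. As a worked example of what that normalization does to a single pixel, here is a minimal sketch. `normalize_pixel` is a hypothetical helper written for illustration, and it assumes 8-bit RGB input that is first scaled to the range 0..1, as `transforms.ToTensor()` does.

```python
IMAGE_MEAN = [0.485, 0.456, 0.406]  # per-channel mean, as printed by the feature extractor
IMAGE_STD = [0.229, 0.224, 0.225]   # per-channel standard deviation


def normalize_pixel(rgb_uint8):
    """Scale an (R, G, B) pixel from 0..255 to 0..1, then standardize each channel."""
    scaled = [c / 255.0 for c in rgb_uint8]
    return [(c - m) / s for c, m, s in zip(scaled, IMAGE_MEAN, IMAGE_STD)]


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization the pixel values are no longer confined to 0..1: bright pixels map to positive values and dark pixels to negative ones.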

In [11]:
tokenizer

PreTrainedTokenizerFast(name_or_path='distilgpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'})

In [12]:



# Define the transforms to be applied to the images
transform = transforms.Compose([

    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

])

def preprocess_fn(examples):

    # Swin expects pixel_values instead of input_ids
    examples['pixel_values'] = [transform(Image.open(path).convert('RGB')) for path in examples['image_path']]
    # We are padding tokens here instead of using a datacollator
    tokenized = tokenizer(
        examples['caption'], padding='max_length', max_length=10, truncation=True
    )['input_ids']
    # the output captions
    examples['labels'] = [[l if l != tokenizer.pad_token_id else -100 for l in t] for t in tokenized]

    # delete unused keys
    del examples['image_path']
    del examples['caption']
    return examples

processed_dataset = ds.map(
    function=preprocess_fn,
    batched=True,
    batch_size=50,
    # remove_columns=['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path']
)

Map:   0%|          | 0/990 [00:00<?, ? examples/s]
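
The label construction in `preprocess_fn` replaces every padding token id with -100 so that padded positions do not contribute to the loss. A minimal sketch of that step on a single sequence follows. The token ids are hypothetical and `mask_padding` is an illustrative helper, not a function from the notebook. The sketch assumes the GPT-2 padding id is 50256, the id of `<|endoftext|>`.

```python
PAD_TOKEN_ID = 50256  # id of '<|endoftext|>' in the GPT-2 vocabulary, used here as padding
IGNORE_INDEX = -100   # default ignore_index of PyTorch's cross-entropy loss


def mask_padding(token_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace padding token ids with IGNORE_INDEX so the loss skips those positions."""
    return [tok if tok != pad_token_id else IGNORE_INDEX for tok in token_ids]


# Hypothetical token ids for a short caption padded to max_length=10.
padded = [32, 3797, 319, 257, 2603, 50256, 50256, 50256, 50256, 50256]
print(mask_padding(padded))
```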

In [13]:
# By default data are shuffled.
processed_dataset = processed_dataset.train_test_split(test_size=0.1)
processed_dataset

DatasetDict({
    train: Dataset({
        features: ['image_id', 'caption_id', 'height', 'width', 'file_name', 'coco_url', 'pixel_values', 'labels'],
        num_rows: 891
    })
    test: Dataset({
        features: ['image_id', 'caption_id', 'height', 'width', 'file_name', 'coco_url', 'pixel_values', 'labels'],
        num_rows: 99
    })
})

## Model's layers
Below are the parameter names of the encoder and the decoder, respectively.

### Encoder Layers

In [14]:

for name, param in model.encoder.named_parameters():
    print(name)

embeddings.patch_embeddings.projection.weight
embeddings.patch_embeddings.projection.bias
embeddings.norm.weight
embeddings.norm.bias
encoder.layers.0.blocks.0.layernorm_before.weight
encoder.layers.0.blocks.0.layernorm_before.bias
encoder.layers.0.blocks.0.attention.self.relative_position_bias_table
encoder.layers.0.blocks.0.attention.self.query.weight
encoder.layers.0.blocks.0.attention.self.query.bias
encoder.layers.0.blocks.0.attention.self.key.weight
encoder.layers.0.blocks.0.attention.self.key.bias
encoder.layers.0.blocks.0.attention.self.value.weight
encoder.layers.0.blocks.0.attention.self.value.bias
encoder.layers.0.blocks.0.attention.output.dense.weight
encoder.layers.0.blocks.0.attention.output.dense.bias
encoder.layers.0.blocks.0.layernorm_after.weight
encoder.layers.0.blocks.0.layernorm_after.bias
encoder.layers.0.blocks.0.intermediate.dense.weight
encoder.layers.0.blocks.0.intermediate.dense.bias
encoder.layers.0.blocks.0.output.dense.weight
encoder.layers.0.blocks.0.outp

### Decoder Layers

In [15]:
for name, param in model.decoder.named_parameters():
    print(name)

transformer.wte.weight
transformer.wpe.weight
transformer.h.0.ln_1.weight
transformer.h.0.ln_1.bias
transformer.h.0.attn.c_attn.weight
transformer.h.0.attn.c_attn.bias
transformer.h.0.attn.c_proj.weight
transformer.h.0.attn.c_proj.bias
transformer.h.0.ln_2.weight
transformer.h.0.ln_2.bias
transformer.h.0.crossattention.c_attn.weight
transformer.h.0.crossattention.c_attn.bias
transformer.h.0.crossattention.q_attn.weight
transformer.h.0.crossattention.q_attn.bias
transformer.h.0.crossattention.c_proj.weight
transformer.h.0.crossattention.c_proj.bias
transformer.h.0.ln_cross_attn.weight
transformer.h.0.ln_cross_attn.bias
transformer.h.0.mlp.c_fc.weight
transformer.h.0.mlp.c_fc.bias
transformer.h.0.mlp.c_proj.weight
transformer.h.0.mlp.c_proj.bias
transformer.h.1.ln_1.weight
transformer.h.1.ln_1.bias
transformer.h.1.attn.c_attn.weight
transformer.h.1.attn.c_attn.bias
transformer.h.1.attn.c_proj.weight
transformer.h.1.attn.c_proj.bias
transformer.h.1.ln_2.weight
transformer.h.1.ln_2.bias
tr

### Freeze layers

In [16]:
# Freeze the early part of the encoder (assuming it is already good at understanding images).
# All cross-attention weights need to be optimized, so none of the decoder layers are frozen.
for name, param in model.encoder.named_parameters():
    # Freeze the patch embeddings and Swin stages 0 to 2; stop at the last stage (`encoder.layers.3`).
    # The marker must match the printed parameter names, which use `layers` (plural).
    if 'encoder.layers.3' in name:
        break
    param.requires_grad = False
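
The freezing loop relies on two things: `named_parameters()` yields parameters in registration order, and the stop marker is matched as a substring of each name. A minimal sketch of that logic on plain strings follows. The name list is an abbreviated, hypothetical subset of the encoder's parameter names, and `names_to_freeze` is an illustrative helper rather than notebook code.

```python
def names_to_freeze(param_names, stop_marker="encoder.layers.3"):
    """Collect names until the first one containing stop_marker (exclusive)."""
    frozen = []
    for name in param_names:
        if stop_marker in name:
            break
        frozen.append(name)
    return frozen


# Abbreviated, hypothetical subset of Swin parameter names in iteration order.
names = [
    "embeddings.patch_embeddings.projection.weight",
    "encoder.layers.0.blocks.0.attention.self.query.weight",
    "encoder.layers.1.blocks.0.attention.self.query.weight",
    "encoder.layers.2.blocks.0.attention.self.query.weight",
    "encoder.layers.3.blocks.0.attention.self.query.weight",
    "layernorm.weight",
]

print(names_to_freeze(names))
# A marker that never matches lets the loop run to the end, so every name is frozen.
print(len(names_to_freeze(names, stop_marker="encoder.layer.3")))
```

Because a non-matching marker silently freezes everything, it can be worth counting the parameters with `requires_grad=True` after the loop to confirm the result.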

## Evaluation Metrics

In [17]:
import evaluate
metric = evaluate.load("rouge")

In [18]:
import numpy as np

ignore_pad_token_for_loss = True


def postprocess_text(preds, labels):


    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    """
    Process the predicted captions and the labels (reference captions) and feed them to the ROUGE metric to compute evaluation scores.
    :param eval_preds: set of predicted and target tokens
    :return:
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # predicted token ids are decoded into strings using the GPT-2 tokenizer
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # -100 marks label positions (padding) that the loss ignores; it is not a valid token id.
        # Replace every -100 with the tokenizer's padding token id so the labels can be decoded.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # decoded strings are further post-processed, e.g. split into sentences using `sent_tokenize` from NLTK
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)
    # Compute the ROUGE metric. The metric uses stemming to match words with the same root.
    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
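
To give an intuition for what the ROUGE scores measure, here is a simplified sketch of a ROUGE-1 style F1 score based on unigram overlap. It is an illustration only: it tokenizes on whitespace and does no stemming, so it is not the implementation behind `evaluate.load("rouge")` and its numbers will not match that library exactly. `rouge1_f1` is a hypothetical helper.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between a predicted and a reference caption."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("a cat on a mat", "a cat sits on the mat")
print(round(score * 100, 4))  # same x100 scaling and rounding as in compute_metrics
```

In this example 4 of the 5 predicted words and 4 of the 6 reference words overlap, giving a precision of 0.8, a recall of about 0.667, and an F1 of 8/11.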

## Training Procedure

- The Swin + GPT-2 image-captioning model is trained with the `Trainer` class from the Transformers library.
- The training arguments are set via the `TrainingArguments` class.
- Training is started by the `train` method of the `Trainer` class.

In [19]:
training_arg = TrainingArguments(
    output_dir='../models/swin_image_captioning', # The output directory
    overwrite_output_dir=True, # overwrite the content of the output directory
    num_train_epochs=2, # number of training epochs
    per_device_train_batch_size=64, # batch size for training
    per_device_eval_batch_size=64,  # batch size for evaluation
    load_best_model_at_end=True,
    log_level='info',
    logging_steps=50,
    evaluation_strategy='epoch',
    save_strategy='epoch',
)

trainer = Trainer(
    model=model,
    args=training_arg,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['test'],
)

In [20]:
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `VisionEncoderDecoderModel.forward` and have been ignored: caption_id, coco_url, file_name, height, image_id, width. If caption_id, coco_url, file_name, height, image_id, width are not expected by `VisionEncoderDecoderModel.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 99
  Batch size = 64


TypeError: argument 'ids': 'list' object cannot be interpreted as an integer
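
The `TypeError` above is consistent with `compute_metrics` receiving raw logits instead of token ids: the plain `Trainer` passes the model's output logits, of shape (batch, sequence length, vocabulary size), as predictions, so `tokenizer.batch_decode` ends up with nested lists where it expects integers. This is a likely cause rather than one verified against this run. A minimal sketch of one possible fix, taking the arg-max over the vocabulary axis before decoding, is shown here on toy data. `logits_to_token_ids` is a hypothetical helper.

```python
import numpy as np


def logits_to_token_ids(preds):
    """Reduce (batch, seq_len, vocab) logits to (batch, seq_len) token ids."""
    if isinstance(preds, tuple):
        preds = preds[0]  # keep only the first element, assumed to be the logits
    preds = np.asarray(preds)
    if preds.ndim == 3:
        preds = preds.argmax(axis=-1)  # most likely token at each position
    return preds


# Toy logits: batch of 2, sequence length 3, vocabulary of 5.
toy_logits = np.random.default_rng(0).normal(size=(2, 3, 5))
token_ids = logits_to_token_ids(toy_logits)
print(token_ids.shape)  # (2, 3)
```

An alternative, under the same assumption about the cause, would be to use `Seq2SeqTrainer` with `Seq2SeqTrainingArguments(predict_with_generate=True)`, so that evaluation decodes captions with `generate` and passes token ids to `compute_metrics` directly.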

In [None]:
# 4 epochs took 8 hours
trainer.train()

In [None]:
trainer.save_model()

In [None]:
#need to save the tokenizer
tokenizer.save_pretrained('../models/swin_image_captioning')

In [None]:
# loading model and config from pretrained folder
finetuned_model = VisionEncoderDecoderModel.from_pretrained('../models/swin_image_captioning')

In [None]:

from IPython.display import display

inference_transforms = transforms.Compose(
    [

    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

# a helper function to caption an image given its local file path
def generate_caption(m, path):

    img = Image.open(path).convert('RGB')
    img_transformed = inference_transforms(img).unsqueeze(0)

    model_output = m.generate(
        img_transformed,
        num_beams=3,
        max_length=15,
        early_stopping=True,
        do_sample=True,
        top_k=10,
        num_return_sequences=5,
    )

    captions = [tokenizer.decode(g, skip_special_tokens=True).strip() for g in model_output]
    #Show image
    display(img)
    return captions, model_output, img_transformed


captions, model_output, img_transformed = generate_caption(  # Out of sample photo
    finetuned_model, '../data/test_data/000000421195_test.jpg'
)

captions

In [None]:
#from transformers import pipeline

In [None]:
# image_captioner = pipeline("image-to-text", model="../models/swin_image_captioning")