In [None]:
!pip install transformers datasets evaluate torch torchvision nltk rouge_score
import nltk
nltk.download('punkt')
# get flickr
!mkdir -p flickr_data
%cd flickr_data
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip -q Flickr8k_Dataset.zip
!unzip -q Flickr8k_text.zip
%cd ..

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ev

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


/content/flickr_data
/content


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    TrainingArguments,
    Trainer
)
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer


In [None]:
class Flickr8kDataset(Dataset):
    """Image/caption pairs read from a Flickr8k-style caption file.

    Each non-blank line of ``captions_file`` is expected to look like
    ``<image file name>#<caption index><TAB><caption text>``. Captions are
    grouped per image, and only images that exist inside ``image_dir`` are kept.

    Args:
        image_dir: Directory that contains the image files.
        captions_file: Path of the caption file described above.
        processor: Callable used as ``processor(image, return_tensors="pt")``;
            its result must expose ``pixel_values``.
        tokenizer: Callable tokenizer whose result exposes ``input_ids`` and
            ``attention_mask``; ``eos_token_id`` is read from it.
        max_length: Length every caption is padded / truncated to.
    """

    def __init__(self, image_dir, captions_file, processor, tokenizer, max_length=64):
        self.image_dir = image_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

        # image file name -> list of its captions (insertion ordered)
        self.image_captions = {}
        # Names already known to be absent, so each is checked on disk only once.
        missing_images = set()
        with open(captions_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                parsed = self._parse_caption_line(line)
                if parsed is None:
                    print(f"Error processing line: {line}")
                    continue
                image_name, caption = parsed

                if image_name in missing_images:
                    continue
                if image_name not in self.image_captions:
                    # verify image exists before adding
                    if not os.path.exists(os.path.join(self.image_dir, image_name)):
                        missing_images.add(image_name)
                        continue
                    self.image_captions[image_name] = []
                self.image_captions[image_name].append(caption)

        self.images = list(self.image_captions.keys())
        print(f"Loaded {len(self.images)} images with captions")

    @staticmethod
    def _parse_caption_line(line):
        """Split one stripped caption line into ``(image_name, caption)``.

        The tab is located first and the *last* ``#`` of the part before it
        separates the image name from the caption index, so a ``#`` inside
        the caption text is preserved. Returns ``None`` for malformed lines.
        """
        key, tab, caption = line.partition('\t')
        image_name, hash_mark, _caption_index = key.rpartition('#')
        image_name = image_name.strip()
        if not tab or not hash_mark or not image_name:
            return None
        return image_name, caption.strip()

    def __len__(self):
        """Number of images that have at least one caption."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels"}`` for image ``idx``.

        If that image cannot be processed, the error is printed and the
        following images are tried in turn (wrapping around), each at most
        once.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no image in the dataset could be processed.
        """
        total = len(self.images)
        # Explicit bounds check: plain iteration over the dataset relies on
        # IndexError to terminate, and the wrap-around below would hide it.
        if not -total <= idx < total:
            raise IndexError("dataset index out of range")

        start = idx % total
        for offset in range(total):
            image_name = self.images[(start + offset) % total]
            image_path = os.path.join(self.image_dir, image_name)
            try:
                return self._build_example(image_name, image_path)
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
        raise RuntimeError(f"None of the {total} images could be processed")

    def _build_example(self, image_name, image_path):
        """Load one image and one randomly chosen caption as model inputs."""
        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        caption = str(np.random.choice(self.image_captions[image_name]))

        tokenized_caption = self.tokenizer(
            caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        labels = tokenized_caption.input_ids.squeeze(0)
        attention_mask = tokenized_caption.attention_mask.squeeze(0)

        # Terminate the caption with EOS when the tokenizer has not done so
        # itself and there is room, so the decoder gets a "stop" target.
        # The first padding slot is the position right after the caption only
        # for right-padding tokenizers, hence the padding_side guard.
        eos_id = self.tokenizer.eos_token_id
        num_real = int(attention_mask.sum())
        right_padded = getattr(self.tokenizer, "padding_side", "right") == "right"
        if (
            eos_id is not None
            and right_padded
            and num_real < labels.numel()
            and (num_real == 0 or labels[num_real - 1].item() != eos_id)
        ):
            labels[num_real] = eos_id
            attention_mask[num_real] = 1

        # Mask padding by position rather than by token id: when the pad token
        # is the same id as EOS, masking by id would also hide the real EOS.
        labels[attention_mask == 0] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels
        }



In [None]:
def collate_fn(batch):
    """Combine per-sample dicts into a single batch dict.

    Args:
        batch: Sequence of dicts, each holding a ``"pixel_values"`` tensor and
            a ``"labels"`` tensor; tensors under the same key must share a
            shape so they can be stacked.

    Returns:
        Dict with the same two keys, each tensor stacked along a new leading
        dimension.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["labels"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.stack(targets),
    }

In [None]:
# Checkpoints and locations used by this cell.
ENCODER_CHECKPOINT = "google/vit-base-patch16-224-in21k"
DECODER_CHECKPOINT = "gpt2"
FLICKR_IMAGE_DIR = "flickr_data/Flicker8k_Dataset"
FLICKR_CAPTIONS_FILE = "flickr_data/Flickr8k.token.txt"
TRAINING_OUTPUT_DIR = "./image_captioning_output"
FINETUNED_DIR = "./image_captioning_finetuned"

# Pair the pretrained image encoder with the pretrained text decoder, and load
# the matching image processor and tokenizer.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    ENCODER_CHECKPOINT,
    DECODER_CHECKPOINT,
)
image_processor = ViTImageProcessor.from_pretrained(ENCODER_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(DECODER_CHECKPOINT)

# Special tokens: reuse EOS as the padding token, then tell the combined model
# which ids start decoding and mark padding. This has to happen before the
# dataset below is built, because the dataset tokenizes with padding.
tokenizer.pad_token = tokenizer.eos_token
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

train_dataset = Flickr8kDataset(
    image_dir=FLICKR_IMAGE_DIR,
    captions_file=FLICKR_CAPTIONS_FILE,
    processor=image_processor,
    tokenizer=tokenizer,
)

# Hyper-parameters and bookkeeping for the Trainer.
training_options = dict(
    output_dir=TRAINING_OUTPUT_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    warmup_steps=500,
    save_steps=1000,
    logging_steps=100,
    learning_rate=3e-5,
    weight_decay=0.01,
    save_total_limit=2,
    report_to=[],  # empty list: no experiment-tracking integrations (e.g. wandb)
    run_name=None,
)
training_args = TrainingArguments(**training_options)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

print("Starting training...")
trainer.train()

# Store model, tokenizer and image processor side by side in one directory.
print("Saving model...")
model.save_pretrained(FINETUNED_DIR)
tokenizer.save_pretrained(FINETUNED_DIR)
image_processor.save_pretrained(FINETUNED_DIR)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded 8091 images with captions
Starting training...


Step,Training Loss
100,4.431
200,3.399
300,3.2879
400,3.1347
500,3.0695
600,2.9714
700,2.9363
800,2.8475
900,2.9024
1000,2.8359


Saving model...


['./image_captioning_finetuned/preprocessor_config.json']

In [None]:
!pip install pycocoevalcap nltk
import nltk
nltk.download('wordnet')
nltk.download('punkt')
from nltk.translate import meteor_score
from pycocoevalcap.cider.cider import Cider

def prepare_evaluation_data(caption_file, image_dir, skip_missing=True):
    """Group the captions of a Flickr8k-style caption file by image.

    Each non-blank line is expected to look like
    ``<image file name>#<caption index><TAB><caption text>``. The tab is
    located first and the last ``#`` before it separates the image name from
    the caption index, so a ``#`` inside the caption text is preserved.
    Malformed lines are reported with ``print`` and skipped.

    Args:
        caption_file: Path of the caption file.
        image_dir: Directory the image names are joined onto.
        skip_missing: When true (default), images whose file does not exist
            in ``image_dir`` are left out, so later image loading cannot fail
            on them. Pass ``False`` to keep every image named in the file.

    Returns:
        List of ``{'image_path': str, 'captions': list[str]}`` dicts, one per
        image, in order of first appearance in the caption file.
    """
    image_captions = {}

    with open(caption_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            key, tab, caption = line.partition('\t')
            image_name, hash_mark, _caption_index = key.rpartition('#')
            image_name = image_name.strip()
            if not (tab and hash_mark and image_name):
                print(f"Error processing line: {line}")
                continue

            image_captions.setdefault(image_name, []).append(caption.strip())

    evaluation_data = []
    for image_name, captions in image_captions.items():
        image_path = os.path.join(image_dir, image_name)
        if skip_missing and not os.path.exists(image_path):
            continue
        evaluation_data.append({
            'image_path': image_path,
            'captions': captions
        })

    return evaluation_data

class CaptionEvaluator:
    """Generate captions with a model and score them against references.

    Args:
        model: Model exposing ``to``, ``eval`` and ``generate``; it is moved
            to ``device``.
        tokenizer: Used to decode generated token ids back to text.
        image_processor: Callable used as
            ``image_processor(image, return_tensors="pt")``; its result must
            expose ``pixel_values``.
        device: Device the model and inputs are placed on. The default is
            resolved once, when the class is defined.
    """

    def __init__(self, model, tokenizer, image_processor, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.device = device
        self.smoothing = SmoothingFunction().method1
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)

    def generate_caption(self, image):
        """Return the caption decoded from a 4-beam search (max length 64) for ``image``."""
        self.model.eval()
        with torch.no_grad():
            pixel_values = self.image_processor(image, return_tensors="pt").pixel_values.to(self.device)
            output_ids = self.model.generate(
                pixel_values,
                max_length=64,
                num_beams=4,
                early_stopping=True
            )
            caption = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return caption

    def evaluate_dataset(self, dataset, num_samples=None):
        """Caption every selected image and compute corpus-level metrics.

        Args:
            dataset: Sliceable sequence of dicts with ``'image_path'`` and
                ``'captions'`` (list of reference strings).
            num_samples: Evaluate only the first ``num_samples`` items;
                ``None`` evaluates everything.

        Returns:
            ``(metrics, pairs)``. ``metrics`` maps ``'BLEU-4'``, ``'CIDEr'``,
            ``'METEOR'`` and ``'ROUGE-1'`` to scores. ``pairs`` is a list of
            ``(generated_tokens, list_of_reference_token_lists)``.
            Captions are tokenized by whitespace.

        Raises:
            ValueError: the selection contains no items.
        """
        # `is None` rather than truthiness, so num_samples=0 is not silently
        # treated as "evaluate the whole dataset".
        samples = dataset if num_samples is None else dataset[:num_samples]
        if len(samples) == 0:
            raise ValueError("evaluate_dataset needs at least one sample to score")

        generated_captions = []
        reference_captions = []

        gts = {}  # sample index -> reference caption strings (CIDEr input)
        res = {}  # sample index -> [generated caption string] (CIDEr input)

        for idx, item in enumerate(tqdm(samples)):
            # Context manager closes the image file once the RGB copy exists.
            with Image.open(item['image_path']) as img:
                image = img.convert('RGB')

            generated_tokens = self.generate_caption(image).split()
            generated_captions.append(generated_tokens)

            refs = [ref.split() for ref in item['captions']]
            reference_captions.append(refs)

            gts[idx] = [' '.join(ref) for ref in refs]
            res[idx] = [' '.join(generated_tokens)]

        bleu4 = corpus_bleu(reference_captions, generated_captions,
                           weights=(0.25, 0.25, 0.25, 0.25),
                           smoothing_function=self.smoothing)

        cider_score, _ = Cider().compute_score(gts, res)

        # METEOR and ROUGE-1: best score over an image's references,
        # averaged over all images.
        meteor_scores = []
        rouge_scores = []
        for gen, refs in zip(generated_captions, reference_captions):
            meteor_scores.append(max(meteor_score.meteor_score([ref], gen) for ref in refs))

            gen_text = ' '.join(gen)
            rouge_scores.append(max(
                self.rouge_scorer.score(' '.join(ref), gen_text)['rouge1'].fmeasure
                for ref in refs
            ))

        metrics = {
            'BLEU-4': bleu4,
            'CIDEr': cider_score,
            'METEOR': np.mean(meteor_scores),
            'ROUGE-1': np.mean(rouge_scores)
        }

        return metrics, list(zip(generated_captions, reference_captions))

# Score the fine-tuned model on the first 100 images of the caption file.
# NOTE(review): this is the same caption file and image directory the training
# cell builds its dataset from, so these numbers are measured on images the
# model was trained on — confirm whether a held-out split was intended.
print("starting eval...")
evaluation_data = prepare_evaluation_data(
    caption_file="flickr_data/Flickr8k.token.txt",
    image_dir="flickr_data/Flicker8k_Dataset",
)

evaluator = CaptionEvaluator(model, tokenizer, image_processor)
metrics, caption_pairs = evaluator.evaluate_dataset(evaluation_data, num_samples=100)

# Report every metric with four decimals.
print("\neval metrics:")
for metric_name in ("BLEU-4", "CIDEr", "METEOR", "ROUGE-1"):
    print(f"{metric_name}: {metrics[metric_name]:.4f}")

# Show the first three generated captions next to their references.
print("\nExample Predictions:")
for image_number, (generated_tokens, reference_token_lists) in enumerate(caption_pairs[:3], start=1):
    print(f"\nImage {image_number}:")
    print(f"Generated: {' '.join(generated_tokens)}")
    print("References:")
    for ref_number, reference_tokens in enumerate(reference_token_lists, start=1):
        print(f"  {ref_number}: {' '.join(reference_tokens)}")