**IMPORTING LIBRARIES**

In [1]:
import os
import json
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    get_linear_schedule_with_warmup
)
from PIL import Image
from nltk.translate.bleu_score import corpus_bleu
from rouge_score import rouge_scorer

**DATASET**

In [2]:
class CustomBLIPDataset(Dataset):
    def __init__(self, json_file, image_dir, processor, max_length=32):
        with open(json_file, 'r') as f:
            data = json.load(f)

        self.samples = []
        for ann in data['annotations']:
            self.samples.append({
                'image_id': ann['image_id'],
                'caption': ann['caption']
            })

        self.image_dir = image_dir
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image_id = sample['image_id']
        caption = sample['caption']
        image_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        image = Image.open(image_path).convert('RGB')
        encoding = self.processor(
            images=image,
            text=caption,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt"
        )
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        return encoding

**TRAINING**

In [3]:
class BLIPTrainer:
    def __init__(
        self,
        model_name="Salesforce/blip-image-captioning-base",
        train_batch_size=8,
        eval_batch_size=16,
        learning_rate=5e-5,
        num_epochs=10,
        warmup_steps=1000,
        max_length=32,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        output_dir='blip_checkpoints'
    ):
        self.model_name = model_name
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.warmup_steps = warmup_steps
        self.max_length = max_length
        self.device = device
        self.output_dir = output_dir

        print(f"Using device: {self.device}")

        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForConditionalGeneration.from_pretrained(model_name)
        self.model.to(device)

        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        os.makedirs(output_dir, exist_ok=True)

    def prepare_data(self, json_path, image_dir, train_ratio=0.8, val_ratio=0.1):
        full_dataset = CustomBLIPDataset(
            json_file=json_path,
            image_dir=image_dir,
            processor=self.processor,
            max_length=self.max_length
        )

        total_size = len(full_dataset)
        train_size = int(train_ratio * total_size)
        val_size = int(val_ratio * total_size)
        test_size = total_size - train_size - val_size

        print(f"Dataset sizes - Train: {train_size}, Val: {val_size}, Test: {test_size}")

        train_dataset, val_dataset, test_dataset = random_split(
            full_dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42)
        )

        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4
        )

        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=4
        )

        self.test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=4
        )

        return self.train_dataloader, self.val_dataloader, self.test_dataloader

    def compute_metrics(self, pred_captions, true_captions):
        # Calculate BLEU score
        references = [[caption.split()] for caption in true_captions]
        predictions = [caption.split() for caption in pred_captions]
        bleu_score = corpus_bleu(references, predictions)

        # Calculate ROUGE scores
        rouge_scores = {
            'rouge1': 0.0,
            'rouge2': 0.0,
            'rougeL': 0.0
        }
        for pred, ref in zip(pred_captions, true_captions):
            scores = self.rouge_scorer.score(ref, pred)
            rouge_scores['rouge1'] += scores['rouge1'].fmeasure
            rouge_scores['rouge2'] += scores['rouge2'].fmeasure
            rouge_scores['rougeL'] += scores['rougeL'].fmeasure
        num_samples = len(pred_captions)
        for key in rouge_scores:
            rouge_scores[key] /= num_samples
        return {
            'bleu': bleu_score,
            **rouge_scores
        }

    def train(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        total_steps = len(self.train_dataloader) * self.num_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=total_steps
        )

        # Training loop
        best_val_loss = float('inf')
        for epoch in range(self.num_epochs):
            print(f"\nStarting epoch {epoch + 1}/{self.num_epochs}")

            # Training
            self.model.train()
            train_loss = 0
            progress_bar = tqdm(self.train_dataloader, desc=f"Training epoch {epoch + 1}")

            for batch in progress_bar:
                input_ids = batch['input_ids'].to(self.device)
                pixel_values = batch['pixel_values'].to(self.device)
                attention_mask = batch.get('attention_mask', None)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(self.device)

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    attention_mask=attention_mask,
                    labels=input_ids
                )

                loss = outputs.loss
                train_loss += loss.item()

                # Backward pass
                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                progress_bar.set_postfix({'loss': loss.item()})

            avg_train_loss = train_loss / len(self.train_dataloader)

            # Validation
            val_metrics = self.evaluate(self.val_dataloader, "Validation")

            metrics = {
                'train_loss': avg_train_loss,
                'val_loss': val_metrics['loss'],
                'val_bleu': val_metrics['bleu'],
                'val_rouge1': val_metrics['rouge1'],
                'val_rouge2': val_metrics['rouge2'],
                'val_rougeL': val_metrics['rougeL']
            }

            print(f"Epoch {epoch + 1} metrics:", metrics)

            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                self.save_model(f"{self.output_dir}/best_model")
                print(f"New best model saved with validation loss: {val_metrics['loss']:.4f}")
            self.save_model(f"{self.output_dir}/checkpoint_epoch_{epoch + 1}")

    @torch.no_grad()
    def evaluate(self, dataloader, split="Validation"):
        self.model.eval()
        total_loss = 0
        all_predictions = []
        all_references = []

        progress_bar = tqdm(dataloader, desc=f"{split} evaluation")

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device)
            pixel_values = batch['pixel_values'].to(self.device)
            attention_mask = batch.get('attention_mask', None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                labels=input_ids
            )

            total_loss += outputs.loss.item()

            generated_ids = self.model.generate(
                pixel_values=pixel_values,
                max_length=self.max_length,
                num_beams=4,
                early_stopping=True
            )

            generated_captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
            reference_captions = self.processor.batch_decode(input_ids, skip_special_tokens=True)

            all_predictions.extend(generated_captions)
            all_references.extend(reference_captions)

        metrics = self.compute_metrics(all_predictions, all_references)
        metrics['loss'] = total_loss / len(dataloader)

        return metrics

    def test(self):
        test_metrics = self.evaluate(self.test_dataloader, "Test")
        print(f"Test metrics:", test_metrics)
        return test_metrics

    def save_model(self, filename):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'processor_state': self.processor.save_pretrained,
            'model_name': self.model_name
        }
        torch.save(checkpoint, filename)
        print(f"Model saved to {filename}")

    def load_model(self, filename):
        checkpoint = torch.load(filename, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model_name = checkpoint['model_name']
        self.model.to(self.device)
        print(f"Model loaded from {filename}")

In [4]:
trainer = BLIPTrainer(
    model_name="Salesforce/blip-image-captioning-base",
    train_batch_size=8,
    eval_batch_size=16,
    learning_rate=5e-5,
    num_epochs=10,
    warmup_steps=1000,
    max_length=32,
    output_dir='blip_checkpoints'
)

train_dataloader, val_dataloader, test_dataloader = trainer.prepare_data(
    json_path='images.json',
    image_dir='images',
    train_ratio=0.8,
    val_ratio=0.1
)

Using device: cpu
Dataset sizes - Train: 2000, Val: 250, Test: 250


In [14]:
trainer.train()
trainer.save_model('fine-tuned_BLIP.pt')


Starting epoch 1/10


Training epoch 1: 100%|██████████| 250/250 [05:00<00:00,  1.20s/it, loss=3.79]
Validation evaluation: 100%|██████████| 16/16 [02:11<00:00,  8.23s/it]


Epoch 1 metrics: {'train_loss': 6.141959931373596, 'val_loss': 3.811283975839615, 'val_bleu': 0.08775112768060989, 'val_rouge1': 0.15882476738688986, 'val_rouge2': 0.12781450853520607, 'val_rougeL': 0.15387159506901166}
Model saved to blip_checkpoints/best_model
New best model saved with validation loss: 3.8113
Model saved to blip_checkpoints/checkpoint_epoch_1

Starting epoch 2/10


Training epoch 2: 100%|██████████| 250/250 [05:03<00:00,  1.21s/it, loss=1.01]
Validation evaluation: 100%|██████████| 16/16 [02:17<00:00,  8.57s/it]


Epoch 2 metrics: {'train_loss': 1.742139730811119, 'val_loss': 1.2196671664714813, 'val_bleu': 0.12549425110812118, 'val_rouge1': 0.2955645744029393, 'val_rouge2': 0.16630010683776775, 'val_rougeL': 0.2686275117736633}
Model saved to blip_checkpoints/best_model
New best model saved with validation loss: 1.2197
Model saved to blip_checkpoints/checkpoint_epoch_2

Starting epoch 3/10


Training epoch 3: 100%|██████████| 250/250 [05:03<00:00,  1.21s/it, loss=0.936]
Validation evaluation: 100%|██████████| 16/16 [01:54<00:00,  7.13s/it]


Epoch 3 metrics: {'train_loss': 0.9558675435781478, 'val_loss': 1.1771351844072342, 'val_bleu': 0.11596268235628784, 'val_rouge1': 0.27993081614563514, 'val_rouge2': 0.1546276744252678, 'val_rougeL': 0.25455645865289744}
Model saved to blip_checkpoints/best_model
New best model saved with validation loss: 1.1771
Model saved to blip_checkpoints/checkpoint_epoch_3

Starting epoch 4/10


Training epoch 4: 100%|██████████| 250/250 [05:03<00:00,  1.21s/it, loss=0.806]
Validation evaluation: 100%|██████████| 16/16 [02:21<00:00,  8.81s/it]


Epoch 4 metrics: {'train_loss': 0.7763670245409012, 'val_loss': 1.1783964596688747, 'val_bleu': 0.11987048823038354, 'val_rouge1': 0.26725660665003764, 'val_rouge2': 0.15605687311310165, 'val_rougeL': 0.24812848601448573}
Model saved to blip_checkpoints/checkpoint_epoch_4

Starting epoch 5/10


Training epoch 5: 100%|██████████| 250/250 [05:02<00:00,  1.21s/it, loss=0.54]
Validation evaluation: 100%|██████████| 16/16 [02:16<00:00,  8.56s/it]


Epoch 5 metrics: {'train_loss': 0.6025062891244888, 'val_loss': 1.228714156895876, 'val_bleu': 0.12027737961385931, 'val_rouge1': 0.31319443262298247, 'val_rouge2': 0.1637502339826814, 'val_rougeL': 0.27944674810300973}
Model saved to blip_checkpoints/checkpoint_epoch_5

Starting epoch 6/10


Training epoch 6: 100%|██████████| 250/250 [05:03<00:00,  1.21s/it, loss=0.31]
Validation evaluation: 100%|██████████| 16/16 [02:15<00:00,  8.48s/it]


Epoch 6 metrics: {'train_loss': 0.4066934297084808, 'val_loss': 1.2845053859055042, 'val_bleu': 0.1159657336296715, 'val_rouge1': 0.33710858786786063, 'val_rouge2': 0.15615731703470428, 'val_rougeL': 0.2979902985199454}
Model saved to blip_checkpoints/checkpoint_epoch_6

Starting epoch 7/10


Training epoch 7: 100%|██████████| 250/250 [05:02<00:00,  1.21s/it, loss=0.388]
Validation evaluation: 100%|██████████| 16/16 [02:17<00:00,  8.62s/it]


Epoch 7 metrics: {'train_loss': 0.262628405213356, 'val_loss': 1.3757603019475937, 'val_bleu': 0.09932028079526377, 'val_rouge1': 0.3261094840564905, 'val_rouge2': 0.13346278278586304, 'val_rougeL': 0.2819953558372457}
Model saved to blip_checkpoints/checkpoint_epoch_7

Starting epoch 8/10


Training epoch 8: 100%|██████████| 250/250 [05:02<00:00,  1.21s/it, loss=0.187]
Validation evaluation: 100%|██████████| 16/16 [02:24<00:00,  9.04s/it]


Epoch 8 metrics: {'train_loss': 0.17672417384386063, 'val_loss': 1.435774713754654, 'val_bleu': 0.08205557206418007, 'val_rouge1': 0.3054646386785037, 'val_rouge2': 0.11439032932115378, 'val_rougeL': 0.2647234504782206}
Model saved to blip_checkpoints/checkpoint_epoch_8

Starting epoch 9/10


Training epoch 9: 100%|██████████| 250/250 [05:02<00:00,  1.21s/it, loss=0.133]
Validation evaluation: 100%|██████████| 16/16 [02:19<00:00,  8.70s/it]


Epoch 9 metrics: {'train_loss': 0.12455823394656182, 'val_loss': 1.5176988244056702, 'val_bleu': 0.07525727109097474, 'val_rouge1': 0.28785992290393553, 'val_rouge2': 0.10286787225701541, 'val_rougeL': 0.24671647870046157}
Model saved to blip_checkpoints/checkpoint_epoch_9

Starting epoch 10/10


Training epoch 10: 100%|██████████| 250/250 [05:02<00:00,  1.21s/it, loss=0.0944]
Validation evaluation: 100%|██████████| 16/16 [02:31<00:00,  9.50s/it]


Epoch 10 metrics: {'train_loss': 0.09534445084631443, 'val_loss': 1.5522729605436325, 'val_bleu': 0.08777703652939359, 'val_rouge1': 0.29933885036681745, 'val_rouge2': 0.1153205375408667, 'val_rougeL': 0.2555412900249298}
Model saved to blip_checkpoints/checkpoint_epoch_10
Model saved to fine-tuned_BLIP.pt


**TESTING**

In [15]:
test_metrics = trainer.test()
print("\nEvaluation on test set:\n", test_metrics)

Test evaluation: 100%|██████████| 16/16 [02:34<00:00,  9.68s/it]


Test metrics: {'bleu': 0.09442623168048496, 'rouge1': 0.2932291978870383, 'rouge2': 0.1275179068814095, 'rougeL': 0.2548687881025692, 'loss': 1.4956378638744354}

Evaluation on test set:
 {'bleu': 0.09442623168048496, 'rouge1': 0.2932291978870383, 'rouge2': 0.1275179068814095, 'rougeL': 0.2548687881025692, 'loss': 1.4956378638744354}
