In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import pickle
from transformers import GPT2Tokenizer

from dataset import Flickr8kDataset, Flickr8kPreprocessor
from models import ClipCapModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


Using device: cuda


In [2]:
DATA_DIR = './data'
IMAGES_DIR = os.path.join(DATA_DIR, 'Images')
CAPTIONS_FILE = os.path.join(DATA_DIR, 'captions.txt')

preprocessor = Flickr8kPreprocessor(data_dir=DATA_DIR)

captions_data = preprocessor.load_captions(captions_file='captions.txt')
print(f'Loaded {len(captions_data)} images with captions.')

# CLIP features
features_path = os.path.join(DATA_DIR, 'clip_features.pkl')
if not os.path.exists(features_path):
    features_data = preprocessor.extract_clip_features(images_dir='Images', save_path='clip_features.pkl')
else:
    with open(features_path, 'rb') as f:
        features_data = pickle.load(f)
    print(f'Loaded features for {len(features_data["image_names"])} images.')

splits = preprocessor.prepare_train_data(features_data, captions_data, test_size=0.2, val_size=0.1)

print(f"Train: {len(splits['train'][0])} samples")
print(f"Val: {len(splits['val'][0])} samples")
print(f"Test: {len(splits['test'][0])} samples")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loaded 8091 images with captions.
Loaded features for 8091 images.
Train: 28318 samples
Val: 4046 samples
Test: 8091 samples


In [3]:
## 2. Create Dataset and DataLoader

# Hyperparameters
PREFIX_LENGTH = 10
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
EPOCHS = 10
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
GPT2_MODEL_NAME = "openai-community/gpt2"

# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Create datasets
train_dataset = Flickr8kDataset(
    image_features=splits['train'][0],
    captions=splits['train'][1],
    tokenizer=tokenizer,
    prefix_length=PREFIX_LENGTH
)

val_dataset = Flickr8kDataset(
    image_features=splits['val'][0],
    captions=splits['val'][1],
    tokenizer=tokenizer,
    prefix_length=PREFIX_LENGTH
)

test_dataset = Flickr8kDataset(
    image_features=splits['test'][0],
    captions=splits['test'][1],
    tokenizer=tokenizer,
    prefix_length=PREFIX_LENGTH
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')
print(f'Test batches: {len(test_loader)}')


Train batches: 885
Val batches: 127
Test batches: 253


In [4]:
## 3. Initialize Model

# Initialize model
model = ClipCapModel(
    prefix_length=PREFIX_LENGTH,
).to(device)

# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(train_loader))

print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.')


137,031,936 trainable parameters.


In [5]:
## 4. Training Loop

def train_epoch(model, dataloader, optimizer, scheduler, device, tokenizer):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Training')
    for batch in progress_bar:
        optimizer.zero_grad()

        image_features = batch['image_features'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Labels are the same as input_ids, with padding ignored
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100

        outputs = model(
            image_features=image_features,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})

    return total_loss / len(dataloader)

def validate_epoch(model, dataloader, device, tokenizer):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Validation')
        for batch in progress_bar:
            image_features = batch['image_features'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            labels = input_ids.clone()
            labels[labels == tokenizer.pad_token_id] = -100

            outputs = model(
                image_features=image_features,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)


In [7]:
train_losses = []
val_losses = []
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch+1}/{EPOCHS}')
    print('-' * 50)

    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, tokenizer)
    train_losses.append(train_loss)

    val_loss = validate_epoch(model, val_loader, device, tokenizer)
    val_losses.append(val_loss)

    print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best_model.pth')
        print('Saved best model!')

print('Training completed!')



Epoch 1/10
--------------------------------------------------


Training:   0%|          | 0/885 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Training:  40%|████      | 358/885 [29:10<42:56,  4.89s/it, loss=1.82, lr=4.98e-5]


KeyboardInterrupt: 

In [None]:
## 5. Visualization and Evaluation

# Plot training curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

In [9]:
# Load the best model
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


# Generate sample captions
def generate_sample_captions(model, dataloader, tokenizer, device, num_samples=5):
    model.eval()
    samples = []

    with torch.no_grad():
        for batch in dataloader:
            if len(samples) >= num_samples:
                break

            image_features = batch['image_features'].to(device)
            input_ids = batch['input_ids']

            generated_captions = model.generate(
                image_features,
                max_length=50,
                temperature=0.8,
                do_sample=True,
                top_p=0.9
            )

            for i in range(len(generated_captions)):
                if len(samples) >= num_samples:
                    break
                gt_caption = tokenizer.decode(input_ids[i], skip_special_tokens=True)
                samples.append({
                    'generated': generated_captions[i],
                    'ground_truth': gt_caption
                })
    return samples

# Generate and display sample captions from the test set
sample_captions = generate_sample_captions(model, test_loader, tokenizer, device, num_samples=10)

for i, sample in enumerate(sample_captions):
    print(f'\n--- Sample {i+1} ---')
    print(f'Generated:    {sample["generated"]}')
    print(f'Ground Truth: {sample["ground_truth"]}')


  checkpoint = torch.load('best_model.pth', map_location=device)



--- Sample 1 ---
Generated:    A lioness is chasing after a black animal .
Ground Truth: A large wild cat is pursuing a horse across a meadow .

--- Sample 2 ---
Generated:    Two dogs are fighting over a stick .
Ground Truth: Two brown dogs fight on the leafy ground .

--- Sample 3 ---
Generated:    A man in a white shirt is standing on a cliff overlooking the ocean .
Ground Truth: A man in shorts is standing on a rock looking out at the view from the hilltop .

--- Sample 4 ---
Generated:    A white dog with a muzzle runs through the grass with a stick in its mouth .
Ground Truth: a muzzled white dog is running on the grass .

--- Sample 5 ---
Generated:    A skier in a red jacket is skiing down a snowy hill .
Ground Truth: A person skiing downhill .

--- Sample 6 ---
Generated:    A German shepherd dog is running through the grass with a tennis ball in its mouth .
Ground Truth: Shepherd dog catches tennis ball

--- Sample 7 ---
Generated:    A man in a yellow shirt is jumping on a 