In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import pickle
from prepare_data import CaptionDataset
from tqdm import tqdm
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
from src.model import Encoder, DecoderTransformer

In [3]:
# Load the preprocessing artifacts written by the data-preparation step.
# NOTE: pickle.load can execute arbitrary code while unpickling — only load files you produced yourself.
def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


embedding_matrix = _load_pickle('../Processed Data/embedding_matrix.pkl')
word_to_index = _load_pickle('../Processed Data/word_to_index.pkl')


In [3]:
vocab_size = len(word_to_index)
pad_idx = word_to_index["<PAD>"]  # id of the padding token

# Reproducibility: torch.manual_seed seeds the generators on all devices (CPU and CUDA),
# which covers weight initialisation, dropout masks and default DataLoader shuffling.
# Python's `random` and NumPy are NOT seeded here — add them if CaptionDataset uses them.
SEED = 42
torch.manual_seed(SEED)

# Model Hyperparameters
embed_dim = 300  # Should match the embedding matrix dimension (checked below)
nhead = 6
num_decoder_layers = 4
dim_feedforward = 1024
dropout = 0.5

# Fail fast on an embedding-width mismatch instead of deep inside the decoder.
# Guarded with hasattr because the pickled object's type is not visible here
# (presumably a NumPy array / tensor of shape (vocab, embed_dim) — confirm).
if hasattr(embedding_matrix, "shape"):
    assert embedding_matrix.shape[-1] == embed_dim, (
        f"embed_dim={embed_dim} does not match embedding_matrix width {embedding_matrix.shape[-1]}"
    )

# Training Hyperparameters
batch_size = 32
num_epochs = 15
learning_rate = 1e-4

In [None]:
# Resolve the target device first so each module can be moved as soon as it is built
# (nn.Module.to returns the module itself).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = Encoder(embed_dim=embed_dim, dropout=dropout).to(device)

# NOTE(review): `dropout` is forwarded to the encoder only; DecoderTransformer is built
# without a dropout argument — confirm its internal default is what you intend.
decoder_config = dict(
    vocab_size=vocab_size,
    embedding_dim=embed_dim,
    nhead=nhead,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    embedding_matrix=embedding_matrix,
)
decoder = DecoderTransformer(**decoder_config).to(device)

print(f"Using device: {device}")

In [None]:
# Image preprocessing: fixed-size resize, then per-channel normalisation with the
# standard ImageNet statistics (presumably matching the encoder's backbone — confirm).
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# NOTE(review): the whole dataset is used for training; there is no validation split.
dataset = CaptionDataset(
    image_dir='../Flickr8k/Images',
    captions_path='../Processed Data/encoded_captions.pkl',
    transform=transform,
)

# num_workers=0 loads batches in the main process — keep it at 0 on Windows if
# worker processes raise DataLoader errors.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [6]:
# All encoder + decoder parameters; the training loop clips gradients over this list.
params = [*encoder.parameters(), *decoder.parameters()]

# Only parameters with requires_grad=True are handed to the optimizer, so any frozen
# weights are left untouched.
trainable_params = [p for p in params if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

# Target positions equal to pad_idx contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [None]:
print("Starting training...")

# Fail up front with a clear message instead of a ZeroDivisionError at the end of epoch 1.
if len(dataloader) == 0:
    raise RuntimeError("dataloader yields no batches; check the dataset paths")

epoch_losses = []  # mean training loss per epoch, kept for later inspection/plotting

for epoch in range(num_epochs):
    # NOTE(review): if Encoder wraps a frozen pretrained CNN, .train() still updates its
    # BatchNorm running statistics even though its weights are frozen — confirm intended.
    encoder.train()
    decoder.train()
    total_loss = 0.0

    # The description is constant within an epoch, so set it once here.
    loop = tqdm(dataloader, leave=True, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for images, captions in loop:
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: the decoder is fed tokens [0, T-1) and trained to predict [1, T).
        caption_input = captions[:, :-1]
        caption_target = captions[:, 1:]

        features = encoder(images)
        outputs = decoder(caption_input, features)

        # Flatten to (N, vocab_size) logits vs (N,) targets; padding is skipped via ignore_index.
        # NOTE(review): assumes outputs is batch-first (B, T-1, vocab_size) so rows line up
        # with caption_target — confirm against DecoderTransformer.
        loss = criterion(
            outputs.reshape(-1, vocab_size),
            caption_target.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optimizer.step()

        # .item() forces a device->host sync; do it once per step and reuse the value.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Mean of the per-batch (token-averaged) losses for this epoch.
    avg_loss = total_loss / len(dataloader)
    epoch_losses.append(avg_loss)
    print(f"\nEnd of Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("Training finished.")

In [None]:
print("Saving model...")
CHECKPOINT_PATH = '../Models/caption_model.pth'
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# Besides the state dicts, record the hyperparameters used to build the models so an
# inference notebook can reconstruct them from the checkpoint alone. vocab_size, dropout
# and pad_idx were previously missing; all pre-existing keys are kept unchanged so
# existing loaders keep working.
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'embed_dim': embed_dim,
    'nhead': nhead,
    'num_decoder_layers': num_decoder_layers,
    'dim_feedforward': dim_feedforward,
    'vocab_size': vocab_size,
    'dropout': dropout,
    'pad_idx': pad_idx,
    'num_epochs': num_epochs,
}, CHECKPOINT_PATH)

print("✅ Model saved successfully.")
