In [None]:
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from model import Encoder, Decoder
from prepare_data import CaptionDataset

In [None]:
def _load_pickle(path):
    """Open `path` in binary mode and return the object unpickled from it."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only point these paths at artifacts produced by a trusted preprocessing run.
embedding_matrix = _load_pickle('Processed Data/embedding_matrix.pkl')
word_to_index = _load_pickle('Processed Data/word_to_index.pkl')

In [None]:
# Model hyperparameters.
# hidden_dim is the only hand-picked value; the other sizes are read from the
# loaded artifacts so they cannot drift out of sync with the data.
hidden_dim = 256

vocab_size = len(word_to_index)              # number of entries in the vocabulary mapping
embedding_dim = embedding_matrix.shape[1]    # size of the embedding matrix's second axis
pad_idx = word_to_index["<PAD>"]             # index stored under the "<PAD>" key

In [None]:
# Build the two halves of the captioning model from the sizes defined earlier.
encoder = Encoder(encoded_image_size=hidden_dim)

# Decoder arguments are gathered in one place so they are easy to inspect.
decoder_config = dict(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    embedding_matrix=embedding_matrix,
)
decoder = Decoder(**decoder_config)

In [None]:
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Rebind both names to whatever `.to(device)` returns.
encoder, decoder = encoder.to(device), decoder.to(device)

In [None]:
# Dataset built from two preprocessed pickle files (judging by the file names:
# ResNet-50 image features and encoded captions — confirm against prepare_data).
dataset = CaptionDataset(
    captions_path='Processed Data/encoded_captions.pkl',
    features_path='Processed Data/image_features_resnet50.pkl',
)

# Mini-batches of 32 samples, reshuffled at the start of every epoch.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

In [None]:
# Parameters handed to the optimizer: from the encoder only its `linear` and
# `dropout` sub-modules are included, plus every parameter of the decoder.
params = [
    *encoder.linear.parameters(),
    *encoder.dropout.parameters(),
    *decoder.parameters(),
]

# Target positions equal to pad_idx are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(params, lr=1e-3)

In [None]:
# Training loop: one pass over `dataloader` per epoch, printing the mean
# per-batch loss at the end of each epoch.
num_epochs = 10
for epoch in range(num_epochs):
    # Put both modules in training mode before iterating over the batches.
    encoder.train()
    decoder.train()
    total_loss = 0.0
    for features, captions in dataloader:
        features = features.to(device)  # (batch, 2048)
        captions = captions.to(device)  # (batch, seq_len)

        # Shift by one position: the input drops the last token and the target
        # drops the first, so input position t is scored against token t + 1.
        caption_input = captions[:, :-1]
        caption_target = captions[:, 1:]

        # Only the encoder's dropout + linear layers are applied, because the
        # features are already 2048-dimensional vectors.
        features_encoded = encoder.linear(encoder.dropout(features))  # (batch, 256)
        # NOTE(review): logits are presumably (batch, seq_len - 1, vocab_size),
        # since they are flattened against a (batch * (seq_len - 1)) target
        # below — confirm against Decoder.forward.
        outputs = decoder(caption_input, features_encoded)

        # Flatten logits and targets for CrossEntropyLoss. `.reshape` is used
        # for both: unlike `.view`, it also works when the tensor is not
        # contiguous (the sliced target never is; the decoder output may not be).
        loss = criterion(
            outputs.reshape(-1, vocab_size),
            caption_target.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

# 7. Save the model
# torch.save does not create missing parent directories, so make sure the
# output folder exists first; otherwise the save fails only after the whole
# training run has finished and the trained weights are lost.
checkpoint_path = 'Models/caption_model_resnet50.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# The checkpoint bundles the encoder, decoder and optimizer state dicts.
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, checkpoint_path)