In [None]:
from training import *

In [None]:
# --- Device Setup ---
device = torch.device("xpu" if torch.xpu.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# --- Paths, Tokenizer, and Dataset ---
# Image folder, feature folder, and captions file consumed by the dataset.
IMAGE_DIR = "train2017_50k"
FEATURES_DIR = "train2017_50k_features"
CAPTIONS_FILE = "merged_captions.json"

# Caption length limit; the same value is reused by the decoder and the trainer.
max_length = 50
tokenizer = AutoTokenizer.from_pretrained("nemotron_tokenizer")

# use_features=True together with features_dir is passed so the dataset works
# from FEATURES_DIR (presumably precomputed encoder outputs — confirm in
# ImageCaptionDataset).
dataset = ImageCaptionDataset(
    IMAGE_DIR,
    CAPTIONS_FILE,
    tokenizer,
    max_length=max_length,
    use_features=True,
    features_dir=FEATURES_DIR,
)
len(dataset)

In [None]:
# --- Precompute Features ---
# Optional step, intentionally left disabled. Uncomment both lines to build a
# MobileNetV3Encoder and run precompute_features over `dataset`, targeting
# FEATURES_DIR with batch_size=800.
# NOTE(review): presumably only needed while FEATURES_DIR has not been
# populated yet — confirm before re-enabling.
# encoder = MobileNetV3Encoder()
# precompute_features(dataset, encoder, device, FEATURES_DIR, batch_size=800)

In [None]:
# --- Create Train, Validation, and Test Splits ---
# 80% / 10% / 10% partition. The test split takes whatever remains after the
# two int() truncations, so the three sizes always sum to len(dataset).
batch_size = 8
SPLIT_SEED = 42

total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# Seed the split explicitly. Without a fixed generator every kernel restart
# yields a different partition, so a checkpoint reloaded from an earlier run
# (e.g. best_model.pth) could end up being evaluated on samples it was
# trained on, and reported metrics would not be comparable across runs.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# --- Hyperparameters ---
embed_dim = 512
num_heads = 8
hidden_dim = 512
num_layers = 6
dropout = 0.2
# NOTE(review): presumably the width of the feature vectors the decoder
# receives from the encoder / feature files — confirm against MobileNetV3Encoder.
feature_dim = 960

# --- Instantiate Encoder, Decoder, and Model ---
encoder = MobileNetV3Encoder()

# Gather the decoder settings in one place, then expand them as keyword arguments.
decoder_config = {
    "embed_dim": embed_dim,
    "num_heads": num_heads,
    "hidden_dim": hidden_dim,
    "vocab_size": tokenizer.vocab_size,
    "num_layers": num_layers,
    "max_length": max_length,
    "feature_dim": feature_dim,
    "dropout": dropout,
}
decoder = TransformerDecoder(**decoder_config)

model = ImageCaptionModel(encoder, decoder, use_features=True)

In [None]:
# --- Loss, Optimizer, Scheduler and Training ---
num_epochs = 50

# Switch to training mode and move the model to the target device *before*
# building the optimizer, so the optimizer is constructed over the parameters
# that will actually be trained.
model.train()
model = model.to(device)

# Padding positions are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# ipex.optimize hands back the model/optimizer pair to train with; the returned
# optimizer is not guaranteed to be the object that was passed in.
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# Create the scheduler exactly once, and only after ipex.optimize, so that it
# adjusts the learning rate of the optimizer the trainer really steps.
# Previously a second `ReduceLROnPlateau(optimizer)` silently replaced this
# configuration with the library defaults and was bound to the pre-IPEX optimizer.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

# model.load_state_dict(torch.load("best_model.pth", weights_only=True))
trainer = ImageCaptionTrainer(model, tokenizer, criterion, optimizer, scheduler, device)

In [None]:
# Launch training for up to num_epochs epochs. patience / min_delta are
# forwarded to the trainer (presumably its early-stopping settings — confirm
# in ImageCaptionTrainer.train).
trainer.train(
    train_loader,
    val_loader,
    num_epochs,
    patience=10,
    min_delta=0.001,
    max_length=max_length,
)

In [None]:
# Persist the weights as they are at the end of training.
torch.save(
    model.state_dict(),
    "last_model.pth",
)

In [None]:
# --- After Training, Evaluate on the Test Set ---
# Score the model currently held by the trainer on the held-out test loader
# and print every reported metric with four decimal places.
metrics = trainer.evaluate_test_set(test_loader, max_length)
for name, score in metrics.items():
    print(f"{name}: {score:.4f}")

In [None]:
# --- Load best model then check bleu score ---
# map_location remaps the stored tensors onto the device in use now, so the
# checkpoint still loads when it was written from a different device (for
# example, saved on an accelerator and reopened on a CPU-only machine).
best_state = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(best_state)

# Point the trainer at the model carrying the restored weights, then re-run
# the same test-set evaluation and print each metric with four decimals.
trainer.model = model
metrics = trainer.evaluate_test_set(test_loader, max_length)
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")