In [1]:
from training_enhanced import *

  warn(


In [2]:
# --- Device Setup ---
# Pick the compute device in order of preference: Intel XPU, then CUDA, then CPU.
# `torch.xpu` is looked up with getattr so that PyTorch builds which do not
# expose that namespace fall through to CUDA/CPU instead of raising
# AttributeError.
_xpu_backend = getattr(torch, "xpu", None)
if _xpu_backend is not None and _xpu_backend.is_available():
    device = torch.device("xpu")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='xpu')

In [3]:
# --- Paths, Tokenizer, Dataset, and DataLoader ---
# Image folder, feature folder, and captions file handed to the dataset below.
IMAGE_DIR, FEATURES_DIR, CAPTIONS_FILE = (
    "train2017_50k",
    "train2017_50k_features_en",
    "merged_captions.json",
)

# Shared sequence-length setting: passed to the dataset here and to the
# decoder, trainer, and test evaluation in later cells.
max_length = 50

# Use the tokenizer's "[CLS]" / "[SEP]" tokens as the begin- and
# end-of-sequence markers, and record their vocabulary ids.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.bos_token, tokenizer.eos_token = "[CLS]", "[SEP]"
tokenizer.bos_token_id, tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"])

# use_features=True together with features_dir points the dataset at the
# feature directory (presumably so it serves cached features instead of raw
# images -- TODO confirm in ImageCaptionDataset).
dataset = ImageCaptionDataset(
    IMAGE_DIR,
    CAPTIONS_FILE,
    tokenizer,
    max_length=max_length,
    use_features=True,
    features_dir=FEATURES_DIR,
)
len(dataset)

50000

In [4]:
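# --- Precompute Features (one-off step; intentionally left commented out) ---
# The saved output under this cell comes from an earlier run that wrote the
# features into FEATURES_DIR (500 batches, roughly 55 minutes). Uncomment both
# lines only when the features need to be regenerated.
# NOTE(review): `dataset` is built with use_features=True in the cell above;
# confirm that precompute_features accepts it in that mode before re-running.
# encoder = EfficientNetEncoder()
# precompute_features(dataset, encoder, device, FEATURES_DIR, batch_size=100)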

Precomputing features: 100%|██████████| 500/500 [54:52<00:00,  6.58s/it]


In [4]:
# --- Create Train, Validation, and Test Splits ---
# 80% / 10% / 10% partition of `dataset`, wrapped in one DataLoader per split.
batch_size = 32

# Fixed seed for the partition. Without it random_split draws a different
# permutation on every kernel restart, so a checkpoint reloaded in a fresh
# session could be "tested" on images it was trained on.
SPLIT_SEED = 42

total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size  # remainder absorbs int() rounding

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Options common to all three loaders; only `shuffle` differs per split.
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    persistent_workers=True,
    num_workers=4,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [5]:
# --- Hyperparameters ---
# Decoder architecture settings.
embed_dim, num_heads = 512, 8
hidden_dim, num_layers = 2048, 6
dropout = 0.2
# Passed to the decoder as feature_dim (presumably the width of the
# precomputed encoder features -- TODO confirm against EfficientNetEncoder).
feature_dim = 1792
# Optimizer settings, consumed by the optimizer cell further down.
lr = weight_decay = 1e-4


# --- Alternative preset (inactive) ---
# NOTE(review): this block was labelled "High Compute", but its embed_dim and
# hidden_dim are smaller than the active values above and num_layers is the
# same; only num_heads is larger. Kept for reference only.
# embed_dim = 256
# hidden_dim = 1024
# num_heads = 16  # More attention heads
# num_layers = 6
# dropout = 0.2
# feature_dim = 960

# --- Instantiate Encoder, Decoder, and Model ---
encoder = EfficientNetEncoder()

# Gather the decoder arguments in one place so the full configuration can be
# read (or logged) at a glance before the decoder is built.
decoder_kwargs = dict(
    embed_dim=embed_dim,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    vocab_size=tokenizer.vocab_size,
    num_layers=num_layers,
    max_length=max_length,
    feature_dim=feature_dim,
    dropout=dropout,
)
decoder = TransformerDecoder(**decoder_kwargs)

model = ImageCaptionModel(encoder, decoder, use_features=True)

In [6]:
# --- Loss, Optimizer, Scheduler and Training ---
# Builds the loss, optimizer, LR scheduler, and trainer for the model.
num_epochs = 50

# Move the model to the target device *before* creating the optimizer, as the
# torch.optim documentation recommends, and switch to training mode before
# handing the model to ipex.optimize.
model = model.to(device)
model.train()

# Positions holding the pad token do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id).to(device)

optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# The scheduler is created from the optimizer *returned* by ipex.optimize so
# that LR reductions reach the optimizer the trainer actually steps.
# NOTE(review): ipex.optimize defaults to non-inplace mode and may return new
# objects; confirm against the installed IPEX version.
# The LR is multiplied by 0.3 when the monitored value (mode='min') stops
# improving for more than `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=2)

# Optional: resume from a previously saved checkpoint.
# model.load_state_dict(torch.load("best_model.pth", weights_only=True))
trainer = ImageCaptionTrainer(model, tokenizer, criterion, optimizer, scheduler, device)



In [7]:
# Run training for up to `num_epochs` epochs. patience / min_delta are
# presumably the early-stopping settings (the saved log below ends with
# "Early stopping triggered!") -- TODO confirm in ImageCaptionTrainer.train.
early_stopping = dict(patience=10, min_delta=0.001)
trainer.train(
    train_loader,
    val_loader,
    num_epochs,
    max_length=max_length,
    **early_stopping,
)

Epoch 1/50: 100%|██████████| 1250/1250 [07:14<00:00,  2.88it/s, loss=3.2695]
                                                                              

Epoch 1 | Train Loss: 4.0235 | Val Loss: 3.1067 | LR: 0.000100
--> Best model saved. | Val loss 3.1067


Epoch 2/50: 100%|██████████| 1250/1250 [06:59<00:00,  2.98it/s, loss=2.5017]
                                                                              

Epoch 2 | Train Loss: 2.9507 | Val Loss: 2.7084 | LR: 0.000100
--> Best model saved. | Val loss 2.7084


Epoch 3/50: 100%|██████████| 1250/1250 [07:00<00:00,  2.98it/s, loss=2.7224]
                                                                              

Epoch 3 | Train Loss: 2.6355 | Val Loss: 2.5426 | LR: 0.000100
--> Best model saved. | Val loss 2.5426


Epoch 4/50: 100%|██████████| 1250/1250 [07:11<00:00,  2.90it/s, loss=2.2854]
                                                                              

Epoch 4 | Train Loss: 2.4458 | Val Loss: 2.4249 | LR: 0.000100
--> Best model saved. | Val loss 2.4249


Epoch 5/50: 100%|██████████| 1250/1250 [07:31<00:00,  2.77it/s, loss=2.3916]
                                                                              

Epoch 5 | Train Loss: 2.3048 | Val Loss: 2.3583 | LR: 0.000100
--> Best model saved. | Val loss 2.3583


Epoch 6/50: 100%|██████████| 1250/1250 [07:21<00:00,  2.83it/s, loss=2.2293]
                                                                              

Epoch 6 | Train Loss: 2.1909 | Val Loss: 2.3218 | LR: 0.000100
--> Best model saved. | Val loss 2.3218


Epoch 7/50: 100%|██████████| 1250/1250 [07:19<00:00,  2.84it/s, loss=2.1898]
                                                                              

Epoch 7 | Train Loss: 2.0922 | Val Loss: 2.2813 | LR: 0.000100
--> Best model saved. | Val loss 2.2813


Epoch 8/50: 100%|██████████| 1250/1250 [07:22<00:00,  2.83it/s, loss=2.0460]
                                                                              

Epoch 8 | Train Loss: 2.0031 | Val Loss: 2.2744 | LR: 0.000100
--> Best model saved. | Val loss 2.2744


Epoch 9/50: 100%|██████████| 1250/1250 [07:21<00:00,  2.83it/s, loss=1.9321]
                                                                              

Epoch 9 | Train Loss: 1.9209 | Val Loss: 2.2646 | LR: 0.000100
--> Best model saved. | Val loss 2.2646


Epoch 10/50: 100%|██████████| 1250/1250 [07:38<00:00,  2.73it/s, loss=1.9549]
                                                                              

Epoch 10 | Train Loss: 1.8435 | Val Loss: 2.2650 | LR: 0.000100


Epoch 11/50: 100%|██████████| 1250/1250 [07:10<00:00,  2.90it/s, loss=1.7347]
                                                                              

Epoch 11 | Train Loss: 1.7681 | Val Loss: 2.2683 | LR: 0.000100


Epoch 12/50: 100%|██████████| 1250/1250 [06:52<00:00,  3.03it/s, loss=1.8081]
                                                                              

Epoch 12 | Train Loss: 1.6953 | Val Loss: 2.2792 | LR: 0.000030


Epoch 13/50: 100%|██████████| 1250/1250 [06:50<00:00,  3.04it/s, loss=1.5428]
                                                                              

Epoch 13 | Train Loss: 1.5330 | Val Loss: 2.2667 | LR: 0.000030


Epoch 14/50: 100%|██████████| 1250/1250 [06:50<00:00,  3.04it/s, loss=1.5429]
                                                                              

Epoch 14 | Train Loss: 1.4871 | Val Loss: 2.2836 | LR: 0.000030


Epoch 15/50: 100%|██████████| 1250/1250 [06:51<00:00,  3.04it/s, loss=1.4619]
                                                                              

Epoch 15 | Train Loss: 1.4534 | Val Loss: 2.2979 | LR: 0.000009


Epoch 16/50: 100%|██████████| 1250/1250 [06:52<00:00,  3.03it/s, loss=1.3804]
                                                                              

Epoch 16 | Train Loss: 1.3948 | Val Loss: 2.2950 | LR: 0.000009


Epoch 17/50: 100%|██████████| 1250/1250 [06:53<00:00,  3.02it/s, loss=1.3325]
                                                                              

Epoch 17 | Train Loss: 1.3805 | Val Loss: 2.2967 | LR: 0.000009


Epoch 18/50: 100%|██████████| 1250/1250 [06:53<00:00,  3.03it/s, loss=1.4249]
                                                                              

Epoch 18 | Train Loss: 1.3695 | Val Loss: 2.3042 | LR: 0.000003


Epoch 19/50: 100%|██████████| 1250/1250 [06:51<00:00,  3.04it/s, loss=1.2540]
                                                                              

Epoch 19 | Train Loss: 1.3493 | Val Loss: 2.3040 | LR: 0.000003
Early stopping triggered!




In [8]:
# Keep the final-epoch weights on disk, then switch the model back to the
# "best_model.pth" checkpoint (presumably the file the trainer writes on each
# "Best model saved" epoch -- TODO confirm) before evaluating.
LAST_CKPT, BEST_CKPT = "last_model.pth", "best_model.pth"

torch.save(model.state_dict(), LAST_CKPT)
model.load_state_dict(torch.load(BEST_CKPT, weights_only=True))

# Re-attach explicitly so the trainer evaluates the reloaded weights.
trainer.model = model

In [9]:
# --- After Training, Evaluate on the Test Set (Greedy Decoding) ---
# evaluate_test_set returns a mapping of metric name -> numeric score; each
# entry is printed rounded to four decimal places.
metrics = trainer.evaluate_test_set(test_loader, max_length)
for metric in metrics:
    print(f"{metric}: {metrics[metric]:.4f}")

Testing: 100%|██████████| 157/157 [2:02:12<00:00, 46.71s/it] 


Computing BLEU Score... 0.0740467658803539
Computing CIDEr... 0.30955993101206697
Computing METEOR... 0.2613576305799852
Computing ROUGE-L Score... 0.2964625358581543
Computing BERT Score... 0.9674673080444336
BLEU Score: 0.0740
CIDEr Score: 0.3096
METEOR Score: 0.2614
ROUGE-L Score: 0.2965
BERT Score: 0.9675
