In [None]:
import torch.optim as optim
import time
import torch
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt

In [None]:
from utils import *
from data_loader import *
from encoder_decoder import *

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Bare expression so the notebook displays which device was selected.
device

In [None]:
# Root folder of the dataset download; the dataset paths in the next cell are built from it.
# NOTE(review): "dataa" may be a typo for "data" - confirm against the folder name on disk.
data_location = "dataa"

In [None]:
# Root folder of the dataset download (set in the previous cell).
path_to_dataset = data_location

# Directory containing the image files. The trailing slash is kept as-is.
# NOTE(review): Flickr8kDataset may build image paths by string concatenation - confirm
# before changing it.
img_dir = f"{path_to_dataset}/Flickr8k_Dataset/Flicker8k_Dataset/"

# File containing the captions
captions_file = f"{path_to_dataset}/Flickr8k_text/Flickr8k.token.txt"

# Files listing which images belong to the train / validation / test splits
train_file = f"{path_to_dataset}/Flickr8k_text/Flickr_8k.trainImages.txt"
val_file = f"{path_to_dataset}/Flickr8k_text/Flickr_8k.devImages.txt"
test_file = f"{path_to_dataset}/Flickr8k_text/Flickr_8k.testImages.txt"

# Resize every image to 224x224, convert it to a tensor and normalise each channel with
# the standard ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# One dataset per split; all three share the same image folder, caption file and transform.
train_dataset = Flickr8kDataset(img_dir, captions_file, train_file, transform=transform)
val_dataset = Flickr8kDataset(img_dir, captions_file, val_file, transform=transform)
test_dataset = Flickr8kDataset(img_dir, captions_file, test_file, transform=transform)

# Settings shared by all three loaders, written once so they cannot drift apart.
loader_kwargs = dict(
    batch_size=32,
    collate_fn=Flickr8kDataset.collate_fn,
    num_workers=4,
    # The training loop copies batches with `.to(device, non_blocking=True)`; that copy
    # can only overlap with compute when the host tensors live in pinned memory.
    # Pinning is useless without CUDA, so it is enabled only when a GPU is available.
    pin_memory=torch.cuda.is_available(),
)

# Only the training split is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Vocabulary size = number of entries in the training dataset's `stoi` table
# (the table that is indexed with token strings such as "<PAD>").
vocab_size = len(train_dataset.stoi)

In [None]:
# Hyperparameters

# Model dimensions; each one is forwarded to EncoderDecoder under the same keyword name.
embed_size = 300
attention_dim = 256
encoder_dim = 2048
decoder_dim = 512

# Learning rate handed to the Adam optimizer.
learning_rate = 1e-4

In [None]:
# Build the captioning model from the hyperparameters defined above ...
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size=vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim,
)
# ... and move it to the selected device before the optimizer is created.
model = model.to(device)

# Cross-entropy loss; target positions equal to the "<PAD>" index do not contribute.
criterion = nn.CrossEntropyLoss(
    ignore_index=train_dataset.stoi["<PAD>"],
)

# Adam over every model parameter.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [None]:
#helper function to save the model
def save_model(model, num_epochs, path='attention_model_state.pth'):
    """Save a checkpoint holding the model weights and the settings needed to rebuild it.

    The dimensions (``embed_size``, ``vocab_size``, ``attention_dim``, ``encoder_dim``,
    ``decoder_dim``) are read from the notebook's global variables at call time, so the
    hyperparameter cells must have been run before this function is called.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'state_dict'``.
        num_epochs: Epoch count stored under ``'num_epochs'``.
        path: Destination file. Defaults to ``'attention_model_state.pth'`` in the
            current working directory; pass another path to keep several checkpoints.
    """
    model_state = {
        'num_epochs': num_epochs,
        'embed_size': embed_size,
        'vocab_size': vocab_size,
        'attention_dim': attention_dim,
        'encoder_dim': encoder_dim,
        'decoder_dim': decoder_dim,
        'state_dict': model.state_dict(),
    }

    torch.save(model_state, path)

In [None]:
import time
import torch
from torch.cuda.amp import GradScaler, autocast
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import matplotlib.pyplot as plt

# Training configuration
num_epochs = 15        # total passes over the training set
print_every = 100      # print the running batch loss every N training steps
validate_every = 1     # run validation every N epochs

# torch.cuda.amp only does something on CUDA. Disabling it explicitly on CPU keeps the
# same numerical behaviour as before while avoiding the "CUDA is not available" warnings.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# Index of the padding token; padded positions are excluded from BLEU below.
pad_idx = train_dataset.stoi["<PAD>"]

# n-gram weights for BLEU-1 .. BLEU-4. Each tuple sums to exactly 1
# (the previous BLEU-3 weights of 0.33 * 3 summed to 0.99).
bleu_weights = (
    (1, 0, 0, 0),
    (0.5, 0.5, 0, 0),
    (1 / 3, 1 / 3, 1 / 3, 0),
    (0.25, 0.25, 0.25, 0.25),
)
smoothing = SmoothingFunction().method1

best_val_loss = float("inf")

# Per-epoch history, consumed by the plotting cells further down.
train_losses = []
val_losses = []
bleu_scores = []

for epoch in range(1, num_epochs + 1):
    start_time = time.time()

    # ---- Training pass ----
    model.train()
    total_train_loss = 0.0
    for step, (images, captions) in enumerate(train_loader, start=1):
        images = images.to(device, non_blocking=True)
        captions = captions.to(device, non_blocking=True)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs, _ = model(images, captions)
            # The model is scored against the caption shifted left by one token.
            targets = captions[:, 1:]
            # reshape (not view) so a non-contiguous output tensor cannot raise here.
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_train_loss += batch_loss

        if step % print_every == 0:
            print(f"Epoch: {epoch}, Step: {step}, Loss: {batch_loss:.5f}")

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Validation pass ----
    if epoch % validate_every == 0:
        model.eval()
        total_val_loss = 0.0
        references = []
        hypotheses = []
        with torch.no_grad():
            for images, captions in val_loader:
                images = images.to(device, non_blocking=True)
                captions = captions.to(device, non_blocking=True)

                with autocast(enabled=use_amp):
                    outputs, _ = model(images, captions)
                    targets = captions[:, 1:]
                    val_loss = criterion(
                        outputs.reshape(-1, vocab_size), targets.reshape(-1)
                    )
                total_val_loss += val_loss.item()

                # Move the whole batch to Python lists once instead of calling
                # .item() per token, which forces a device sync for every element.
                predicted_rows = outputs.argmax(dim=-1).tolist()
                target_rows = targets.tolist()

                for predicted_ids, target_ids in zip(predicted_rows, target_rows):
                    hypothesis_tokens = []
                    reference_tokens = []
                    # The model was fed the ground-truth caption, so prediction i lines
                    # up with target i. Padded positions are dropped from both sides so
                    # "<PAD>" n-grams do not distort the BLEU score.
                    for predicted_id, target_id in zip(predicted_ids, target_ids):
                        if target_id == pad_idx:
                            continue
                        hypothesis_tokens.append(train_dataset.itos[predicted_id])
                        reference_tokens.append(train_dataset.itos[target_id])

                    if not reference_tokens:
                        # Caption consisted of padding only; nothing to score.
                        continue
                    hypotheses.append(hypothesis_tokens)
                    # corpus_bleu expects a list of references per hypothesis.
                    references.append([reference_tokens])

        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.5f}")

        # Compute BLEU-1 .. BLEU-4 on the de-padded token sequences
        bleu_1, bleu_2, bleu_3, bleu_4 = (
            corpus_bleu(
                references,
                hypotheses,
                weights=weights,
                smoothing_function=smoothing,
            )
            for weights in bleu_weights
        )
        bleu_scores.append((bleu_1, bleu_2, bleu_3, bleu_4))

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Keep the weights of the best epoch so far; without this the tracked
            # best loss had no effect and no checkpoint was ever written.
            save_model(model, epoch)

        # Leave the model in training mode, also after the final epoch.
        model.train()

    end_time = time.time()
    epoch_duration = end_time - start_time
    print(f"Time for epoch {epoch}: {epoch_duration:.2f} seconds")

In [None]:
import matplotlib.pyplot as plt

# Split the per-validation (BLEU-1, BLEU-2, BLEU-3, BLEU-4) tuples into one series per order.
bleu_1_scores, bleu_2_scores, bleu_3_scores, bleu_4_scores = (
    [score[order] for score in bleu_scores] for order in range(4)
)

# The k-th validation pass happens at epoch k * validate_every, so scale the x axis
# accordingly instead of assuming one validation per epoch.
epochs = [validate_every * k for k in range(1, len(bleu_scores) + 1)]

fig, ax = plt.subplots(figsize=(10, 8))
for order, series in enumerate(
    (bleu_1_scores, bleu_2_scores, bleu_3_scores, bleu_4_scores), start=1
):
    ax.plot(epochs, series, label=f"BLEU-{order}", marker="o")
ax.set_xlabel("Epochs")
ax.set_ylabel("BLEU Score")
ax.set_title("BLEU Scores Over Training Epochs")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Training loss is recorded once per epoch (epochs are numbered from 1 in the training loop).
train_epochs = range(1, len(train_losses) + 1)

# Validation runs when epoch % validate_every == 0, i.e. at epochs validate_every,
# 2 * validate_every, ... Deriving x from len(val_losses) guarantees x and y have the
# same length; range(0, len(train_losses), validate_every) could be one element longer
# and started one interval too early.
val_epochs = [validate_every * k for k in range(1, len(val_losses) + 1)]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_epochs, train_losses, label='Training Loss')
ax.plot(val_epochs, val_losses, label='Validation Loss', linestyle='--')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()