In [7]:
import torch
import nltk
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from load_data import get_loader
from model import EncoderDecoder
import matplotlib.pyplot as plt

In [8]:
transform = transforms.Compose([
    transforms.Resize(226),                     
    transforms.RandomCrop(224),                 
    transforms.ToTensor(),                               
    transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

train_loader, test_loader, train_dataset, test_dataset = get_loader(
    root_folder="flickr8k/images",
    train_annotation_file="flickr8k/train_captions.txt",
    test_annotation_file="flickr8k/test_captions.txt",
    transform=transform,
    num_workers=2
)

In [9]:
len(train_loader), len(test_loader), len(train_dataset), len(test_dataset)

(1012, 253, 32364, 8091)

In [10]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
embed_size=300
vocab_size = len(train_dataset.vocab)
attention_dim=256
encoder_dim=2048
decoder_dim=512
learning_rate = 3e-4


# initialize model, loss etc
model = EncoderDecoder(embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
def test_function(model, test_loader):
    model.eval()
    bleu_score = 0
    # f = open("results.txt", "w")
    for idx, (images, correct_captions) in enumerate(iter(test_loader)):
        features = model.encoder(images.to(device))
        for i in range(features.shape[0]):
            caps, alphas = model.decoder.generate_caption(features[i:i+1], vocab=train_dataset.vocab)
            caps = caps[:-1]
            caption = ' '.join(caps)
            correct_caption = []
            for j in correct_captions[i:i+1][0]:
                if j.item() not in [0, 1, 2]:
                    correct_caption.append(test_dataset.vocab.itos[j.item()])
            bleu = nltk.translate.bleu_score.sentence_bleu([correct_caption], caps, weights=(0.5, 0.5))
            correct_caption = ' '.join(correct_caption)
            write_this = correct_caption + ', ' + caption + '\n'
            bleu_score += bleu
            # f.write(write_this)
    # f.close()
    return bleu_score/8091


def save_model(model,num_epochs):
    model_state = {
        'num_epochs':num_epochs,
        'embed_size':embed_size,
        'vocab_size':len(train_dataset.vocab),
        'attention_dim':attention_dim,
        'encoder_dim':encoder_dim,
        'decoder_dim':decoder_dim,
        'state_dict':model.state_dict()
    }

    torch.save(model_state,'attention_model_state.pth')

checkpoint = torch.load('attention_model_state.pth')
model.load_state_dict(checkpoint['state_dict'])
epoch = checkpoint['num_epochs']

In [13]:
num_epochs = 5
print_every = 100
for epoch in range(1, num_epochs + 1):
    model.train()

    for idx, (image, captions) in enumerate(iter(train_loader)):
        image, captions = image.to(device), captions.to(device)
        # Zero the gradients.
        optimizer.zero_grad()
        # Feed forward
        outputs, attentions = model(image, captions)
        # Calculate the batch loss.
        targets = captions[:,1:]
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
        # Backward pass.
        loss.backward()
        # Update the parameters in the optimizer.
        optimizer.step()
        if (idx+1)%print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch,loss.item()))
    
    print(test_function(model, test_loader))

    #save the latest model
    save_model(model, epoch)

Epoch: 1 loss: 1.34103
Epoch: 1 loss: 1.47104
Epoch: 1 loss: 1.51540
Epoch: 1 loss: 1.43858
Epoch: 1 loss: 1.52330
Epoch: 1 loss: 1.63536
Epoch: 1 loss: 1.54243
Epoch: 1 loss: 1.63800
Epoch: 1 loss: 1.61689
Epoch: 1 loss: 1.70182
0.1738885610679138
Epoch: 2 loss: 1.54467
Epoch: 2 loss: 1.45809
Epoch: 2 loss: 1.34410
Epoch: 2 loss: 1.50166
Epoch: 2 loss: 1.55057
Epoch: 2 loss: 1.53857
Epoch: 2 loss: 1.62845
Epoch: 2 loss: 1.65814
Epoch: 2 loss: 1.63853
Epoch: 2 loss: 1.50285
0.17804776382807722
Epoch: 3 loss: 1.54221
Epoch: 3 loss: 1.34699
Epoch: 3 loss: 1.59455
Epoch: 3 loss: 1.64221
Epoch: 3 loss: 1.47994
Epoch: 3 loss: 1.60703
Epoch: 3 loss: 1.50478
Epoch: 3 loss: 1.68979
Epoch: 3 loss: 1.49244
Epoch: 3 loss: 1.45444
0.17801017608868377
Epoch: 4 loss: 1.70248
Epoch: 4 loss: 1.42919
Epoch: 4 loss: 1.53869
Epoch: 4 loss: 1.47361
Epoch: 4 loss: 1.62460
Epoch: 4 loss: 1.33672
Epoch: 4 loss: 1.35910
Epoch: 4 loss: 1.41541
Epoch: 4 loss: 1.61562
Epoch: 4 loss: 1.26492
0.177392239158814
Epo

In [12]:
model.eval()
bleu_score = 0
f = open("results.txt", "w")
for idx, (images, correct_captions) in enumerate(iter(test_loader)):
    features = model.encoder(images.to(device))
    for i in range(features.shape[0]):
        caps, alphas = model.decoder.generate_caption(features[i:i+1], vocab=train_dataset.vocab)
        caps = caps[:-1]
        caption = ' '.join(caps)
        correct_caption = []
        for j in correct_captions[i:i+1][0]:
            if j.item() not in [0, 1, 2]:
                correct_caption.append(test_dataset.vocab.itos[j.item()])
        bleu = nltk.translate.bleu_score.sentence_bleu([correct_caption], caps, weights=(0.5, 0.5))
        correct_caption = ' '.join(correct_caption)
        write_this = correct_caption + ', ' + caption + '\n'
        bleu_score += bleu
        f.write(write_this)
f.close()
print(bleu_score/8091)

0.1763335136512202
