In [1]:
# Download the GloVe embeddings with the Kaggle API:
#   1. Generate a Kaggle API key, see
#      https://www.kaggle.com/docs/api#getting-started-installation-&-authentication
#   2. Run `kaggle datasets download -d anmolkumar/glove-embeddings`
#      (this creates glove-embeddings.zip in the current directory).
#   3. mkdir -p glove.6B
#   4. unzip -qq glove-embeddings.zip -d glove.6B

In [2]:
# !conda install -c anaconda bcolz
# !pip install spacy
# python -m spacy download en

In [3]:
# !mkdir -p raw_data
# !mkdir -p data
# !wget -O raw_data/Flickr8k_Dataset.zip "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
# !wget -O raw_data/Flickr8k_text.zip "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
# !unzip -qq raw_data/Flickr8k_Dataset.zip -d data/
# !unzip -qq raw_data/Flickr8k_text.zip -d data/

In [1]:
import os
import random
import numpy as np
import pickle
import torch.nn as nn
import torch
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm
import torchvision.models as models
import torchvision.transforms as transforms

%load_ext autoreload
%autoreload 2

In [2]:
from generate_vocab import Vocabulary
from dataset import get_data_loader

%reload_ext autoreload
%autoreload 2

In [3]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Sets ``PYTHONHASHSEED``, seeds Python's ``random``, NumPy and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Each of these takes the seed as its only argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [4]:
# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load vocab/GloVe pickles that were produced locally by this project.
with open('data/vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Use a context manager so the file handle is closed deterministically
# (the bare `pickle.load(open(...))` form leaves it to the garbage collector).
with open('glove.6B/glove_words.pkl', 'rb') as glove_file:
    glove_vectors = torch.tensor(pickle.load(glove_file))

# Sizes used later to build the decoder: number of vocabulary entries and the
# width of one embedding vector (last dimension of the GloVe table).
vocab_size = len(vocab)
embed_size = glove_vectors.size(-1)

In [5]:
# class EncoderCNN(nn.Module):
#     def __init__(self, embed_size):
#         super(EncoderCNN, self).__init__()
#         resnet = models.resnet101(pretrained=True)
#         self.resnet = nn.Sequential(*list(resnet.children())[:-1])
#         self.fc = nn.Linear(2048, embed_size)
#         self.relu = nn.ReLU()

#     def forward(self, images):
#         out = self.resnet(images)
#         out = out.view(1, -1)
#         out = self.relu(self.fc(out))
#         return out

class EncoderCNN(nn.Module):
    """Image encoder: pretrained Inception v3 with its classifier head replaced.

    The final fully connected layer of the backbone is swapped for a new
    ``nn.Linear`` that projects to ``embed_size``; the result is passed
    through ReLU and dropout (p=0.5).

    Args:
        embed_size: Width of the feature vector produced per image.
        train_CNN: Stored on the instance as ``self.train_CNN``. This class
            does not act on the flag itself.
    """

    def __init__(self, embed_size, train_CNN=False):
        super().__init__()
        self.train_CNN = train_CNN

        backbone = models.inception_v3(pretrained=True, aux_logits=False)
        # Replace the original classification head with a projection layer.
        head_in_features = backbone.fc.in_features
        backbone.fc = nn.Linear(head_in_features, embed_size)
        self.inception = backbone

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Return dropout(ReLU(backbone(images)))."""
        activated = self.relu(self.inception(images))
        return self.dropout(activated)
    
class DecoderRNN(nn.Module):
    """LSTM caption decoder whose embedding table starts from GloVe vectors.

    The embedding weights are initialised from the module-level
    ``glove_vectors`` tensor and remain trainable.

    Args:
        embed_size: Width of the word embeddings / LSTM input.
        hidden_size: Width of the LSTM hidden state.
        vocab_size: Number of output classes of the final linear layer.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # nn.Parameter wraps the tensor it is given without copying it, so
        # passing `glove_vectors` directly would let optimizer steps rewrite
        # the module-level table in place (whenever the model stays on the
        # same device), and a re-created model would then start from trained
        # weights. Give the layer its own copy instead. A new nn.Parameter is
        # trainable (requires_grad=True) by default.
        self.embed.weight = nn.Parameter(glove_vectors.detach().clone())

        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, features, captions):
        """Score every vocabulary entry at every time step.

        Args:
            features: Encoder output; it gains a leading axis and is prepended
                to the caption embeddings as the first LSTM input.
            captions: Token ids looked up in the embedding table. The LSTM is
                built without ``batch_first``, so dim 0 is the time axis.

        Returns:
            Output of the linear layer applied to every LSTM hidden state
            (one more time step than ``captions`` because of the prepended
            image feature).
        """
        embeddings = self.dropout(self.embed(captions))
        # .float() brings the concatenated input to float32 for the LSTM
        # regardless of the dtype of the embedding table.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0).float()
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

class CNNtoRNN(nn.Module):
    """Captioning model: an ``EncoderCNN`` feeding a ``DecoderRNN``.

    Args:
        embed_size: Passed to both encoder and decoder.
        hidden_size: Decoder LSTM hidden width.
        vocab_size: Decoder output size.
        num_layers: Number of decoder LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and decode them together with ``captions``."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self, image, vocab, max_length=50):
        """Greedily decode a caption for ``image``.

        Runs under ``torch.no_grad()``. Decoding stops after ``max_length``
        tokens or as soon as the token whose ``vocab.idx2word`` entry is
        ``"<eos>"`` has been emitted (that token is included in the result).

        Args:
            image: Encoder input; ``.item()`` is called on each prediction,
                so the encoder must yield a single sample.
            vocab: Object exposing an ``idx2word`` mapping from id to word.
            max_length: Upper bound on the number of generated tokens.

        Returns:
            List of words, one per generated token id.
        """
        decoder = self.decoderRNN
        token_ids = []

        with torch.no_grad():
            # First LSTM input is the image feature with a leading time axis.
            step_input = self.encoderCNN(image).unsqueeze(0)
            lstm_state = None

            for _ in range(max_length):
                hiddens, lstm_state = decoder.lstm(step_input.float(), lstm_state)
                scores = decoder.linear(hiddens.squeeze(0))
                next_token = scores.argmax(1)
                token_id = next_token.item()
                token_ids.append(token_id)
                # Feed the chosen token back in as the next input.
                step_input = decoder.embed(next_token).unsqueeze(0)

                if vocab.idx2word[token_id] == "<eos>":
                    break

        return [vocab.idx2word[token_id] for token_id in token_ids]

In [6]:
class AverageMeter:
    """Running weighted mean.

    Attributes are public and read directly by callers:
      avg   -- current mean (``sum / count``)
      sum   -- weighted sum of all values seen so far
      count -- total weight seen so far
    """

    def __init__(self):
        self.avg, self.sum, self.count = 0., 0., 0.

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running mean."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        
# class EarlyStopping:
#     def __init__(self, patience=7, mode="max", delta=0.001):
#         self.patience = patience
#         self.counter = 0
#         self.mode = mode
#         self.best_score = None
#         self.best_epoch = 0
#         self.early_stop = False
#         self.delta = delta
#         if self.mode == "min":
#             self.val_score = np.Inf
#         else:
#             self.val_score = -np.Inf

#     def __call__(self, epoch, epoch_score, model, model_path):

#         if self.mode == "min":
#             score = -1.0 * epoch_score
#         else:
#             score = np.copy(epoch_score)

#         if self.best_score is None:
#             self.best_score = score
#             self.best_epoch = epoch
#             self.save_checkpoint(epoch_score, model, model_path)
#         elif score < self.best_score + self.delta:
#             self.counter += 1
#             print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
#             if self.counter >= self.patience:
#                 self.early_stop = True
#         else:
#             self.best_score = score
#             self.best_epoch = epoch
#             self.save_checkpoint(epoch_score, model, model_path)
#             self.counter = 0

#     def save_checkpoint(self, epoch_score, model, model_path):
#         if epoch_score not in [-np.inf, np.inf, -np.nan, np.nan]:
#             print(f"Validation score improved ({self.val_score} --> {epoch_score}). Saving model!")
#             torch.save(model.state_dict(), model_path)
#         self.val_score = epoch_score

In [7]:
def train(model, optimizer, data_loader):
    """Run one training epoch over ``data_loader``.

    Reads the module-level ``device`` and ``criterion``.

    Args:
        model: Network called as ``model(images, captions[:-1])``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable with ``len()`` yielding 5-tuples
            ``(image_filenames, image_nums, text_captions, images, captions)``.
    """
    model.train()
    # Iterate the loader that was passed in; the previous version ignored the
    # argument and always read the global `train_dataloader`.
    for _filenames, _image_nums, _text_captions, images, captions in tqdm(
            data_loader, total=len(data_loader), leave=False):

        optimizer.zero_grad()

        images = images.to(device)
        captions = captions.to(device)

        # The last caption row is dropped from the input; the model output is
        # scored against the full `captions` tensor.
        outputs = model(images, captions[:-1])
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
        )
        loss.backward()
        optimizer.step()

def evaluate(model, optimizer, data_loader):
    """Compute the mean per-batch loss of ``model`` over ``data_loader``.

    Reads the module-level ``device`` and ``criterion``. Runs in eval mode
    with gradient tracking disabled.

    Args:
        model: Network called as ``model(images, captions[:-1])``.
        optimizer: Unused; kept so existing call sites keep working.
        data_loader: Iterable with ``len()`` yielding 5-tuples
            ``(image_filenames, image_nums, text_captions, images, captions)``.

    Returns:
        Average of the batch losses as a float (0.0 for an empty loader).
    """
    model.eval()
    avg_meter = AverageMeter()
    predicted_captions = []
    actual_captions = []

    # No gradients are needed for evaluation; skipping them avoids building
    # autograd graphs that are never used.
    with torch.no_grad():
        # Iterate the loader that was passed in; the previous version always
        # read the global `train_dataloader`, so it evaluated on training data.
        for _filenames, _image_nums, _text_captions, images, captions in tqdm(
                data_loader, total=len(data_loader), leave=False):

            images = images.to(device)
            captions = captions.to(device)

            outputs = model(images, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )
            avg_meter.update(loss.item())

    # TODO: BLEU evaluation, kept from the original draft. Note that
    # `bleu_score` is not defined in this notebook (only `corpus_bleu` is
    # imported), so this needs adapting before it can be enabled.
    #     for image, caption in zip(images, captions):
    #         preds = model.caption_image(image.unsqueeze(0), vocab)
    #         predicted_captions.append(preds)
    #         actual_captions.append([vocab.idx2word[i.item()] for i in caption][1:-1])
    # bleu_score = bleu_score(actual_captions, predicted_captions)
    # return bleu_score
    return avg_meter.avg

In [8]:
# Preprocessing shared by all splits: resize to 224x224, convert to a tensor,
# then normalise each channel with the given mean / std.
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    # transforms.RandomCrop(224),
    # transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# The three loaders differ only in the split they read.
_loader_kwargs = dict(vocab=vocab, transforms=transform, num_workers=4)
train_dataloader = get_data_loader(data_type="train", **_loader_kwargs)
dev_dataloader = get_data_loader(data_type="dev", **_loader_kwargs)
test_dataloader = get_data_loader(data_type="test", **_loader_kwargs)

In [9]:
# Training configuration and model construction.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 2
seed_everything(2022)
model = CNNtoRNN(embed_size, 256, vocab_size, 1).to(device)
# Targets equal to index 0 are excluded from the loss
# (presumably the padding token -- confirm against the vocabulary).
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Always train the replacement `fc` head; train the rest of the Inception
# backbone only when the encoder was built with train_CNN=True. Without the
# `else` branch this loop was a no-op, because every parameter is trainable by
# default, so the encoder's train_CNN flag had no effect.
for name, param in model.encoderCNN.inception.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
        param.requires_grad = True
    else:
        param.requires_grad = model.encoderCNN.train_CNN

for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch + 1}")
    train(model, optimizer, train_dataloader)
    evaluate(model, optimizer, dev_dataloader)

Epoch: 1


                                                

KeyboardInterrupt: 