# Libraries

In [42]:
from google.colab import drive
drive.mount('/content/gdrive')

%cd /content/gdrive/MyDrive/ImageCaptioning

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
/content/gdrive/.shortcut-targets-by-id/13dGpwyY-c5FPJTEacGkw8XNTkbGVWT2D/ImageCaptioning


In [43]:
import re
import cv2
import glob
import spacy
import numpy as np
import pandas as pd
from time import time
from PIL import Image
from numpy import array
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter
import torch
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [44]:
!python -m spacy download en_core_web_sm
!pip install tensorboard

2024-01-03 07:05:50.123641: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-03 07:05:50.123756: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-03 07:05:50.126339: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Collecting en-core-web-sm==3.6.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.6.0/en_core_web_sm-3.6.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m68.0 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load

# Dataset

In [45]:
spacy_eng = spacy.load("en_core_web_sm")

In [46]:
class Vocabulary:
    def __init__(self,freq_threshold):
        # Setting the pre-reserved tokens int to string tokens
        self.index2word = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}

        # String to int tokens
        # Tts reverse dict self.index2word
        self.word2index = {v: k for k, v in self.index2word.items()}

        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.index2word)

    @staticmethod
    def tokenize(text):
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def build_vocab(self, sentence_list):
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1

                #add the word to the vocab if it reaches minum frequecy threshold
                if frequencies[word] == self.freq_threshold:
                    self.word2index[word] = idx
                    self.index2word[idx] = word
                    idx += 1

    def numericalize(self,text):
        """ For each word in the text corresponding index token for that word form the vocab built as list """
        tokenized_text = self.tokenize(text)
        return [self.word2index[token] if token in self.word2index else self.word2index["<UNK>"] for token in tokenized_text ]

In [47]:
class ImageCaptioningDataset(Dataset):
    """Image Captioning dataset"""

    def __init__(self, csv_file, transform, freq_threshold=5):
        self.dataframe = pd.read_csv(csv_file)
        self.transform = transform

        self.images = self.dataframe['image']
        self.captions = self.dataframe['caption']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())


    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        caption = self.captions[idx]
        image_path = self.images[idx]

        image = cv2.imread(f'dataset/Images/{image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        caption_vec = []
        caption_vec += [self.vocab.word2index["<SOS>"]]
        caption_vec += self.vocab.numericalize(caption)
        caption_vec += [self.vocab.word2index["<EOS>"]]

        return image, torch.tensor(caption_vec)

In [48]:
class CapsCollate:
    def __init__(self, pad_idx, batch_first=False):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]

        imgs = torch.cat(imgs, dim=0)

        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=self.batch_first, padding_value=self.pad_idx)
        return imgs, targets

# Model

## Image

In [49]:
class ImageFeatureExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Load pretrained model and remove last fc layer
        pretrained_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.model = torch.nn.Sequential(*list(pretrained_model.children())[:-1]).to(device)

        # Freeze layer
        for param in self.model.parameters():
            param.requires_grad = False

        # Add a linear layer add the end of model
        self.linear = torch.nn.Linear(2048, 512).to(device)
        self.drop = torch.nn.Dropout(0.3)

    def forward(self, images):
        # Preprocess images
        images = images.to(device)

        # Forward pass
        feature = self.model(images)                   # (batch_size, 2048, 1, 1)
        feature = feature.view(images.shape[0], 1, -1) # (batch_size, 1, 2048)
        feature = self.drop(feature)
        output = self.linear(feature).squeeze(1)       # (batch_size, 512)

        # Return output
        return output

## Text

In [50]:
class TextFeatureExtractor(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.vocab_size = vocab_size

        # Embedding layer
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim).to(device)

        # LSTM layer
        self.lstm = torch.nn.LSTMCell(input_size=embed_dim, hidden_size=512).to(device)

        # Linear layer
        self.decoder1 = torch.nn.Linear(1024, 512).to(device)
        self.decoder2 = torch.nn.Linear(512, self.vocab_size).to(device)
        self.drop = torch.nn.Dropout(0.3)

    def forward(self, features, sequences):

        sequence_length = len(sequences[0]) - 1
        preds = torch.zeros(sequences.shape[0], sequence_length, self.vocab_size)

        sequences = sequences.to(device)
        preds = preds.to(device)

        # Embedding sequence
        embeds = self.embedding(sequences)
        embeds = embeds.to(torch.float32)

        # Forward pass
        for idx in range(sequence_length):
            # Compute feature vector of input text
            embed_word = embeds[:, idx]
            h, c = self.lstm(embed_word)
            # Concat fe of image and text
            concat = torch.cat((features, h), 1)

            # Pass concat fe to decoder
            decoded = self.decoder1(concat)
            decoded = self.drop(decoded)
            output = self.decoder2(decoded)

            # Predicted vector
            preds[:, idx] = output

        return preds

    def predict(self, feature, max_length=20, vocab=None):
        # Starting input
        word = torch.tensor(vocab.word2index['<SOS>']).view(1, -1)
        word = word.to(device)
        feature = feature.to(device)

        # Embedding sequence
        embeds = self.embedding(word)

        captions = []

        for i in range(max_length):

            # Compute feature vector of input text
            embed_word = embeds[:, idx]
            hidden_state, cell_state = self.lstm(embed_word)

            # Concat fe of image and text
            concat = torch.cat((feature, hidden_state), 1)

            # Pass concat fe to decoder
            decoded = self.decoder1(concat)
            decoded = self.drop(decoded)
            output = self.decoder2(decoded)

            # Predict word index
            predicted_word_idx = output.argmax(dim=1)
            captions.append(predicted_word_idx.item())

            # End if <EOS> appears
            if vocab.index2word[predicted_word_idx.item()] == "<EOS>":
                break

            # Send generated word as the next caption
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        # Convert the vocab idx to words and return sentence
        return ' '.join([vocab.index2word[idx] for idx in captions])

## Captioner

In [51]:
class Captioner(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim, vocab):
        super().__init__()
        self.image_fe =  ImageFeatureExtractor()
        self.text_fe = TextFeatureExtractor(vocab_size, embed_dim)
        self.vocab = vocab

    def forward(self, images, captions):

        image_fv = self.image_fe(images)
        output = self.text_fe(image_fv, captions)

        return output

    def generate_caption(self, image, max_length=20):
        feature = self.image_fe(image)
        predicted_caption = self.text_fe.predict(feature, max_length, self.vocab)

        return predicted_caption


# Train

In [52]:
# Warmp up GPU and CPU
def cpu():
  with tf.device('/cpu:0'):
    random_image_cpu = tf.random.normal((100, 100, 100, 3))
    net_cpu = tf.keras.layers.Conv2D(32, 7)(random_image_cpu)
    return tf.math.reduce_sum(net_cpu)

def gpu():
  with tf.device('/device:GPU:0'):
    random_image_gpu = tf.random.normal((100, 100, 100, 3))
    net_gpu = tf.keras.layers.Conv2D(32, 7)(random_image_gpu)
    return tf.math.reduce_sum(net_gpu)

cpu()
gpu()

<tf.Tensor: shape=(), dtype=float32, numpy=-319.3532>

In [60]:
class Trainer():
    def __init__(self):
        # Dataset
        self.dataset =  ImageCaptioningDataset(
            csv_file="dataset/captions.txt",
            transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((224, 224), antialias=True),
                    transforms.CenterCrop(224),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
                    ])
            )


        # Hyperparameters
        self.vocab_size = len(self.dataset.vocab)
        self.vocab = self.dataset.vocab
        self.embed_dim = 300
        self.num_epochs = 50
        self.batch_size = 4

        # Parameters
        self.writer = SummaryWriter('runs')

        # Dataloader
        self.data_loader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=CapsCollate(pad_idx=self.dataset.vocab.word2index["<PAD>"], batch_first=True)
        )
    def train(self,resume=False):
        # Init model, optimizer, criterion
        model = Captioner(
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
            vocab=self.dataset.vocab
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        criterion = torch.nn.CrossEntropyLoss(ignore_index=self.dataset.vocab.word2index["<PAD>"])

        # Starting epoch
        start_epoch = 0
        min_loss = 9999

        if resume:
            # Load model and optimizer state
            model_state, optimizer_state, prev_epoch, prev_loss = self.load_model()
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

            # Starting epoch
            start_epoch = prev_epoch
            min_loss = prev_loss

        for epoch in range(start_epoch + 1, self.num_epochs + 1):
            epoch_loss = []
            train_pbar = tqdm(enumerate(iter(self.data_loader)), position=0, leave=True)
            for idx, (image, captions) in train_pbar:
                image, captions = image.to(device), captions.to(device)

                # Zero the gradients
                optimizer.zero_grad()

                # Feed forward
                outputs = model(image, captions)

                # Calculate the batch loss
                targets = captions[:, 1:]
                loss = criterion(outputs.view(-1, self.vocab_size), targets.reshape(-1))
                epoch_loss.append(loss.item())

                # Backward pass
                loss.backward()

                # Update the parameters in the optimizer
                optimizer.step()

                # Show progess bar with loss per batch
                train_pbar.set_postfix_str(f"Loss: {sum(epoch_loss) / len(epoch_loss):0.4f}")
                break

            # Compute average loss per epoch
            avg_epoch_loss = sum(epoch_loss) / len(epoch_loss)
            self.writer.add_scalar("Loss/Train", avg_epoch_loss, epoch + 1)

            # Save model
            if avg_epoch_loss < min_loss:
                self.save_model(model, optimizer, epoch, avg_epoch_loss)
            break
        self.writer.close()

    def save_model(self, model, optimizer, epoch, loss):
        model_state = {
            'epoch': epoch,
            'loss': loss,
            'embed_dim': self.embed_dim,
            'vocab_size': self.vocab_size,
            'vocab': self.vocab,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }

        torch.save(model_state, f'models/merge/model_{epoch}_{loss:.4f}.pth')
        torch.save(model_state, f'models/merge/model_best.pth')

    def load_model(self):
        path = glob.glob("models/merge/*.pth")[-2]
        checkpoint = torch.load(path)

        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        embed_dim = checkpoint['embed_dim']
        vocab_size = checkpoint['vocab_size']
        vocab = checkpoint['vocab']

        model_state = checkpoint['model_state_dict']
        optimizer_state = checkpoint['optimizer_state_dict']

        return model_state, optimizer_state, epoch, loss

In [62]:
trainer = Trainer()
trainer.train(resume=True)

0it [00:03, ?it/s, Loss: 7.8404]


In [None]:
%reload_ext tensorboard
%tensorboard --logdir runs