In [None]:
#Importing required libraries
import torch
from torch import nn, optim
import torch.onnx
import torch.nn.functional as F
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from nltk.tokenize import word_tokenize
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
import nltk
# Download nltk's "punkt" tokenizer data; captions are tokenized with
# word_tokenize in later cells.
nltk.download('punkt')

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; every dataset path used below lives
# under this mount point.
drive.mount('/content/drive')

In [None]:
#Loading CSV file
# Caption table; later cells read its 'image' and 'caption' columns.
captions = pd.read_csv("/content/drive/MyDrive/NN_BASEMODEL/flicke 1k/flickr1k/captions.csv")
# Bare last expression so the notebook renders the table as the cell output.
captions

In [None]:
#Displays randomly chosen images together with their captions
def display_random_data(count=5, seed=1,
                        image_dir='/content/drive/MyDrive/NN_BASEMODEL/flicke 1k/flickr1k/images'):
    """Show ``count`` randomly chosen images, each followed by its captions.

    Reads the module-level ``captions`` DataFrame (columns ``image`` and
    ``caption``).

    Args:
        count: Number of images to draw. Sampling is with replacement, so the
            same image can be selected more than once.
        seed: Seed of the private random generator; the same seed always
            selects the same images.
        image_dir: Directory holding the image files (no trailing slash).
    """
    # A private RandomState produces the same draws as np.random.seed(seed)
    # followed by np.random.choice, but leaves NumPy's global generator alone.
    rng = np.random.RandomState(seed)
    chosen = rng.choice(captions['image'].unique(), count)
    for image in chosen:
        # Release the file handle as soon as the image has been rendered.
        with Image.open(f'{image_dir}/{image}') as picture:
            display(picture)
        # Print every caption recorded for this image.
        for cap in captions.loc[captions['image'] == image, 'caption'].tolist():
            print(cap)
# Preview two randomly selected images and print their captions.
display_random_data(2)

In [None]:
#Load and preprocess dataset
class My_Flickr1k(Dataset):
    """Dataset of ``(image, caption)`` pairs backed by an array of rows.

    Args:
        root_file: Prefix prepended to every image file name. It is
            concatenated as-is, so it should end with a path separator.
        captions: Indexable collection of ``(image_file_name, caption)`` rows,
            for example ``DataFrame.values``.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root_file, captions, transform=None):
        self.transform = transform
        self.root = root_file
        self.ids = captions

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for row ``idx``."""
        image_path, caption = self.ids[idx]
        # Force three channels so greyscale / RGBA / CMYK files still collate
        # into one batch. convert() loads the pixel data, so the underlying
        # file can be closed right away.
        with Image.open(self.root + image_path) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image, caption

    def __len__(self):
        """Number of (image, caption) rows."""
        return len(self.ids)

In [None]:
#Build the caption vocabulary and split the caption rows into training and validation datasets
def build_datasets_vocab(root_file, captions_file, transform, split=0.15):
    """Create the train/validation datasets and the token -> id vocabulary.

    Args:
        root_file: Image path prefix handed to ``My_Flickr1k``.
        captions_file: CSV (path or file-like) with a ``caption`` column.
        transform: Image transform handed to both datasets.
        split: Fraction of the rows held out for validation.

    Returns:
        ``(train_dataset, valid_dataset, vocab)`` where ``vocab`` maps each
        lower-cased token to an integer id.
    """
    table = pd.read_csv(captions_file)

    # Ids are assigned in order of first appearance. The vocabulary is built
    # from every caption, before the split, so validation captions can be
    # encoded as well.
    vocab = {}
    for caption in table["caption"]:
        for word in word_tokenize(caption):
            vocab.setdefault(word.lower(), len(vocab))

    # The split is made over caption rows with a fixed random_state.
    train_rows, valid_rows = train_test_split(table, test_size=split, random_state=42)
    train_set = My_Flickr1k(root_file, train_rows.values, transform)
    valid_set = My_Flickr1k(root_file, valid_rows.values, transform)
    return train_set, valid_set, vocab

In [None]:
# Tokenizer tables used by word_tokenize while the vocabulary is built.
nltk.download('punkt_tab')

# Every image is resized to 128x128 and converted to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

train_dataset, valid_dataset, vocab = build_datasets_vocab(
    root_file="/content/drive/MyDrive/NN_BASEMODEL/flicke 1k/flickr1k/images/",
    captions_file="/content/drive/MyDrive/NN_BASEMODEL/flicke 1k/flickr1k/captions.csv",
    transform=transform,
)

# Reverse lookup (id -> token) used when decoding generated captions.
id_to_word = dict((index, token) for token, index in vocab.items())

In [None]:
# NOTE(review): `df` is only needed by the commented-out computation below and
# is not referenced by any other visible cell.
df = pd.read_csv("/content/drive/MyDrive/NN_BASEMODEL/flicke 1k/flickr1k/captions.csv")
# The longest tokenized caption can be recomputed with the next line; the value
# is hard-coded instead.
# MAX_CAPTION_LEN = df["caption"].apply(lambda x: len(word_tokenize(x))).max()
# Length that transform_captions pads every caption to; also the generation
# limit used by CaptionRNN.
MAX_CAPTION_LEN = 38

In [None]:
#Transforms text captions into padded token ID sequences and decodes token IDs back into readable captions.
def transform_captions(captions):
    """Encode captions as fixed-length lists of vocabulary ids.

    Each caption is tokenized with ``word_tokenize``, lower-cased and mapped
    through the module-level ``vocab``. Sequences shorter than
    ``MAX_CAPTION_LEN`` are right-padded with the id of ``"."``; longer ones
    are truncated, so every row has exactly ``MAX_CAPTION_LEN`` entries and
    the result can always be turned into a rectangular tensor.

    Args:
        captions: Iterable of caption strings.

    Returns:
        A list with one list of ``MAX_CAPTION_LEN`` ints per caption.

    Raises:
        KeyError: If a token (or ``"."``) is missing from ``vocab``.
    """
    pad_id = vocab["."]
    padded = []
    for caption in captions:
        ids = [vocab[word.lower()] for word in word_tokenize(caption)]
        # Without truncation an over-long caption would yield a ragged batch.
        ids = ids[:MAX_CAPTION_LEN]
        padded.append(ids + [pad_id] * (MAX_CAPTION_LEN - len(ids)))

    return padded

def get_caption(caption_sequence):
    """Decode a sequence of token ids into a space-separated string.

    Every occurrence of the ``"."`` id is dropped, because that id is used
    both as padding and as the stop token.
    """
    words = []
    for token_id in caption_sequence:
        if token_id == vocab["."]:
            continue
        words.append(id_to_word[token_id])
    return " ".join(words)

In [None]:
# Constants
# Total spatial down-sampling applied by the encoder: its three max-pool
# stages use kernel sizes 4, 4 and 2 (4 * 4 * 2 = 32).
POOLING_FACTOR = 32

In [None]:
#Defines convolutional and transpose convolutional layers with LeakyReLU activation
class ConvLeak(nn.Module):
    """2-D convolution followed by a LeakyReLU.

    The padding is ``(kernel_size - 1) // 2``, so for odd kernel sizes the
    spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)
        self.layer = nn.Sequential(conv, nn.LeakyReLU())

    def forward(self, x):
        """Apply the convolution and the activation to ``x``."""
        features = self.layer(x)
        return features


class ConvTransposeLeak(nn.Module):
    """2-D transposed convolution followed by a LeakyReLU.

    The padding is ``(kernel_size - 1) // 2``, so for odd kernel sizes the
    spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, padding=same_padding)
        self.layer = nn.Sequential(deconv, nn.LeakyReLU())

    def forward(self, x):
        """Apply the transposed convolution and the activation to ``x``."""
        features = self.layer(x)
        return features

In [None]:
#Defines a VAE encoder that extracts image features through convolution and pooling, then projects them into a latent space
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Three conv stages, each followed by max pooling (kernel sizes 4, 4 and 2),
    and a final conv stage produce a flattened feature vector. A small MLP maps
    it to ``latent_dim`` features, from which the mean and log standard
    deviation of the code are predicted.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the last conv stage.
        image_dim: Two image sizes; each is divided by ``POOLING_FACTOR`` to
            get the pooled feature-map size.
        latent_dim: Width of the hidden MLP.
    """

    def __init__(self, in_channels, out_channels, image_dim, latent_dim):
        super().__init__()

        width, height = image_dim
        pooled_w = width // POOLING_FACTOR
        pooled_h = height // POOLING_FACTOR
        # Length of the flattened conv features; also the length of the code.
        vec_dim = out_channels * pooled_w * pooled_h

        self.layer1 = nn.Sequential(ConvLeak(in_channels, 48), ConvLeak(48, 48))
        self.layer2 = nn.Sequential(ConvLeak(48, 84), ConvLeak(84, 84))
        self.layer3 = nn.Sequential(ConvLeak(84, 128), ConvLeak(128, 128))
        self.layer4 = nn.Sequential(ConvLeak(128, out_channels), nn.Flatten())

        # The pooling indices are returned so the decoder can unpool with them.
        self.pooling = nn.MaxPool2d(4, return_indices=True)
        self.pooling_2 = nn.MaxPool2d(2, return_indices=True)

        self.hidden = nn.Sequential(
            nn.Linear(vec_dim, latent_dim),
            nn.LeakyReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.Tanh(),
        )

        self.encoder_mean = nn.Linear(latent_dim, vec_dim)
        self.encoder_logstd = nn.Linear(latent_dim, vec_dim)

    def generate_code(self, mean, log_std):
        """Sample ``mean + exp(log_std) * noise`` with standard normal noise."""
        noise = torch.randn_like(mean)
        return (torch.exp(log_std) * noise) + mean

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(code, indices_1, indices_2, indices_3, mean, log_std)`` where the
            three index tensors come from the three pooling stages, in order.
        """
        x, indices_1 = self.pooling(self.layer1(x))
        x, indices_2 = self.pooling(self.layer2(x))
        x, indices_3 = self.pooling_2(self.layer3(x))

        hidden = self.hidden(self.layer4(x))
        mean = self.encoder_mean(hidden)
        log_std = self.encoder_logstd(hidden)
        code = self.generate_code(mean, log_std)

        return code, indices_1, indices_2, indices_3, mean, log_std

In [None]:
#Defines a VAE decoder that reconstructs images from latent vectors
class Decoder(nn.Module):
    """VAE decoder that mirrors the encoder with transposed convs and unpooling.

    Args:
        in_channels: Channels of the code once it is unflattened.
        out_channels: Channels of the reconstructed image.
        image_dim: Two image sizes; each is divided by ``POOLING_FACTOR`` to
            get the spatial size the code is unflattened to.
    """

    def __init__(self, in_channels, out_channels, image_dim):

        super().__init__()

        iW, iH = image_dim
        hW, hH = iW//POOLING_FACTOR, iH//POOLING_FACTOR

        self.layer4 = nn.Sequential(
            nn.Unflatten(1, unflattened_size=(in_channels, hW, hH)),
            ConvTransposeLeak(in_channels=in_channels, out_channels=128)
        )

        self.layer3 = nn.Sequential(
            ConvTransposeLeak(128, 128),
            ConvTransposeLeak(128, 84)
        )
        self.layer2 = nn.Sequential(
            ConvTransposeLeak(84, 84),
            ConvTransposeLeak(84, 48)
        )
        self.layer1 = nn.Sequential(
            ConvTransposeLeak(48, 48),
            # Honour the requested channel count (previously hard-coded to 3).
            ConvTransposeLeak(48, out_channels)
        )

        self.unpooling = nn.MaxUnpool2d(4)
        self.unpooling_2 = nn.MaxUnpool2d(2)

        # Learnable scalar used only by generate_data; forward() does not read
        # it. It stays registered so the state_dict layout is unchanged.
        self.precision = nn.Parameter(torch.rand(1))

    def generate_data(self, mean, precision):
        """Sample ``mean + exp(-precision) * noise`` with standard normal noise."""
        sigma = torch.exp(-precision)
        epsilon = torch.randn_like(mean)
        return (sigma * epsilon) + mean

    def forward(self, x, indices_1, indices_2, indices_3):
        """Reconstruct an image from the flat code ``x``.

        ``indices_1``..``indices_3`` are the pooling indices returned by the
        encoder's first, second and third pooling stage; they are consumed in
        reverse order.
        """
        x = self.layer4(x)
        x = self.unpooling_2(x, indices_3)
        x = self.layer3(x)
        x = self.unpooling(x, indices_2)
        x = self.layer2(x)
        x = self.unpooling(x, indices_1)
        x = self.layer1(x)

        return x

In [None]:
#Implementing Cross Attention
class CrossAttention(nn.Module):
    """Single-query dot-product attention over a context sequence.

    Args:
        query_dim: Feature size of the query vector.
        context_dim: Feature size of each context element (and of the output).
        hidden_dim: Size of the shared query/key/value projection space.
    """

    def __init__(self, query_dim, context_dim, hidden_dim):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, hidden_dim)
        self.key_proj = nn.Linear(context_dim, hidden_dim)
        self.value_proj = nn.Linear(context_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, context_dim)
        # Normalise over the sequence axis. The scores have shape
        # (batch, 1, seq_len); a softmax over dim=1 (size 1) would make every
        # weight exactly 1.0 and turn the attention into a plain sum.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, context):
        """Attend from ``query`` (batch, query_dim) over ``context``
        (batch, seq_len, context_dim); returns (batch, context_dim)."""
        query = self.query_proj(query).unsqueeze(1)  # (batch, 1, hidden_dim)
        key = self.key_proj(context)                 # (batch, seq_len, hidden_dim)
        value = self.value_proj(context)             # (batch, seq_len, hidden_dim)

        scores = torch.bmm(query, key.transpose(1, 2))  # (batch, 1, seq_len)
        attn_weights = self.softmax(scores)             # (batch, 1, seq_len), sums to 1 over seq_len

        attended = torch.bmm(attn_weights, value)        # (batch, 1, hidden_dim)
        attended = attended.squeeze(1)                   # (batch, hidden_dim)

        output = self.out_proj(attended)                 # (batch, context_dim)

        return output



In [None]:
# Captioning model using GRU and cross attention over VAE-encoded image features
class CaptionRNN(nn.Module):
    """GRU caption decoder conditioned on a flat image code.

    The code initialises the hidden state and, reshaped into a short sequence,
    serves as the context for cross attention at every step.

    Args:
        input_size: Length of the flat code; must be a multiple of 64.
        vocab_size: Number of tokens in the vocabulary.
        embedding_size: Size of the token embeddings.
        hidden_size: Size of the GRU hidden state.
        stop_index: Token id that ends generation.
    """
    CAPTION_LIMIT = MAX_CAPTION_LEN

    def __init__(self, input_size, vocab_size, embedding_size, hidden_size, stop_index):
        super().__init__()

        # The flat code is viewed as 64 chunks that act as the attention
        # sequence. This is a plain reshape of the vector, not a spatial grid.
        self.code_seq_len = 64
        if input_size % self.code_seq_len != 0:
            raise ValueError(
                f"input_size ({input_size}) must be a multiple of {self.code_seq_len}"
            )
        self.code_dim = input_size // self.code_seq_len

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.gru = nn.GRU(embedding_size + self.code_dim, hidden_size, batch_first=True)

        # Cross Attention
        self.cross_attention = CrossAttention(hidden_size, context_dim=self.code_dim, hidden_dim=hidden_size)
        # Not used by the methods of this class; kept so the parameter set and
        # state_dict layout stay unchanged.
        self.context_proj = nn.Linear(self.code_dim, embedding_size)

        # Output MLP
        self.fc_out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, vocab_size)
        )

        self.init_hidden_proj = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh()
        )

        self.stop_index = stop_index

    def generate_caption(self, code):
        """Sample a caption for one image code.

        Args:
            code: Code of a single image, shaped ``(1, input_size)`` or
                ``(input_size,)``.

        Returns:
            List of sampled token ids, at most ``CAPTION_LIMIT`` long.
        """
        # Accept an un-batched code vector as well as a batch of one.
        if code.dim() == 1:
            code = code.unsqueeze(0)
        batch_size = code.size(0)

        # Pure inference: no autograd graph is needed for sampling.
        with torch.no_grad():
            code_seq = code.view(batch_size, self.code_seq_len, self.code_dim)  # sequence for attention

            h_t = self.init_hidden_proj(code)
            context = self.cross_attention(h_t, code_seq)

            # The first word is sampled from the initial hidden state.
            first_logits = self.fc_out(h_t)
            y_t = torch.multinomial(F.softmax(first_logits, dim=-1), 1)
            w_t = self.embedding(y_t)

            words = [y_t.item()]

            for _ in range(CaptionRNN.CAPTION_LIMIT - 1):
                # Stop once the stop token appears after at least five words.
                if words[-1] == self.stop_index and len(words) >= 5:
                    break

                gru_input = torch.cat((w_t, context.unsqueeze(1)), dim=-1)  # (batch, 1, embedding+code_dim)
                out = self.gru(gru_input, h_t.unsqueeze(0))[0]
                h_t = out.squeeze(1)

                context = self.cross_attention(h_t, code_seq)

                logits = self.fc_out(h_t)
                y_t = torch.multinomial(F.softmax(logits, dim=-1), 1)

                words.append(y_t.item())
                w_t = self.embedding(y_t)

        return words

    def caption_prob(self, code, caption):
        """Teacher-forced logits for every position of ``caption``.

        Position 0 is predicted from the initial hidden state; position ``t``
        (t > 0) is predicted after the GRU has consumed tokens ``0..t-1``. The
        logits at ``t`` therefore never see token ``t`` itself, and the
        computation mirrors ``generate_caption`` step for step. Feeding token
        ``t`` and scoring against token ``t`` would train a copy task instead.

        Args:
            code: Image codes, ``(batch, input_size)``.
            caption: Token ids, ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        batch_size = code.size(0)
        code_seq = code.view(batch_size, self.code_seq_len, self.code_dim)

        h_t = self.init_hidden_proj(code)
        caption_embed = self.embedding(caption)  # Embed input caption

        outputs = []
        for t in range(caption.size(1)):
            if t > 0:
                # Advance the state with the previous ground-truth token.
                context = self.cross_attention(h_t, code_seq)
                gru_input = torch.cat((caption_embed[:, t-1:t], context.unsqueeze(1)), dim=-1)
                out = self.gru(gru_input, h_t.unsqueeze(0))[0]
                h_t = out.squeeze(1)

            logits = self.fc_out(h_t)  # Predict the distribution of token t
            outputs.append(logits.unsqueeze(1))

        outputs = torch.cat(outputs, dim=1)
        return outputs



In [None]:
# Combines VAE and captioning modules to reconstruct images and generate captions
class VAECaptioner(nn.Module):
    """VAE (encoder + decoder) combined with a caption decoder on the code.

    Args:
        in_channel: Channels of the input image.
        code_channels: Channels of the encoder's last conv stage.
        image_dim: Two image sizes, forwarded to the encoder and decoder.
        vocab: Token -> id mapping; must contain ``"."`` (the stop token).
    """

    def __init__(self, in_channel, code_channels, image_dim, vocab):
        super().__init__()

        LATENT_DIM = 300
        EMBEDDING_SIZE = 600
        HIDDEN_SIZE = 512
        # Same formula the encoder and decoder use for the flat code length,
        # so the three modules also agree when a size is not a multiple of
        # POOLING_FACTOR.
        CODE_FLAT = code_channels * (image_dim[0] // POOLING_FACTOR) * (image_dim[1] // POOLING_FACTOR)

        self.vocab = vocab

        self.encoder = Encoder(in_channel, code_channels, image_dim, LATENT_DIM)
        self.decoder = Decoder(code_channels, in_channel, image_dim)
        self.captionr = CaptionRNN(CODE_FLAT, len(vocab), EMBEDDING_SIZE, HIDDEN_SIZE, vocab["."])

    def forward(self, x, y):
        """Return ``(reconstructed, caption_logits, mean, log_std)`` for images
        ``x`` and caption token ids ``y``."""
        c, indices_1, indices_2, indices_3, mean, log_std = self.encoder(x)  # Encode image to latent representation
        reconstructed = self.decoder(c, indices_1, indices_2, indices_3)  # Reconstruct image from code
        caption_prob = self.captionr.caption_prob(c, y)  # Caption logits for every position

        return reconstructed, caption_prob, mean, log_std

    def generate_caption(self, x):
        """Sample a caption (list of token ids) for the first image in ``x``."""
        c, indices_1, indices_2, indices_3, mean, log_std = self.encoder(x)
        # Slice rather than index so the batch dimension is kept; the caption
        # decoder reshapes the code using its batch size.
        return self.captionr.generate_caption(c[:1])

In [None]:
#Initializes training parameters, data loaders, model, optimizer, and loss functions for image reconstruction and caption generation
EPOCHS = 3
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation needs no shuffling. A fixed order also makes the last validation
# batch, which the caption demo cells reuse, the same on every run.
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
# 3 input channels, 128 code channels, 128x128 images.
model = VAECaptioner(3, 128, (128, 128), vocab).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0002)
# Both losses are summed (not averaged) over the batch.
criterion = nn.MSELoss(reduction="sum")            # image reconstruction
criterion2 = nn.CrossEntropyLoss(reduction="sum")  # caption tokens

In [None]:
#Calculates the total VAE loss
def calculate_loss(reconstructed, caption_prob, images, captions_transformed, mean, log_std, kl_weight):
    """Combine reconstruction, caption and KL terms into one training loss.

    Args:
        reconstructed: Decoder output, compared with ``images`` by ``criterion``.
        caption_prob: Caption logits, ``(batch, seq_len, vocab)``.
        images: Target images.
        captions_transformed: Target token ids, ``(batch, seq_len)``.
        mean, log_std: Parameters of the code distribution from the encoder.
        kl_weight: Multiplier applied to the KL term.

    Returns:
        ``(total_loss, caption_loss)``.
    """
    image_term = criterion(reconstructed, images)

    # criterion2 expects the class axis second: (batch, vocab, seq_len).
    class_first_logits = caption_prob.transpose(1, 2)
    caption_loss = criterion2(class_first_logits, captions_transformed)

    # -sum(1 - mean^2 - sigma^2 + log sigma^2) with sigma = exp(log_std); this
    # is twice the closed-form KL(N(mean, sigma^2) || N(0, 1)).
    elementwise = 1 - mean.pow(2) - torch.exp(2 * log_std) + (2 * log_std)
    kl_term = -elementwise.sum()

    return image_term + caption_loss + kl_weight * kl_term, caption_loss


In [None]:
#Training the model while tracking losses
# Per-step histories: total loss and caption loss, for training and validation.
losses = []
caption_losses = []
val_losses = []
val_caption_losses = []

kl_annealing_steps = 45  # based on batch size and 3 epochs
current_step = 0         # counter for total updates

for epoch in range(EPOCHS):
    model.train()
    t = tqdm(train_dataloader, desc=f"Train: Epoch {epoch}")

    # The batch of caption strings is called `caption_batch` so the loop does
    # not overwrite the global `captions` DataFrame read by display_random_data.
    for images, caption_batch in t:
        images = images.to(device)
        captions_transformed = torch.LongTensor(transform_captions(caption_batch)).to(device)

        reconstructed, caption_prob, mean, log_std = model(images, captions_transformed)

        # KL annealing: the weight grows linearly from 0 to 1 over the first
        # kl_annealing_steps updates.
        kl_weight = min(1.0, current_step / kl_annealing_steps)
        loss, caption_loss = calculate_loss(reconstructed, caption_prob, images, captions_transformed, mean, log_std, kl_weight)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())
        caption_losses.append(caption_loss.item())
        current_step += 1

    model.eval()
    v = tqdm(valid_dataloader, desc=f"Valid: Epoch {epoch}")
    with torch.no_grad():
        for images, caption_batch in v:
            images = images.to(device)
            captions_transformed = torch.LongTensor(transform_captions(caption_batch)).to(device)
            reconstructed, caption_prob, mean, log_std = model(images, captions_transformed)

            # Validation always uses the full KL weight.
            loss, caption_loss = calculate_loss(reconstructed, caption_prob, images, captions_transformed, mean, log_std, kl_weight=1.0)

            val_losses.append(loss.item())
            val_caption_losses.append(caption_loss.item())


In [None]:
# NOTE: despite the ".onnx" suffix this file is a PyTorch state_dict written by
# torch.save, not an ONNX graph. The name is kept because the next cell loads
# the weights from this exact path.
torch.save(model.state_dict(), 'VAECaptioner.onnx')
# Import retained in case another cell still references Variable; plain
# tensors no longer need to be wrapped.
from torch.autograd import Variable
trained_model = VAECaptioner(3, 128, (128, 128), vocab)
# trained_model lives on the CPU, so map the saved tensors there directly.
trained_model.load_state_dict(torch.load('VAECaptioner.onnx', map_location='cpu'))
# Dummy image batch matching the model's input: 3 channels, 128x128 pixels.
dummy_input = torch.randn(1, 3, 128, 128)
#torch.onnx.export(trained_model, dummy_input, "VAECaptioner.onnx")

In [None]:
torch_model = VAECaptioner(3, 128, (128, 128), vocab).to(device)
model_path = "VAECaptioner.onnx"

# Initialize model with the weights, loading the tensors straight onto the
# device selected at the top of the notebook (GPU if available, else CPU).
map_location = device
torch_model.load_state_dict(torch.load(model_path, map_location=map_location))

In [None]:
# Total training loss, one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Time')
ax.grid(True)
plt.show()

In [None]:
# Caption (cross-entropy) part of the training loss, one point per step.
fig, ax = plt.subplots()
ax.plot(caption_losses)
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Time - For Captions')
ax.grid(True)
plt.show()


In [None]:
# Step 1: Select and prepare the image. `images` is the last batch left over
# from the validation loop.
img = images[4].unsqueeze(0).to(device)  # (1, 3, 128, 128)

# Steps 2-3: Encode the image and sample a caption from its latent code.
# Inference only, so no autograd graph is built.
with torch.no_grad():
    code = model.encoder(img)[0]  # the encoder returns six values; only the code is needed
    caption_ids = model.captionr.generate_caption(code)

# Step 4: Show the image with the generated caption as its title
plt.imshow(images[4].cpu().permute(1, 2, 0))
plt.axis("off")
_ = plt.title(get_caption(caption_ids))

In [None]:
# Step 1: Select and prepare the image. `images` is the last batch left over
# from the validation loop.
img = images[4].unsqueeze(0).to(device)  # (1, 3, 128, 128)

# Steps 2-3: Encode the image and sample a caption from its latent code.
# Both the code and the words are sampled, so every run can give a new caption.
with torch.no_grad():
    code = model.encoder(img)[0]  # the encoder returns six values; only the code is needed
    caption_ids = model.captionr.generate_caption(code)

# Step 4: Show the image with the generated caption as its title
plt.imshow(images[4].cpu().permute(1, 2, 0))
plt.axis("off")
_ = plt.title(get_caption(caption_ids))


In [None]:
# Step 1: Select and prepare the image. `images` is the last batch left over
# from the validation loop.
img = images[4].unsqueeze(0).to(device)  # (1, 3, 128, 128)

# Steps 2-3: Encode the image and sample a caption from its latent code.
# The code already carries a batch dimension, so it is passed on unchanged.
with torch.no_grad():
    code = model.encoder(img)[0]  # the encoder returns six values; only the code is needed
    caption_ids = model.captionr.generate_caption(code)

# Step 4: Show the image with the generated caption as its title
plt.imshow(images[4].cpu().permute(1, 2, 0))
plt.axis("off")
_ = plt.title(get_caption(caption_ids))