In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image

!pip install git+https://github.com/openai/CLIP.git

import clip
import numpy as np

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-_nr8oo9j
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-_nr8oo9j
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25ldone
[?25h  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=a534a69350096a7a0194571395d60e2c1b08af85defa4bf6a11fb13488320f71
  Stored in directory: /tmp/pip-ephem-wheel-cache-hsvmphjv/wheels/da/2b/4c/d6691fa9597aac8bb

In [2]:
# Load captions
def load_captions(captions_file):
    captions = []
    with open(captions_file, 'r') as f:
        for line in f:
            img, caption = line.strip().split(',', 1)  # Split only on the first comma
            if (img != 'image'):
                captions.append((img.strip(), caption.strip()))
    return captions


# Preprocess images using CLIP
def preprocess_images(image_folder, clip_preprocess, device):
    images = {}
    for img_name in os.listdir(image_folder):
        try:
            img_path = os.path.join(image_folder, img_name)
            image = Image.open(img_path).convert("RGB")
            image_input = clip_preprocess(image).unsqueeze(0).to(device).to(torch.float32)

            # image_input = clip_preprocess(image).unsqueeze(0).to(device)
            images[img_name] = image_input
        except Exception as e:
            print(f"Error processing {img_name}: {e}")
    return images

# Tokenizer function
# def tokenize_caption(caption, vocab, max_len=20):
#     tokens = caption.lower().split()[:max_len]
#     return [vocab.get(token, vocab['<unk>']) for token in tokens] + [0] * (max_len - len(tokens))

# def tokenize_caption(caption, vocab, max_len=20):
#     tokens = caption.lower().split()
#     token_indices = [vocab.get(word, vocab["<unk>"]) for word in tokens]
#     token_indices = [vocab["start"]] + token_indices[:max_len - 2] + [vocab["end"]]
#     token_indices += [vocab["<pad>"]] * (max_len - len(token_indices))
#     return token_indices

def tokenize_caption(caption, vocab, max_len=20):
    words = caption.lower().split()
    
    # Convert words to indices and include start/end tokens
    tokens = [vocab['<start>']] + [vocab.get(w, vocab['<unk>']) for w in words] + [vocab['<end>']]

    # If the sequence is longer than max_len, truncate and ensure last token is <end>
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
        tokens[-1] = vocab['<end>']

    # If shorter, pad with <pad>
    if len(tokens) < max_len:
        tokens += [vocab['<pad>']] * (max_len - len(tokens))

    return tokens


In [3]:
from torch.utils.data import Dataset

class FlickrDataset(Dataset):
    def __init__(self, captions, images, vocab):
        self.captions = [(img, cap) for img, cap in captions if img in images]
        self.images = images
        self.vocab = vocab

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        img_name, caption = self.captions[idx]
        if img_name not in self.images:
            raise ValueError(f"Image {img_name} not found in preprocessed images!")
        
        return img_name, torch.tensor(tokenize_caption(caption, self.vocab))


In [4]:
from sklearn.model_selection import train_test_split

def split_data(captions, test_size=0.2, random_state=42):

    img_names = list(set([img for img, _ in captions]))  # Unique image names
    train_imgs, test_imgs = train_test_split(img_names, test_size=test_size, random_state=random_state)

    train_captions = [(img, cap) for img, cap in captions if img in train_imgs]
    test_captions = [(img, cap) for img, cap in captions if img in test_imgs]

    return train_captions, test_captions


In [5]:
class CLIPEncoder(nn.Module):
    def __init__(self, device):
        super(CLIPEncoder, self).__init__()
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)

    def forward(self, image):
        with torch.no_grad():
            features = self.model.encode_image(image)
            features = features / features.norm(dim=-1, keepdim=True)  # Normalize
        return features.unsqueeze(1)  # Shape: (batch_size, 1, 512)


In [6]:
class Attention(nn.Module):
    def __init__(self, encoder_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, 512)
        self.decoder_att = nn.Linear(512, 512)
        self.full_att = nn.Linear(512, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)
        att2 = self.decoder_att(decoder_hidden)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)
        alpha = self.softmax(att)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha


In [7]:
class Decoder(nn.Module):
    def __init__(self, vocabulary_size, encoder_dim, tf=False):
        super(Decoder, self).__init__()
        self.use_tf = tf
        self.vocabulary_size = vocabulary_size
        self.encoder_dim = encoder_dim

        self.init_h = nn.Linear(encoder_dim, 512)
        self.init_c = nn.Linear(encoder_dim, 512)
        self.tanh = nn.Tanh()

        self.f_beta = nn.Linear(512, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        self.deep_output = nn.Linear(512, vocabulary_size)
        self.dropout = nn.Dropout()

        self.attention = Attention(encoder_dim)
        self.embedding = nn.Embedding(vocabulary_size, 512)
        self.lstm = nn.LSTMCell(512 + encoder_dim, 512)

    def forward(self, img_features, captions):
        batch_size = img_features.size(0)
        max_timespan = captions.size(1) - 1

        h, c = self.get_init_lstm_state(img_features)

        embedding = self.embedding(captions)

        preds = torch.zeros(batch_size, max_timespan, self.vocabulary_size).to(img_features.device)
        alphas = torch.zeros(batch_size, max_timespan, img_features.size(1)).to(img_features.device)

        for t in range(max_timespan):
            context, alpha = self.attention(img_features, h)
            gate = self.sigmoid(self.f_beta(h))
            gated_context = gate * context

            lstm_input = torch.cat((embedding[:, t], gated_context), dim=1)
            h, c = self.lstm(lstm_input, (h, c))
            preds[:, t] = self.deep_output(self.dropout(h))
            alphas[:, t] = alpha

        return preds, alphas

    def get_init_lstm_state(self, img_features):
        img_features = img_features.to(torch.float32)
        avg_features = img_features.mean(dim=1)
        c = self.tanh(self.init_c(avg_features))
        h = self.tanh(self.init_h(avg_features))
        return h, c


In [8]:
class EncoderDecoder(nn.Module):
    def __init__(self, vocab_size, device):
        super(EncoderDecoder, self).__init__()
        self.encoder = CLIPEncoder(device)
        self.decoder = Decoder(vocab_size, encoder_dim=512)

    def forward(self, images, captions):
        encoder_out = self.encoder(images)
        preds, alphas = self.decoder(encoder_out, captions)
        return preds, alphas


In [9]:
def train_model(model, data_loader, criterion, optimizer, vocab_size, device, num_epochs=10):
    
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for img_names, captions in data_loader:
            # Filter out invalid keys (e.g., 'image' or missing images)
            valid_images = []
            for name in img_names:
                if name in preprocessed_images:
                    valid_images.append(preprocessed_images[name])
                else:
                    print(f"Skipping invalid key: {name}")

            if not valid_images:
                continue  # Skip batch if no valid images

            # Combine valid images into a tensor
            images = torch.cat(valid_images).to(device)
            # print(type(captions))
            captions = captions.to(device)


            # Forward pass
            preds, alphas = model(images, captions)

            # Align preds and targets
            max_seq_len = captions.size(1) - 1  # Exclude <end> token
            preds = preds[:, :max_seq_len, :].contiguous()
            targets = captions[:, 1:max_seq_len + 1].contiguous()

            # Debugging: Print shapes to ensure alignment
            # print(f"Preds shape: {preds.shape}, Targets shape: {targets.shape}")

            # Compute loss
            loss = criterion(preds.view(-1, vocab_size), targets.view(-1))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}")



In [10]:
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
# vocab = {"<pad>": 0, "<unk>": 1, "start": 2, "end": 3}  # Example vocabulary
# vocab_size = len(vocab)


captions_file = "/kaggle/input/captions.txt"
image_folder = "/kaggle/input/Images"

# Load captions
captions = load_captions(captions_file)

# Preprocess images
preprocessed_images = preprocess_images(image_folder, clip.load("ViT-B/32")[1], device)

100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 193MiB/s]


In [11]:
from collections import Counter

def build_vocab(captions, min_freq=1):
    """
    Build vocabulary from captions.
    :param captions: List of (img_name, caption) pairs
    :param min_freq: Minimum frequency for a word to be included
    :return: vocab (word to index mapping)
    """
    word_counts = Counter()

    # Tokenize and count words
    for _, caption in captions:
        words = caption.lower().split()  # Simple split by spaces
        word_counts.update(words)
    
    # Create the vocabulary with special tokens
    vocab = {"<pad>": 0, "<unk>": 1, "<start>": 2, "<end>": 3}
    index = 4

    # Add words to the vocabulary that meet the frequency threshold
    for word, count in word_counts.items():
        if count >= min_freq:
            vocab[word] = index
            index += 1

    return vocab


In [12]:
vocab = build_vocab(captions, min_freq=1)
print(f"Vocabulary size: {len(vocab)}")
print(f"Sample words: {list(vocab.keys())[:10]}")


Vocabulary size: 9184
Sample words: ['<pad>', '<unk>', '<start>', '<end>', 'a', 'child', 'in', 'pink', 'dress', 'is']


In [13]:
# Split the captions
train_captions, test_captions = split_data(captions, test_size=0.2)

# Create datasets
train_dataset = FlickrDataset(train_captions, preprocessed_images, vocab)
test_dataset = FlickrDataset(test_captions, preprocessed_images, vocab)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training samples: {len(train_dataset)}, Testing samples: {len(test_dataset)}")


Training samples: 32360, Testing samples: 8095


In [14]:
dataset = FlickrDataset(captions, preprocessed_images, vocab)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = EncoderDecoder(len(vocab), device).to(device)
model = model.to(torch.float32)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

print("00")

00


In [15]:
print(f"Vocabulary size passed to Decoder: {len(vocab)}")
print(f"Decoder's vocabulary size: {model.decoder.vocabulary_size}")


Vocabulary size passed to Decoder: 9184
Decoder's vocabulary size: 9184


In [16]:
train_model(model, train_loader, criterion, optimizer, len(vocab), device)

Epoch 1/10, Loss: 4.600524904935257
Epoch 2/10, Loss: 3.7553504078755737
Epoch 3/10, Loss: 3.4727476170882876
Epoch 4/10, Loss: 3.276135066045603
Epoch 5/10, Loss: 3.12910344807998
Epoch 6/10, Loss: 3.0113821642200937
Epoch 7/10, Loss: 2.913134193231937
Epoch 8/10, Loss: 2.828891282025062
Epoch 9/10, Loss: 2.75844442703036
Epoch 10/10, Loss: 2.6925692461696067


In [25]:
def generate_captions(model, test_loader, vocab, device, max_len=20):
    """
    Generate captions for test images
    :param model: Trained Encoder-Decoder model
    :param test_loader: DataLoader for test data
    :param vocab: Vocabulary dictionary (word to index mapping)
    :param device: Device to run the model (CPU or GPU)
    :param max_len: Maximum length of generated captions
    """
    model.eval()  # Set the model to evaluation mode
    idx_to_word = {v: k for k, v in vocab.items()}  # Reverse vocab to map indices to words

    print("Generating captions on test data...")
    with torch.no_grad():
        for img_names, captions in test_loader:
            # Prepare images
            images = torch.cat([preprocessed_images[name] for name in img_names]).to(device)
            actual_captions = [[idx_to_word[idx.item()] for idx in caption if idx.item() != vocab["<pad>"]] for caption in captions]

            # Forward pass through the encoder
            encoder_out = model.encoder(images)

            # Initialize LSTM state and start token
            h, c = model.decoder.get_init_lstm_state(encoder_out)
            prev_word = torch.tensor([vocab['start']] * len(img_names)).to(device)  # Start tokens

            # Store generated captions
            generated_captions = [[] for _ in range(len(img_names))]

            for _ in range(max_len):
                # Embed previous word
                embedding = model.decoder.embedding(prev_word)

                # Attention
                context, _ = model.decoder.attention(encoder_out, h)
                gate = model.decoder.sigmoid(model.decoder.f_beta(h))
                gated_context = gate * context

                # LSTM forward
                lstm_input = torch.cat((embedding, gated_context), dim=1)
                h, c = model.decoder.lstm(lstm_input, (h, c))

                # Predict next word
                output = model.decoder.deep_output(h)  # (batch_size, vocab_size)
                prev_word = output.argmax(1)  # Greedy decoding: Pick the word with max probability

                # Append words to captions
                for i in range(len(img_names)):
                    word = idx_to_word.get(prev_word[i].item(), '<unk>')
                    if word == 'end':  # Stop generating if 'end' token is predicted
                        continue
                    generated_captions[i].append(word)

            # temp=0
            # count=5
            # # Print generated and actual captions
            # for i, name in enumerate(img_names):
            #     temp=temp+1
            #     if temp < count:
            #         print(f"Image: {name}")
            #         print(f"Generated Caption: {' '.join(generated_captions[i])}")
            #         print(f"Actual Caption: {' '.join(actual_captions[i])}")
            #         print("-" * 50)
    return actual_captions, generated_captions

In [26]:
print(f"Vocabulary size: {len(vocab)}")
print(f"Sample words: {list(vocab.keys())[:10]}")


Vocabulary size: 9184
Sample words: ['<pad>', '<unk>', '<start>', '<end>', 'a', 'child', 'in', 'pink', 'dress', 'is']


In [27]:
actual_captions, generated_captions = generate_captions(model, test_loader, vocab, device, max_len=20)

Generating captions on test data...


In [29]:
for i in range(10):
    print(f"Generated Caption: {' '.join(generated_captions[i])}")
    print(f"Actual Caption: {' '.join(actual_captions[i])}")
    print("---------")

Generated Caption: of a child in a blue shirt and a blue shirt <end> <end> <end> <end> <end> <end> <end> <end> <end>
Actual Caption: <start> the young boy is pushing the cart inside the store . <end>
---------
Generated Caption: man in a blue shirt and blue shorts is riding a bike <end> <end> <end> <end> <end> <end> <end> <end>
Actual Caption: <start> a man in red swim trunks is jumping onto a bodyboard . <end>
---------
Generated Caption: man in a blue shirt and blue shorts is riding a bike <end> <end> <end> <end> <end> <end> <end> <end>
Actual Caption: <start> a man in red trunks flies through the air with a boogie board . <end>
---------
Generated Caption: man in a blue shirt and blue shorts is riding a bike <end> <end> <end> <end> <end> <end> <end> <end>
Actual Caption: <start> a man with a wake-board is diving over a surface that is not water . <end>
---------
Generated Caption: man in a blue shirt and blue shorts is riding a bike <end> <end> <end> <end> <end> <end> <end> <end>
Ac