In [15]:
from torch_snippets import *
import json, os, urllib.request
import pandas as pd, numpy as np
import os, random
import torch
import torch, torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from torchtext.vocab import build_vocab_from_iterator
from tqdm.notebook import tqdm
import torchvision.transforms.functional as TF
from torchvision.io import read_image



In [16]:
# load and parse JSONL captions
url = "https://storage.googleapis.com/localized-narratives/annotations/open_images_train_v6_captions.jsonl"
DEST_DIR = os.path.expanduser("~/Projects/img_captioning/data")
FILENAME = "open_images_train_captions.jsonl"
DEST_PATH = os.path.join(DEST_DIR, FILENAME)

rows = []
with open(DEST_PATH, 'r') as f:
    for js in f:
        try:
            d = json.loads(js.strip())
            rows.append({"image_id": d["image_id"], "caption": d["caption"]})
        except:
            continue

data = pd.DataFrame(rows).drop_duplicates("image_id").reset_index(drop=True)
print("Total caption rows:", len(data))

Total caption rows: 504413


In [17]:

#filter the caption DataFrame to include only those image IDs that exist in the image folders
train_files = {f.replace(".jpg", "") for f in os.listdir("train-images") if f.endswith(".jpg")}
val_files = {f.replace(".jpg", "") for f in os.listdir("val-images") if f.endswith(".jpg")}
all_ids = list(train_files.union(val_files))

df2 = data[data["image_id"].isin(all_ids)].reset_index(drop=True)
train_ids, val_ids = train_test_split(df2["image_id"], test_size=0.2, random_state=42)
train_df = df2[df2["image_id"].isin(train_ids)].reset_index(drop=True)
val_df = df2[df2["image_id"].isin(val_ids)].reset_index(drop=True)
print("Captions after filtering:", len(df2))

print("Train images:", len(train_df))
print("Val images:", len(val_df))

print("\nSample train rows:")
print(train_df.head())

print("\nSample val rows:")
print(val_df.head())

Captions after filtering: 2091
Train images: 1672
Val images: 419

Sample train rows:
           image_id                                            caption
0  46e58e6c8d11cd70  To this statue there is a cloth and objects. B...
1  a0e1fd6810824cd8  In this picture I can see few toilet paper rol...
2  06091e5f7e946d3d  In this image there is a camera on a table, in...
3  ed23444b40d5aa86  Here we can see a man presenting a award to a ...
4  7f6926d1659d3caa  At the bottom of this image, there are plants ...

Sample val rows:
           image_id                                            caption
0  a991a3f532d4d0f9  In front of the image there is a person drinki...
1  b3e0c474f1cffdce  In this picture there is a girl in the center ...
2  b7d04dcbac05475d  In this image there is a tissue holder and on ...
3  efb8555f62112d9a  In the center of the image we can see a man st...
4  29c5019f5023618c  This picture is a black and white image. In th...


In [18]:
# build vocabulary
SPECIALS = ["<pad>", "<unk>", "<start>", "<end>"]

def tokenize(text):
    return text.lower().strip().split()

def yield_tokens(captions):
    for c in captions:
        yield tokenize(c)

vocab = build_vocab_from_iterator(yield_tokens(train_df["caption"]), specials=SPECIALS)
vocab.set_default_index(vocab["<unk>"])
itos = vocab.get_itos()
stoi = {w: i for i, w in enumerate(itos)}

PAD_IDX = stoi["<pad>"]
BOS_IDX = stoi["<start>"]
EOS_IDX = stoi["<end>"]

print("Vocabulary built")
print("Vocabulary size:", len(vocab))

# show first 20 tokens
print("\nFirst 20 tokens in vocab:")
print(itos[:20])




Vocabulary built
Vocabulary size: 2580

First 20 tokens in vocab:
['<pad>', '<unk>', '<start>', '<end>', 'the', 'a', 'in', 'and', 'there', 'is', 'see', 'can', 'on', 'we', 'are', 'this', 'of', 'image', 'i', 'background']


In [19]:
# define the dataset class
class CaptioningDataset(Dataset):
    def __init__(self, df, folder, stoi, size=(224, 224)):
        self.df = df.reset_index(drop=True)
        self.folder = folder
        self.stoi = stoi
        self.size = size
        self.normalize = transforms.Normalize((0.485, 0.456, 0.406),
                                              (0.229, 0.224, 0.225))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = f"{self.folder}/{row.image_id}.jpg"
        if not os.path.exists(img_path):
            return self.__getitem__(torch.randint(0, len(self.df), (1,)).item())

        img = read_image(img_path)
        if img.shape[0] == 1: img = img.repeat(3, 1, 1)
        elif img.shape[0] == 4: img = img[:3]

        img = TF.resize(img, self.size)
        img = img.float() / 255.0
        img = self.normalize(img)

        tokens = row.caption.lower().split()
        encoded = [self.stoi["<start>"]] + [self.stoi.get(w, self.stoi["<unk>"]) for w in tokens] + [self.stoi["<end>"]]
        return img, torch.tensor(encoded, dtype=torch.long), len(encoded)

    def __len__(self):
        return len(self.df)




In [21]:
# define a method to work on batch data
def collate_fn(batch):
    batch.sort(key=lambda x: x[2], reverse=True)
    imgs, caps, lens = zip(*batch)
    imgs = torch.stack(imgs)
    max_len = max(lens)
    padded = torch.zeros(len(caps), max_len).long()
    for i, cap in enumerate(caps):
        padded[i, :len(cap)] = cap
    return imgs, padded, torch.tensor(lens)

In [22]:
# create  datasets and dataloaders
train_dataset = CaptioningDataset(train_df, "train-images", stoi)
val_dataset = CaptioningDataset(val_df, "val-images", stoi)

train_dl = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_dl = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))

Train samples: 1672
Val samples: 419


In [24]:
# define the CNN encoder
class EncoderCNN(nn.Module):
    def __init__(self, embed_dim=256):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])  # remove final FC layer
        for param in self.cnn.parameters():
            param.requires_grad = False
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)

    def forward(self, x):
        feats = self.cnn(x).flatten(1)
        return self.fc(feats)



In [25]:
# define the RNN decoder
class DecoderRNN(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.init_h = nn.Linear(embed_dim, hidden_dim)
        self.init_c = nn.Linear(embed_dim, hidden_dim)

    def forward(self, feats, inputs, lengths):
        h = torch.tanh(self.init_h(feats)).unsqueeze(0)
        c = torch.tanh(self.init_c(feats)).unsqueeze(0)
        emb = self.embed(inputs)
        packed = pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=True
        )
        out, _ = self.lstm(packed, (h, c))
        return self.fc(out.data)

    @torch.no_grad()
    def predict(self, feat, max_len=20):
        # initialize hidden state from image features
        h = torch.tanh(self.init_h(feat)).unsqueeze(0)
        c = torch.tanh(self.init_c(feat)).unsqueeze(0)

        # start with BOS token
        word = torch.tensor([BOS_IDX], device=feat.device)
        out_words = []

        for _ in range(max_len):
            emb = self.embed(word.unsqueeze(0))       # (1,1,embed_dim)
            o, (h, c) = self.lstm(emb, (h, c))        # (1,1,hidden_dim)
            logits = self.fc(o.squeeze(1))            # (1,vocab_size)

            # choose next word
            word = logits.argmax(1)                   # (1,)
            tok = itos[word.item()]                   # map idx → token string

            # skip repeating <start> tokens
            if tok == "<start>":
                continue

            # stop at <end>
            if tok == "<end>":
                break

            out_words.append(tok)

        return " ".join(out_words)


In [26]:
# define training and validation steps
def train_batch(imgs, caps, lens, encoder, decoder, optimizer, criterion):
    encoder.train(); decoder.train()
    feats = encoder(imgs)
    out = decoder(feats, caps, lens)
    tgt_packed = pack_padded_sequence(caps, lens.cpu(), batch_first=True).data
    loss = criterion(out, tgt_packed)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_batch(imgs, caps, lens, encoder, decoder, criterion):
    encoder.eval(); decoder.eval()
    feats = encoder(imgs)
    out = decoder(feats, caps, lens)
    tgt_packed = pack_padded_sequence(caps, lens.cpu(), batch_first=True).data
    loss = criterion(out, tgt_packed)
    return loss.item()

In [27]:
# training loop
enc = EncoderCNN(256).to(device)
dec = DecoderRNN(256, 512, len(vocab)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(dec.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

EPOCHS = 5
log = Report(EPOCHS)

for epoch in range(EPOCHS):
    enc.train(); dec.train()
    train_losses = []

    for imgs, caps, lens in tqdm(train_dl, desc=f"Epoch {epoch+1} - Training"):
        imgs, caps, lens = imgs.to(device), caps.to(device), lens.to(device)
        loss = train_batch(imgs, caps, lens, enc, dec, optimizer, criterion)
        train_losses.append(loss)

    enc.eval(); dec.eval()
    val_losses = []

    for imgs, caps, lens in tqdm(val_dl, desc=f"Epoch {epoch+1} - Validation"):
        imgs, caps, lens = imgs.to(device), caps.to(device), lens.to(device)
        val_loss = validate_batch(imgs, caps, lens, enc, dec, criterion)
        val_losses.append(val_loss)

    scheduler.step()

    # record average losses for the epoch
    log.record(epoch + 1, trn_loss=np.mean(train_losses), val_loss=np.mean(val_losses))
    log.report_avgs(epoch + 1)

# plot after all epochs
log.plot_epochs(log=True)

Epoch 1 - Training:   0%|          | 0/105 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [49]:
# prediction function
device = torch.device("cpu")

# same normalization used in your dataset
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))
])

@torch.no_grad()
def predict_caption(image_path):
    img = read_image(image_path)

    # fix grayscale / RGBA
    if img.shape[0] == 1:
        img = img.repeat(3, 1, 1)
    elif img.shape[0] == 4:
        img = img[:3]

    img = img_transform(img)
    img = img.unsqueeze(0).to(device) # add batch dimension

    enc.eval()
    dec.eval()

    # encode image → get embedding
    feats = enc(img)

    # call decoder to generate caption
    caption = dec.predict(feats)

    return caption


In [50]:
# initialize models for prediction
enc = EncoderCNN(256)
dec = DecoderRNN(256, 512, len(vocab))


In [51]:
# load trained model weights

test_img = random.choice(os.listdir("val-images"))
test_path = f"val-images/{test_img}"

print("Image:", test_path)
print("Caption:", predict_caption(test_path))


Image: val-images/7773dcec1af6dd14.jpg
Caption: gate,wall rays. singing. lake plant, signboard, ocean an costume carpets, surface. cloud. fences, jacket. crane gun, moving. full autos. brown,grey
