In [None]:
# NOTE(review): `--force-reinstall` reinstalls torch/torchvision/torchaudio (and their
# dependencies) on every run, and versions are unpinned, so results can drift between runs.
# If torch was already imported in this kernel, restart it after installing.
# Consider `%pip install torch==<ver> torchvision==<ver> torchaudio==<ver>` instead.
!pip install --upgrade --force-reinstall torch torchvision torchaudio

In [None]:
import csv
import os
import pickle
import zipfile
from collections import Counter

import gdown
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import inception_v3, Inception_V3_Weights
from tqdm import tqdm

In [None]:
# ==================== Step 1: Download dataset ====================
base_path = 'archive'
os.makedirs(base_path, exist_ok=True)
# Google Drive file ids passed to gdown. They are blank placeholders: fill them in
# before running on a machine that does not already have the files under `base_path`.
dataset_id = ''
captions_id = ''

dataset_zip_path = os.path.join(base_path, 'flickr_dataset.zip')
image_folder = os.path.join(base_path, 'Images')
captions_file = os.path.join(base_path, 'captions.txt')

# Only download/extract the archive when the images are not already on disk, so a
# re-run does not re-download a zip that was removed after extraction.
if not os.path.exists(image_folder):
    if not os.path.exists(dataset_zip_path):
        if not dataset_id:
            raise ValueError(
                f"{dataset_zip_path} does not exist and dataset_id is empty; "
                "set dataset_id to the Google Drive id of the image archive."
            )
        print("Downloading dataset...")
        gdown.download(id=dataset_id, output=dataset_zip_path, quiet=False)
    print("Extracting dataset...")
    with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
        zip_ref.extractall(base_path)
    # Fail here rather than later with zero extracted features if the archive layout differs.
    if not os.path.isdir(image_folder):
        raise FileNotFoundError(f"Expected extracted images in {image_folder}")

if not os.path.exists(captions_file):
    if not captions_id:
        raise ValueError(
            f"{captions_file} does not exist and captions_id is empty; "
            "set captions_id to the Google Drive id of captions.txt."
        )
    print("Downloading captions.txt...")
    gdown.download(id=captions_id, output=captions_file, quiet=False)

In [None]:
# ==================== Step 2: Load captions ====================
def load_captions(filename):
    """Read an ``image,caption`` CSV file into a dict of image id -> captions.

    Args:
        filename: path to the captions file. Each row is ``<image file>,<caption>``.
            Captions containing commas may be double-quoted (standard CSV quoting),
            and an optional leading ``image,caption`` header row is skipped.

    Returns:
        dict mapping image id (file name without its extension) to the list of its
        captions, lower-cased and stripped, in file order.
    """
    captions_mapping = {}
    # newline='' is what the csv module requires; explicit utf-8 avoids locale-dependent decoding.
    with open(filename, 'r', encoding='utf-8', newline='') as f:
        for row_no, row in enumerate(csv.reader(f)):
            if len(row) < 2:
                continue
            # Unquoted captions that contain commas arrive split across fields; re-join them.
            img_file, caption = row[0].strip(), ','.join(row[1:])
            # Skip the header so "image" does not become a fake image id.
            if row_no == 0 and img_file.lower() == 'image' and caption.strip().lower() == 'caption':
                continue
            img_id = os.path.splitext(img_file)[0]
            captions_mapping.setdefault(img_id, []).append(caption.lower().strip())
    return captions_mapping

# Parse the captions file into {image id (file name without extension): [lower-cased captions]}.
captions_mapping = load_captions(captions_file)

In [None]:
# ==================== Step 3: Feature Extraction ====================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the RNGs used later (np.random.choice in the dataset; torch for weight init,
# DataLoader shuffling and dropout) so Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

def get_inception():
    """Build a frozen Inception-v3 feature extractor on ``device``.

    Loads the default pretrained weights, switches off the auxiliary-classifier
    output and swaps the final ``fc`` layer for ``nn.Identity`` so the network
    returns the vector that used to feed the classifier instead of class logits.
    The returned model is in eval mode.
    """
    # aux_logits has to be True while the pretrained weights are loaded; only the
    # main output is wanted afterwards, so it is turned off once loading is done.
    backbone = inception_v3(weights=Inception_V3_Weights.DEFAULT, aux_logits=True)
    backbone.aux_logits = False
    backbone.fc = nn.Identity()
    return backbone.to(device).eval()

image_model = get_inception()

# Preprocessing for the frozen Inception-v3 encoder: resize to its 299x299 input
# size, convert to a tensor, then normalise per channel with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def extract_features(image_folder, captions_mapping):
    """Encode every captioned image with ``image_model``.

    Args:
        image_folder: directory containing ``<img_id>.jpg`` files.
        captions_mapping: dict whose keys are the image ids to encode.

    Returns:
        dict mapping image id -> numpy array output of ``image_model`` for a batch
        of one image (leading dimension 1).

    Images that are missing or cannot be decoded (any ``OSError``, which includes
    ``FileNotFoundError`` and PIL's ``UnidentifiedImageError``) are skipped and
    reported. Any other error (e.g. CUDA out-of-memory, KeyboardInterrupt) now
    propagates instead of being silently swallowed by a bare ``except``.
    """
    features = {}
    skipped = []
    for img_id in tqdm(captions_mapping.keys(), desc="Extracting features"):
        img_path = os.path.join(image_folder, f"{img_id}.jpg")
        try:
            # The context manager closes the file handle as soon as pixels are loaded.
            with Image.open(img_path) as img:
                img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)
        except OSError:
            skipped.append(img_id)
            continue
        with torch.no_grad():
            features[img_id] = image_model(img_tensor).cpu().numpy()
    if skipped:
        print(f"Skipped {len(skipped)} missing/unreadable images (e.g. {skipped[:3]})")
    return features

# Feature extraction is the slowest step, so reuse the pickle written by a previous run.
# Delete image_features.pkl to force re-extraction (e.g. after changing the image model
# or the captions file). Only load pickles this notebook produced: unpickling can run code.
FEATURES_CACHE = 'image_features.pkl'
if os.path.exists(FEATURES_CACHE):
    with open(FEATURES_CACHE, 'rb') as f:
        features = pickle.load(f)
    print(f"Loaded cached features for {len(features)} images from {FEATURES_CACHE}")
else:
    features = extract_features(image_folder, captions_mapping)
    if features:  # never cache an empty result
        with open(FEATURES_CACHE, 'wb') as f:
            pickle.dump(features, f)

# Training averages losses over batches; fail here with a clear message instead.
if not features:
    raise RuntimeError(f"No image features extracted from {image_folder}; check the dataset paths.")

In [None]:
# ==================== Step 4: Vocabulary ====================
def build_vocabulary(captions_mapping, min_count=5):
    """Build the word vocabulary from every caption.

    Args:
        captions_mapping: dict of image id -> list of caption strings
            (tokenised by whitespace).
        min_count: minimum number of occurrences for a word to be kept
            (default 5, the previously hard-coded threshold).

    Returns:
        ``(vocab, word2idx, idx2word)``. ``vocab`` starts with the special tokens
        ``<pad>`` (index 0), ``<start>``, ``<end>`` and ``<unk>``, followed by the
        kept words in first-seen order. ``word2idx`` and ``idx2word`` are inverse maps.
    """
    specials = ['<pad>', '<start>', '<end>', '<unk>']
    word_counts = Counter(
        word
        for captions in captions_mapping.values()
        for caption in captions
        for word in caption.split()
    )
    # A caption word equal to a special token would otherwise appear twice in vocab
    # and make word2idx/idx2word disagree; the special entry wins.
    vocab = specials + [
        word for word, count in word_counts.items()
        if count >= min_count and word not in specials
    ]

    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = dict(enumerate(vocab))

    return vocab, word2idx, idx2word

vocab, word2idx, idx2word = build_vocabulary(captions_mapping)
vocab_size = len(vocab)
# Longest whitespace-tokenised caption, plus two slots for the <start>/<end> markers.
max_length = 2 + max(
    len(text.split())
    for caption_list in captions_mapping.values()
    for text in caption_list
)

In [None]:
# ==================== Step 5: Dataset ====================
class ImageCaptionDataset(Dataset):
    """Pairs pre-extracted image features with one tokenised caption per image.

    Only ids that have an entry in ``features`` are kept. Each item is built by:
    picking one of the image's captions at random (``np.random.choice``), wrapping
    it in ``<start>``/``<end>``, mapping words to indices (unknown words become
    ``<unk>``), then truncating/padding with ``<pad>`` to exactly ``max_length``.
    """

    def __init__(self, img_ids, captions_mapping, features, word2idx, max_length):
        # Drop ids whose features are missing (e.g. images that failed to load).
        self.img_ids = list(filter(lambda key: key in features, img_ids))
        self.captions_mapping = captions_mapping
        self.features = features
        self.word2idx = word2idx
        self.max_length = max_length

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(feature vector with the leading batch dim squeezed, LongTensor of max_length ids)``."""
        key = self.img_ids[idx]
        image_vec = torch.FloatTensor(self.features[key])
        sampled = np.random.choice(self.captions_mapping[key])

        unk_id = self.word2idx['<unk>']
        token_ids = [self.word2idx.get(tok, unk_id)
                     for tok in ['<start>', *sampled.split(), '<end>']]
        token_ids = token_ids[:self.max_length]
        pad_id = self.word2idx['<pad>']
        token_ids.extend([pad_id] * (self.max_length - len(token_ids)))
        return image_vec.squeeze(0), torch.LongTensor(token_ids)

In [None]:
# ==================== Step 6: Model ====================
class ImageCaptionModel(nn.Module):
    """Merge-style captioner: LSTM caption states summed with a projected image feature.

    The LSTM reads only the caption tokens. The image feature is projected to
    ``hidden_size`` and added to every LSTM output before the vocabulary decoder.

    Args:
        vocab_size: number of tokens (embedding rows and output logits).
        embed_size: word-embedding width fed to the LSTM.
        hidden_size: LSTM hidden width; also the width of the projected image feature.
        feature_size: width of the input image feature vectors (default 2048).
        dropout: dropout probability applied to the image and word embeddings.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=256, feature_size=2048, dropout=0.5):
        super().__init__()
        # padding_idx=0 assumes '<pad>' is vocabulary index 0; that row gets no gradient.
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.encoder = nn.Linear(feature_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, captions):
        """Score the next token at every caption position.

        Args:
            features: ``(batch, feature_size)`` image features.
            captions: ``(batch, seq_len)`` token indices.

        Returns:
            ``(batch, seq_len, vocab_size)`` logits. Output ``t`` depends only on
            tokens ``0..t`` because the LSTM is unidirectional.
        """
        img_embed = self.dropout(self.encoder(features))
        cap_embed = self.dropout(self.embed(captions))
        lstm_out, _ = self.lstm(cap_embed)
        # Broadcast the single image vector across all time steps.
        img_embed = img_embed.unsqueeze(1).expand(-1, lstm_out.size(1), -1)
        combined = img_embed + lstm_out
        return self.decoder(combined)

model = ImageCaptionModel(vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Padding positions carry no caption content, so exclude them from the loss. Look the
# index up instead of hard-coding 0 so it stays in sync with the vocabulary.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<pad>'])

In [None]:
# ==================== Step 7: Train Model ====================
# 80/20 split over image ids in `features` insertion order, so every caption of an
# image lands in the same split.
img_ids = list(features)
split = int(0.8 * len(img_ids))
train_ids, val_ids = img_ids[:split], img_ids[split:]

train_dataset, val_dataset = (
    ImageCaptionDataset(ids, captions_mapping, features, word2idx, max_length)
    for ids in (train_ids, val_ids)
)

# Only the training batches are shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

def train_model(model, train_loader, val_loader, epochs=20):
    """Train ``model`` with teacher forcing, keeping the weights with the lowest val loss.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``. Each batch
    is ``(features, captions)``. The model reads ``captions[:, :-1]`` and is scored
    against ``captions[:, 1:]`` (next-token prediction).

    Whenever the validation loss improves, the state dict is saved to
    ``best_model.pth`` and snapshotted in memory. After the final epoch that
    snapshot is loaded back into ``model``, so the returned model holds the best
    weights seen rather than the last epoch's.

    Returns:
        ``model`` (the same object), holding the best-validation-loss weights
        whenever at least one epoch improved on ``inf``.

    Raises:
        ValueError: if either loader yields no batches (the averages would divide by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    def _batch_loss(features, captions):
        # Shift by one token: predict captions[:, t + 1] from captions[:, :t + 1].
        features, captions = features.to(device), captions.to(device)
        outputs = model(features, captions[:, :-1])
        return criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))

    best_val_loss = float('inf')
    best_state = None
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for features, captions in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            loss = _batch_loss(features, captions)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for features, captions in val_loader:
                val_loss += _batch_loss(features, captions).item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Detached copies, so later optimizer steps do not overwrite the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), 'best_model.pth')

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# Runs the training loop (epochs defaults to 20); train_model also writes the
# best-validation-loss weights to best_model.pth.
trained_model = train_model(model, train_loader, val_loader)

In [None]:
# ==================== Step 8: Caption Generation ====================
def generate_caption(model, image_path, max_length=20):
    """Greedily decode a caption for one image.

    The image is preprocessed with ``transform`` and encoded by ``image_model``.
    Starting from ``<start>``, the whole prefix is fed to ``model`` at each step and
    the highest-scoring next token is appended, until ``<end>`` is produced or
    ``max_length`` tokens have been generated.

    Args:
        model: trained caption model (switched to eval mode here).
        image_path: path to an image file readable by PIL.
        max_length: maximum number of tokens to generate.

    Returns:
        The caption as a space-joined string with ``<start>``, ``<end>`` and
        ``<pad>`` tokens removed.
    """
    model.eval()
    # The context manager closes the file handle once the pixels are loaded.
    with Image.open(image_path) as img:
        img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    end_idx = word2idx['<end>']
    # Track indices directly; the previous word -> index round-trip was redundant.
    token_ids = [word2idx['<start>']]
    with torch.no_grad():
        features = image_model(img_tensor)
        for _ in range(max_length):
            caption_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
            output = model(features, caption_tensor)
            next_idx = output[0, -1].argmax().item()
            token_ids.append(next_idx)
            if next_idx == end_idx:
                break

    words = (idx2word[i] for i in token_ids)
    return ' '.join(word for word in words if word not in ('<start>', '<end>', '<pad>'))

In [None]:
# ==================== Step 9: Test ====================
# Caption a held-out validation image rather than one the model was trained on.
# Guard against an empty split instead of raising IndexError.
sample_id = val_ids[0] if val_ids else None
sample_image = os.path.join(image_folder, sample_id + '.jpg') if sample_id is not None else None
if sample_image is not None and os.path.exists(sample_image):
    caption = generate_caption(trained_model, sample_image)
    print("Generated caption:", caption)
else:
    print("Sample image not found.")