# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import inception_v3
from PIL import Image
from tqdm import tqdm
from collections import Counter
import pickle

# Compute device: the default CUDA device if torch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Dataset layout: <base_path>/Images/<id>.jpg and <base_path>/captions.txt
base_path = 'archive (2)'
image_folder, captions_file = (
    os.path.join(base_path, 'Images'),
    os.path.join(base_path, 'captions.txt'),
)

# Load and preprocess captions
def load_captions(filename):
    """Load a ``<image file>,<caption>`` text file into an id -> captions map.

    Each line is split at its first comma. The part before the comma is the
    image file name; its extension is dropped to form the id. The rest of the
    line is the caption, which is lower-cased and stripped. Lines without a
    comma are skipped. A leading ``image,caption`` header row, if present, is
    skipped instead of being registered as an image called "image".

    Args:
        filename: path to the captions file. It is read as UTF-8, and a
            leading byte-order mark is ignored.

    Returns:
        dict mapping image id -> list of captions, in file order.
    """
    captions_mapping = {}
    # 'utf-8-sig' decodes plain UTF-8 and also drops a BOM if one is present,
    # instead of relying on the platform's default encoding.
    with open(filename, 'r', encoding='utf-8-sig') as f:
        for line_no, line in enumerate(f):
            img_file, sep, caption = line.strip().partition(',')
            if not sep:
                continue  # no comma -> not an "<image>,<caption>" record
            # Column-header row: not a real image.
            if line_no == 0 and (img_file.strip().lower(),
                                 caption.strip().lower()) == ('image', 'caption'):
                continue
            img_id = os.path.splitext(img_file)[0]  # Remove file extension
            captions_mapping.setdefault(img_id, []).append(caption.lower().strip())
    return captions_mapping

# image id (file name without extension) -> list of lower-cased captions
captions_mapping = load_captions(filename=captions_file)

# Initialize InceptionV3 for feature extraction
def get_inception():
    """Build the pretrained torchvision InceptionV3 used for feature extraction.

    The classification head (``fc``) is replaced by ``nn.Identity`` so that
    ``forward`` returns the pre-``fc`` features instead of class scores. The
    model is moved to the module-level ``device`` and put in eval mode.

    NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    favour of ``weights=Inception_V3_Weights.DEFAULT``; kept here for
    compatibility with older torchvision versions.
    """
    backbone = inception_v3(pretrained=True)
    # Pass-through head: the forward pass yields features, not logits.
    backbone.fc = nn.Identity()
    return backbone.to(device).eval()

image_model = get_inception()

# Preprocessing for the CNN: fixed 299x299 resize, conversion to a [0, 1]
# tensor, then per-channel normalization with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Feature extraction with progress tracking
def extract_features(image_folder, captions_mapping):
    """Run every captioned image through ``image_model`` and collect its output.

    For each image id in ``captions_mapping``, ``<image_folder>/<id>.jpg`` is
    opened and converted to RGB. It is preprocessed with the module-level
    ``transform`` and passed through ``image_model`` on ``device``.

    Images that cannot be opened or decoded (missing file, corrupt or
    unsupported data) are skipped and reported. Any other error, such as a
    CUDA or model failure, is allowed to propagate. Previously those errors
    were swallowed and silently produced an empty feature dict.

    Args:
        image_folder: directory containing the ``<id>.jpg`` files.
        captions_mapping: dict whose keys are the image ids to process.

    Returns:
        dict mapping image id -> numpy array of the model output for that
        image, keeping the leading batch dimension of 1.
    """
    features = {}
    missing_images = []

    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for img_id in tqdm(captions_mapping, desc="Extracting features"):
            img_path = os.path.join(image_folder, f"{img_id}.jpg")
            try:
                # Keep the try narrow: only reading/decoding the image is
                # expected to fail per-item. PIL reports missing/undecodable
                # files as OSError subclasses (FileNotFoundError,
                # UnidentifiedImageError, truncated-file errors).
                with Image.open(img_path) as img:
                    img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)
            except (OSError, ValueError):
                missing_images.append(img_id)
                continue
            features[img_id] = image_model(img_tensor).cpu().numpy()

    print(f"\nFailed to process {len(missing_images)} images")
    if missing_images:
        print(f"First unreadable/missing image ids: {missing_images[:5]}")
    return features

features = extract_features(image_folder, captions_mapping)

# Persist the extracted features to disk in pickle format.
# NOTE: unpickling can execute arbitrary code; only load files you created.
with open('image_features.pkl', mode='wb') as pkl_file:
    pickle.dump(features, pkl_file)

# Vocabulary construction
def build_vocabulary(captions_mapping, min_count=5):
    """Build the token vocabulary from every caption in ``captions_mapping``.

    Captions are tokenized by whitespace. The special tokens always occupy
    indices 0-3 in the order ``<pad>``, ``<start>``, ``<end>``, ``<unk>``, so
    index 0 is padding. They are followed by every word that occurs at least
    ``min_count`` times, in first-seen order.

    Args:
        captions_mapping: dict mapping image id -> list of caption strings.
        min_count: minimum total occurrence count for a word to be included.
            Rarer words are left out of the vocabulary.

    Returns:
        ``(vocab, word2idx, idx2word)``: the token list and its two exact
        inverse lookup dicts.
    """
    specials = ['<pad>', '<start>', '<end>', '<unk>']
    word_counts = Counter(
        word
        for captions in captions_mapping.values()
        for caption in captions
        for word in caption.split()
    )

    # Counter keeps first-seen order, so indices are deterministic for a given
    # captions file. A caption word that equals a special token is skipped, so
    # each token appears once and word2idx/idx2word stay exact inverses.
    special_set = set(specials)
    vocab = specials + [
        word for word, count in word_counts.items()
        if count >= min_count and word not in special_set
    ]

    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = dict(enumerate(vocab))

    return vocab, word2idx, idx2word

vocab, word2idx, idx2word = build_vocabulary(captions_mapping)
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

# Longest caption in whitespace-separated tokens, plus two slots for the
# <start>/<end> markers that get wrapped around each caption.
max_length = 2 + max(
    len(caption.split())
    for caption_list in captions_mapping.values()
    for caption in caption_list
)
print(f"Maximum caption length: {max_length}")

# Dataset class
class ImageCaptionDataset(Dataset):
    """Pairs each image's cached feature array with one of its captions.

    Only ids present in ``features`` are kept. Each item draws one of the
    image's captions at random (via ``np.random``) and wraps it in
    ``<start>``/``<end>``. Words are mapped to indices, with words missing from
    ``word2idx`` becoming ``<unk>``. The index list is truncated or padded with
    ``<pad>`` to exactly ``max_length`` entries.
    """

    def __init__(self, img_ids, captions_mapping, features, word2idx, max_length):
        # Drop ids for which no feature was extracted.
        self.img_ids = list(filter(features.__contains__, img_ids))
        self.captions_mapping = captions_mapping
        self.features = features
        self.word2idx = word2idx
        self.max_length = max_length

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        feature = torch.FloatTensor(self.features[img_id])

        # One caption per access, chosen at random.
        caption = np.random.choice(self.captions_mapping[img_id])
        tokens = ['<start>', *caption.split(), '<end>']

        unk = self.word2idx['<unk>']
        ids = [self.word2idx.get(token, unk) for token in tokens][:self.max_length]
        ids.extend([self.word2idx['<pad>']] * (self.max_length - len(ids)))

        # squeeze(0) drops a leading singleton (batch) axis from the feature.
        return feature.squeeze(0), torch.LongTensor(ids)

# Model architecture: merge-style decoder (image features are added to the LSTM outputs; no attention)
class ImageCaptionModel(nn.Module):
    """Merge-style caption decoder: image features are added to LSTM outputs.

    The image feature vector is projected to ``hidden_size`` and broadcast over
    every time step. It is then summed with the LSTM hidden states of the
    caption tokens, and a linear layer maps the sum to vocabulary logits. The
    LSTM itself never sees the image, and there is no attention mechanism.

    Args:
        vocab_size: number of tokens (embedding rows and output logits).
        embed_size: word-embedding width (the LSTM input size).
        hidden_size: LSTM hidden size; also the width of the projected image
            features.
        feature_dim: width of the incoming image feature vectors.
        dropout: dropout probability applied to the projected image features
            and to the word embeddings.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=256,
                 feature_dim=2048, dropout=0.5):
        super().__init__()
        # padding_idx=0: the row for index 0 starts at zero and gets no
        # gradient, so padded positions contribute nothing to the embedding.
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.encoder = nn.Linear(feature_dim, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, captions):
        """Return logits for every position of ``captions``.

        Args:
            features: ``(batch, feature_dim)`` image feature tensor.
            captions: ``(batch, T)`` LongTensor of token indices.

        Returns:
            ``(batch, T, vocab_size)`` logits, one distribution per position.
        """
        img_embed = self.dropout(self.encoder(features))    # (batch, hidden)
        cap_embed = self.dropout(self.embed(captions))      # (batch, T, embed)

        lstm_out, _ = self.lstm(cap_embed)                  # (batch, T, hidden)

        # Broadcast the image vector across all T steps and merge by addition.
        img_embed = img_embed.unsqueeze(1).expand(-1, lstm_out.size(1), -1)
        combined = img_embed + lstm_out

        return self.decoder(combined)

# Model, optimizer and loss. Targets equal to index 0 (the '<pad>' token of
# the vocabulary) are excluded from the loss.
model = ImageCaptionModel(vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# 80/20 train/validation split in the order the feature dict was filled.
# There is no shuffling, so the split is fixed for a given feature dict.
img_ids = list(features)
split = int(0.8 * len(img_ids))
train_ids, val_ids = img_ids[:split], img_ids[split:]

train_dataset = ImageCaptionDataset(
    train_ids, captions_mapping, features, word2idx, max_length
)
val_dataset = ImageCaptionDataset(
    val_ids, captions_mapping, features, word2idx, max_length
)

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Training loop
def _caption_loss(model, features, captions):
    """Return the teacher-forced next-token loss for one batch.

    Inputs are moved to the module-level ``device``. The model is fed
    ``captions[:, :-1]``, and its logits are scored against ``captions[:, 1:]``
    with the module-level ``criterion``.
    """
    features = features.to(device)
    captions = captions.to(device)
    outputs = model(features, captions[:, :-1])
    # Flatten logits to (N, vocab) against N targets. The vocab width is read
    # from the logits themselves rather than from a global.
    return criterion(outputs.reshape(-1, outputs.size(-1)),
                     captions[:, 1:].reshape(-1))


def train_model(model, train_loader, val_loader, epochs=20,
                checkpoint_path='best_model.pth', restore_best=True):
    """Train ``model`` with the module-level ``optimizer`` and ``criterion``.

    Each epoch makes one teacher-forced pass over ``train_loader`` and then
    computes the mean loss over ``val_loader``. Whenever the validation loss
    improves, the weights are saved to ``checkpoint_path``.

    Args:
        model: the captioning model; it is trained in place.
        train_loader, val_loader: loaders yielding ``(features, captions)``
            batches. Both must yield at least one batch.
        epochs: number of passes over the training data.
        checkpoint_path: file that receives the best-so-far ``state_dict``.
        restore_best: if True, load the best-validation weights back into
            ``model`` before returning. The returned model then matches the
            saved checkpoint rather than the last epoch.

    Returns:
        ``model`` (the same object that was passed in).

    Raises:
        ValueError: if ``epochs > 0`` and either loader yields no batches,
            since the mean losses would be undefined.
    """
    if epochs > 0 and (len(train_loader) == 0 or len(val_loader) == 0):
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for features, captions in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            loss = _caption_loss(model, features, captions)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation phase: eval mode disables dropout; no autograd needed.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for features, captions in val_loader:
                val_loss += _caption_loss(model, features, captions).item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save best model (on disk, plus an in-memory copy for restoring).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
    return model

# Train for the default 20 epochs. train_model returns the same `model`
# object it trained in place.
trained_model = train_model(model, train_loader, val_loader, epochs=20)

# Prediction function
def generate_caption(model, image_path, max_length=20):
    """Greedily decode a caption for a single image file.

    The image is preprocessed with ``transform`` and encoded by
    ``image_model``. ``model`` is then run on the growing token prefix,
    starting from ``<start>``, and at every step the highest-scoring next token
    is appended. Decoding stops after ``<end>`` is produced or after
    ``max_length`` new tokens.

    Returns:
        The generated words joined by single spaces, with ``<start>``,
        ``<end>`` and ``<pad>`` removed.
    """
    model.eval()

    with Image.open(image_path) as raw_image:
        pixels = transform(raw_image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        features = image_model(pixels)

    unk_idx = word2idx['<unk>']
    words = ['<start>']
    # Index form of `words`, kept in step with it instead of being rebuilt
    # from scratch on every decoding step.
    token_ids = [word2idx.get('<start>', unk_idx)]

    with torch.no_grad():
        for _ in range(max_length):
            prefix = torch.LongTensor(token_ids).unsqueeze(0).to(device)
            logits = model(features, prefix)
            # Only the prediction after the last prefix token is needed.
            next_idx = logits[0, -1].argmax().item()
            next_word = idx2word[next_idx]

            words.append(next_word)
            if next_word == '<end>':
                break
            token_ids.append(word2idx.get(next_word, unk_idx))

    markers = ('<start>', '<end>', '<pad>')
    return ' '.join(word for word in words if word not in markers)

# Test on sample image
# sample_image = os.path.join(image_folder, list(features.keys())[0] + '.jpg')
# print("Generated caption:", generate_caption(trained_model, sample_image))

# test_image = os.path.join(base_path, 'Images', '10815824_2997e03d76.jpg')
# if os.path.exists(test_image):
#     caption = generate_caption(model, test_image)
#     print(f"Generated caption: {caption}")
# else:
#     print(f"Test image not found at path: {test_image}")


# Images to caption, taken from the dataset's Images folder; edit the
# filenames below to try other images.
test_images = [
    os.path.join(base_path, 'Images', file_name)
    for file_name in (
        '667626_18933d713e.jpg',
        '3637013_c675de7705.jpg',
        '3639617775_149001232a.jpg',
        '3639547922_0b00fed5cd.jpg',
        '3639428663_dae5e8146e.jpg',
        '3639363462_bcdb21de29.jpg',
    )
]

# Caption and display each test image that exists on disk.
for test_image in test_images:
    if not os.path.exists(test_image):
        print(f"Test image not found at path: {test_image}")
        continue

    caption = generate_caption(model, test_image)

    # Image.show() hands the image to one of Pillow's registered viewers;
    # whether it appears inline or in an external window depends on the
    # environment.
    img = Image.open(test_image)
    img.show()
    print(f"Generated caption: {caption}")