# In [None]:
# Show, Attend and Tell: Neural Image Captioning with Visual Attention
# Implementation based on Xu et al. (2015)

import json
import math
import os
import pickle
import time
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset, DataLoader

# Select the compute device: CUDA when it is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================================================
# 1. DATA PREPROCESSING AND VOCABULARY
# =============================================================================

class Vocabulary:
    """Two-way mapping between caption words and consecutive integer indices."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # index that the next new word will receive

    def add_word(self, word):
        """Register ``word`` under the next free index; known words are left untouched."""
        if word in self.word2idx:
            return
        new_index = self.idx
        self.word2idx[word] = new_index
        self.idx2word[new_index] = word
        self.idx = new_index + 1

    def __call__(self, word):
        """Return the index of ``word``, or the index of '<unk>' when it is unknown."""
        try:
            return self.word2idx[word]
        except KeyError:
            # Raises KeyError itself if '<unk>' was never added
            return self.word2idx['<unk>']

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.word2idx)

def build_vocab(caption_path, threshold=4):
    """
    Build a small demonstration vocabulary.

    This simplified version does not read ``caption_path`` and does not apply
    ``threshold``; a real pipeline would load the COCO annotations, count word
    frequencies and keep only the words above the threshold.

    Returns:
        Vocabulary holding the four special tokens followed by a fixed word list.
    """
    # Special tokens are added first so they receive the lowest indices
    special_tokens = ('<pad>', '<start>', '<end>', '<unk>')

    # Hand-picked words standing in for a frequency-filtered vocabulary
    demo_words = (
        'a', 'the', 'man', 'woman', 'dog', 'cat', 'car', 'tree',
        'house', 'red', 'blue', 'large', 'small', 'sitting', 'standing',
        'walking', 'running', 'eating', 'playing', 'on', 'in', 'with',
    )

    vocab = Vocabulary()
    for token in special_tokens + demo_words:
        vocab.add_word(token)
    return vocab

# =============================================================================
# 2. CNN ENCODER (Feature Extractor)
# =============================================================================

class EncoderCNN(nn.Module):
    """Image encoder built on a pretrained ResNet-101 with frozen weights."""

    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        # Keep every child of the pretrained network except the last two
        # (the average-pooling and fully-connected layers)
        backbone = models.resnet101(pretrained=True)
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # Pool the feature map to a fixed spatial size regardless of input size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        # Freeze the backbone so its weights receive no gradients
        for weight in self.resnet.parameters():
            weight.requires_grad = False

    def forward(self, images):
        """
        Encode a batch of images.
        Args:
            images: (batch_size, 3, image_size, image_size)
        Returns:
            features: (batch_size, encoded_image_size, encoded_image_size, 2048)
        """
        feature_map = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        pooled = self.adaptive_pool(feature_map)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        # Move the channel axis last
        return pooled.permute(0, 2, 3, 1)

# =============================================================================
# 3. ATTENTION MECHANISM
# =============================================================================

class Attention(nn.Module):
    """Soft attention over encoder locations, conditioned on the decoder state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Projections of the image features and the decoder state into a shared space
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapses each location's attention vector to a single score
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalises the scores across locations (dim=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        Compute attention weights and the resulting context vector.
        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            decoder_hidden: (batch_size, decoder_dim)
        Returns:
            attention_weighted_encoding: (batch_size, encoder_dim)
            alpha: (batch_size, num_pixels)
        """
        pixel_proj = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        # unsqueeze so the hidden-state projection broadcasts over all locations
        hidden_proj = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch_size, 1, attention_dim)
        energies = self.full_att(self.relu(pixel_proj + hidden_proj)).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(energies)  # (batch_size, num_pixels)

        # Weighted sum of the location features
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return attention_weighted_encoding, alpha

# =============================================================================
# 4. LSTM DECODER WITH ATTENTION
# =============================================================================

class DecoderWithAttention(nn.Module):
    """LSTM decoder that attends over the encoder output at every time-step."""

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        """
        Args:
            attention_dim: size of the attention network's hidden space
            embed_dim: word embedding size
            decoder_dim: LSTM hidden/cell state size
            vocab_size: number of words scored at each step
            encoder_dim: feature size of the encoder output
            dropout: dropout probability applied before the output layer
        """
        super().__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        # Attention network
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # LSTM cell; its input is the word embedding concatenated with the context vector
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)

        # Linear layers producing the initial hidden and cell states
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)

        # Linear layer feeding the sigmoid gate applied to the context vector
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()

        # Linear layer producing scores over the vocabulary
        self.fc = nn.Linear(decoder_dim, vocab_size)

        # Initialize some layers with the uniform distribution
        self.init_weights()

    def init_weights(self):
        """Initialize embedding and output-layer weights uniformly in [-0.1, 0.1]; zero the output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weights with ``embeddings`` wrapped in a new Parameter."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable (default) or disable gradient updates for the embedding layer."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM
        from the encoder output averaged over all locations.
        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
        Returns:
            h: (batch_size, decoder_dim)
            c: (batch_size, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)  # (batch_size, decoder_dim)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward pass of decoder
        Args:
            encoder_out: (batch_size, enc_image_size, enc_image_size, encoder_dim)
            encoded_captions: (batch_size, max_caption_length)
            caption_lengths: (batch_size, 1)
        Returns:
            predictions: (batch_size, max(decode_lengths), vocab_size), zero where nothing was decoded
            alphas: (batch_size, max(decode_lengths), num_pixels), zero where nothing was decoded
            encoded_captions: the input captions reordered by decreasing caption length
            decode_lengths: list with (caption length - 1) per sorted sample
            sort_ind: indices that were used to sort the batch
        All per-sample outputs follow the sorted order, not the input order.
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        # Allocate outputs next to the inputs instead of on a module-level global device
        out_device = encoder_out.device

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths so that, at every step, the
        # samples still being decoded form a prefix of the batch
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Create tensors to hold word prediction scores and alphas
        predictions = torch.zeros(batch_size, max_decode_length, vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels, device=out_device)

        # At each time-step, attend over the encoder output using the previous hidden state,
        # then feed the previous word and the gated context vector to the LSTM cell
        for t in range(max_decode_length):
            # Number of samples whose caption is still being decoded at step t
            batch_size_t = sum(length > t for length in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout_layer(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, alphas, encoded_captions, decode_lengths, sort_ind

# =============================================================================
# 5. COMPLETE MODEL
# =============================================================================

class ShowAttendTell(nn.Module):
    """Encoder and attention decoder bundled into a single captioning module."""

    def __init__(self, vocab_size, attention_dim=512, embed_dim=512, decoder_dim=512, dropout=0.5):
        super().__init__()

        self.encoder = EncoderCNN()
        decoder_config = dict(
            attention_dim=attention_dim,
            embed_dim=embed_dim,
            decoder_dim=decoder_dim,
            vocab_size=vocab_size,
            dropout=dropout,
        )
        self.decoder = DecoderWithAttention(**decoder_config)

    def forward(self, images, captions, caption_lengths):
        """
        Encode ``images`` and decode them against ``captions``.

        Returns the decoder's five outputs unchanged:
        (predictions, alphas, encoded_captions, decode_lengths, sort_ind).
        """
        image_features = self.encoder(images)
        (predictions, alphas, encoded_captions,
         decode_lengths, sort_ind) = self.decoder(image_features, captions, caption_lengths)
        return predictions, alphas, encoded_captions, decode_lengths, sort_ind

# =============================================================================
# 6. DATASET CLASS
# =============================================================================

class CaptionDataset(Dataset):
    """Dataset of image-caption pairs backed by an HDF5 image file and JSON caption files."""

    def __init__(self, data_folder, data_name, split, transform=None):
        """
        Args:
            data_folder: folder containing the preprocessed data files
            data_name: base name shared by the data files
            split: one of 'TRAIN', 'VAL' or 'TEST'
            transform: optional callable applied to every image tensor
        """
        # h5py is only needed here and is not imported at module level;
        # without this import the constructor fails with a NameError
        import h5py

        self.split = split
        assert self.split in {'TRAIN', 'VAL', 'TEST'}

        # Open hdf5 file where images are stored (kept open for the dataset's lifetime)
        self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r')
        self.imgs = self.h['images']

        # Captions per image
        self.cpi = self.h.attrs['captions_per_image']

        # Load encoded captions
        with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
            self.captions = json.load(j)

        # Load caption lengths
        with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j:
            self.caplens = json.load(j)

        # Transform
        self.transform = transform

        # Total number of datapoints (one per caption, not one per image)
        self.dataset_size = len(self.captions)

    def __getitem__(self, i):
        """
        Return the i-th caption together with its image.

        Returns:
            (img, caption, caplen) for the 'TRAIN' split; for the other splits
            additionally ``all_captions``, every caption of the same image.
        """
        # The Nth caption corresponds to the (N // captions_per_image)th image
        img_index = i // self.cpi
        # Stored pixel values are divided by 255 before the optional transform
        img = torch.FloatTensor(self.imgs[img_index] / 255.)
        if self.transform is not None:
            img = self.transform(img)

        caption = torch.LongTensor(self.captions[i])
        caplen = torch.LongTensor([self.caplens[i]])

        if self.split == 'TRAIN':
            return img, caption, caplen

        # For validation and testing, also return all 'captions_per_image' captions for evaluation
        first_caption = img_index * self.cpi
        all_captions = torch.LongTensor(self.captions[first_caption:first_caption + self.cpi])
        return img, caption, caplen, all_captions

    def __len__(self):
        """Number of captions in the split."""
        return self.dataset_size

# =============================================================================
# 7. TRAINING UTILITIES
# =============================================================================

def clip_gradient(optimizer, grad_clip):
    """
    Clamp every gradient element of the optimizer's parameters to
    [-grad_clip, grad_clip], in place. Parameters without a gradient are skipped.
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue
            gradient.data.clamp_(-grad_clip, grad_clip)

def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder,
                   encoder_optimizer, decoder_optimizer, bleu4, is_best):
    """
    Write the training state to 'checkpoint_<data_name>.pth.tar' in the working
    directory; when ``is_best`` is true, also write a copy prefixed with 'BEST_'.
    The given encoder, decoder and optimizer objects are stored as passed.
    """
    filename = 'checkpoint_' + data_name + '.pth.tar'

    state = {
        'epoch': epoch,
        'epochs_since_improvement': epochs_since_improvement,
        'bleu-4': bleu4,
        'encoder': encoder,
        'decoder': decoder,
        'encoder_optimizer': encoder_optimizer,
        'decoder_optimizer': decoder_optimizer,
    }

    # The regular checkpoint is always written first
    destinations = [filename]
    if is_best:
        destinations.append('BEST_' + filename)
    for path in destinations:
        torch.save(state, path)

class AverageMeter(object):
    """Tracks the latest value, running sum, sample count and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(scores, targets, k):
    """
    Top-k accuracy in percent: the share of rows in ``scores`` whose target
    index is among that row's k highest-scoring columns.
    """
    num_samples = targets.size(0)
    top_indices = scores.topk(k, dim=1, largest=True, sorted=True)[1]
    # Compare each row's k candidates against that row's target
    hits = top_indices.eq(targets.view(-1, 1).expand_as(top_indices))
    hit_count = hits.view(-1).float().sum().item()
    return hit_count * (100.0 / num_samples)

# =============================================================================
# 8. TRAINING FUNCTION
# =============================================================================

def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer,
          epoch, vocab_size, grad_clip, alpha_c=1.0, print_freq=100):
    """
    Performs one epoch's training

    Args:
        train_loader: iterable yielding (images, captions, caption lengths) batches
        encoder: encoder model
        decoder: decoder model
        criterion: loss applied to the packed scores and targets
        encoder_optimizer: optimizer for the encoder, or None to leave it untouched
        decoder_optimizer: optimizer for the decoder
        epoch: epoch number, used only in the status line
        vocab_size: unused here; kept for interface compatibility
        grad_clip: gradient clamp value, or None to disable clipping
        alpha_c: weight of the doubly stochastic attention regularization
        print_freq: print a status line every ``print_freq`` batches
    """
    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        encoder_out = encoder(imgs)
        # Unpack in the order the decoder returns its outputs:
        # (predictions, alphas, sorted captions, decode lengths, sort indices)
        scores, alphas, caps_sorted, decode_lengths, sort_ind = decoder(encoder_out, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads.
        # pack_padded_sequence is an easy trick to do this; take .data because a
        # PackedSequence carries more than two fields and cannot be unpacked as a pair
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization: pushes each location's
        # attention, summed over the time-steps, towards 1
        loss = loss + alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics, weighted by the number of decoded words
        num_words = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), num_words)
        top5accs.update(top5, num_words)
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top5=top5accs))

# =============================================================================
# 9. INFERENCE AND BEAM SEARCH
# =============================================================================

def beam_search(encoder, decoder, image, word_map, beam_size=3, max_steps=50):
    """
    Beam Search for generating captions

    Args:
        encoder: encoder model
        decoder: decoder model
        image: image tensor with a leading batch dimension of 1
        word_map: word2ix mapping; must contain '<start>' and '<end>'
        beam_size: number of beams
        max_steps: decoding stops once the step counter exceeds this value

    Returns:
        seq: best caption as a list of word indices, starting with '<start>'
        alphas: nested list with one (enc_image_size x enc_image_size) attention
            map per entry of ``seq``; the map paired with '<start>' is all ones

    If no beam produced '<end>' before the step limit, the highest-scoring
    unfinished beam is returned instead of failing on an empty result list.
    """
    k = beam_size
    vocab_size = len(word_map)
    end_idx = word_map['<end>']

    # Encode
    encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
    dev = encoder_out.device  # keep every working tensor next to the encoder output
    enc_image_size = encoder_out.size(1)
    encoder_dim = encoder_out.size(3)

    # Flatten encoding
    encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
    num_pixels = encoder_out.size(1)

    # We'll treat the problem as having a batch size of k
    encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

    # Tensor to store top k previous words at each step; now they're just <start>
    k_prev_words = torch.full((k, 1), word_map['<start>'], dtype=torch.long, device=dev)  # (k, 1)

    # Tensor to store top k sequences; now they're just <start>
    seqs = k_prev_words  # (k, 1)

    # Tensor to store top k sequences' scores; now they're just 0
    top_k_scores = torch.zeros(k, 1, device=dev)  # (k, 1)

    # Tensor to store top k sequences' alphas; now they're just 1s
    seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size, device=dev)

    # Lists to store completed sequences, their alphas and scores
    complete_seqs = []
    complete_seqs_alpha = []
    complete_seqs_scores = []

    # Start decoding
    step = 1
    h, c = decoder.init_hidden_state(encoder_out)

    # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
    while True:
        embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

        awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)
        alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

        gate = decoder.sigmoid(decoder.f_beta(h))  # (s, encoder_dim)
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

        scores = F.log_softmax(decoder.fc(h), dim=1)  # (s, vocab_size)

        # Accumulate log-probabilities along each beam
        scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

        # For the first step, all k points will have the same scores (since same k previous words, h, c)
        if step == 1:
            top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (k)
        else:
            # Unroll and find top scores, and their unrolled indices
            top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (k)

        # Convert unrolled indices to (beam, word) indices
        prev_word_inds = top_k_words // vocab_size  # (k)
        next_word_inds = top_k_words % vocab_size  # (k)

        # Add new words to sequences, alphas
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (k, step+1)
        seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                               dim=1)  # (k, step+1, enc_image_size, enc_image_size)

        # Split the beams into those that just produced <end> and those still running
        next_words = next_word_inds.tolist()
        incomplete_inds = [ind for ind, word in enumerate(next_words) if word != end_idx]
        complete_inds = [ind for ind, word in enumerate(next_words) if word == end_idx]

        # Set aside complete sequences
        if complete_inds:
            complete_seqs.extend(seqs[complete_inds].tolist())
            complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
            complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
        k -= len(complete_inds)  # reduce beam length accordingly

        # Proceed with incomplete sequences
        if k == 0:
            break
        seqs = seqs[incomplete_inds]
        seqs_alpha = seqs_alpha[incomplete_inds]
        h = h[prev_word_inds[incomplete_inds]]
        c = c[prev_word_inds[incomplete_inds]]
        encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

        # Break if things have been going on too long
        if step > max_steps:
            break
        step += 1

    if not complete_seqs:
        # No beam reached <end>: fall back to the beams that were still running
        complete_seqs = seqs.tolist()
        complete_seqs_alpha = seqs_alpha.tolist()
        complete_seqs_scores = top_k_scores.squeeze(1).tolist()

    # First entry with the highest accumulated score
    best = max(range(len(complete_seqs_scores)), key=complete_seqs_scores.__getitem__)
    seq = complete_seqs[best]
    alphas = complete_seqs_alpha[best]

    return seq, alphas

# =============================================================================
# 10. VISUALIZATION UTILITIES
# =============================================================================

def visualize_attention(image_path, seq, alphas, vocab, smooth=True):
    """
    Visualizes caption with attention over image

    Args:
        image_path: path of the image that was captioned
        seq: caption as a sequence of word indices
        alphas: one attention map per entry of ``seq``; either a tensor or the
            nested list returned by ``beam_search``
        vocab: vocabulary providing ``idx2word``
        smooth: upscale the maps with smoothing (pyramid_expand) instead of a plain resize
    """
    # scikit-image is only needed for this plot and is not imported at module level
    import skimage.transform

    image = Image.open(image_path)
    image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)

    words = [vocab.idx2word[ind] for ind in seq]

    # Accept both tensors and plain nested lists, and work on a NumPy array either way
    if torch.is_tensor(alphas):
        alpha_maps = alphas.detach().cpu().numpy()
    else:
        alpha_maps = np.asarray(alphas, dtype=np.float32)

    for t in range(len(words)):
        if t > 50:
            break
        plt.subplot(int(np.ceil(len(words) / 5.)), 5, t + 1)

        plt.text(0, 1, '%s' % (words[t]), color='black', fontsize=12)
        plt.imshow(image)
        current_alpha = alpha_maps[t]
        if smooth:
            alpha = skimage.transform.pyramid_expand(current_alpha, upscale=24, sigma=8)
        else:
            alpha = skimage.transform.resize(current_alpha, [14 * 24, 14 * 24])
        if t == 0:
            # The first map is drawn fully transparent so the plain image shows through
            plt.imshow(alpha, alpha=0)
        else:
            plt.imshow(alpha, alpha=0.8)
        plt.set_cmap(plt.cm.Greys_r)
        plt.axis('off')
    plt.show()

# =============================================================================
# 11. EXAMPLE USAGE AND TRAINING SETUP
# =============================================================================

def main():
    """
    Build the demonstration vocabulary, the model and its optimizers.

    Data loading and the training loop are only sketched in comments.

    Returns:
        (model, vocab)
    """
    # Hyperparameters (several are only used by the commented-out training sketch)
    batch_size = 32
    workers = 4
    encoder_lr = 1e-4
    decoder_lr = 4e-4
    grad_clip = 5.0
    alpha_c = 1.0  # regularization parameter for 'doubly stochastic attention'
    best_bleu4 = 0.
    epochs = 120
    epochs_since_improvement = 0

    # Build vocabulary (in practice, load from preprocessed files)
    vocab = build_vocab('path_to_captions', threshold=5)
    vocab_size = len(vocab)

    # Initialize model
    model = ShowAttendTell(vocab_size=vocab_size).to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Optimizers. The encoder freezes its backbone, so it may expose no trainable
    # parameters at all; torch optimizers reject an empty parameter list, and
    # train() accepts None for the encoder optimizer.
    encoder_params = [p for p in model.encoder.parameters() if p.requires_grad]
    if encoder_params:
        encoder_optimizer = torch.optim.Adam(params=encoder_params, lr=encoder_lr)
    else:
        encoder_optimizer = None
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.decoder.parameters()),
                                         lr=decoder_lr)

    # Data loaders (you would need to implement the actual data loading)
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

    print("Model initialized successfully!")
    print(f"Vocabulary size: {vocab_size}")
    print(f"Using device: {device}")

    # Example of how training would work:
    # for epoch in range(epochs):
    #     train(train_loader, model.encoder, model.decoder, criterion,
    #           encoder_optimizer, decoder_optimizer, epoch, vocab_size, grad_clip, alpha_c)

    return model, vocab

# Script entry point: build the model and vocabulary only when this file is run directly
if __name__ == "__main__":
    model, vocab = main()

# =============================================================================
# 12. EXAMPLE INFERENCE CODE
# =============================================================================

def generate_caption_example():
    """
    Walk through the steps for captioning a new image.

    Only the model, vocabulary and preprocessing pipeline are actually built;
    image loading, beam search and visualization are left as commented templates.
    """

    # Load trained model (you would load from checkpoint)
    # checkpoint = torch.load('BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar')
    # model = checkpoint['model']
    # vocab = checkpoint['vocab']

    # Stand-ins for a trained model and its vocabulary
    vocab = build_vocab('dummy_path')
    model = ShowAttendTell(vocab_size=len(vocab)).to(device)
    model.eval()

    # Image preprocessing: resize, centre-crop, convert to tensor, normalise per channel
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocessing_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Load and preprocess image (example path)
    # image_path = 'path/to/your/image.jpg'
    # image = Image.open(image_path).convert('RGB')
    # image = transform(image).unsqueeze(0).to(device)

    # Generate caption using beam search
    # with torch.no_grad():
    #     seq, alphas = beam_search(model.encoder, model.decoder, image, vocab.word2idx)

    # Convert sequence to words
    # words = [vocab.idx2word[ind] for ind in seq if ind not in {vocab.word2idx['<start>'],
    #                                                            vocab.word2idx['<end>'],
    #                                                            vocab.word2idx['<pad>']}]
    # caption = ' '.join(words)
    # print(f"Generated Caption: {caption}")

    # Visualize attention
    # visualize_attention(image_path, seq, alphas, vocab)

    print("Caption generation example prepared!")

# =============================================================================
# 13. EVALUATION METRICS (BLEU SCORE)
# =============================================================================

def _ngram_counts(tokens, n):
    """Count every contiguous n-gram of ``tokens``."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def _corpus_bleu4(references, hypotheses):
    """
    Corpus-level BLEU-4 with uniform n-gram weights, a brevity penalty and no smoothing.

    Args:
        references: per sample, a list of reference token lists
        hypotheses: per sample, one hypothesis token list
    Returns:
        Score in [0, 1]; 0.0 when any n-gram order (1-4) has no clipped match.
    """
    matches = [0] * 4
    totals = [0] * 4
    hyp_len = 0
    ref_len = 0
    for refs, hyp in zip(references, hypotheses):
        hyp_len += len(hyp)
        if refs:
            # Reference length closest to the hypothesis length (shorter one on ties)
            ref_len += min((len(r) for r in refs), key=lambda length: (abs(length - len(hyp)), length))
        for n in range(1, 5):
            hyp_counts = _ngram_counts(hyp, n)
            max_ref_counts = Counter()
            for ref in refs:
                max_ref_counts |= _ngram_counts(ref, n)  # element-wise maximum
            clipped = hyp_counts & max_ref_counts  # element-wise minimum
            matches[n - 1] += sum(clipped.values())
            totals[n - 1] += sum(hyp_counts.values())

    if hyp_len == 0 or min(matches) == 0:
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / 4
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)


def evaluate_bleu(model, data_loader, vocab):
    """
    Evaluate model using BLEU-4 score

    Args:
        model: model exposing ``encoder`` and ``decoder``
        data_loader: iterable yielding (images, captions, caption lengths, all captions)
        vocab: vocabulary providing ``word2idx`` and ``idx2word``
    Returns:
        Corpus-level BLEU-4 of the beam-search captions against all reference captions.
    """
    model.eval()

    # Indices that never count as caption words
    skipped = {vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']}

    references = []  # per image: list of reference token lists
    hypotheses = []  # per image: predicted token list

    with torch.no_grad():
        for imgs, _caps, _caplens, allcaps in data_loader:
            imgs = imgs.to(device)

            # Generate captions one image at a time
            for img in imgs:
                seq, _ = beam_search(model.encoder, model.decoder, img.unsqueeze(0), vocab.word2idx)
                # Hypotheses stay token lists so they are comparable with the references
                hypotheses.append([vocab.idx2word[ind] for ind in seq if ind not in skipped])

            # References
            for img_caps in allcaps.tolist():
                references.append(
                    [[vocab.idx2word[ind] for ind in cap if ind not in skipped] for cap in img_caps])

    return _corpus_bleu4(references, hypotheses)

# =============================================================================
# 14. DATA PREPROCESSING PIPELINE
# =============================================================================

def preprocess_coco_data(coco_dir, output_dir, captions_per_image=5, min_word_freq=5, max_len=50):
    """
    Placeholder for the COCO preprocessing pipeline.

    Nothing is processed yet: the function imports the modules a real pipeline
    would use and prints two status lines. None of the arguments is read.
    A full implementation would:
    1. Load COCO annotations
    2. Build vocabulary
    3. Encode captions
    4. Resize and save images
    5. Create train/val/test splits
    """

    import json
    import h5py
    from collections import Counter
    from PIL import Image
    import numpy as np

    # Load COCO annotations (you would need to download COCO dataset)
    # with open(os.path.join(coco_dir, 'annotations', 'captions_train2014.json'), 'r') as f:
    #     train_data = json.load(f)

    # Remaining steps of a real pipeline:
    # 1. Extract image paths and captions
    # 2. Build vocabulary from captions
    # 3. Encode captions using vocabulary
    # 4. Store processed data in HDF5 files

    status_lines = (
        "Data preprocessing pipeline prepared!",
        "To use with actual COCO data, uncomment and modify the preprocessing code.",
    )
    for line in status_lines:
        print(line)

# =============================================================================
# 15. HYPERPARAMETER TUNING UTILITIES
# =============================================================================

def hyperparameter_search():
    """
    Build the candidate grid for a hyperparameter search.

    No search is performed here; the function prints a notice and hands
    back the grid so a caller can drive grid or random search itself.

    Returns:
        dict: Maps each hyperparameter name to a fresh list of candidate
        values.
    """
    # The three size-like hyperparameters share the same candidates, but
    # each key gets its own list so callers can edit one independently.
    layer_sizes = (256, 512, 1024)
    search_space = {
        name: list(layer_sizes)
        for name in ('attention_dim', 'embed_dim', 'decoder_dim')
    }

    # Optimisation and regularisation candidates
    search_space.update(
        learning_rate=[1e-4, 5e-4, 1e-3],
        batch_size=[16, 32, 64],
        dropout=[0.3, 0.5, 0.7],
    )

    # You could implement grid search or random search here
    print("Hyperparameter search configuration prepared!")

    return search_space

# =============================================================================
# 16. MODEL ANALYSIS AND DEBUGGING
# =============================================================================

def analyze_attention_patterns(model, image, caption, vocab):
    """
    Debugging hook for inspecting attention (currently a stub).

    Switches the model to eval mode, runs its encoder once on `image`
    with a leading batch dimension added, and prints a notice. The
    encoder output is not used yet, and `caption` and `vocab` are not
    read. Nothing is returned.
    """
    model.eval()

    with torch.no_grad():
        # Encode the single image as a batch of one
        encoded_image = model.encoder(image.unsqueeze(0))

    # Collecting per-time-step attention weights would require the model's
    # forward pass to return them; that is not implemented here.
    print("Attention analysis prepared!")

def count_parameters(model):
    """Print and return the model's (total, trainable) parameter counts"""
    total_params = 0
    trainable_params = 0

    # Single pass: every parameter adds to the total, and only those with
    # requires_grad set add to the trainable count.
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return total_params, trainable_params

# =============================================================================
# 17. EXPORT AND DEPLOYMENT UTILITIES
# =============================================================================

def export_model_for_inference(model, vocab, save_path,
                               attention_dim=512, embed_dim=512, decoder_dim=512):
    """
    Export model for production inference

    Writes a single checkpoint to `save_path` containing the model's
    state dict, the vocabulary object, and the configuration needed to
    rebuild the model.

    Args:
        model: Model whose `state_dict()` is saved.
        vocab: Vocabulary stored alongside the weights; `len(vocab)` is
            recorded as `vocab_size`. It is serialized as a Python object,
            so loading the checkpoint needs full (non weights-only)
            unpickling when `vocab` is a custom class.
        save_path: Destination passed to `torch.save`.
        attention_dim: Attention dimension recorded in the config.
        embed_dim: Embedding dimension recorded in the config.
        decoder_dim: Decoder dimension recorded in the config.

    The three dimension arguments default to 512, the values this function
    previously hard-coded. Pass the values the model was actually built
    with; otherwise the recorded config will not match the saved weights.
    """
    model_config = {
        'vocab_size': len(vocab),
        'attention_dim': attention_dim,
        'embed_dim': embed_dim,
        'decoder_dim': decoder_dim,
    }

    # Save model state dict and vocabulary
    torch.save({
        'model_state_dict': model.state_dict(),
        'vocab': vocab,
        'model_config': model_config,
    }, save_path)

    print(f"Model exported to {save_path}")

def load_model_for_inference(model_path):
    """
    Load model for inference

    Reads a checkpoint containing 'model_state_dict', 'vocab' and
    'model_config', rebuilds a `ShowAttendTell` model from the stored
    config on the module-level `device`, loads the weights and switches
    the model to eval mode.

    Args:
        model_path: Path (or file-like object) accepted by `torch.load`.

    Returns:
        tuple: `(model, vocab)`.

    Warning:
        The checkpoint is fully unpickled, which can execute arbitrary
        code. Only load checkpoints from a trusted source.
    """
    # The checkpoint stores the vocabulary as a pickled Python object, not
    # just tensors. Recent PyTorch releases default to weights_only=True,
    # which refuses such objects, so full unpickling is requested explicitly.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept `weights_only`; their default
        # behaviour is already full unpickling.
        checkpoint = torch.load(model_path, map_location=device)

    # Reconstruct model
    config = checkpoint['model_config']
    model = ShowAttendTell(
        vocab_size=config['vocab_size'],
        attention_dim=config['attention_dim'],
        embed_dim=config['embed_dim'],
        decoder_dim=config['decoder_dim']
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    vocab = checkpoint['vocab']

    model.eval()
    return model, vocab

# =============================================================================
# 18. TESTING AND VALIDATION
# =============================================================================

def test_model_components():
    """
    Smoke-test the building blocks with random inputs on `device`.

    Checks that the vocabulary is non-empty, that the encoder and attention
    module produce the expected output shapes, and that the full model runs
    one forward pass. A decoder is constructed but not otherwise exercised.

    Returns:
        bool: True when every check passes; a failed check raises
        AssertionError.
    """
    print("Testing model components...")

    # Dimensions shared by the checks below
    batch_size = 2
    encoder_dim = 2048
    hidden_dim = 512
    num_pixels = 14 * 14

    # --- Vocabulary ---
    vocab = build_vocab('dummy_path')
    assert len(vocab) > 0, "Vocabulary should not be empty"
    vocab_size = len(vocab)

    # --- Encoder: expects a 14x14 grid of 2048-dim features per image ---
    encoder = EncoderCNN().to(device)
    dummy_image = torch.randn(batch_size, 3, 224, 224).to(device)
    encoder_out = encoder(dummy_image)
    expected_shape = (batch_size, 14, 14, encoder_dim)
    assert encoder_out.shape == expected_shape, f"Expected {expected_shape}, got {encoder_out.shape}"

    # --- Attention: flatten the spatial grid into a list of pixels first ---
    attention = Attention(encoder_dim, hidden_dim, hidden_dim).to(device)
    flat_features = encoder_out.view(batch_size, -1, encoder_dim)
    hidden_state = torch.randn(batch_size, hidden_dim).to(device)
    att_out, alpha = attention(flat_features, hidden_state)
    assert att_out.shape == (batch_size, encoder_dim), f"Attention output shape mismatch: {att_out.shape}"
    assert alpha.shape == (batch_size, num_pixels), f"Alpha shape mismatch: {alpha.shape}"

    # --- Decoder (construction only) and inputs for the full model ---
    decoder = DecoderWithAttention(hidden_dim, hidden_dim, hidden_dim, vocab_size).to(device)
    dummy_captions = torch.randint(0, vocab_size, (batch_size, 20)).to(device)
    # Lengths have shape (2, 1) after unsqueeze
    dummy_lengths = torch.tensor([15, 18]).unsqueeze(1).to(device)

    # --- Complete model: one forward pass ---
    model = ShowAttendTell(vocab_size).to(device)
    predictions, alphas, _, _, _ = model(dummy_image, dummy_captions, dummy_lengths)

    print("All component tests passed!")

    # Report parameter counts for the full model
    total, trainable = count_parameters(model)

    return True

# Run tests
if __name__ == "__main__":
    # Sanity-check the individual components first
    test_model_components()

    # Build the main model/vocabulary pair
    model, vocab = main()

    # Walk through the example inference path
    generate_caption_example()

    # Fetch the hyperparameter search space
    hyperparams = hyperparameter_search()

    # Closing summary, emitted line by line between banner rules
    banner = "=" * 80
    summary_lines = (
        "\n" + banner,
        "SHOW, ATTEND AND TELL - IMPLEMENTATION COMPLETE",
        banner,
        "\nNext steps:",
        "1. Download and preprocess COCO dataset",
        "2. Modify data loading paths in the code",
        "3. Run training with your dataset",
        "4. Evaluate model performance",
        "5. Generate captions for new images",
        "\nImplementation includes:",
        "- Complete encoder-decoder architecture with attention",
        "- Beam search for caption generation",
        "- Training and evaluation utilities",
        "- Visualization tools for attention maps",
        "- Model export/import for deployment",
        banner,
    )
    print(*summary_lines, sep="\n")