In [None]:
# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Paths ---
# Path to your trained model checkpoint (update with your Kaggle dataset path)
MODEL_CHECKPOINT_PATH = "/kaggle/input/your-model-dataset/best_model.pt"

# Path to your test set embeddings (update with your Kaggle dataset path)
TEST_EMBEDDINGS_PATH = "/kaggle/input/your-embeddings-dataset/test_embeddings.pt"

# --- Model Architecture Hyperparameters ---
# CRITICAL: These MUST match the values used during training; otherwise the
# checkpoint's mapping-network weights fail to load (shape mismatch or
# missing/unexpected keys).
EMBED_DIM = 768         # Width of each stored image embedding, i.e. embeddings.shape[1] in TEST_EMBEDDINGS_PATH
                        # (e.g. 512 for CLIP ViT-B/32, 768 for CLIP ViT-L/14)
GPT_DIM = 768           # GPT-2 hidden size (768 for base "gpt2"); must be divisible by 8 (the mapping network uses 8 attention heads)
PREFIX_LENGTH = 10      # Number of prefix tokens the mapping network feeds to GPT-2
HIDDEN_LENGTH = 10      # Number of image tokens produced by the mapping network's linear projection
NUM_LAYERS = 8          # Number of transformer encoder layers in the mapping network

# --- Inference Parameters ---
BATCH_SIZE = 64         # Batch size for inference (adjust based on available GPU memory)
MAX_LENGTH = 50         # Maximum number of tokens to generate per caption
TEMPERATURE = 1.0       # Sampling temperature (0 = greedy, higher = more random)
TOP_P = 0.9             # Nucleus sampling threshold (ignored when TEMPERATURE == 0 or TOP_P >= 1.0)

# --- Output ---
# NOTE(review): the COCO caption evaluation server has historically expected
# files named like "captions_test2014_<alg>_results.json" (zipped) — confirm
# the required name before submitting.
OUTPUT_FILENAME = "results.json"


In [None]:
# MODEL DEFINITIONS
# ==============================================================================

def load_gpt2_tokenizer() -> GPT2Tokenizer:
    """
    Build the stock "gpt2" tokenizer with padding mapped onto end-of-sequence.

    GPT-2 ships without a dedicated padding token, so the EOS token is reused
    as the pad token to allow batched tokenization.

    Returns:
        A GPT2Tokenizer whose pad_token is the same as its eos_token.
    """
    tok = GPT2Tokenizer.from_pretrained("gpt2")
    eos = tok.eos_token
    tok.pad_token = eos
    return tok


class TransformerMappingNetwork(nn.Module):
    """
    Turn one image embedding per sample into a sequence of GPT-2 prefix vectors.

    The embedding is linearly projected into ``hidden_length`` vectors of width
    ``gpt_dim``. A learned block of ``prefix_length`` vectors is appended, the
    joint sequence runs through a pre-norm Transformer encoder, and only the
    outputs at the learned positions are returned as the caption prefix.

    Args:
        embed_dim: Width of each input image embedding
        gpt_dim: Width of GPT-2 token embeddings (must be divisible by 8 heads)
        prefix_length: Number of prefix vectors produced per image
        hidden_length: Number of projected image vectors fed to the encoder
        num_layers: Depth of the Transformer encoder
    """

    def __init__(
        self,
        embed_dim: int,
        gpt_dim: int,
        prefix_length: int,
        hidden_length: int,
        num_layers: int = 8,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.gpt_dim = gpt_dim
        self.hidden_length = hidden_length
        self.prefix_length = prefix_length

        # NOTE: creation order (linear → prefix_const → encoder) fixes the RNG
        # sequence used for initialization; the attribute names fix the
        # state_dict keys the checkpoint loader expects.
        self.linear = nn.Linear(embed_dim, hidden_length * gpt_dim)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, gpt_dim))

        layer = nn.TransformerEncoderLayer(
            d_model=gpt_dim,
            nhead=8,
            dim_feedforward=int(4 * gpt_dim),
            activation="relu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of image embeddings to prefix vectors.

        Args:
            x: Image embeddings of shape (batch_size, embed_dim)

        Returns:
            Tensor of shape (batch_size, prefix_length, gpt_dim)
        """
        n = x.shape[0]
        image_tokens = self.linear(x).view(n, self.hidden_length, self.gpt_dim)
        # expand() prepends the new batch dimension to the (P, G) parameter.
        learned_tokens = self.prefix_const.expand(n, -1, -1)
        encoded = self.transformer(torch.cat([image_tokens, learned_tokens], dim=1))
        # Drop the image positions; keep only the learned-prefix outputs.
        return encoded[:, self.hidden_length:]


class ImageCaptioningModel(nn.Module):
    """
    Complete image captioning model combining mapping network and GPT-2.

    Image embeddings are converted to prefix tokens by the mapping network.
    They are optionally followed by learned task-prompt embeddings, and GPT-2
    then generates (or is trained on) captions conditioned on that prefix.

    Args:
        mapping_network: Network that converts image embeddings to prefix tokens
            (must expose a ``prefix_length`` attribute)
        image_prefix_length: Length of image prefix (defaults to mapping_network.prefix_length)
        prefix_task_prompt: Optional task-specific text prefix (e.g., "A photo of")
            whose token embeddings become a trainable parameter
        tokenizer: GPT-2 tokenizer (loaded via load_gpt2_tokenizer() if None)
        gpt: GPT-2 model (the pretrained "gpt2" checkpoint is loaded if None)
        freeze_gpt_weights: Whether to freeze GPT-2 weights during training
    """

    def __init__(
        self,
        mapping_network: nn.Module,
        image_prefix_length: int | None = None,
        prefix_task_prompt: str | None = None,
        tokenizer: GPT2Tokenizer | None = None,
        gpt: GPT2LMHeadModel | None = None,
        freeze_gpt_weights: bool = True,
    ) -> None:
        super().__init__()
        self.image_prefix_length = image_prefix_length or mapping_network.prefix_length
        self.mapping_network = mapping_network

        # Explicit None checks: truthiness of tokenizers/modules is not a
        # reliable "was it provided" test (tokenizers define __len__).
        self.gpt = gpt if gpt is not None else GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.tokenizer = tokenizer if tokenizer is not None else load_gpt2_tokenizer()

        # Freeze/unfreeze GPT-2 parameters
        for p in self.gpt.parameters():
            p.requires_grad = not freeze_gpt_weights

        # Optional task-specific prefix embeddings, initialised from GPT-2's
        # own token embeddings of the prompt text.
        self.task_prefix_embeds: nn.Parameter | None = None
        if prefix_task_prompt:
            wte = self.gpt.transformer.wte
            with torch.no_grad():
                task_token_ids = self.tokenizer.encode(
                    prefix_task_prompt, return_tensors="pt"
                )
                # Token ids come back on CPU; move them next to the embedding
                # table so an already-GPU-resident `gpt` also works.
                task_token_embeds = wte(task_token_ids.to(wte.weight.device))
            self.task_prefix_embeds = nn.Parameter(
                task_token_embeds.squeeze(0), requires_grad=True
            )

    def _build_prefix(self, image_embeddings: torch.Tensor) -> torch.Tensor:
        """Return mapping-network prefix tokens, followed by task-prompt tokens if configured."""
        prefix_tokens = self.mapping_network(image_embeddings)  # (B, P, G)
        if self.task_prefix_embeds is not None:
            batch_size = image_embeddings.shape[0]
            task_prefix_tokens = self.task_prefix_embeds.unsqueeze(0).expand(
                batch_size, -1, -1
            )
            prefix_tokens = torch.cat((prefix_tokens, task_prefix_tokens), dim=1)
        return prefix_tokens

    def forward(
        self,
        caption_token_ids: torch.Tensor,
        image_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """
        Forward pass for training (computes loss on caption tokens).

        Args:
            caption_token_ids: Tokenized captions (batch_size, seq_len)
            image_embeddings: Image embeddings (batch_size, embed_dim)
            attention_mask: Attention mask for caption tokens
            labels: Target tokens for loss computation

        Returns:
            GPT-2 model outputs including loss
        """
        # Get caption token embeddings
        caption_tokens = self.gpt.transformer.wte(caption_token_ids)

        prefix_tokens = self._build_prefix(image_embeddings)
        total_prefix_length = prefix_tokens.shape[1]

        # Concatenate prefix + caption tokens
        input_tokens = torch.cat((prefix_tokens, caption_tokens), dim=1)

        # Prefix positions carry no target: label them -100 so the LM loss ignores them
        if labels is not None:
            dummy_labels = torch.full(
                (labels.shape[0], total_prefix_length),
                -100,
                dtype=torch.int64,
                device=caption_token_ids.device,
            )
            labels = torch.cat((dummy_labels, labels), dim=1)

        # Prefix positions are always attended to
        if attention_mask is not None:
            dummy_attention_mask = torch.ones(
                (attention_mask.shape[0], total_prefix_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat((dummy_attention_mask, attention_mask), dim=1)

        return self.gpt(
            inputs_embeds=input_tokens,
            labels=labels,
            attention_mask=attention_mask,
        )

    @staticmethod
    def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Set to -inf every logit outside the nucleus of cumulative probability ``top_p``."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        sorted_to_remove[:, 1:] = sorted_to_remove[:, :-1].clone()
        sorted_to_remove[:, 0] = False
        to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
        return logits.masked_fill(to_remove, float("-inf"))

    def generate(
        self,
        image_embeddings: torch.Tensor,
        max_length: int = 50,
        temperature: float = 1.0,
        top_p: float = 0.9,
        stop_token_id: int = 50256,
    ) -> torch.Tensor:
        """
        Generate captions for given image embeddings using nucleus sampling.

        Puts the module into eval mode as a side effect.

        Args:
            image_embeddings: Image embeddings (batch_size, embed_dim)
            max_length: Maximum number of tokens to generate
            temperature: Sampling temperature (0 = greedy, higher = more random)
            top_p: Nucleus sampling threshold: sample only from the smallest set of
                most likely tokens whose cumulative probability exceeds top_p
                (disabled when top_p >= 1.0 or temperature == 0)
            stop_token_id: Token ID to stop generation (default: GPT-2 EOS token)

        Returns:
            Generated token IDs (batch_size, generated_length). Once a row emits
            ``stop_token_id`` every later position in that row is also
            ``stop_token_id``, so decoding with ``skip_special_tokens=True``
            drops the padding when ``stop_token_id`` is the tokenizer's EOS.

        Raises:
            ValueError: if temperature is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        self.eval()
        device = image_embeddings.device
        batch_size = image_embeddings.shape[0]

        generated_tokens: list[torch.Tensor] = []
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        with torch.no_grad():
            current_input_tokens = self._build_prefix(image_embeddings)

            # Autoregressive generation loop (full sequence re-encoded each step)
            for _ in range(max_length):
                if is_finished.all():
                    break

                next_token_logits = self.gpt(
                    inputs_embeds=current_input_tokens
                ).logits[:, -1, :]

                if temperature == 0:
                    next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    next_token_logits = next_token_logits / temperature
                    if top_p < 1.0:
                        next_token_logits = self._top_p_filter(next_token_logits, top_p)
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token_id = torch.multinomial(probs, num_samples=1)

                # Mark rows that just emitted the stop token as finished, then pad
                # every finished row with the stop token itself. Padding with id 0
                # would be wrong: 0 is a regular vocabulary token ("!" for GPT-2)
                # that skip_special_tokens does not remove when decoding.
                is_finished |= next_token_id.squeeze(-1).eq(stop_token_id)
                next_token_id[is_finished] = stop_token_id

                generated_tokens.append(next_token_id)

                # Append the new token's embedding to the running input sequence
                next_token = self.gpt.transformer.wte(next_token_id)
                current_input_tokens = torch.cat((current_input_tokens, next_token), dim=1)

        if not generated_tokens:
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)

        return torch.cat(generated_tokens, dim=1)

In [None]:
# DATASET AND DATA LOADING
# ==============================================================================

class CocoTestEmbeddingsDataset(Dataset):
    """
    Dataset for COCO test set pre-computed embeddings.

    Expected format of embeddings file:
        {
            "embeddings": tensor of shape (N, embed_dim),
            "filenames": list of N filenames (e.g., "COCO_test2014_000000123456.jpg")
        }

    Filenames may include a directory part; only the basename is used to
    derive the numeric image ID.

    Args:
        embeddings_path: Path to .pt file containing embeddings and filenames

    Raises:
        ValueError: if the number of embeddings and filenames differ, or a
            filename does not end in a numeric image ID.
    """

    def __init__(self, embeddings_path: str):
        # NOTE: torch.load unpickles the file — only load embeddings you trust.
        data = torch.load(embeddings_path, map_location="cpu")
        self.image_embeddings = data["embeddings"]
        self.image_filenames = data["filenames"]

        # Fail fast instead of hitting an IndexError partway through inference.
        if len(self.image_embeddings) != len(self.image_filenames):
            raise ValueError(
                f"Embeddings file {embeddings_path!r} has "
                f"{len(self.image_embeddings)} embeddings but "
                f"{len(self.image_filenames)} filenames"
            )

        # Parse every image ID once, so malformed filenames surface at load time.
        self.image_ids = [self._parse_image_id(f) for f in self.image_filenames]

    @staticmethod
    def _parse_image_id(fname: str) -> int:
        """Extract the numeric ID: "dir/COCO_test2014_000000123456.jpg" → 123456."""
        # Strip any directory first so underscores in folder names are ignored.
        basename = fname.rsplit("/", 1)[-1]
        id_text = basename.split("_")[-1].split(".")[0]
        try:
            return int(id_text)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse a numeric image ID from filename {fname!r}"
            ) from err

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        return {
            "image_id": self.image_ids[idx],
            "image_embedding": self.image_embeddings[idx],
        }


def collate_fn(batch):
    """
    Merge per-image samples into one batch.

    Args:
        batch: Sequence of sample dicts, each with an integer "image_id" and a
            1-D "image_embedding" tensor

    Returns:
        Dict with "image_id" as a 1-D int64 tensor and "image_embedding" as the
        embeddings stacked along a new leading batch dimension
    """
    id_values = [sample["image_id"] for sample in batch]
    stacked_embeddings = torch.stack([sample["image_embedding"] for sample in batch])
    collated = {"image_id": torch.tensor(id_values, dtype=torch.long)}
    collated["image_embedding"] = stacked_embeddings
    return collated


In [None]:
# MODEL SETUP AND CHECKPOINT LOADING
# ==============================================================================

print("\n" + "=" * 80)
print("LOADING MODEL")
print("=" * 80)

# Initialize tokenizer
gpt2_tokenizer = load_gpt2_tokenizer()
print("✓ Loaded GPT-2 tokenizer")

# Build mapping network with your hyperparameters
mapping_network = TransformerMappingNetwork(
    embed_dim=EMBED_DIM,
    gpt_dim=GPT_DIM,
    prefix_length=PREFIX_LENGTH,
    hidden_length=HIDDEN_LENGTH,
    num_layers=NUM_LAYERS,
)
print(f"✓ Built TransformerMappingNetwork (embed_dim={EMBED_DIM}, prefix_length={PREFIX_LENGTH})")

# Build complete captioning model
model = ImageCaptioningModel(
    mapping_network=mapping_network,
    tokenizer=gpt2_tokenizer,
    freeze_gpt_weights=True,  # GPT-2 weights are frozen during inference
).to(DEVICE)
print("✓ Built ImageCaptioningModel")

# Load trained weights from checkpoint
print(f"\nLoading checkpoint from: {MODEL_CHECKPOINT_PATH}")
state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location="cpu")

# Keep only mapping_network weights and strip their "mapping_network." prefix
# (checkpoint keys look like "mapping_network.linear.weight"; the submodule
# expects "linear.weight").
mapping_state = {
    k.removeprefix("mapping_network."): v
    for k, v in state_dict.items()
    if k.startswith("mapping_network.")
}
if not mapping_state:
    raise RuntimeError(
        "No 'mapping_network.*' keys found in checkpoint "
        f"{MODEL_CHECKPOINT_PATH!r}; first keys: {list(state_dict)[:5]}"
    )

# Only the mapping network is restored below. If training used a task prompt,
# its learned embeddings are stored under "task_prefix_embeds" and would be
# silently dropped here, so make that visible.
task_prefix_keys = [k for k in state_dict if k.startswith("task_prefix_embeds")]
if task_prefix_keys:
    print(
        f"⚠ Checkpoint contains {task_prefix_keys}, which are NOT loaded: "
        "this model was built without prefix_task_prompt, so captions may "
        "differ from training-time behaviour"
    )

# Load the weights into the mapping network (strict: shapes/keys must match)
model.mapping_network.load_state_dict(mapping_state)
model.eval()
print("✓ Loaded mapping network weights successfully")


In [None]:
# TEST DATASET SETUP
# ==============================================================================

print("\n" + "=" * 80, "LOADING TEST DATA", "=" * 80, sep="\n")

test_dataset = CocoTestEmbeddingsDataset(TEST_EMBEDDINGS_PATH)
# Single-process loading: the dataset already holds the whole embeddings file
# in memory, so worker processes are unlikely to help here.
test_loader = DataLoader(
    dataset=test_dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,  # keep file order so predictions follow the input listing
)

print(
    f"✓ Loaded test dataset: {len(test_dataset)} images",
    f"✓ Batch size: {BATCH_SIZE}",
    sep="\n",
)


# ==============================================================================
# CAPTION GENERATION
# ==============================================================================

print("\n" + "=" * 80, "GENERATING CAPTIONS", "=" * 80, sep="\n")

predictions: list[dict] = []  # one {"image_id", "caption"} record per unique image
seen: set[int] = set()        # image IDs already captioned (skips duplicate IDs)

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating captions", unit="batch"):
        batch_ids = batch["image_id"].tolist()

        token_ids = model.generate(
            image_embeddings=batch["image_embedding"].to(DEVICE),
            max_length=MAX_LENGTH,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        decoded = gpt2_tokenizer.batch_decode(token_ids, skip_special_tokens=True)

        # Keep only the first caption produced for each image ID.
        for raw_id, text in zip(batch_ids, decoded):
            if raw_id in seen:
                continue
            seen.add(raw_id)
            predictions.append({"image_id": raw_id, "caption": text})


# ==============================================================================
# SAVE RESULTS
# ==============================================================================

print("\n" + "=" * 80)
print("SAVING RESULTS")
print("=" * 80)

# json.dump escapes non-ASCII by default; the explicit encoding just removes
# any dependence on the platform default.
with open(OUTPUT_FILENAME, "w", encoding="utf-8") as f:
    json.dump(predictions, f)

print(f"✓ Saved {len(predictions)} predictions to {OUTPUT_FILENAME}")
print("✓ Ready for submission to COCO evaluation server!")
print("\nNext steps:")
# Refer to the configured output name rather than a hard-coded "results.json".
print(f"1. Download the {OUTPUT_FILENAME} file from Kaggle")
print("2. Submit to: https://competitions.codalab.org/competitions/3221")
print("3. Check the 'Participate > Submit / View Results' tab")