In [None]:
# Mount Google Drive so the dataset files under /content/drive/MyDrive are readable.
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import os
import json
import random
import re
import tempfile
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel, CLIPTextModel
from transformers import CLIPTokenizer

from PIL import Image
from tqdm import tqdm
import numpy as np


@dataclass
class Config:
    """Every experiment setting (backbone, data locations, VPT and optimiser knobs) in one place."""

    # --- Backbone -------------------------------------------------------------
    model_name: str = "openai/clip-vit-large-patch14-336"

    # --- Data locations (Google Drive mount) ----------------------------------
    dataset_path: str = "/content/drive/MyDrive/GD-VCR-Dataset/questions.jsonl"
    image_base_path: str = "/content/drive/MyDrive/GD-VCR-Dataset/"

    # --- Visual prompt tuning -------------------------------------------------
    num_prompt_tokens: int = 20
    # Width of each prompt token; it is concatenated with the vision encoder's
    # tokens, so it has to equal that encoder's hidden size.
    prompt_dim: int = 1024
    prompt_dropout: float = 0.1

    # --- Optimisation ---------------------------------------------------------
    batch_size: int = 16
    learning_rate: float = 0.001
    num_epochs: int = 10
    warmup_steps: int = 100  # length of the linear LR warm-up, in optimizer steps
    weight_decay: float = 1e-2

    # --- Runtime --------------------------------------------------------------
    # Evaluated once, when this class body runs.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 42
    num_workers: int = 2

    # --- Validation split / tokenisation --------------------------------------
    val_split: float = 0.3  # fraction of samples held out for validation
    max_text_length: int = 77  # token budget passed to the CLIP processor


def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs (plus CUDA and cuDNN when a GPU exists).

    On CUDA machines cuDNN is also switched to deterministic kernels with
    auto-tuning disabled, trading some speed for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_jsonl(file_path: str) -> List[Dict]:
    """Read a JSON-Lines file into a list of parsed objects.

    Blank lines are ignored. Lines that are not valid JSON are reported with
    their 1-based line number and skipped, so one bad line does not abort the load.
    """
    print(f"Loading dataset from: {file_path}")
    records: List[Dict] = []

    with open(file_path, "r", encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as err:
                print(f"Warning: Could not parse line {lineno}: {err}")

    print(f"Loaded {len(records)} records from {file_path}")
    return records


def validate_record(record: Dict) -> bool:
    """Return True if ``record`` is a well-formed GD-VCR sample.

    A valid record is a dict containing ``img_fn``, ``objects``, ``question``,
    ``answer_choices`` and ``answer_label``. ``answer_choices`` must be a list,
    and ``answer_label`` must be an integer index into it.

    Malformed values (a JSON line that is not an object, a null/string/float
    label, non-list choices) are rejected with ``False``. Previously they raised
    ``TypeError`` and aborted loading the whole dataset.
    """
    if not isinstance(record, dict):
        return False

    required_fields = ("img_fn", "objects", "question", "answer_choices", "answer_label")
    if any(field not in record for field in required_fields):
        return False

    choices = record["answer_choices"]
    label = record["answer_label"]
    if not isinstance(choices, list):
        return False
    # bool is a subclass of int; a JSON true/false is not a meaningful label.
    if isinstance(label, bool) or not isinstance(label, int):
        return False

    return 0 <= label < len(choices)


def resolve_object_reference(token: Any, objects: List[str]) -> str:
    """Resolve an object reference token to its actual object name.

    Handled token forms:

    * an int ``i`` -> ``objects[i]``;
    * a non-empty list of ints -> the names joined in English, e.g.
      ``[0]`` -> ``"person"``, ``[0, 1]`` -> ``"person and dog"``,
      ``[0, 1, 2]`` -> ``"person, dog and car"``. Previously any list longer
      than one fell through to ``str(token)`` and leaked the literal ``"[0, 1]"``
      into the sentence;
    * anything else -> ``str(token)``.

    An out-of-range index becomes the placeholder ``"object{index}"``.
    """

    def _name(index: int) -> str:
        if 0 <= index < len(objects):
            return objects[index]
        return f"object{index}"

    if isinstance(token, int):
        return _name(token)

    if isinstance(token, list) and token and all(isinstance(t, int) for t in token):
        names = [_name(index) for index in token]
        if len(names) == 1:
            return names[0]
        return ", ".join(names[:-1]) + " and " + names[-1]

    return str(token)


def tokens_to_sentence(tokens: List[Any], objects: List[str]) -> str:
    """Convert a list of tokens (with object references) into a natural sentence.

    Object references are resolved with ``resolve_object_reference``. The
    tokens are joined with spaces, split possessives (``"dog ' s"``) are
    re-merged, the space before common punctuation is dropped, and whitespace
    is collapsed.
    """
    resolved_tokens = [resolve_object_reference(token, objects) for token in tokens]
    sentence = " ".join(resolved_tokens)

    # Merge split possessives BEFORE the punctuation pass. That pass rewrites
    # " '" to "'", after which " ' s" can no longer match and the result would
    # be "dog' s". The \b keeps this from swallowing words such as "said".
    sentence = re.sub(r" ' s\b", "'s", sentence)

    # Removing " '" here also turns a standalone "'s" token into "...'s".
    for punct in (".", ",", "!", "?", "'", '"', ":", ";"):
        sentence = sentence.replace(f" {punct}", punct)

    # Collapse whitespace runs to single spaces and trim both ends.
    return " ".join(sentence.split())


def preprocess_sample(record: Dict) -> Dict:
    """Turn a raw GD-VCR record into plain-text fields ready for CLIP.

    The tokenised question and each tokenised answer choice are rendered as
    sentences with ``tokens_to_sentence``, using the record's ``objects`` list
    for object references. Label, image filename and objects are passed through.
    """
    objects = record["objects"]
    return {
        "question": tokens_to_sentence(record["question"], objects),
        "answer_choices": [
            tokens_to_sentence(choice, objects) for choice in record["answer_choices"]
        ],
        "answer_label": record["answer_label"],
        "img_fn": record["img_fn"],
        "objects": objects,
    }


def create_qa_text(question: str, answer: str) -> str:
    """Render one question/answer pair as the single prompt string fed to CLIP's text encoder."""
    return "Question: {} Answer: {}".format(question, answer)


class GDVCRDataset(Dataset):
    """PyTorch Dataset for GD-VCR visual commonsense reasoning.

    Each item pairs one image with every "Question: ... Answer: ..." string for
    that sample's answer choices. Records failing ``validate_record`` are
    skipped at construction time.
    """

    def __init__(
        self,
        jsonl_path: str,
        processor: CLIPProcessor,
        config: Config,
        image_base_path: Optional[str] = None,
    ):
        self.processor = processor
        self.config = config
        self.image_base_path = image_base_path or config.image_base_path

        raw_data = load_jsonl(jsonl_path)

        self.data = []
        for record in raw_data:
            if validate_record(record):
                preprocessed = preprocess_sample(record)
                self.data.append(preprocessed)
            else:
                print(f"Warning: Skipping invalid record")

        print(f"Dataset initialized with {len(self.data)} valid samples")

    def __len__(self) -> int:
        return len(self.data)

    def load_image(self, img_fn: str) -> Image.Image:
        """Load ``img_fn`` (relative to ``image_base_path``) as an RGB image.

        Any read failure is reported and replaced by a 224x224 gray placeholder,
        so a single bad file does not crash a training run.
        """
        img_path = os.path.join(self.image_base_path, img_fn)

        try:
            # Image.open is lazy and keeps the file handle open until the image
            # is garbage-collected. The context manager closes it right away;
            # convert() returns a fully loaded copy, so the result stays valid.
            with Image.open(img_path) as img:
                return img.convert("RGB")
        except FileNotFoundError:
            print(f"Warning: Image not found: {img_path}")
            return Image.new("RGB", (224, 224), color="gray")
        except Exception as e:
            print(f"Warning: Error loading {img_path}: {e}")
            return Image.new("RGB", (224, 224), color="gray")

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return pixel values, per-choice token ids/masks, the label and the choice count."""
        sample = self.data[idx]

        image = self.load_image(sample["img_fn"])
        image_inputs = self.processor(images=image, return_tensors="pt")
        # Drop the batch dimension the processor adds for a single image.
        pixel_values = image_inputs["pixel_values"].squeeze(0)

        question = sample["question"]
        answer_choices = sample["answer_choices"]

        qa_texts = [create_qa_text(question, answer) for answer in answer_choices]

        # Fixed-length padding means every choice has the same sequence length,
        # which collate_fn relies on.
        text_inputs = self.processor(
            text=qa_texts,
            padding="max_length",
            truncation=True,
            max_length=self.config.max_text_length,
            return_tensors="pt",
        )

        return {
            "pixel_values": pixel_values,
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
            "label": torch.tensor(sample["answer_label"], dtype=torch.long),
            "num_choices": len(answer_choices),
        }


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Stack GD-VCR samples into one batch, zero-padding the answer-choice axis.

    Every sample's ``input_ids``/``attention_mask`` is ``(num_choices, seq_len)``.
    These are copied into ``(batch, widest, seq_len)`` tensors, where ``widest`` is
    the largest choice count in the batch. Unused slots stay all-zero, and the
    true counts are returned under ``"num_choices"``.
    """
    counts = [item["num_choices"] for item in batch]
    widest = max(counts)
    seq_len = batch[0]["input_ids"].size(1)

    ids = torch.zeros((len(batch), widest, seq_len), dtype=torch.long)
    mask = torch.zeros_like(ids)
    for row, (item, count) in enumerate(zip(batch, counts)):
        ids[row, :count] = item["input_ids"]
        mask[row, :count] = item["attention_mask"]

    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": ids,
        "attention_mask": mask,
        "labels": torch.stack([item["label"] for item in batch]),
        "num_choices": torch.tensor(counts),
    }


class BaselineCLIP(nn.Module):
    """Zero-shot CLIP scorer for multiple-choice VCR; every weight is frozen.

    Each (image, "Question: ... Answer: ...") pair is scored by the cosine
    similarity of CLIP's normalised image and text features, multiplied by
    CLIP's learned ``logit_scale`` temperature.
    """

    def __init__(self, config: Config):
        super().__init__()

        print(f"Loading CLIP model: {config.model_name}")
        self.clip = CLIPModel.from_pretrained(config.model_name)
        # Nothing is trained in the zero-shot baseline.
        self.clip.requires_grad_(False)

        self.config = config
        self.to(config.device)

        print(f"Baseline CLIP loaded with {self.count_parameters():,} parameters (all frozen)")

    def count_parameters(self) -> int:
        """Number of parameters, trainable or not."""
        return sum(param.numel() for param in self.parameters())

    def count_trainable_parameters(self) -> int:
        """Number of parameters with ``requires_grad`` set (0 for this model)."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    @staticmethod
    def _as_features(output):
        """Return the feature tensor whether HF gave a plain tensor or an output object."""
        return output.pooler_output if hasattr(output, "pooler_output") else output

    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """L2-normalised CLIP image embeddings, one row per image."""
        raw = self.clip.get_image_features(pixel_values=pixel_values)
        return F.normalize(self._as_features(raw), dim=-1)

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode ``(batch, choices, seq)`` token ids into ``(batch, choices, dim)`` unit vectors."""
        batch_size, num_choices, seq_len = input_ids.shape

        raw = self.clip.get_text_features(
            input_ids=input_ids.view(-1, seq_len),
            attention_mask=attention_mask.view(-1, seq_len),
        )
        unit = F.normalize(self._as_features(raw), dim=-1)
        return unit.view(batch_size, num_choices, -1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``(batch, choices)`` temperature-scaled image/answer similarities."""
        text_features = self.encode_text(input_ids, attention_mask)
        image_features = self.encode_image(pixel_values).unsqueeze(1)

        # Broadcast each image over its answer choices, then take the dot product per pair.
        similarity = torch.einsum(
            "bcd,bcd->bc", image_features.expand_as(text_features), text_features
        )
        return similarity * self.clip.logit_scale.exp()


class DeepVisualPromptTokens(nn.Module):
    """Learnable visual prompt tokens for VPT-Deep (separate prompts per layer).

    Parameters are stored as ``(num_layers, 1, num_prompts, hidden_dim)``. The
    singleton axis lets a layer's prompts be expanded over the batch without copying.
    """

    def __init__(
        self, num_prompts: int, hidden_dim: int, num_layers: int, dropout: float = 0.1
    ):
        super().__init__()

        self.num_prompts = num_prompts
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Separate prompts for each transformer layer
        self.prompts = nn.Parameter(torch.zeros(num_layers, 1, num_prompts, hidden_dim))

        # Xavier-uniform init of each layer's 2-D (num_prompts, hidden_dim) matrix.
        # Passing the 3-D (1, num_prompts, hidden_dim) slice instead makes PyTorch
        # treat hidden_dim as a conv receptive field, which shrinks the bound to
        # sqrt(6 / ((num_prompts + 1) * hidden_dim)), roughly 4.5x too small by default.
        for layer_idx in range(num_layers):
            nn.init.xavier_uniform_(self.prompts[layer_idx, 0])

        self.dropout = nn.Dropout(dropout)

        total_prompt_params = num_layers * num_prompts * hidden_dim
        print(f"Created VPT-Deep with {num_prompts} prompts x {num_layers} layers")
        print(f"Total prompt parameters: {total_prompt_params:,}")

    def forward(self, batch_size: int, layer_idx: int) -> torch.Tensor:
        """Return layer ``layer_idx``'s prompts as ``(batch_size, num_prompts, hidden_dim)``, with dropout."""
        layer_prompts = self.prompts[layer_idx]
        prompts = layer_prompts.expand(batch_size, -1, -1)
        prompts = self.dropout(prompts)
        return prompts


class CLIPWithVPT(nn.Module):
    """CLIP model enhanced with Visual Prompt Tuning (VPT-Deep).

    The CLIP backbone is frozen; only ``self.visual_prompts`` is trainable.
    Prompt tokens sit right after [CLS] in the vision token sequence. After
    every encoder layer except the last, those slots are overwritten with the
    next layer's own prompts.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        print(f"Loading CLIP model for VPT-Deep: {config.model_name}")
        self.clip = CLIPModel.from_pretrained(config.model_name)

        self.num_layers = len(self.clip.vision_model.encoder.layers)
        print(f"Vision encoder has {self.num_layers} transformer layers")

        # Prompts are concatenated into the vision token sequence, so their width
        # must equal the vision hidden size. Fail here with a clear message rather
        # than with a torch.cat shape error on the first batch (e.g. after
        # switching model_name to a checkpoint with a different width).
        vision_hidden_size = self.clip.config.vision_config.hidden_size
        if config.prompt_dim != vision_hidden_size:
            raise ValueError(
                f"config.prompt_dim ({config.prompt_dim}) must equal the vision hidden "
                f"size of {config.model_name} ({vision_hidden_size})"
            )

        # Freeze CLIP backbone - only prompts are trainable
        for param in self.clip.parameters():
            param.requires_grad = False

        self.visual_prompts = DeepVisualPromptTokens(
            num_prompts=config.num_prompt_tokens,
            hidden_dim=config.prompt_dim,
            num_layers=self.num_layers,
            dropout=config.prompt_dropout,
        )

        self.to(config.device)

        total_params = self.count_parameters()
        trainable_params = self.count_trainable_parameters()
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Trainable ratio: {100*trainable_params/total_params:.4f}%")

    def count_parameters(self) -> int:
        """Number of parameters, trainable or not."""
        return sum(p.numel() for p in self.parameters())

    def count_trainable_parameters(self) -> int:
        """Number of parameters with ``requires_grad`` set (the prompt tokens)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_visual_embeddings_with_prompts(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Get visual embeddings with VPT-Deep prompts injected at every layer.

        Returns L2-normalised projected image features, one row per image.
        """
        batch_size = pixel_values.shape[0]
        num_prompts = self.config.num_prompt_tokens

        vision_model = self.clip.vision_model

        # Patch + position embeddings with [CLS] at index 0.
        embeddings = vision_model.embeddings(pixel_values)

        cls_token = embeddings[:, :1, :]
        patch_embeddings = embeddings[:, 1:, :]

        initial_prompts = self.visual_prompts(batch_size, layer_idx=0)

        # Inject prompts: [CLS] + [Prompts] + [Patches]
        hidden_states = torch.cat([cls_token, initial_prompts, patch_embeddings], dim=1)

        hidden_states = vision_model.pre_layrnorm(hidden_states)

        for layer_idx, encoder_layer in enumerate(vision_model.encoder.layers):
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=False,
            )
            # Some transformers versions return a tuple, others a bare tensor.
            hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs

            # Replace prompts with fresh layer-specific prompts (except last layer)
            if layer_idx < self.num_layers - 1:
                next_layer_prompts = self.visual_prompts(batch_size, layer_idx=layer_idx + 1)
                cls_token = hidden_states[:, :1, :]
                patch_embeddings = hidden_states[:, 1 + num_prompts :, :]
                hidden_states = torch.cat([cls_token, next_layer_prompts, patch_embeddings], dim=1)

        # CLIP pools the vision tower at the [CLS] position.
        pooled_output = hidden_states[:, 0, :]
        pooled_output = vision_model.post_layernorm(pooled_output)

        image_features = self.clip.visual_projection(pooled_output)
        image_features = F.normalize(image_features, dim=-1)

        return image_features

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode ``(batch, choices, seq)`` token ids into ``(batch, choices, dim)`` unit vectors."""
        batch_size, num_choices, seq_len = input_ids.shape

        input_ids_flat = input_ids.view(-1, seq_len)
        attention_mask_flat = attention_mask.view(-1, seq_len)

        text_output = self.clip.get_text_features(
            input_ids=input_ids_flat, attention_mask=attention_mask_flat
        )

        if hasattr(text_output, "pooler_output"):
            text_features = text_output.pooler_output
        elif hasattr(text_output, "last_hidden_state"):
            # NOTE(review): this takes position 0 without the text projection,
            # while CLIP's own text pooling uses the EOS position. Confirm this
            # branch is never reached with the installed transformers version.
            text_features = text_output.last_hidden_state[:, 0, :]
        elif isinstance(text_output, torch.Tensor):
            text_features = text_output
        else:
            text_features = text_output[0] if isinstance(text_output, tuple) else text_output

        text_features = F.normalize(text_features, dim=-1)
        text_features = text_features.view(batch_size, num_choices, -1)

        return text_features

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``(batch, choices)`` temperature-scaled image/answer similarities."""
        image_features = self.get_visual_embeddings_with_prompts(pixel_values)
        text_features = self.encode_text(input_ids, attention_mask)

        image_features = image_features.unsqueeze(1)
        logits = torch.einsum("bcd,bcd->bc", image_features.expand_as(text_features), text_features)

        logit_scale = self.clip.logit_scale.exp()
        logits = logits * logit_scale

        return logits


class Trainer:
    """Training manager for VPT models.

    Optimises only the parameters of ``model`` with ``requires_grad=True``, using:

    * AdamW;
    * a linear learning-rate warm-up over ``config.warmup_steps`` optimizer
      steps, then a constant rate;
    * gradient-norm clipping at 1.0.

    Logits of zero-padded answer slots (see ``collate_fn``) are masked out
    before the loss and argmax.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Config,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Collected once; reused by the optimizer and by gradient clipping every step.
        self.trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.optimizer = AdamW(
            self.trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay
        )

        self.criterion = nn.CrossEntropyLoss()

        def lr_lambda(step):
            # Linear warm-up to the base learning rate, then hold it constant.
            if step < config.warmup_steps:
                return float(step) / float(max(1, config.warmup_steps))
            return 1.0

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        self.global_step = 0
        self.best_val_accuracy = 0.0

        print(f"Trainer initialized:")
        print(f"  - Training samples: {len(train_loader.dataset)}")
        print(f"  - Validation samples: {len(val_loader.dataset)}")
        print(f"  - Epochs: {config.num_epochs}")
        print(f"  - Batch size: {config.batch_size}")
        print(f"  - Learning rate: {config.learning_rate}")

    @staticmethod
    def _mask_padded_choices(logits: torch.Tensor, num_choices: torch.Tensor) -> torch.Tensor:
        """Give zero-padded answer slots the lowest representable logit.

        ``collate_fn`` pads each sample to the largest choice count in the batch.
        Without this mask a padded slot could win the argmax or take probability
        mass in the cross-entropy loss.
        """
        slot = torch.arange(logits.size(1), device=logits.device)
        is_padding = slot.unsqueeze(0) >= num_choices.to(logits.device).unsqueeze(1)
        return logits.masked_fill(is_padding, torch.finfo(logits.dtype).min)

    def _forward_batch(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Move ``batch`` to the device, run the model and return ``(masked logits, labels)``."""
        device = self.config.device
        logits = self.model(
            pixel_values=batch["pixel_values"].to(device),
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        if "num_choices" in batch:
            logits = self._mask_padded_choices(logits, batch["num_choices"])
        return logits, batch["labels"].to(device)

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one optimisation pass over the training set; return mean loss and accuracy."""
        self.model.train()

        total_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch+1}/{self.config.num_epochs}",
            leave=True,
        )

        for batch in progress_bar:
            self.optimizer.zero_grad()
            logits, labels = self._forward_batch(batch)

            loss = self.criterion(logits, labels)
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(self.trainable_params, max_norm=1.0)

            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            predictions = logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            self.global_step += 1

            progress_bar.set_postfix(
                {"loss": f"{loss.item():.4f}", "acc": f"{100*correct/total:.2f}%"}
            )

        # Guard against an empty loader instead of raising ZeroDivisionError.
        avg_loss = total_loss / max(1, len(self.train_loader))
        accuracy = correct / total if total else 0.0

        return {"loss": avg_loss, "accuracy": accuracy}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Evaluate on the validation loader; return mean loss and accuracy."""
        self.model.eval()

        total_loss = 0.0
        correct = 0
        total = 0

        for batch in tqdm(self.val_loader, desc="Validating", leave=False):
            logits, labels = self._forward_batch(batch)

            loss = self.criterion(logits, labels)

            total_loss += loss.item()
            predictions = logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

        # An empty validation split (e.g. val_split rounding to 0) must not crash training.
        avg_loss = total_loss / max(1, len(self.val_loader))
        accuracy = correct / total if total else 0.0

        return {"loss": avg_loss, "accuracy": accuracy}

    def train(self) -> Dict[str, List[float]]:
        """Train for ``config.num_epochs`` epochs, validating after each; return per-epoch metrics."""
        history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

        print("\n" + "=" * 60)
        print("STARTING TRAINING")
        print("=" * 60 + "\n")

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate()

            print(f"\nEpoch {epoch+1}/{self.config.num_epochs}:")
            print(f"  Train Loss: {train_metrics['loss']:.4f}, Train Acc: {100*train_metrics['accuracy']:.2f}%")
            print(f"  Val Loss: {val_metrics['loss']:.4f}, Val Acc: {100*val_metrics['accuracy']:.2f}%")

            if val_metrics["accuracy"] > self.best_val_accuracy:
                self.best_val_accuracy = val_metrics["accuracy"]
                print(f"  → New best validation accuracy: {100*self.best_val_accuracy:.2f}%")

            history["train_loss"].append(train_metrics["loss"])
            history["train_acc"].append(train_metrics["accuracy"])
            history["val_loss"].append(val_metrics["loss"])
            history["val_acc"].append(val_metrics["accuracy"])

        total_time = time.time() - start_time

        print("\n" + "=" * 60)
        print("TRAINING COMPLETE")
        print("=" * 60)
        print(f"Total time: {total_time/60:.2f} minutes")
        print(f"Best validation accuracy: {100*self.best_val_accuracy:.2f}%")

        return history


@torch.no_grad()
def evaluate_model(
    model: nn.Module, data_loader: DataLoader, config: Config, model_name: str = "Model"
) -> Dict[str, Any]:
    """Evaluate a model on a dataset.

    Returns a dict with:

    * ``accuracy`` (float, 0.0 for an empty loader);
    * ``correct`` and ``total`` counts;
    * per-sample ``predictions`` and ``labels`` lists.

    Zero-padded answer slots from ``collate_fn`` are masked so they can never
    be predicted.
    """
    model.eval()

    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    print(f"\nEvaluating {model_name}...")

    for batch in tqdm(data_loader, desc="Evaluating"):
        pixel_values = batch["pixel_values"].to(config.device)
        input_ids = batch["input_ids"].to(config.device)
        attention_mask = batch["attention_mask"].to(config.device)
        labels = batch["labels"].to(config.device)

        logits = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        if "num_choices" in batch:
            slot = torch.arange(logits.size(1), device=logits.device)
            is_padding = slot.unsqueeze(0) >= batch["num_choices"].to(logits.device).unsqueeze(1)
            logits = logits.masked_fill(is_padding, torch.finfo(logits.dtype).min)

        predictions = logits.argmax(dim=-1)

        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        all_predictions.extend(predictions.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    # An empty loader yields 0.0 rather than ZeroDivisionError.
    accuracy = correct / total if total else 0.0

    print(f"\n{model_name} Results:")
    print(f"  Accuracy: {100*accuracy:.2f}%")
    print(f"  Correct: {correct}/{total}")

    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "predictions": all_predictions,
        "labels": all_labels,
    }


def compare_models(
    baseline_results: Dict,
    vpt_results: Dict,
    baseline_params: int,
    vpt_trainable_params: int,
    vpt_training_time: float,
) -> None:
    """Print a comparison between baseline and VPT models.

    ``baseline_results``/``vpt_results`` need an ``"accuracy"`` entry in [0, 1].
    ``vpt_training_time`` is in seconds. Zero parameter counts are tolerated:
    the ratio lines are reported as 0 or skipped instead of raising ZeroDivisionError.
    """
    print("\n" + "=" * 70)
    print("MODEL COMPARISON: BASELINE CLIP vs CLIP + VPT")
    print("=" * 70)

    print("\n ACCURACY COMPARISON:")
    print("-" * 40)
    print(f"  Baseline CLIP (Zero-shot): {100*baseline_results['accuracy']:.2f}%")
    print(f"  CLIP + VPT (Trained):      {100*vpt_results['accuracy']:.2f}%")

    improvement = vpt_results["accuracy"] - baseline_results["accuracy"]
    print(f"\n  Improvement: {100*improvement:+.2f}%")

    trainable_ratio = 100 * vpt_trainable_params / baseline_params if baseline_params else 0.0

    print("\n PARAMETER COMPARISON:")
    print("-" * 40)
    print(f"  Baseline trainable params: 0 (all frozen)")
    print(f"  VPT trainable params:      {vpt_trainable_params:,}")
    print(f"  Trainable ratio:           {trainable_ratio:.4f}%")

    print("\n TRAINING TIME:")
    print("-" * 40)
    print(f"  Baseline: 0 (no training)")
    print(f"  VPT:      {vpt_training_time/60:.2f} minutes")

    print("\n" + "=" * 70)

    print("\n INTERPRETATION:")
    if improvement > 0:
        print(f"  → VPT improved accuracy by {100*improvement:.2f}% over zero-shot CLIP")
        print(f"  → This was achieved by training only {vpt_trainable_params:,} parameters")
        if vpt_trainable_params > 0:
            print(f"  → VPT is {baseline_params/vpt_trainable_params:.0f}x more parameter-efficient than full fine-tuning")
    else:
        print(f"  → VPT did not improve over baseline in this experiment")
        print(f"  → Consider: adjusting prompt count, learning rate, or training longer")


def save_model(model: CLIPWithVPT, save_path: str, config: Config) -> None:
    """Save the trained VPT model (only prompt tokens, backbone is frozen).

    The checkpoint holds the ``visual_prompts`` state dict plus the config fields
    needed to rebuild the model (``model_name``, ``num_prompt_tokens``,
    ``prompt_dim``, ``prompt_dropout``). Missing parent directories are created.
    """
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        "visual_prompts": model.visual_prompts.state_dict(),
        "config": {
            "model_name": config.model_name,
            "num_prompt_tokens": config.num_prompt_tokens,
            "prompt_dim": config.prompt_dim,
            "prompt_dropout": config.prompt_dropout,
        },
    }

    torch.save(checkpoint, save_path)
    print(f"\n Model saved to: {save_path}")
    print(f"   Saved prompt parameters: {model.count_trainable_parameters():,}")


def load_model(load_path: str, config: Optional[Config] = None) -> CLIPWithVPT:
    """Rebuild a CLIPWithVPT and restore its prompt tokens from a checkpoint.

    With no ``config``, a default ``Config`` is created and its model/prompt
    fields are overwritten with the values stored in the checkpoint. Otherwise
    the given config is used as-is.
    """
    checkpoint = torch.load(load_path, map_location="cpu")

    if config is None:
        config = Config()
        stored = checkpoint["config"]
        for key in ("model_name", "num_prompt_tokens", "prompt_dim", "prompt_dropout"):
            setattr(config, key, stored[key])

    model = CLIPWithVPT(config)
    model.visual_prompts.load_state_dict(checkpoint["visual_prompts"])

    print(f"\n Model loaded from: {load_path}")
    print(f"   Loaded prompt parameters: {model.count_trainable_parameters():,}")

    return model


def create_data_loaders(dataset: GDVCRDataset, config: Config) -> Tuple[DataLoader, DataLoader]:
    """Split ``dataset`` reproducibly into train/val subsets and wrap each in a DataLoader.

    ``config.val_split`` of the samples (rounded down) go to validation. The
    split is seeded with ``config.seed``. Only the training loader shuffles.
    """
    n_total = len(dataset)
    n_val = int(n_total * config.val_split)
    n_train = n_total - n_val

    split_rng = torch.Generator().manual_seed(config.seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_val], generator=split_rng
    )

    print(f"\nDataset split:")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")

    shared = dict(
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
        pin_memory=config.device == "cuda",
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader


def run_experiment(config: Config) -> Tuple[Dict[str, List[float]], CLIPWithVPT]:
    """Run the complete VPT experiment.

    Steps: evaluate zero-shot CLIP on the validation split, train CLIP + VPT,
    re-evaluate it, print a comparison, and save the prompt tokens to
    ``clip_model_vpt_deep_<num_prompt_tokens>.pt`` in the working directory.

    Returns:
        ``(history, vpt_model)``: the per-epoch metrics from ``Trainer.train``
        and the trained model. The previous ``-> None`` annotation was wrong.
    """
    print("\n" + "=" * 70)
    print("VISUAL PROMPT TUNING EXPERIMENT")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f"  Model: {config.model_name}")
    print(f"  Device: {config.device}")
    print(f"  Prompt tokens: {config.num_prompt_tokens}")
    print(f"  Learning rate: {config.learning_rate}")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Epochs: {config.num_epochs}")

    print("\n[Step 1] Setting random seed for reproducibility...")
    set_seed(config.seed)

    print("\n[Step 2] Loading CLIP processor...")
    processor = CLIPProcessor.from_pretrained(config.model_name)

    print("\n[Step 3] Loading GD-VCR dataset...")
    dataset = GDVCRDataset(
        jsonl_path=config.dataset_path,
        processor=processor,
        config=config,
        image_base_path=config.image_base_path,
    )

    print("\n[Step 4] Creating data loaders...")
    train_loader, val_loader = create_data_loaders(dataset, config)

    print("\n[Step 5] Evaluating baseline CLIP (zero-shot)...")
    baseline_model = BaselineCLIP(config)
    baseline_results = evaluate_model(baseline_model, val_loader, config, "Baseline CLIP")
    baseline_params = baseline_model.count_parameters()

    # Release the baseline before loading a second copy of CLIP.
    del baseline_model
    if config.device == "cuda":
        torch.cuda.empty_cache()

    print("\n[Step 6] Training CLIP + VPT...")
    vpt_model = CLIPWithVPT(config)
    vpt_trainable_params = vpt_model.count_trainable_parameters()

    trainer = Trainer(
        model=vpt_model, train_loader=train_loader, val_loader=val_loader, config=config
    )

    start_time = time.time()
    history = trainer.train()
    training_time = time.time() - start_time

    print("\n[Step 7] Evaluating trained CLIP + VPT...")
    vpt_results = evaluate_model(vpt_model, val_loader, config, "CLIP + VPT")

    print("\n[Step 8] Comparing models...")
    compare_models(
        baseline_results=baseline_results,
        vpt_results=vpt_results,
        baseline_params=baseline_params,
        vpt_trainable_params=vpt_trainable_params,
        vpt_training_time=training_time,
    )

    print("\n TRAINING HISTORY SUMMARY:")
    print("-" * 40)
    # With num_epochs == 0 the history lists are empty; indexing them would raise IndexError.
    if history["train_acc"]:
        print(f"  Initial train acc: {100*history['train_acc'][0]:.2f}%")
        print(f"  Final train acc:   {100*history['train_acc'][-1]:.2f}%")
        print(f"  Initial val acc:   {100*history['val_acc'][0]:.2f}%")
        print(f"  Final val acc:     {100*history['val_acc'][-1]:.2f}%")
    else:
        print("  (no training epochs were run)")

    print("\n[Step 9] Saving trained model...")
    save_path = f"clip_model_vpt_deep_{config.num_prompt_tokens}.pt"
    save_model(vpt_model, save_path, config)

    print("\n" + "=" * 70)
    print("EXPERIMENT COMPLETE")
    print("=" * 70)

    return history, vpt_model


def main():
    """Build the default Config, run the experiment and return ``(history, model)``.

    Failures are reported with a short hint and then re-raised unchanged.
    """
    config = Config()

    try:
        history, model = run_experiment(config)
    except FileNotFoundError as err:
        print(f"\n Error: Dataset file not found: {err}")
        print("Please update the dataset_path in the Config class.")
        raise
    except Exception as err:
        print(f"\n Error during experiment: {err}")
        raise
    else:
        print("\n Experiment completed successfully!")
        return history, model


def quick_test():
    """Smoke-test the text preprocessing on one synthetic GD-VCR record.

    The throwaway JSONL file and image are written to a private temporary
    directory, which is removed automatically even if something fails. They
    therefore can never overwrite or delete a user's own ``test.jpg`` /
    ``test_data.jsonl`` in the working directory, as the previous version did.

    Only ``preprocess_sample`` is exercised; no model is loaded or trained.
    """
    print("Running quick test with synthetic data...")

    synthetic_data = [
        {
            "img_fn": "test.jpg",
            "objects": ["person", "dog", "car"],
            "question": ["What", "is", [0], "doing", "?"],
            "answer_choices": [
                ["Walking", "the", [1], "."],
                ["Driving", "the", [2], "."],
                ["Sleeping", "."],
                ["Running", "."],
            ],
            "answer_label": 0,
        }
    ]

    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(os.path.join(tmp_dir, "test_data.jsonl"), "w", encoding="utf-8") as f:
            for item in synthetic_data:
                f.write(json.dumps(item) + "\n")

        test_image = Image.new("RGB", (224, 224), color="white")
        test_image.save(os.path.join(tmp_dir, "test.jpg"))

        print("\nTesting preprocessing...")
        for record in synthetic_data:
            processed = preprocess_sample(record)
            print(f"  Question: {processed['question']}")
            print(f"  Answers: {processed['answer_choices']}")

        print("\n Quick test passed! The pipeline is working.")


def install_dependencies():
    """Quietly pip-install the runtime packages into the running interpreter (for Colab)."""
    import subprocess
    import sys

    for package in ("torch", "transformers", "Pillow", "tqdm"):
        print(f"Installing {package}...")
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q", package]
        subprocess.check_call(pip_cmd)

    print("\n All dependencies installed!")

In [None]:
if __name__ == "__main__":
    main()