# CLIP Baseline Experiments
## Zero-Shot, Linear Probe, and Few-Shot Evaluation

This notebook demonstrates the complete evaluation pipeline for the CLIP baseline model: zero-shot classification, linear probing, and few-shot learning.

## Setup

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from models.clip_baseline import CLIPBaseline
from utils.config import CLIPConfig
from utils.helpers import seed_everything
from evaluation.metrics import MetricsTracker
from training.trainer_clip import CLIPTrainer
from datasets.dataloaders import DatasetFactory
from utils.templates import *

# Render matplotlib figures inline below each cell.
%matplotlib inline
# Automatically re-import edited project modules (models/, utils/, training/, ...)
# before each cell executes, so source edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

## Configuration

In [None]:
# Single source of truth for the seed: the global RNG seeding and the seed
# recorded in the config must agree, so neither gets a separate literal.
SEED = 42

# Seed RNGs before anything stochastic is built (presumably Python/NumPy/torch —
# see utils.helpers.seed_everything).
seed_everything(SEED)

# Experiment configuration: backbone name, pretrained tag, batch size and seed.
config = CLIPConfig(
    model_name="ViT-B-32",
    pretrained_tag="openai",
    batch_size=128,
    seed=SEED,
)

print(f"Device: {config.device}")
print(f"Model: {config.model_name}")

## Initialize Model

In [None]:
# Build the baseline CLIP model from the configuration cell above.
model = CLIPBaseline(config)

# Metrics are grouped under one experiment name and pointed at config.results_dir;
# the model is registered with the tracker before the trainer is created.
tracker_name = "CLIP_BASELINE_NOTEBOOK"
metrics = MetricsTracker(tracker_name, config.results_dir)
metrics.track_parameters(model)

# The trainer bundles model, config and tracker for all evaluation cells below.
trainer = CLIPTrainer(model, config, metrics)

print("\n✓ Model initialized successfully")

## Experiment 1: Zero-Shot Classification

In [None]:
print("=" * 60)
print("ZERO-SHOT CLASSIFICATION")
print("=" * 60)

# Evaluation datasets, loaded with the model's own preprocessing transform.
# Deliberately NOT called `datasets`: that name belongs to the project package
# imported in the setup cell, and the linear-probe cell binds it to a
# differently-shaped mapping (train/test splits).
zeroshot_datasets = DatasetFactory.get_zeroshot_config(model.preprocess, config.data_root)

zero_shot_results = {}

for dataset_name, dataset_config in zeroshot_datasets.items():
    print(f"\nEvaluating {dataset_name}...")

    # Text-side classifier built from this dataset's class names and prompt templates.
    text_classifier = trainer.create_text_classifier(
        dataset_config['class_names'],
        dataset_config['templates'],
    )

    accuracy, predictions, labels = trainer.zero_shot(
        dataset_config['dataset'],
        text_classifier,
    )

    zero_shot_results[dataset_name] = {
        'accuracy': accuracy,
        'predictions': predictions,
        'labels': labels,
    }

print("\n" + "=" * 60)
print("ZERO-SHOT RESULTS SUMMARY")
print("=" * 60)
for name, result in zero_shot_results.items():
    print(f"{name:20} {result['accuracy']:.2f}%")

## Experiment 2: Linear Probe Evaluation

In [None]:
banner = "=" * 60
print(banner)
print("LINEAR PROBE EVALUATION")
print(banner)

# Mapping of dataset name -> {'train': ..., 'test': ...} splits.
# The global name `datasets` is kept because the few-shot cell below reads it.
datasets = DatasetFactory.get_linear_probe_datasets(model.preprocess, config.data_root)

linear_probe_results = {}
for dataset_name, dataset_splits in datasets.items():
    print(f"\nEvaluating {dataset_name}...")

    train_split, test_split = dataset_splits['train'], dataset_splits['test']
    accuracy, predictions, labels = trainer.linear_probe(train_split, test_split)

    linear_probe_results[dataset_name] = dict(
        accuracy=accuracy,
        predictions=predictions,
        labels=labels,
    )

print("\n" + banner)
print("LINEAR PROBE RESULTS SUMMARY")
print(banner)
for name, result in linear_probe_results.items():
    print(f"{name:20} {result['accuracy']:.2f}%")

## Experiment 3: Few-Shot Learning

In [None]:
print("=" * 60)
print("FEW-SHOT EVALUATION")
print("=" * 60)

# Few-shot analysis reuses the train/test splits loaded by the linear-probe
# cell (`datasets`), so that cell must have run first.
# `dataset_name` is also read by the few-shot plot title further down.
dataset_name = "CIFAR100"
K_SHOTS = [1, 2, 4, 8, 16, 32]  # shots per class to evaluate

if dataset_name not in datasets:
    raise KeyError(
        f"{dataset_name!r} is not among the linear-probe datasets; "
        f"available: {list(datasets)}"
    )
dataset_splits = datasets[dataset_name]

print(f"\nEvaluating {dataset_name} with k-shot learning...")

few_shot_results, test_labels = trainer.few_shot(
    dataset_splits['train'],
    dataset_splits['test'],
    k_shots=K_SHOTS,
)

print("\n" + "=" * 60)
print("FEW-SHOT RESULTS")
print("=" * 60)
for k_shot, result in few_shot_results.items():
    print(f"{k_shot:10} {result['accuracy']:.2f}%")

## Visualization

In [None]:
# Plot Zero-Shot vs Linear Probe comparison.
# The two protocols load datasets through separate factory methods, so their
# dataset sets need not match; only datasets evaluated by BOTH are plotted.
dataset_names = [name for name in zero_shot_results if name in linear_probe_results]
skipped = (
    [name for name in zero_shot_results if name not in linear_probe_results]
    + [name for name in linear_probe_results if name not in zero_shot_results]
)
if skipped:
    print(f"Skipping datasets evaluated by only one protocol: {skipped}")

zs_accuracies = [zero_shot_results[name]['accuracy'] for name in dataset_names]
lp_accuracies = [linear_probe_results[name]['accuracy'] for name in dataset_names]

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(dataset_names))
width = 0.35  # bar width; the two bars per dataset sit at x ± width/2

ax.bar(x - width/2, zs_accuracies, width, label='Zero-Shot', alpha=0.8)
ax.bar(x + width/2, lp_accuracies, width, label='Linear Probe', alpha=0.8)

ax.set_xlabel('Dataset')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Zero-Shot vs Linear Probe Performance')
ax.set_xticks(x)
ax.set_xticklabels(dataset_names, rotation=45, ha='right')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# savefig does not create missing parent directories.
plots_dir = Path('../plots')
plots_dir.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig(plots_dir / 'clip_baseline_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Plot Few-Shot Learning Curve.
# Result keys are parsed as "<k>-..." strings (e.g. "16-shot"); str() also lets
# plain integer keys through. Points are sorted by k so the line is monotone in x.
points = sorted(
    (int(str(key).split('-')[0]), result['accuracy'])
    for key, result in few_shot_results.items()
)
k_values = [k for k, _ in points]
accuracies = [acc for _, acc in points]

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(k_values, accuracies, marker='o', linewidth=2, markersize=8)
ax.set_xlabel('Number of shots per class (k)')
ax.set_ylabel('Accuracy (%)')
# `dataset_name` is set by the few-shot evaluation cell.
ax.set_title(f'Few-Shot Learning Curve - {dataset_name}')
ax.grid(True, alpha=0.3)
ax.set_xscale('log', base=2)  # k doubles each step, so log2 spacing is even

# savefig does not create missing parent directories.
plots_dir = Path('../plots')
plots_dir.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig(plots_dir / 'clip_baseline_fewshot.png', dpi=300, bbox_inches='tight')
plt.show()

## Save Results

In [None]:
# Headline number for this run: mean zero-shot accuracy across all datasets.
# loss=0.0 is a placeholder — zero-shot evaluation produces no training loss.
mean_zero_shot_accuracy = np.mean(
    [res['accuracy'] for res in zero_shot_results.values()]
)
metrics.track_performance(accuracy=mean_zero_shot_accuracy, loss=0.0)

# Persist everything tracked so far under a fixed run id, then print a summary.
metrics.save_metrics(run_id="notebook_experiment")
metrics.print_summary()

print("\n✓ Results saved successfully!")