# Shesha Tutorial: Vision Model Analysis

This notebook demonstrates how to use Shesha to analyze the geometric stability of vision model representations.

**What you'll learn:**
- How to extract embeddings from vision models
- How to compare stability across architectures (here ResNet-18, ResNet-50 and VGG-16)
- How to measure class separability

**Requirements:**
```bash
pip install shesha-geometry torch torchvision matplotlib
```

## 1. Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import shesha

# --- Configuration ----------------------------------------------------------
# One seed shared by NumPy (used for dataset subsampling) and PyTorch.
SEED = 320
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

## 2. Load Dataset

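We'll use a random 1,000-image subset of the CIFAR-10 test set. The images are resized to 224 pixels and normalized with ImageNet statistics, which matches the input format the ImageNet-pretrained models expect.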

In [None]:
# CIFAR-10 images are small (32x32); upsample them to the 224-pixel input size
# the ImageNet-pretrained models expect and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 test set
dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Subset for speed. A dedicated RandomState seeded with SEED keeps this cell
# idempotent: re-running it selects the same images instead of advancing the
# global NumPy RNG. On a fresh kernel (assuming nothing else has drawn from
# the global generator since it was seeded) it yields the same indices as
# np.random.choice would.
N_SAMPLES = 1000
subset_rng = np.random.RandomState(SEED)
indices = subset_rng.choice(len(dataset), N_SAMPLES, replace=False)
subset = torch.utils.data.Subset(dataset, indices)
dataloader = DataLoader(subset, batch_size=64, shuffle=False)

print(f"Dataset: {N_SAMPLES} images, {len(dataset.classes)} classes")

## 3. Extract Embeddings from a Model

In [None]:
def _strip_classifier(model):
    """Replace ``model``'s final classification layer with ``nn.Identity()`` in place."""
    if hasattr(model, 'fc'):
        model.fc = nn.Identity()
        return
    # 'heads' is the attribute torchvision's VisionTransformer uses;
    # 'head' covers models that expose a single head module.
    for attr in ('classifier', 'heads', 'head'):
        if not hasattr(model, attr):
            continue
        head = getattr(model, attr)
        if isinstance(head, nn.Sequential):
            linear_positions = [i for i, layer in enumerate(head) if isinstance(layer, nn.Linear)]
            if linear_positions:
                # Keep hidden FC layers (e.g. VGG's 4096-d block) and drop only
                # the last Linear, i.e. the actual class-score projection.
                head[linear_positions[-1]] = nn.Identity()
                return
        setattr(model, attr, nn.Identity())
        return


def extract_features(model, dataloader, device):
    """Extract penultimate-layer features for every batch in ``dataloader``.

    The final classification layer is replaced *in place* with ``nn.Identity()``
    so the forward pass returns the representation feeding the classifier:

    - ``model.fc`` (e.g. ResNet) is replaced outright.
    - ``model.classifier`` / ``model.heads`` / ``model.head``: if the attribute
      is an ``nn.Sequential`` containing ``nn.Linear`` layers, only the last
      ``nn.Linear`` is replaced, so multi-layer heads such as VGG's keep their
      hidden layers; otherwise the whole attribute is replaced.

    Args:
        model: A torch ``nn.Module``. NOTE: it is mutated (head removed, moved
            to ``device``, switched to eval mode); pass a fresh instance if the
            original classifier is still needed.
        dataloader: Iterable of ``(images, targets)`` tensor batches.
        device: Device (string or ``torch.device``) used for inference.

    Returns:
        ``(features, labels)``: a 2-D array ``(n_samples, n_features)`` (outputs
        with more than two dims are flattened per sample) and a 1-D array of the
        concatenated targets.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    _strip_classifier(model)

    model = model.to(device)
    model.eval()

    features = []
    labels = []
    with torch.no_grad():
        for images, targets in dataloader:
            output = model(images.to(device))

            # Flatten any trailing dims; unlike .view(), torch.flatten also
            # handles non-contiguous outputs.
            if output.dim() > 2:
                output = torch.flatten(output, start_dim=1)

            features.append(output.cpu().numpy())
            labels.append(targets.numpy())

    if not features:
        raise ValueError("dataloader yielded no batches; cannot extract features")

    return np.vstack(features), np.concatenate(labels)

# Sanity-check the extractor on a single ImageNet-pretrained ResNet-18.
print("Loading ResNet-18...")
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
features, labels = extract_features(resnet, dataloader, device=DEVICE)
print(f"Feature shape: {features.shape}")

## 4. Analyze Single Model

In [None]:
# Three complementary Shesha views of the ResNet-18 embedding geometry.
stability = shesha.feature_split(features, n_splits=30, seed=SEED)
alignment = shesha.supervised_alignment(features, labels, seed=SEED)
var_ratio = shesha.variance_ratio(features, labels)

report_rows = (
    ("  Stability (feature_split):    ", stability),
    ("  Task Alignment (supervised):  ", alignment),
    ("  Variance Ratio:               ", var_ratio),
)
print("ResNet-18 Metrics:")
for prefix, value in report_rows:
    print(f"{prefix}{value:.3f}")

## 5. Compare Architectures

In [None]:
def analyze_model(model_fn, model_name, dataloader, n_splits=30):
    """Load a pretrained vision model and compute its Shesha metrics.

    Args:
        model_fn: torchvision model constructor (e.g. ``models.resnet18``);
            called with ``weights='IMAGENET1K_V1'``.
        model_name: Display name stored under ``'model'`` in the result.
        dataloader: Batches of ``(images, targets)`` passed to
            ``extract_features``.
        n_splits: Number of splits for ``shesha.feature_split``. Defaults to
            30, the value used in the single-model analysis (Section 4), so the
            stability scores of the two sections are comparable.

    Returns:
        dict with keys ``'model'``, ``'stability'``, ``'alignment'``,
        ``'var_ratio'`` and ``'dim'`` (feature dimensionality).
    """
    print(f"Analyzing {model_name}...")
    model = model_fn(weights='IMAGENET1K_V1')
    features, labels = extract_features(model, dataloader, DEVICE)
    # Drop the reference to the network's weights before computing metrics;
    # only the extracted features are needed from here on.
    del model

    return {
        'model': model_name,
        'stability': shesha.feature_split(features, n_splits=n_splits, seed=SEED),
        'alignment': shesha.supervised_alignment(features, labels, seed=SEED),
        'var_ratio': shesha.variance_ratio(features, labels),
        'dim': features.shape[1]
    }

# Compare models
model_configs = [
    (models.resnet18, "ResNet-18"),
    (models.resnet50, "ResNet-50"),
    (models.vgg16, "VGG-16"),
]

results = []
for model_fn, name in model_configs:
    results.append(analyze_model(model_fn, name, dataloader))

# Print comparison
print(f"\n{'Model':<12} {'Dim':>6} {'Stability':>10} {'Alignment':>10} {'VarRatio':>10}")
print("-" * 52)
for r in results:
    print(f"{r['model']:<12} {r['dim']:>6} {r['stability']:>10.3f} {r['alignment']:>10.3f} {r['var_ratio']:>10.3f}")

## 6. Visualization

In [None]:
# Grouped bar chart: one group per model, one bar per metric.
models_names = [r['model'] for r in results]
x = np.arange(len(models_names))
width = 0.25

# (result key, legend label, horizontal offset from the group centre)
bar_specs = [
    ('stability', 'Stability', -width),
    ('alignment', 'Alignment', 0),
    ('var_ratio', 'Var Ratio', width),
]

fig, ax = plt.subplots(figsize=(8, 5))
for key, label, offset in bar_specs:
    ax.bar(x + offset, [r[key] for r in results], width, label=label)

ax.set(ylabel='Score', title='Vision Model Comparison')
ax.set_xticks(x)
ax.set_xticklabels(models_names)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
fig.savefig('vision_comparison.png', dpi=150)
plt.show()

## 7. Quick Reference

```python
import shesha

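# Extract penultimate-layer features (classification head removed) as a NumPy array
features = model(images).detach().cpu().numpy()  # shape: (n_samples, n_features)
# labels: class label per sample, shape (n_samples,)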

# Measure stability
stability = shesha.feature_split(features, seed=320)

# Measure task alignment
alignment = shesha.supervised_alignment(features, labels, seed=320)

# Quick separability check
var_ratio = shesha.variance_ratio(features, labels)
```