## Shesha Tutorial: Measuring Representational Drift

This notebook demonstrates how to use `shesha.rdm_drift` and `shesha.rdm_similarity` to measure representational drift under two kinds of perturbation:

1. **Gaussian noise injection** - Adding noise to model weights
2. **LoRA adapters** - Simulating fine-tuning perturbations

**What you'll learn:**
- How to measure drift between two representations
- How drift scales with perturbation magnitude
- How to interpret drift values

**Requirements:**
```bash
pip install shesha-geometry transformers torch peft datasets
```
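
For intuition about what a drift score reports, the sketch below computes an RDM-based drift by hand with NumPy. It assumes, without confirmation from the library source, that similarity is the Spearman correlation between the upper triangles of two cosine-distance RDMs and that drift is one minus that similarity. `toy_rdm_drift` and its helpers are hypothetical names, and ties in the rank step are ignored for brevity.

```python
import numpy as np

def cosine_rdm(X):
    """Pairwise cosine-distance matrix (n_samples x n_samples)."""
    Xn = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
    return 1.0 - Xn @ Xn.T

def upper_triangle(M):
    """Flatten the strictly upper triangle (each sample pair once)."""
    i, j = np.triu_indices(M.shape[0], k=1)
    return M[i, j]

def ranks(v):
    """Ordinal ranks; ties are not averaged (fine for continuous data)."""
    r = np.empty(len(v))
    r[np.argsort(v)] = np.arange(len(v))
    return r

def toy_rdm_drift(X, Y):
    """1 - Spearman correlation between the two RDMs (assumed definition)."""
    a = ranks(upper_triangle(cosine_rdm(X)))
    b = ranks(upper_triangle(cosine_rdm(Y)))
    return 1.0 - np.corrcoef(a, b)[0, 1]

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 16))
Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # random orthogonal matrix

# A rotation changes every coordinate but keeps all pairwise distances.
drift_rotated = toy_rdm_drift(X, X @ Q)
# An independent random representation shares no geometry with X.
drift_random = toy_rdm_drift(X, rng.normal(size=(50, 16)))
print(f"rotated: {drift_rotated:.4f}, unrelated: {drift_random:.4f}")
```

The rotated copy scores essentially zero drift because the metric compares pairwise geometry rather than raw coordinates, while the unrelated representation lands near the "uncorrelated" end of the scale.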

## 1. Setup and Imports

In [None]:
# Optional: Install dependencies
# !pip install shesha-geometry transformers torch peft datasets

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import warnings

# Shesha
import shesha

# Optional: LoRA support
try:
    from peft import get_peft_model, LoraConfig, TaskType
    HAS_PEFT = True
except ImportError:
    HAS_PEFT = False
    print("peft not installed - LoRA experiments will be skipped")

warnings.filterwarnings('ignore')

# Configuration
SEED = 320
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
BATCH_SIZE = 32 if DEVICE == "cuda" else 8
N_SAMPLES = 200
MAX_SEQ_LEN = 64

np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"SHESHA version: {shesha.__version__}")
print(f"Device: {DEVICE}")
print(f"Available functions: {[x for x in dir(shesha) if not x.startswith('_')]}")

## 2. Load Dataset

In [None]:
print("Loading SST-2 dataset...")
dataset = load_dataset("glue", "sst2", split="validation")

# Sample balanced subset
df = pd.DataFrame({'text': dataset['sentence'], 'label': dataset['label']})
# groupby(...).sample keeps the 'label' column; groupby(...).apply drops it in newer pandas
df = df.groupby('label').sample(n=N_SAMPLES // 2, random_state=SEED).reset_index(drop=True)
texts = df['text'].tolist()
labels = df['label'].values

print(f"Loaded {len(texts)} samples")
print(f"Label distribution: {np.bincount(labels)}")
print(f"\nExample: '{texts[0][:60]}...'")

## 3. Helper Functions

In [None]:
def load_model(model_name):
    """Load a causal LM model."""
    print(f"Loading {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # For causal LMs

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=DTYPE,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(DEVICE).eval()

    return model, tokenizer


def get_embeddings(model, tokenizer, texts, layer=-1):
    """Extract embeddings using mean pooling, L2 normalized."""
    model.eval()
    all_vecs = []

    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i:i+BATCH_SIZE]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_SEQ_LEN
        ).to(DEVICE)

        with torch.no_grad():
            out = model(**inputs, output_hidden_states=True, return_dict=True)
            h = out.hidden_states[layer]

            # Mean pooling over non-padding tokens (accumulate in float32 for precision)
            mask = inputs["attention_mask"].unsqueeze(-1).float()
            vecs = (h.float() * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
            all_vecs.append(vecs.cpu().numpy())

    emb = np.vstack(all_vecs)

    # L2 normalize
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    emb = emb / np.maximum(norms, 1e-9)

    return emb


def inject_noise(model, alpha, seed=320):
    """Inject Gaussian noise into model parameters."""
    if alpha == 0:
        return

    torch.manual_seed(seed)
    with torch.no_grad():
        for p in model.parameters():
            if p.requires_grad and p.numel() > 1:
                std = float(p.std().item())
                if std > 0:
                    noise = torch.randn_like(p) * (std * alpha)
                    p.add_(noise)


def get_state_dict(model):
    """Copy model state dict to CPU."""
    return {k: v.cpu().clone() for k, v in model.state_dict().items()}


def load_state_dict(model, state_dict):
    """Restore model state dict."""
    device_state = {k: v.to(DEVICE) for k, v in state_dict.items()}
    model.load_state_dict(device_state)


print("Helper functions defined!")
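
The pooling step in `get_embeddings` can be checked on a toy batch. The NumPy sketch below applies mask-weighted mean pooling to hand-made values with left padding; the array contents and the `masked_mean_pool` name are illustrative and not part of the notebook's pipeline.

```python
import numpy as np

def masked_mean_pool(hidden, attention_mask):
    """Average hidden states over real tokens only.

    hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)                    # padding contributes 0
    counts = np.maximum(mask.sum(axis=1), 1e-9)             # avoid division by zero
    return summed / counts

# Two sequences of length 4; the first has two left-padding positions.
hidden = np.arange(2 * 4 * 3, dtype=np.float64).reshape(2, 4, 3)
attention_mask = np.array([[0, 0, 1, 1],
                           [1, 1, 1, 1]])

pooled = masked_mean_pool(hidden, attention_mask)
naive = hidden.mean(axis=1)  # includes padding positions in the average
print("masked:", pooled[0])
print("naive: ", naive[0])
```

For the padded sequence the two averages differ, which is why the mask matters whenever a batch mixes sentence lengths.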

## 4. Experiment 1: Gaussian Noise Drift

We'll inject increasing amounts of Gaussian noise into model weights and measure how the representational geometry drifts.
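
The noise level is relative: each parameter tensor receives noise whose standard deviation is `alpha` times that tensor's own standard deviation. The NumPy sketch below applies the same rule to a stand-in weight matrix (random values, not real model weights) and checks that the measured perturbation tracks `alpha`. The function name `add_relative_noise` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(loc=0.0, scale=0.02, size=(256, 256))  # stand-in weight matrix

def add_relative_noise(W, alpha, rng):
    """Return W plus Gaussian noise with std = alpha * std(W)."""
    if alpha == 0:
        return W.copy()
    return W + rng.normal(size=W.shape) * (W.std() * alpha)

for alpha in [0.0, 0.01, 0.1, 0.5]:
    W_noisy = add_relative_noise(W, alpha, rng)
    # Ratio of noise std to weight std; should sit close to alpha.
    ratio = (W_noisy - W).std() / W.std()
    print(f"alpha={alpha:<5} measured ratio={ratio:.4f}")
```

Scaling by each tensor's own spread keeps the perturbation comparable across layers whose weights have very different magnitudes.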

In [None]:
# Models to test (small for fast iteration)
MODELS = [
    "HuggingFaceTB/SmolLM-135M",
    "HuggingFaceTB/SmolLM-360M",
]

# Noise levels (fraction of parameter std)
NOISE_LEVELS = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.5]

noise_results = []

for model_name in MODELS:
    print(f"\n{'='*60}")
    print(f"Model: {model_name}")
    print(f"{'='*60}")

    # Load model
    model, tokenizer = load_model(model_name)
    clean_state = get_state_dict(model)

    # Get clean embeddings
    emb_clean = get_embeddings(model, tokenizer, texts)
    print(f"Clean embeddings: {emb_clean.shape}")

    # Test each noise level
    print(f"\n{'Noise':<10} {'RDM_Sim':<12} {'Drift':<12} {'Stability':<12}")
    print("-" * 50)

    for alpha in NOISE_LEVELS:
        # Restore clean weights
        load_state_dict(model, clean_state)

        # Inject noise
        inject_noise(model, alpha, seed=SEED + int(alpha * 1000))

        # Get noisy embeddings
        emb_noisy = get_embeddings(model, tokenizer, texts)

        # Compute drift metrics using SHESHA
        rdm_sim = shesha.rdm_similarity(emb_clean, emb_noisy)
        drift = shesha.rdm_drift(emb_clean, emb_noisy)
        stability = shesha.feature_split(emb_noisy, n_splits=20, seed=SEED)

        print(f"{alpha:<10.2f} {rdm_sim:<12.4f} {drift:<12.4f} {stability:<12.4f}")

        noise_results.append({
            'model': model_name.split('/')[-1],
            'noise_level': alpha,
            'rdm_similarity': rdm_sim,
            'rdm_drift': drift,
            'stability': stability,
        })

    # Cleanup
    del model, tokenizer, clean_state, emb_clean
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

df_noise = pd.DataFrame(noise_results)
print(f"\n\nCollected {len(df_noise)} measurements")

## 5. Experiment 2: LoRA Adapter Drift

We'll apply LoRA adapters with varying ranks and initialization scales to simulate fine-tuning perturbations.
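
As background, a LoRA adapter adds a low-rank update to a frozen weight: `W + (lora_alpha / r) * B @ A` in the standard formulation. The NumPy sketch below uses toy shapes and hypothetical names to show two consequences: a zero `B` leaves the layer unchanged, and randomly initialized `A` and `B` produce a perturbation of rank at most `r`.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, lora_alpha = 64, 64, 8, 16
W = rng.normal(scale=0.02, size=(d_out, d_in))  # frozen base weight (toy values)

def lora_delta(A, B, lora_alpha, r):
    """Low-rank update under the standard LoRA scaling lora_alpha / r."""
    return (lora_alpha / r) * (B @ A)

A = rng.normal(size=(r, d_in)) * 0.05

# Standard initialization: B = 0, so the adapted layer equals the base layer.
B_zero = np.zeros((d_out, r))
W_standard = W + lora_delta(A, B_zero, lora_alpha, r)

# Random B, as in a drift experiment: a nonzero update of rank <= r.
B_rand = rng.normal(size=(d_out, r)) * 0.05
delta = lora_delta(A, B_rand, lora_alpha, r)
W_perturbed = W + delta

print("unchanged with B = 0:", np.allclose(W_standard, W))
print("update rank:", np.linalg.matrix_rank(delta))
print("relative update size:", np.linalg.norm(delta) / np.linalg.norm(W))
```

Every configuration in this experiment sets `alpha` to twice the rank, so the scaling factor `lora_alpha / r` stays at 2 and differences between runs come from the adapter shapes and initialization scale.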

In [None]:
if not HAS_PEFT:
    print("Skipping LoRA experiment (peft not installed)")
    df_lora = pd.DataFrame()
else:
    # LoRA configurations
    LORA_CONFIGS = [
        # Varying rank at moderate init (simulates "how much capacity")
        {"rank": 1, "alpha": 2, "init_scale": 0.05},
        {"rank": 4, "alpha": 8, "init_scale": 0.05},
        {"rank": 8, "alpha": 16, "init_scale": 0.05},
        {"rank": 16, "alpha": 32, "init_scale": 0.05},
        {"rank": 32, "alpha": 64, "init_scale": 0.05},

        # Varying init at fixed rank (simulates "how much training")
        {"rank": 8, "alpha": 16, "init_scale": 0.01},
        {"rank": 8, "alpha": 16, "init_scale": 0.05},
        {"rank": 8, "alpha": 16, "init_scale": 0.1},
    ]

    lora_results = []

    for model_name in MODELS:
        print(f"\n{'='*60}")
        print(f"Model: {model_name}")
        print(f"{'='*60}")

        # Load base model and get clean embeddings
        model, tokenizer = load_model(model_name)
        emb_clean = get_embeddings(model, tokenizer, texts)
        print(f"Clean embeddings: {emb_clean.shape}")
        del model
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        print(f"\n{'Rank':<8} {'Alpha':<8} {'RDM_Sim':<12} {'Drift':<12} {'Stability':<12}")
        print("-" * 55)

        for config in LORA_CONFIGS:
            rank = config["rank"]
            alpha = config["alpha"]
            init_scale = config["init_scale"]

            # Load fresh model
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=DTYPE,
                trust_remote_code=True,
            ).to(DEVICE)

            # Apply LoRA
            try:
                lora_config = LoraConfig(
                    r=rank,
                    lora_alpha=alpha,
                    target_modules=["q_proj", "v_proj"],
                    lora_dropout=0.0,
                    bias="none",
                    task_type=TaskType.CAUSAL_LM
                )
                model = get_peft_model(model, lora_config)

                # Initialize LoRA weights with random values
                torch.manual_seed(SEED + rank)
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if "lora_" in name:
                            param.data = torch.randn_like(param) * init_scale

                model.eval()

                # Get embeddings
                emb_lora = get_embeddings(model, tokenizer, texts)

                # Compute drift
                rdm_sim = shesha.rdm_similarity(emb_clean, emb_lora)
                drift = shesha.rdm_drift(emb_clean, emb_lora)
                stability = shesha.feature_split(emb_lora, n_splits=20, seed=SEED)

                print(f"{rank:<8} {alpha:<8} {rdm_sim:<12.4f} {drift:<12.4f} {stability:<12.4f}")

                lora_results.append({
                    'model': model_name.split('/')[-1],
                    'lora_rank': rank,
                    'lora_alpha': alpha,
                    'init_scale': init_scale,  # distinguishes runs that share a rank
                    'rdm_similarity': rdm_sim,
                    'rdm_drift': drift,
                    'stability': stability,
                })

            except Exception as e:
                print(f"{rank:<8} {alpha:<8} ERROR: {e}")

            del model
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

        del emb_clean, tokenizer

    df_lora = pd.DataFrame(lora_results)
    print(f"\n\nCollected {len(df_lora)} measurements")

In [None]:
if not HAS_PEFT:
    print("Skipping LoRA experiment (peft not installed)")
    df_lora = pd.DataFrame()
else:
    # LoRA configurations
    LORA_CONFIGS = [
        {"rank": 1, "alpha": 2, "init_scale": 0.01},
        {"rank": 2, "alpha": 4, "init_scale": 0.01},
        {"rank": 4, "alpha": 8, "init_scale": 0.01},
        {"rank": 8, "alpha": 16, "init_scale": 0.01},
        {"rank": 16, "alpha": 32, "init_scale": 0.01},
        {"rank": 32, "alpha": 64, "init_scale": 0.01},
    ]


    lora_results = []

    for model_name in MODELS:
        print(f"\n{'='*60}")
        print(f"Model: {model_name}")

        # 1. Load Base Model
        base_model, tokenizer = load_model(model_name)

        # Get clean embeddings from base model
        emb_clean = get_embeddings(base_model, tokenizer, texts)
        print(f"Clean embeddings: {emb_clean.shape}")

        print(f"\n{'Rank':<8} {'Alpha':<8} {'RDM_Sim':<12} {'Drift':<12} {'Stability':<12}")
        print("-" * 55)

        for config in LORA_CONFIGS:
            rank = config["rank"]
            alpha = config["alpha"]
            init_scale = config["init_scale"]

            try:
                # 2. Target modules (hardcoded for these models; other
                # architectures may name their attention projections differently)
                target_modules = ["q_proj", "v_proj"]

                lora_config = LoraConfig(
                    r=rank,
                    lora_alpha=alpha,
                    target_modules=target_modules,
                    lora_dropout=0.0,
                    bias="none",
                    task_type=TaskType.CAUSAL_LM
                )

                # 3. Apply Adapter to existing model in-memory
                # This wraps the base model without reloading weights
                peft_model = get_peft_model(base_model, lora_config)

                # 4. Initialization
                torch.manual_seed(SEED + rank)
                with torch.no_grad():
                    for name, param in peft_model.named_parameters():
                        if "lora_A" in name:
                            # Initialize A with noise
                            param.data = torch.randn_like(param) * init_scale
                        elif "lora_B" in name:
                            # Initialize B with unit-variance noise for the drift experiment.
                            # NOTE: standard LoRA initializes B to zero so the adapted model
                            # starts out identical to the base model. Random B deliberately
                            # breaks that identity property to induce drift.
                            param.data = torch.randn_like(param)

                peft_model.eval()

                # Get embeddings
                emb_lora = get_embeddings(peft_model, tokenizer, texts)

                # Compute metrics
                rdm_sim = shesha.rdm_similarity(emb_clean, emb_lora)
                drift = shesha.rdm_drift(emb_clean, emb_lora)
                stability = shesha.feature_split(emb_lora, n_splits=20, seed=SEED)

                print(f"{rank:<8} {alpha:<8} {rdm_sim:<12.4f} {drift:<12.4f} {stability:<12.4f}")

                lora_results.append({
                    'model': model_name.split('/')[-1],
                    'lora_rank': rank,
                    'lora_alpha': alpha,
                    'init_scale': init_scale,
                    'rdm_similarity': rdm_sim,
                    'rdm_drift': drift,
                    'stability': stability,
                })

                # This removes the LoRA layers so the next iteration starts clean
                peft_model.unload()
                del peft_model

            except Exception as e:
                print(f"{rank:<8} {alpha:<8} ERROR: {e}")
                # Fallback cleanup in case the adapter was attached before the failure
                if 'peft_model' in locals():
                    try:
                        peft_model.unload()
                    except Exception:
                        pass
                    del peft_model

            if DEVICE == "cuda":
                torch.cuda.empty_cache()

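        # Clean up base model before next model in outer loop
        del base_model, tokenizer
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Build the results table; without this, df_lora is never updated by this cell
    df_lora = pd.DataFrame(lora_results)
    print(f"\n\nCollected {len(df_lora)} measurements")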

## 6. Results Summary

In [None]:
print("="*60)
print("GAUSSIAN NOISE RESULTS")
print("="*60)
print(df_noise.to_string(index=False))

print(f"\n\n{'='*60}")
print("MEAN DRIFT BY NOISE LEVEL")
print("="*60)
noise_summary = df_noise.groupby('noise_level')[['rdm_similarity', 'rdm_drift', 'stability']].mean()
print(noise_summary.to_string())

if len(df_lora) > 0:
    print(f"\n\n{'='*60}")
    print("LORA RESULTS")
    print("="*60)
    print(df_lora.to_string(index=False))

    print(f"\n\n{'='*60}")
    print("MEAN DRIFT BY LORA RANK")
    print("="*60)
    lora_summary = df_lora.groupby('lora_rank')[['rdm_similarity', 'rdm_drift', 'stability']].mean()
    print(lora_summary.to_string())

## 7. Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Drift vs Noise Level
ax = axes[0, 0]
for model in df_noise['model'].unique():
    subset = df_noise[df_noise['model'] == model]
    ax.plot(subset['noise_level'], subset['rdm_drift'], 'o-', label=model, markersize=6)
ax.set_xlabel('Noise Level (fraction of param std)')
ax.set_ylabel('RDM Drift')
ax.set_title('Representational Drift vs Gaussian Noise')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Similarity vs Noise Level
ax = axes[0, 1]
for model in df_noise['model'].unique():
    subset = df_noise[df_noise['model'] == model]
    ax.plot(subset['noise_level'], subset['rdm_similarity'], 's-', label=model, markersize=6)
ax.set_xlabel('Noise Level (fraction of param std)')
ax.set_ylabel('RDM Similarity')
ax.set_title('RDM Similarity vs Gaussian Noise')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])

# Plot 3: Drift vs LoRA Rank
ax = axes[1, 0]
if len(df_lora) > 0:
    for model in df_lora['model'].unique():
        subset = df_lora[df_lora['model'] == model]
        ax.plot(subset['lora_rank'], subset['rdm_drift'], 'o-', label=model, markersize=6)
    ax.set_xlabel('LoRA Rank')
    ax.set_ylabel('RDM Drift')
    ax.set_title('Representational Drift vs LoRA Rank')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xscale('log', base=2)
else:
    ax.text(0.5, 0.5, 'LoRA experiment skipped\n(peft not installed)',
            ha='center', va='center', transform=ax.transAxes, fontsize=12)
    ax.set_title('LoRA Drift (skipped)')

# Plot 4: Stability vs Drift (all experiments)
ax = axes[1, 1]
ax.scatter(df_noise['rdm_drift'], df_noise['stability'],
           alpha=0.7, label='Gaussian Noise', s=50, c='blue')
if len(df_lora) > 0:
    ax.scatter(df_lora['rdm_drift'], df_lora['stability'],
               alpha=0.7, label='LoRA', s=50, c='orange', marker='s')
ax.set_xlabel('RDM Drift (from clean)')
ax.set_ylabel('Feature-Split Stability')
ax.set_title('Internal Stability vs Drift')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('drift_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("\nFigure saved as 'drift_analysis.png'")

## 8. Interpretation Guide

### RDM Drift Values

| Drift | Interpretation |
|-------|----------------|
| 0.0 | Identical geometry |
| 0.0-0.1 | Minimal drift (very similar) |
| 0.1-0.3 | Moderate drift |
| 0.3-0.5 | Substantial drift |
| 0.5+ | Major geometric change |
| 1.0 | Uncorrelated (random) |
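
The bands in the table can be turned into a small lookup for labeling results. This is a convenience sketch only: `interpret_drift` is a hypothetical helper, and treating each band as half-open (lower bound included, upper bound excluded) is an assumption, since the table leaves the boundary values ambiguous.

```python
import math

# (upper bound, label) pairs taken from the table; each band is [lower, upper).
DRIFT_BANDS = [
    (0.1, "Minimal drift (very similar)"),
    (0.3, "Moderate drift"),
    (0.5, "Substantial drift"),
    (math.inf, "Major geometric change"),
]

def interpret_drift(drift):
    """Map a drift value to the qualitative label from the table."""
    # Treat tiny floating-point residue around zero as exactly identical.
    if math.isclose(drift, 0.0, abs_tol=1e-12):
        return "Identical geometry"
    for upper, label in DRIFT_BANDS:
        if drift < upper:
            return label

for value in [0.0, 0.04, 0.25, 0.42, 0.8]:
    print(f"{value:.2f} -> {interpret_drift(value)}")
```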

### Key Findings

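1. **Gaussian Noise**: Drift generally increases with noise level, since larger weight perturbations move the geometry further from the clean model
2. **LoRA Rank**: Higher ranks add more adapter parameters and therefore more room for drift
3. **Stability vs Drift**: Internal stability (`feature_split`) often decreases as drift increases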

### When to Use Each Metric

- **`rdm_similarity`**: When you want similarity (higher = more similar)
- **`rdm_drift`**: When you want to measure change (higher = more change)
- **`feature_split`**: For internal consistency of a single representation
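
To illustrate what "internal consistency of a single representation" could mean, the sketch below implements one plausible reading of a feature-split score: split the feature dimensions into two random halves, build an RDM from each half, and average their correlation over several splits. This is an assumption for illustration, not a description of what `shesha.feature_split` computes internally; `toy_feature_split` is a hypothetical name, and Pearson correlation is used for brevity.

```python
import numpy as np

def rdm_vector(X):
    """Upper triangle of the cosine-distance RDM."""
    Xn = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
    i, j = np.triu_indices(X.shape[0], k=1)
    return (1.0 - Xn @ Xn.T)[i, j]

def toy_feature_split(X, n_splits=20, seed=0):
    """Mean RDM correlation between random disjoint halves of the features."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    scores = []
    for _ in range(n_splits):
        perm = rng.permutation(d)
        half_a, half_b = perm[: d // 2], perm[d // 2:]
        r = np.corrcoef(rdm_vector(X[:, half_a]), rdm_vector(X[:, half_b]))[0, 1]
        scores.append(r)
    return float(np.mean(scores))

rng = np.random.default_rng(0)
# Structured data: a few latent factors spread redundantly across 128 features.
latent = rng.normal(size=(100, 4))
structured = latent @ rng.normal(size=(4, 128)) + 0.1 * rng.normal(size=(100, 128))
# Unstructured data: independent noise in every feature.
unstructured = rng.normal(size=(100, 128))

score_structured = toy_feature_split(structured)
score_unstructured = toy_feature_split(unstructured)
print(f"structured: {score_structured:.3f}, unstructured: {score_unstructured:.3f}")
```

Under this reading, a representation whose geometry is shared across features scores high, while one dominated by independent noise scores near zero.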

## 9. Quick Reference

```python
import shesha

# Measure similarity between two representations
similarity = shesha.rdm_similarity(X_before, X_after)
# Returns: float in [-1, 1], higher = more similar geometry

# Measure drift between two representations
drift = shesha.rdm_drift(X_before, X_after)
# Returns: float in [0, 2], where 0 = identical, 1 = uncorrelated

# Use Pearson instead of Spearman
drift = shesha.rdm_drift(X_before, X_after, method='pearson')

# Use different distance metric
drift = shesha.rdm_drift(X_before, X_after, metric='euclidean')
```