# Shesha Tutorial: LLM Embedding Analysis

This notebook demonstrates how to use Shesha to analyze the geometric stability of language model embeddings.

**What you'll learn:**
- How to extract embeddings from transformer models
- How to measure representation stability with `feature_split`
- How to compare stability across layers and models
- How to use supervised alignment to measure task relevance

**Requirements:**
```bash
pip install shesha-geometry transformers torch datasets
```

## 1. Setup and Imports

In [None]:
# Optional: Install dependencies
# !pip install shesha-geometry transformers torch datasets

In [None]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import matplotlib.pyplot as plt

# Shesha
import shesha

# Configuration: one seed is shared by the NumPy and PyTorch global RNGs
SEED = 320
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

print(f"Using device: {DEVICE}")
print(f"Shesha version: {shesha.__version__}")

## 2. Load Model and Tokenizer

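We'll use BERT-base as an example. The same workflow applies to other HuggingFace encoder models that place a sentence-level token such as `[CLS]` at position 0; for models without one (e.g. decoder-only models), use a different pooling strategy when extracting embeddings.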

In [None]:
MODEL_NAME = "bert-base-uncased"

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Request hidden states from the model, place it on the chosen device,
# then put it in eval mode.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    output_hidden_states=True,
).to(DEVICE)
model.eval()

print(f"Model loaded: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden dim")

## 3. Load a Text Dataset

We'll use a subset of the SST-2 sentiment dataset.

In [None]:
# Fetch the validation split of SST-2 (a GLUE task)
print("Loading SST-2 dataset...")
dataset = load_dataset("glue", "sst2", split="validation")

# Keep only the first N_SAMPLES rows so the tutorial runs quickly
N_SAMPLES = 500
first_n = slice(N_SAMPLES)
texts = dataset["sentence"][first_n]
labels = np.array(dataset["label"][first_n])

# Quick sanity check of what was loaded
print(f"Loaded {len(texts)} samples")
print(f"Label distribution: {np.bincount(labels)}")
print(f"\nExample: '{texts[0]}' -> {labels[0]}")

## 4. Extract Embeddings

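We'll extract `[CLS]` token embeddings (the hidden state at token position 0) from each layer. The model returns `num_hidden_layers + 1` hidden states (the embedding-layer output followed by one per transformer layer), so BERT-base yields 13 arrays.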

In [None]:
def extract_embeddings(texts, model, tokenizer, batch_size=32):
    """Extract first-token ([CLS]) embeddings from every hidden-state layer.

    Args:
        texts: Sequence of input strings; processed in slices of
            ``batch_size``.
        model: Model whose output exposes ``hidden_states`` (one tensor per
            layer, indexed as ``[batch, token, feature]``). Inputs are moved
            to the device of the model's first parameter.
        tokenizer: Callable that tokenizes a list of strings and returns an
            object supporting ``.to(device)`` and ``**`` unpacking.
        batch_size: Number of texts per forward pass.

    Returns:
        A list with one 2-D NumPy array per entry of ``hidden_states``; each
        array has one row per input text.

    Raises:
        ValueError: If ``texts`` is empty, or if the model output carries no
            hidden states.
    """
    if len(texts) == 0:
        # np.vstack on an empty list fails with an unhelpful message.
        raise ValueError("texts is empty: nothing to embed")

    # Send inputs to wherever the model actually lives instead of relying on
    # a module-level constant that may not match the model passed in.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    per_layer_batches = None

    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = list(texts[start:start + batch_size])

            # Tokenize
            inputs = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors="pt",
            ).to(device)

            # Forward pass
            outputs = model(**inputs)
            hidden_states = getattr(outputs, "hidden_states", None)
            if hidden_states is None:
                raise ValueError(
                    "model output has no hidden_states; load the model with "
                    "output_hidden_states=True"
                )

            # Size the accumulator from the real output rather than from the
            # model config, so the two can never disagree.
            if per_layer_batches is None:
                per_layer_batches = [[] for _ in hidden_states]

            # Extract [CLS] token (index 0) from each layer
            for layer_idx, hidden_state in enumerate(hidden_states):
                cls_embedding = hidden_state[:, 0, :].cpu().numpy()
                per_layer_batches[layer_idx].append(cls_embedding)

    # Concatenate batches
    return [np.vstack(batches) for batches in per_layer_batches]

# Run the extraction once over the whole subset and report what came back
print("Extracting embeddings...")
layer_embeddings = extract_embeddings(
    texts=texts,
    model=model,
    tokenizer=tokenizer,
)
print(f"Extracted embeddings from {len(layer_embeddings)} layers")
print(f"Shape per layer: {layer_embeddings[0].shape}")

## 5. Measure Stability Across Layers

Use `shesha.feature_split` to measure geometric stability at each layer.

In [None]:
# Score each layer with feature_split, printing a table row as it finishes
stabilities = []

print(f"{'Layer':<10} {'Stability':>10}")
print("-" * 22)

for idx, layer_matrix in enumerate(layer_embeddings):
    stabilities.append(
        shesha.feature_split(layer_matrix, n_splits=30, seed=SEED)
    )
    # The newest entry is the score for the layer just processed
    print(f"{idx:<10} {stabilities[-1]:>10.3f}")

print(f"\nMean stability: {np.mean(stabilities):.3f}")

## 6. Measure Task Alignment (Supervised)

Use `shesha.supervised_alignment` to measure how well each layer's geometry aligns with sentiment labels.

In [None]:
# Score each layer against the sentiment labels and print it next to the
# stability value already computed for the same layer
alignments = []

print(f"{'Layer':<10} {'Stability':>10} {'Alignment':>10}")
print("-" * 32)

for idx, layer_matrix in enumerate(layer_embeddings):
    alignments.append(
        shesha.supervised_alignment(layer_matrix, labels, seed=SEED)
    )
    print(f"{idx:<10} {stabilities[idx]:>10.3f} {alignments[-1]:>10.3f}")

## 7. Visualization

In [None]:
# Two side-by-side panels sharing the same x axis (layer index)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

layers = list(range(len(layer_embeddings)))

# (values, line format, y-axis label, panel title) for each panel, left to right
panel_specs = [
    (stabilities, 'b-o', 'Stability (feature_split)', 'Geometric Stability Across Layers'),
    (alignments, 'r-s', 'Task Alignment (supervised)', 'Sentiment Alignment Across Layers'),
]

for ax, (series, line_fmt, y_label, panel_title) in zip(axes, panel_specs):
    ax.plot(layers, series, line_fmt, markersize=6)
    ax.set_xlabel('Layer')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(layers)

# Write the figure to disk before showing it
plt.tight_layout()
plt.savefig('llm_layer_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("Figure saved as 'llm_layer_analysis.png'")

## 8. Compare Multiple Models

Let's compare final-layer stability and task alignment across models of different sizes.

In [None]:
def analyze_model(model_name, texts, labels):
    """Analyze a single model and return metrics for its final layer.

    Loads the tokenizer and model for ``model_name``, extracts embeddings
    with ``extract_embeddings``, and scores the last returned layer.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        texts: Input strings to embed.
        labels: Labels passed to ``shesha.supervised_alignment``.

    Returns:
        A dict with keys ``'model'``, ``'stability'``, ``'alignment'`` and
        ``'hidden_dim'`` (the second dimension of the final-layer array).
    """
    print(f"\nAnalyzing {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
    model = model.to(DEVICE)
    model.eval()

    try:
        # Only the final layer is scored, so the other layers are not kept.
        final_layer = extract_embeddings(texts, model, tokenizer)[-1]
    finally:
        # Drop the model even if extraction fails, so a failed run does not
        # keep it alive while further models are loaded. The metrics below
        # work on NumPy arrays and no longer need the model.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Compute metrics
    stability = shesha.feature_split(final_layer, n_splits=30, seed=SEED)
    alignment = shesha.supervised_alignment(final_layer, labels, seed=SEED)

    return {
        'model': model_name,
        'stability': stability,
        'alignment': alignment,
        'hidden_dim': final_layer.shape[1]
    }

# Compare models (uncomment to run - takes a few minutes)
# models_to_compare = [
#     "bert-base-uncased",
#     "bert-large-uncased",
#     "distilbert-base-uncased",
# ]
#
# results = [analyze_model(m, texts, labels) for m in models_to_compare]
#
# for r in results:
#     print(f"{r['model']:<30} stability={r['stability']:.3f}  alignment={r['alignment']:.3f}")

## 9. Monitor Fine-tuning Drift

Shesha can track how representations change during fine-tuning.

In [None]:
# Simulated fine-tuning drift example
# In practice, you'd extract embeddings at each checkpoint

print("Simulating fine-tuning drift monitoring...")
print("(In practice, extract embeddings at each training checkpoint)\n")

# Simulate embeddings evolving during training
rng = np.random.default_rng(SEED)
base_embeddings = layer_embeddings[-1]  # Start from final BERT layer

epochs = [0, 1, 2, 3, 4, 5]
epoch_stabilities = []
epoch_alignments = []

for epoch in epochs:
    # Simulate drift: add task-relevant signal + noise.
    # Both strengths are 0 at epoch 0, so that epoch reproduces the
    # unmodified base embeddings.
    signal_strength = epoch * 0.1
    noise_strength = epoch * 0.05

    # Add class-conditional signal: shift every feature down for label 0
    # and up for label 1 by the same amount
    signal = np.zeros_like(base_embeddings)
    signal[labels == 0] -= signal_strength
    signal[labels == 1] += signal_strength

    # Add noise
    noise = rng.standard_normal(base_embeddings.shape) * noise_strength

    epoch_embeddings = base_embeddings + signal + noise

    # Use the same n_splits as the per-layer analysis so the epoch-0 score
    # is computed with the same settings as the final-layer score there.
    stability = shesha.feature_split(epoch_embeddings, n_splits=30, seed=SEED)
    alignment = shesha.supervised_alignment(epoch_embeddings, labels, seed=SEED)

    epoch_stabilities.append(stability)
    epoch_alignments.append(alignment)

    print(f"Epoch {epoch}: stability={stability:.3f}, alignment={alignment:.3f}")

In [None]:
# Draw both per-epoch curves on a single axis
fig, ax = plt.subplots(figsize=(8, 4))

for curve, line_fmt, curve_label in (
    (epoch_stabilities, 'b-o', 'Stability'),
    (epoch_alignments, 'r-s', 'Task Alignment'),
):
    ax.plot(epochs, curve, line_fmt, label=curve_label, markersize=8)

ax.set_xlabel('Epoch')
ax.set_ylabel('Score')
ax.set_title('Representation Dynamics During Fine-tuning')
ax.legend()
ax.grid(True, alpha=0.3)

# Write the figure to disk before showing it
plt.tight_layout()
plt.savefig('finetuning_dynamics.png', dpi=150, bbox_inches='tight')
plt.show()

## 10. Interpretation Guide

### Stability (feature_split)
- **High (> 0.5)**: Redundant, robust representations - geometry is consistent across feature subsets
- **Low (< 0.2)**: Fragile representations - different features encode different structure

### Task Alignment (supervised_alignment)
- **High (> 0.3)**: Representations encode task-relevant structure
- **Low (< 0.1)**: Representations don't reflect task labels

### Common Patterns
1. **Early layers**: Lower task alignment, variable stability
2. **Middle layers**: Often highest stability (general features)
3. **Final layers**: Higher task alignment (task-specific features)
4. **Fine-tuning**: Alignment increases, stability may decrease (specialization)

## 11. Quick Reference

```python
import shesha

# Measure geometric stability (unsupervised)
stability = shesha.feature_split(embeddings, n_splits=30, seed=320)

# Measure task alignment (supervised)
alignment = shesha.supervised_alignment(embeddings, labels, seed=320)

# Quick separability check
var_ratio = shesha.variance_ratio(embeddings, labels)

# Alternative stability metrics
sample_stab = shesha.sample_split(embeddings, seed=320)
anchor_stab = shesha.anchor_stability(embeddings, seed=320)
```