# Pseudo-Spectral Landscapes: Understanding Neural Network Optimization

This notebook demonstrates **pseudo-spectral landscapes** and their relevance to understanding Large Language Models (LLMs) through geometric approaches.

## What You'll Learn

1. **Loss landscape geometry** through Hessian analysis
2. **Eigenvalue spectra** and what they reveal
3. **Sharp vs flat minima** and generalization
4. **Attention mechanism geometry** in transformers
5. **Practical implications** for training LLMs

---

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from scipy.linalg import eigh, svd
from scipy.optimize import minimize
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# For better figure quality
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

print("✓ All imports successful!")

---
## Part 1: Toy Neural Network

We start with a simple 2-layer neural network. While small, it illustrates the key concepts that also arise, at much larger scale, when analyzing billion-parameter LLMs.

### Key Concepts:
- **Hessian Matrix**: $H_{ij} = \frac{\partial^2 L}{\partial \theta_i \partial \theta_j}$
- **Eigenvalue Spectrum**: The set of eigenvalues $\{\lambda_1, \lambda_2, ..., \lambda_n\}$
- **Geometric Interpretation**: Eigenvalues reveal local curvature of the loss landscape
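
As a small worked example of these definitions (a sketch on a hand-picked quadratic, separate from the toy network below), take $L(\theta) = \theta_0^2 + 3\theta_0\theta_1 + 5\theta_1^2$. Its Hessian is constant, so the eigenvalues can be checked by hand: the trace is 12 and the determinant is 11, which gives $\lambda = 1$ and $\lambda = 11$. The variable names are illustrative.

```python
import numpy as np

# Hessian of L(theta) = theta0^2 + 3*theta0*theta1 + 5*theta1^2 (constant for a quadratic)
H_example = np.array([[2.0, 3.0],
                      [3.0, 10.0]])

# eigh is meant for symmetric matrices; eigenvalues come back in ascending order
example_eigenvalues, example_eigenvectors = np.linalg.eigh(H_example)
print(example_eigenvalues)  # approximately 1 and 11
```

Both eigenvalues are positive, so this quadratic curves upward in every direction, about eleven times more steeply along one eigenvector than along the other.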

In [None]:
class ToyNeuralNetwork:
    """
    A simple 2-layer neural network for demonstration.
    Small enough to compute full Hessian, yet rich enough to show key concepts.
    """
    
    def __init__(self, input_dim=2, hidden_dim=3, output_dim=1, seed=42):
        np.random.seed(seed)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        # Initialize weights
        self.W1 = np.random.randn(input_dim, hidden_dim) * 0.5
        self.b1 = np.random.randn(hidden_dim) * 0.1
        self.W2 = np.random.randn(hidden_dim, output_dim) * 0.5
        self.b2 = np.random.randn(output_dim) * 0.1
        
    def sigmoid(self, x):
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
    
    def sigmoid_derivative(self, x):
        s = self.sigmoid(x)
        return s * (1 - s)
    
    def forward(self, X):
        """Forward pass"""
        self.z1 = X @ self.W1 + self.b1
        self.a1 = self.sigmoid(self.z1)
        self.z2 = self.a1 @ self.W2 + self.b2
        return self.z2
    
    def get_params_vector(self):
        """Flatten all parameters into a single vector"""
        return np.concatenate([
            self.W1.flatten(),
            self.b1.flatten(),
            self.W2.flatten(),
            self.b2.flatten()
        ])
    
    def set_params_vector(self, params):
        """Set parameters from a flattened vector"""
        idx = 0
        
        w1_size = self.input_dim * self.hidden_dim
        self.W1 = params[idx:idx + w1_size].reshape(self.input_dim, self.hidden_dim)
        idx += w1_size
        
        self.b1 = params[idx:idx + self.hidden_dim]
        idx += self.hidden_dim
        
        w2_size = self.hidden_dim * self.output_dim
        self.W2 = params[idx:idx + w2_size].reshape(self.hidden_dim, self.output_dim)
        idx += w2_size
        
        self.b2 = params[idx:idx + self.output_dim]
    
    def loss(self, X, y):
        """MSE loss"""
        pred = self.forward(X)
        return 0.5 * np.mean((pred - y) ** 2)
    
    def compute_gradient(self, X, y):
        """Compute gradient via backpropagation"""
        m = X.shape[0]
        
        # Forward pass
        pred = self.forward(X)
        
        # Backward pass
        dz2 = (pred - y) / m
        dW2 = self.a1.T @ dz2
        db2 = np.sum(dz2, axis=0)
        
        da1 = dz2 @ self.W2.T
        dz1 = da1 * self.sigmoid_derivative(self.z1)
        dW1 = X.T @ dz1
        db1 = np.sum(dz1, axis=0)
        
        return np.concatenate([
            dW1.flatten(),
            db1.flatten(),
            dW2.flatten(),
            db2.flatten()
        ])
    
    def compute_hessian(self, X, y):
        """
        Compute the Hessian matrix numerically using finite differences.
        This is the KEY for pseudo-spectral analysis!
        """
        n_params = len(self.get_params_vector())
        H = np.zeros((n_params, n_params))
        eps = 1e-4  # near machine_eps**(1/4): balances truncation and round-off for second differences
        
        def loss_fn(params):
            self.set_params_vector(params)
            return self.loss(X, y)
        
        params = self.get_params_vector()
        
        for i in range(n_params):
            for j in range(i, n_params):
                # Compute second derivative using finite differences
                params_pp = params.copy()
                params_pp[i] += eps
                params_pp[j] += eps
                
                params_pm = params.copy()
                params_pm[i] += eps
                params_pm[j] -= eps
                
                params_mp = params.copy()
                params_mp[i] -= eps
                params_mp[j] += eps
                
                params_mm = params.copy()
                params_mm[i] -= eps
                params_mm[j] -= eps
                
                f_pp = loss_fn(params_pp)
                f_pm = loss_fn(params_pm)
                f_mp = loss_fn(params_mp)
                f_mm = loss_fn(params_mm)
                
                H[i, j] = (f_pp - f_pm - f_mp + f_mm) / (4 * eps * eps)
                H[j, i] = H[i, j]
        
        # Reset to original parameters
        self.set_params_vector(params)
        return H

print("✓ ToyNeuralNetwork class defined")

### Generate Synthetic Data

In [None]:
def generate_synthetic_data(n_samples=100, seed=42):
    """Generate simple synthetic data for demonstration"""
    np.random.seed(seed)
    X = np.random.randn(n_samples, 2)
    # Simple nonlinear function
    y = (np.sin(X[:, 0]) + 0.5 * X[:, 1]**2).reshape(-1, 1)
    y += np.random.randn(n_samples, 1) * 0.1  # Add noise
    return X, y

# Generate data
X, y = generate_synthetic_data(n_samples=50)
print(f"✓ Generated data: X.shape={X.shape}, y.shape={y.shape}")

---
## Part 2: Loss Landscape Visualization

Let's visualize the loss landscape - a 2D slice through high-dimensional parameter space.
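
The cell below slices along two coordinate axes. A common alternative is to slice along two random orthonormal directions, which avoids privileging any single weight. A minimal sketch of that idea, assuming a generic loss function `loss_fn(theta)`; the helper name `random_plane_slice` and the toy quadratic are illustrative and not part of this notebook's classes.

```python
import numpy as np

def random_plane_slice(loss_fn, theta, span=1.0, resolution=21, seed=0):
    """Evaluate loss_fn on the plane theta + a*d1 + b*d2 for random orthonormal d1, d2."""
    rng = np.random.default_rng(seed)
    # QR of a random (n, 2) matrix yields two orthonormal columns
    Q, _ = np.linalg.qr(rng.standard_normal((theta.size, 2)))
    d1, d2 = Q[:, 0], Q[:, 1]
    coords = np.linspace(-span, span, resolution)
    # Rows index b (y-axis), columns index a (x-axis), as plt.contour expects
    grid = np.array([[loss_fn(theta + a * d1 + b * d2) for a in coords]
                     for b in coords])
    return coords, grid, d1, d2

# Toy usage: an anisotropic quadratic in 5 dimensions (hypothetical example)
curvatures = np.array([0.1, 0.5, 1.0, 5.0, 10.0])
toy_loss = lambda th: 0.5 * np.sum(curvatures * th**2)
coords, grid, d1, d2 = random_plane_slice(toy_loss, np.zeros(5))
print(grid.shape, grid.min())
```

The returned `grid` can be passed to `plt.contour(coords, coords, grid)` in the same way as the axis-aligned slices.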

In [None]:
def visualize_loss_landscape_2d(nn, X, y, param_indices=(0, 1), 
                                 range_scale=2.0, resolution=50):
    """
    Visualize the loss landscape along two parameter directions.
    """
    params = nn.get_params_vector()
    i, j = param_indices
    
    # Create grid around current parameters
    param_range_i = np.linspace(params[i] - range_scale, 
                                 params[i] + range_scale, resolution)
    param_range_j = np.linspace(params[j] - range_scale, 
                                 params[j] + range_scale, resolution)
    
    loss_grid = np.zeros((resolution, resolution))
    
    for idx_i, pi in enumerate(param_range_i):
        for idx_j, pj in enumerate(param_range_j):
            params_temp = params.copy()
            params_temp[i] = pi
            params_temp[j] = pj
            nn.set_params_vector(params_temp)
            loss_grid[idx_j, idx_i] = nn.loss(X, y)
    
    # Reset parameters
    nn.set_params_vector(params)
    
    return param_range_i, param_range_j, loss_grid

# Create network and visualize
nn = ToyNeuralNetwork()

fig = plt.figure(figsize=(16, 5))

# 2D Contour plot
ax1 = plt.subplot(1, 3, 1)
param_i, param_j, loss_grid = visualize_loss_landscape_2d(
    nn, X, y, param_indices=(0, 1), range_scale=1.5, resolution=40
)
contour = ax1.contour(param_i, param_j, loss_grid, levels=20, cmap='viridis')
ax1.contourf(param_i, param_j, loss_grid, levels=20, cmap='viridis', alpha=0.6)
plt.colorbar(contour, ax=ax1)
params = nn.get_params_vector()
ax1.plot(params[0], params[1], 'r*', markersize=15, label='Current position')
ax1.set_xlabel('Parameter 0 (W1[0,0])')
ax1.set_ylabel('Parameter 1 (W1[0,1])')
ax1.set_title('Loss Landscape (2D Slice)', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 3D Surface plot
ax2 = plt.subplot(1, 3, 2, projection='3d')
Pi, Pj = np.meshgrid(param_i, param_j)
surf = ax2.plot_surface(Pi, Pj, loss_grid, cmap='viridis', alpha=0.8)
ax2.set_xlabel('Parameter 0')
ax2.set_ylabel('Parameter 1')
ax2.set_zlabel('Loss')
ax2.set_title('3D Loss Surface', fontsize=12, fontweight='bold')

# Different slice
ax3 = plt.subplot(1, 3, 3)
param_i2, param_j2, loss_grid2 = visualize_loss_landscape_2d(
    nn, X, y, param_indices=(2, 3), range_scale=1.5, resolution=40
)
contour2 = ax3.contour(param_i2, param_j2, loss_grid2, levels=20, cmap='plasma')
ax3.contourf(param_i2, param_j2, loss_grid2, levels=20, cmap='plasma', alpha=0.6)
plt.colorbar(contour2, ax=ax3)
ax3.plot(params[2], params[3], 'r*', markersize=15)
ax3.set_xlabel('Parameter 2 (W1[0,2])')
ax3.set_ylabel('Parameter 3 (W1[1,0])')
ax3.set_title('Loss Landscape (Different Slice)', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Loss landscape visualization complete")

---
## Part 3: Hessian Spectrum Analysis

Now for the **key concept**: analyzing the eigenvalue spectrum of the Hessian matrix.

### What the Eigenvalues Tell Us:
- **λ > 0**: Upward curvature (minimum-like along this eigenvector)
- **λ < 0**: Downward curvature (at a critical point, this indicates a saddle)
- **|λ| ≈ 0**: Flat direction (parameters can move with little change in loss)
- **|λ| large**: Sharp curvature (limits the stable learning rate and dominates optimization difficulty)
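
These cases can be sketched as a small helper that buckets eigenvalues with a tolerance, since a finite-difference Hessian rarely yields exact zeros. The function name `classify_curvature` and the tolerance `tol=1e-6` are illustrative choices, not part of the analysis code below.

```python
import numpy as np

def classify_curvature(eigs, tol=1e-6):
    """Count positive, negative, and near-zero eigenvalues and label the curvature."""
    eigs = np.asarray(eigs, dtype=float)
    n_pos = int(np.sum(eigs > tol))
    n_neg = int(np.sum(eigs < -tol))
    n_flat = eigs.size - n_pos - n_neg
    if n_neg > 0 and n_pos > 0:
        kind = "indefinite (saddle-like)"
    elif n_neg > 0:
        kind = "negative semidefinite (maximum-like)"
    else:
        kind = "positive semidefinite (minimum-like)"
    return {"positive": n_pos, "negative": n_neg, "flat": n_flat, "kind": kind}

print(classify_curvature([-0.3, 0.0, 1e-9, 2.0, 11.0]))
```

The labels describe local curvature only. Calling a point a minimum or a saddle additionally requires the gradient to vanish there.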

In [None]:
def analyze_hessian_spectrum(nn, X, y):
    """
    Compute and analyze the eigenvalue spectrum of the Hessian.
    This is the CORE of pseudo-spectral analysis!
    """
    print("Computing Hessian (this may take a moment)...")
    H = nn.compute_hessian(X, y)
    
    # Compute eigenvalues
    eigenvalues, eigenvectors = eigh(H)
    
    return eigenvalues, eigenvectors, H

# Analyze the spectrum
eigenvalues, eigenvectors, H = analyze_hessian_spectrum(nn, X, y)

print(f"\n✓ Hessian analysis complete!")
print(f"   Number of parameters: {len(eigenvalues)}")
print(f"   Max eigenvalue: {np.max(eigenvalues):.4f}")
print(f"   Min eigenvalue: {np.min(eigenvalues):.4f}")
print(f"   Negative eigenvalues: {np.sum(eigenvalues < 0)}")

### Visualize the Spectrum

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# 1. Eigenvalue Spectrum
ax = axes[0, 0]
ax.bar(range(len(eigenvalues)), eigenvalues, color='steelblue', alpha=0.7)
ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)
ax.set_xlabel('Eigenvalue Index')
ax.set_ylabel('Eigenvalue')
ax.set_title('Hessian Eigenvalue Spectrum', fontweight='bold')
ax.grid(True, alpha=0.3)

# Add statistics
n_positive = np.sum(eigenvalues > 1e-6)
n_negative = np.sum(eigenvalues < -1e-6)
ax.text(0.05, 0.95, f'Positive: {n_positive}\nNegative: {n_negative}\nZero: {len(eigenvalues)-n_positive-n_negative}',
        transform=ax.transAxes, verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# 2. Eigenvalue Distribution
ax = axes[0, 1]
ax.hist(eigenvalues, bins=15, color='coral', alpha=0.7, edgecolor='black')
ax.axvline(x=0, color='r', linestyle='--', linewidth=2, label='Zero')
ax.set_xlabel('Eigenvalue')
ax.set_ylabel('Frequency')
ax.set_title('Eigenvalue Distribution', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 3. Hessian Matrix Heatmap
ax = axes[0, 2]
im = ax.imshow(H, cmap='RdBu_r', aspect='auto', vmin=-np.abs(H).max(), vmax=np.abs(H).max())
plt.colorbar(im, ax=ax)
ax.set_xlabel('Parameter Index')
ax.set_ylabel('Parameter Index')
ax.set_title('Hessian Matrix', fontweight='bold')

# 4. Loss along principal curvature directions
ax = axes[1, 0]
current_params = nn.get_params_vector()
alphas = np.linspace(-1, 1, 50)

for idx in [0, len(eigenvalues)//2, -1]:  # Smallest, median, largest eigenvalue (eigh sorts ascending)
    losses = []
    for alpha in alphas:
        perturbed = current_params + alpha * eigenvectors[:, idx]
        nn.set_params_vector(perturbed)
        losses.append(nn.loss(X, y))
    nn.set_params_vector(current_params)
    
    label = f'λ={eigenvalues[idx]:.3f}'
    ax.plot(alphas, losses, label=label, linewidth=2)

ax.set_xlabel('Step size α')
ax.set_ylabel('Loss')
ax.set_title('Loss Along Principal Curvature Directions', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 5. Sorted eigenvalues (log scale)
ax = axes[1, 1]
sorted_eigs = np.sort(np.abs(eigenvalues))[::-1]
ax.semilogy(range(len(sorted_eigs)), sorted_eigs, 'o-', color='purple', markersize=6)
ax.set_xlabel('Eigenvalue Rank')
ax.set_ylabel('|Eigenvalue| (log scale)')
ax.set_title('Curvature Spectrum (Sorted)', fontweight='bold')
ax.grid(True, alpha=0.3, which='both')

# 6. Gradient magnitudes
ax = axes[1, 2]
grad = nn.compute_gradient(X, y)
ax.bar(range(len(grad)), np.abs(grad), color='teal', alpha=0.7)
ax.set_xlabel('Parameter Index')
ax.set_ylabel('|Gradient|')
ax.set_title('Gradient Magnitudes', fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

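abs_eigs = np.abs(eigenvalues)
print("\n=== KEY INSIGHTS ===")
print(f"Max eigenvalue (sharpest upward curvature): {np.max(eigenvalues):.4f}")
print(f"Min eigenvalue (most negative curvature if < 0): {np.min(eigenvalues):.4f}")
print(f"Smallest |eigenvalue| (flattest direction): {np.min(abs_eigs):.6f}")
print(f"Condition number max|λ| / min|λ|: {np.max(abs_eigs) / (np.min(abs_eigs) + 1e-8):.2f}")
print(f"Trace (total curvature): {np.sum(eigenvalues):.4f}")
# The network is still at its random initialization, not at a critical point,
# so this label describes local curvature only.
print(f"\nLocal curvature: {'Indefinite (saddle-like)' if np.sum(eigenvalues < -1e-6) > 0 else 'Positive semidefinite (minimum-like)'}")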

---
## Part 4: Training Dynamics

Let's train the network and watch how the spectral properties evolve.

In [None]:
def train_and_analyze(X, y, epochs=100, lr=0.05):
    """
    Train the network and track spectral properties.
    """
    nn = ToyNeuralNetwork()
    history = {'loss': [], 'max_eigenvalue': [], 'min_eigenvalue': [], 'trace': []}
    
    print("Training neural network...")
    for epoch in range(epochs):
        # Loss and gradient at the current parameters
        loss = nn.loss(X, y)
        grad = nn.compute_gradient(X, y)
        history['loss'].append(loss)
        
        # Compute Hessian spectrum periodically, at the same parameters as the logged loss
        if epoch % 20 == 0:
            eigenvalues, _, _ = analyze_hessian_spectrum(nn, X, y)
            history['max_eigenvalue'].append(np.max(eigenvalues))
            history['min_eigenvalue'].append(np.min(eigenvalues))
            history['trace'].append(np.sum(eigenvalues))
            print(f"Epoch {epoch}: Loss = {loss:.4f}, "
                  f"λ_max = {np.max(eigenvalues):.4f}, "
                  f"λ_min = {np.min(eigenvalues):.4f}")
        
        # Gradient descent update
        params = nn.get_params_vector()
        params -= lr * grad
        nn.set_params_vector(params)
    
    return nn, history

# Train the network
trained_nn, history = train_and_analyze(X, y, epochs=100, lr=0.05)

### Visualize Training Evolution

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs_hessian = np.arange(0, len(history['loss']), 20)

# Loss curve
axes[0, 0].plot(history['loss'], 'b-', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss', fontweight='bold')
axes[0, 0].set_yscale('log')
axes[0, 0].grid(True, alpha=0.3)

# Max and min eigenvalues
axes[0, 1].plot(epochs_hessian, history['max_eigenvalue'], 'r-', 
                linewidth=2, label='λ_max', marker='o')
axes[0, 1].plot(epochs_hessian, history['min_eigenvalue'], 'b-', 
                linewidth=2, label='λ_min', marker='s')
axes[0, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Eigenvalue')
axes[0, 1].set_title('Extreme Eigenvalues During Training', fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Trace
axes[1, 0].plot(epochs_hessian, history['trace'], 'g-', 
                linewidth=2, marker='o')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Trace(H)')
axes[1, 0].set_title('Hessian Trace (Total Curvature)', fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)

# Sharpness ratio
sharpness = np.array(history['max_eigenvalue']) / (np.abs(np.array(history['min_eigenvalue'])) + 1e-8)
axes[1, 1].plot(epochs_hessian, sharpness, 'm-', linewidth=2, marker='d')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('λ_max / |λ_min|')
axes[1, 1].set_title('Sharpness Ratio (Flat vs Sharp)', fontweight='bold')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True, alpha=0.3, which='both')

plt.tight_layout()
plt.show()

print("✓ Training dynamics visualization complete")

---
## Part 5: Attention Mechanism Geometry

Now let's apply these concepts to transformer attention - the core of LLMs!

In [None]:
class SimpleAttentionLayer:
    """
    Simplified self-attention mechanism to demonstrate geometric properties.
    """
    
    def __init__(self, d_model=8, n_heads=2, seq_len=4, seed=42):
        np.random.seed(seed)
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.seq_len = seq_len
        
        # Query, Key, Value projection matrices
        self.W_Q = np.random.randn(d_model, d_model) * 0.1
        self.W_K = np.random.randn(d_model, d_model) * 0.1
        self.W_V = np.random.randn(d_model, d_model) * 0.1
        self.W_O = np.random.randn(d_model, d_model) * 0.1
        
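    def attention(self, X):
        """
        Scaled dot-product attention.

        Note: for simplicity the heads are not split. Q and K span all
        d_model dimensions, and n_heads only enters through the scale d_k.
        """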
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        
        # Attention scores
        scores = Q @ K.T / np.sqrt(self.d_k)
        
        # Softmax
        attention_weights = self.softmax(scores)
        
        # Apply attention to values
        attended = attention_weights @ V
        
        # Output projection
        output = attended @ self.W_O
        
        return output, attention_weights, Q, K, V
    
    def softmax(self, x):
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

# Create attention layer
attention_layer = SimpleAttentionLayer(d_model=8, n_heads=2, seq_len=4)

# Input sequence (token embeddings)
X_attention = np.random.randn(4, 8) * 0.5

# Get attention outputs
output, attn_weights, Q, K, V = attention_layer.attention(X_attention)

print("✓ Attention layer created")
print(f"   Input shape: {X_attention.shape}")
print(f"   Attention weights shape: {attn_weights.shape}")
print(f"   Output shape: {output.shape}")

### Analyze Attention Geometry
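
One way to quantify rank structure from singular values is to count how many are needed to capture most of the squared spectrum. A minimal sketch, assuming a 99% energy threshold; the helper `rank_for_energy` and the comparison matrices are illustrative. Note that a dense Gaussian matrix, like the randomly initialized projections above, is full rank, while a product of two thin matrices is low rank by construction.

```python
import numpy as np

def rank_for_energy(W, energy=0.99):
    """Smallest k such that the top-k singular values hold `energy` of the squared spectrum."""
    s = np.linalg.svd(W, compute_uv=False)
    cumulative = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(cumulative, energy) + 1)

rng = np.random.default_rng(0)
W_dense = rng.standard_normal((8, 8)) * 0.1                            # random init, full rank
W_lowrank = rng.standard_normal((8, 2)) @ rng.standard_normal((2, 8))  # rank 2 by construction

print("dense:", rank_for_energy(W_dense))
print("low-rank:", rank_for_energy(W_lowrank))
```

Applying the same helper to `attention_layer.W_Q` and its siblings gives a single number per matrix that complements the bar chart of singular values.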

In [None]:
# Analyze spectral properties of attention matrices
U_Q, S_Q, Vt_Q = svd(attention_layer.W_Q)
U_K, S_K, Vt_K = svd(attention_layer.W_K)
U_V, S_V, Vt_V = svd(attention_layer.W_V)
U_O, S_O, Vt_O = svd(attention_layer.W_O)

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# 1. Attention Weights Heatmap
ax = axes[0, 0]
im = ax.imshow(attn_weights, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Attention Weight Matrix', fontweight='bold')
plt.colorbar(im, ax=ax)

# Add values
for i in range(attn_weights.shape[0]):
    for j in range(attn_weights.shape[1]):
        ax.text(j, i, f'{attn_weights[i, j]:.2f}',
               ha="center", va="center", color="black", fontsize=8)

# 2. QK^T Score Matrix
ax = axes[0, 1]
QKT = Q @ K.T / np.sqrt(attention_layer.d_k)
im = ax.imshow(QKT, cmap='RdBu_r', aspect='auto',
               vmin=-np.abs(QKT).max(), vmax=np.abs(QKT).max())
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('QK^T Score Matrix (Before Softmax)', fontweight='bold')
plt.colorbar(im, ax=ax)

# 3. Singular Value Spectra
ax = axes[0, 2]
x_pos = np.arange(len(S_Q))
width = 0.2

ax.bar(x_pos - 1.5*width, S_Q, width, label='W_Q', alpha=0.8)
ax.bar(x_pos - 0.5*width, S_K, width, label='W_K', alpha=0.8)
ax.bar(x_pos + 0.5*width, S_V, width, label='W_V', alpha=0.8)
ax.bar(x_pos + 1.5*width, S_O, width, label='W_O', alpha=0.8)

ax.set_xlabel('Singular Value Index')
ax.set_ylabel('Singular Value Magnitude')
ax.set_title('Singular Value Spectra\n(Random Init: Full Rank)', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 4. Q Matrix
ax = axes[1, 0]
im = ax.imshow(Q, cmap='RdBu_r', aspect='auto',
               vmin=-np.abs(Q).max(), vmax=np.abs(Q).max())
ax.set_xlabel('Dimension')
ax.set_ylabel('Sequence Position')
ax.set_title('Query Matrix Q', fontweight='bold')
plt.colorbar(im, ax=ax)

# 5. K Matrix
ax = axes[1, 1]
im = ax.imshow(K, cmap='RdBu_r', aspect='auto',
               vmin=-np.abs(K).max(), vmax=np.abs(K).max())
ax.set_xlabel('Dimension')
ax.set_ylabel('Sequence Position')
ax.set_title('Key Matrix K', fontweight='bold')
plt.colorbar(im, ax=ax)

# 6. V Matrix
ax = axes[1, 2]
im = ax.imshow(V, cmap='RdBu_r', aspect='auto',
               vmin=-np.abs(V).max(), vmax=np.abs(V).max())
ax.set_xlabel('Dimension')
ax.set_ylabel('Sequence Position')
ax.set_title('Value Matrix V', fontweight='bold')
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

print("\n=== ATTENTION GEOMETRY INSIGHTS ===")
print(f"W_Q condition number: {S_Q[0] / (S_Q[-1] + 1e-8):.2f}")
print(f"W_K condition number: {S_K[0] / (S_K[-1] + 1e-8):.2f}")
print(f"W_V condition number: {S_V[0] / (S_V[-1] + 1e-8):.2f}")
print(f"W_O condition number: {S_O[0] / (S_O[-1] + 1e-8):.2f}")
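row_entropy = -np.sum(attn_weights * np.log(attn_weights + 1e-10), axis=-1)
print(f"\nMean attention entropy per query: {row_entropy.mean():.4f} (max possible: {np.log(attn_weights.shape[-1]):.4f})")
print("\nNote: these randomly initialized matrices are full rank. LoRA instead relies on")
print("fine-tuning *updates* having low intrinsic rank, not on the weights themselves.")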

---
## Part 6: Sharp vs Flat Minima

Let's visualize why flat minima are often associated with better generalization.
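
The comparison rests on a second-order Taylor expansion: at a minimum the gradient vanishes, so a perturbation $\epsilon$ along a direction with curvature $\lambda$ raises the loss by roughly $\Delta L \approx \tfrac{1}{2}\lambda\epsilon^2$. A quick check with the two curvatures used in the next cell ($\lambda = 1$ and $\lambda = 10$); the function name `loss_increase` is illustrative.

```python
def loss_increase(curvature, epsilon):
    """Second-order estimate of the loss increase at a minimum: 0.5 * lambda * eps^2."""
    return 0.5 * curvature * epsilon**2

eps_demo = 0.5
delta_flat = loss_increase(1.0, eps_demo)    # flat minimum, lambda = 1
delta_sharp = loss_increase(10.0, eps_demo)  # sharp minimum, lambda = 10
print(delta_flat, delta_sharp, delta_sharp / delta_flat)  # 0.125 1.25 10.0
```

For an exactly quadratic loss the estimate is exact, so the sensitivity ratio equals the curvature ratio regardless of the perturbation size.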

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

theta = np.linspace(-2, 2, 200)
epsilon = 0.5  # Perturbation magnitude

# Flat minimum: small second derivative
flat_minimum = 0.1 + 0.5 * theta**2
flat_lambda = 1.0

# Sharp minimum: large second derivative
sharp_minimum = 0.1 + 5 * theta**2
sharp_lambda = 10.0

# 1. Flat Minimum
ax = axes[0]
ax.plot(theta, flat_minimum, 'b-', linewidth=3, label='Loss')
ax.scatter([0], [0.1], color='gold', s=300, marker='*', 
           edgecolors='black', linewidths=2, zorder=5, label='Minimum')
ax.axvspan(-epsilon, epsilon, alpha=0.2, color='green', label=f'±{epsilon} perturbation')

loss_at_perturb_flat = 0.1 + 0.5 * epsilon**2
ax.plot([epsilon, epsilon], [0.1, loss_at_perturb_flat], 'r--', linewidth=2)
ax.text(epsilon + 0.1, (0.1 + loss_at_perturb_flat)/2, 
        f'ΔL={loss_at_perturb_flat-0.1:.2f}',
        fontsize=10, color='red', fontweight='bold')

ax.set_xlabel('Parameter θ', fontsize=12)
ax.set_ylabel('Loss L(θ)', fontsize=12)
ax.set_title(f'Flat Minimum (λ = {flat_lambda:.1f})\n✓ Robust to noise', 
             fontsize=13, fontweight='bold', color='blue')
ax.set_ylim([0, 2])
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# 2. Sharp Minimum
ax = axes[1]
ax.plot(theta, sharp_minimum, 'r-', linewidth=3, label='Loss')
ax.scatter([0], [0.1], color='gold', s=300, marker='*', 
           edgecolors='black', linewidths=2, zorder=5, label='Minimum')
ax.axvspan(-epsilon, epsilon, alpha=0.2, color='orange', label=f'±{epsilon} perturbation')

loss_at_perturb_sharp = 0.1 + 5 * epsilon**2
ax.plot([epsilon, epsilon], [0.1, loss_at_perturb_sharp], 'r--', linewidth=2)
ax.text(epsilon + 0.1, (0.1 + loss_at_perturb_sharp)/2, 
        f'ΔL={loss_at_perturb_sharp-0.1:.2f}',
        fontsize=10, color='red', fontweight='bold')

ax.set_xlabel('Parameter θ', fontsize=12)
ax.set_ylabel('Loss L(θ)', fontsize=12)
ax.set_title(f'Sharp Minimum (λ = {sharp_lambda:.1f})\n✗ Sensitive to noise', 
             fontsize=13, fontweight='bold', color='red')
ax.set_ylim([0, 2])
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# 3. Comparison
ax = axes[2]
ax.axis('off')

comparison_text = f"""
ROBUSTNESS COMPARISON

Same perturbation ε = {epsilon}

Flat Minimum:
  • ΔL = {loss_at_perturb_flat-0.1:.3f}
  • Small sensitivity
  • ✓ Robust to noise
  • ✓ Better generalization

Sharp Minimum:  
  • ΔL = {loss_at_perturb_sharp-0.1:.3f}
  • High sensitivity
  • ✗ Sensitive to noise
  • ✗ Poor generalization

Sensitivity Ratio:
{(loss_at_perturb_sharp-0.1)/(loss_at_perturb_flat-0.1):.1f}× more sensitive!

KEY INSIGHT (heuristic):
Test loss ≈ perturbed training loss
Flat minima → consistent performance

For LLMs:
• Consider sharpness-aware training (SAM)
• Monitor λ_max
• Prefer setups that reach flatter minima
"""

ax.text(0.05, 0.95, comparison_text, ha='left', va='top', 
        fontsize=11, family='monospace',
        bbox=dict(boxstyle='round', facecolor='lightyellow', 
                 alpha=0.8, edgecolor='black', linewidth=2))

plt.tight_layout()
plt.show()

print("✓ Sharp vs Flat comparison complete")

---
## Summary and Key Takeaways

### What We've Learned:

1. **Hessian Eigenvalues = Optimization Fingerprint**
   - λ > 0: Upward curvature (minimum-like direction)
   - λ < 0: Downward curvature (saddle direction at a critical point)
   - |λ| ≈ 0: Flat direction (loss barely changes)

2. **Flat Minima Tend to Generalize Better**
   - Small max eigenvalue → loss is robust to parameter perturbations
   - Heuristic: test loss ≈ training loss under a small perturbation
   - Motivates SAM and other sharpness-aware methods

3. **Low-Rank Structure and Attention**
   - Randomly initialized projections, as in this demo, are full rank
   - LoRA assumes that fine-tuning *updates* have low intrinsic rank
   - This enables parameter-efficient LLM fine-tuning

4. **High Dimensions Favor Saddle Points**
   - Most critical points are expected to be saddles rather than minima
   - Negative eigenvalues are normal away from convergence
   - SGD's noise helps escape saddles

5. **Practical Implications**
   - Stable learning rate for gradient descent: η < 2/λ_max
   - Monitor condition number κ = λ_max/λ_min (over positive eigenvalues)
   - Use spectral analysis to debug training
   - Design architectures with better geometry
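
The link between learning rate and λ_max can be checked on a quadratic, where plain gradient descent contracts along a direction with curvature λ only if the step size is below 2/λ. A minimal sketch on a hand-picked diagonal Hessian; the values and the helper name `run_gd` are illustrative.

```python
import numpy as np

H_quad = np.diag([1.0, 10.0])  # lambda_max = 10, so the stability threshold is 2/10 = 0.2

def run_gd(lr, steps=50):
    """Gradient descent on L(theta) = 0.5 * theta^T H theta; returns the final loss."""
    theta = np.array([1.0, 1.0])
    for _ in range(steps):
        theta = theta - lr * (H_quad @ theta)  # gradient of the quadratic is H @ theta
    return float(0.5 * theta @ H_quad @ theta)

print("lr=0.19:", run_gd(0.19))  # just below 2/lambda_max: loss shrinks
print("lr=0.21:", run_gd(0.21))  # just above 2/lambda_max: loss grows
```

The initial loss is 5.5. The sharp direction sets the threshold, while the flat direction (λ = 1) converges slowly at any stable step size, which is why the condition number matters.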

### For LLMs Specifically:

- **Scale**: Billions of parameters → the full Hessian is infeasible, so spectra are estimated with Hessian-vector products (e.g. power iteration or Lanczos)
- **Architecture**: Attention creates complex but analyzable geometry
- **Fine-tuning**: Low-rank structure enables parameter-efficient methods
- **Generalization**: Flat minima explain why some models generalize better
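
As a sketch of what a matrix-free spectral estimate can look like: power iteration needs only Hessian-vector products, which can be approximated from two gradient evaluations, so the full Hessian is never formed. The example assumes a gradient function `grad_fn(theta)` and uses a small quadratic as a stand-in; `hvp_fd` and `power_iteration_lambda_max` are illustrative names. It finds the eigenvalue of largest magnitude. In practice, automatic differentiation and Lanczos-type methods are the more common choice.

```python
import numpy as np

def hvp_fd(grad_fn, theta, v, eps=1e-4):
    """Central finite-difference Hessian-vector product: (g(theta+eps*v) - g(theta-eps*v)) / (2*eps)."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

def power_iteration_lambda_max(grad_fn, theta, n_iter=200, seed=0):
    """Estimate the largest-magnitude Hessian eigenvalue via power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(theta.size)
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        Hv = hvp_fd(grad_fn, theta, v)
        v = Hv / np.linalg.norm(Hv)
    return float(v @ hvp_fd(grad_fn, theta, v))  # Rayleigh quotient at the converged direction

# Stand-in problem: L(theta) = 0.5 * theta^T A theta, so grad = A @ theta and H = A
A = np.array([[2.0, 3.0],
              [3.0, 10.0]])
lam_est = power_iteration_lambda_max(lambda th: A @ th, np.zeros(2))
print(lam_est)  # close to 11, the largest eigenvalue of A
```

Each iteration costs two gradient evaluations, independent of the number of parameters squared, which is what makes this family of methods usable at scale.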

---

## Next Steps

To explore further:
1. Try different network architectures
2. Implement SAM optimizer
3. Analyze real transformer models
4. Study mode connectivity between minima
5. Investigate Neural Tangent Kernel connections
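
For item 2, a minimal sketch of a single SAM-style update as it is commonly described: take an ascent step of radius `rho` along the normalized gradient, evaluate the gradient there, and apply that gradient at the original point. The function name `sam_step`, the values of `lr` and `rho`, and the toy quadratic are assumptions for illustration, not a reference implementation.

```python
import numpy as np

def sam_step(grad_fn, theta, lr=0.05, rho=0.05):
    """One SAM-style update using the gradient at a perturbed, higher-loss nearby point."""
    g = grad_fn(theta)
    g_norm = np.linalg.norm(g)
    if g_norm == 0.0:
        return theta                        # already at a critical point
    theta_adv = theta + rho * g / g_norm    # ascent step of length rho
    return theta - lr * grad_fn(theta_adv)  # descend using the perturbed gradient

# Toy usage on L(theta) = 0.5 * theta^T A theta (hypothetical example)
A_sam = np.diag([1.0, 10.0])
sam_loss = lambda th: float(0.5 * th @ A_sam @ th)
theta_sam = np.array([1.0, 1.0])
for _ in range(100):
    theta_sam = sam_step(lambda th: A_sam @ th, theta_sam)
print(sam_loss(theta_sam))
```

With a fixed `rho` the iterate settles in a small neighborhood of the minimum rather than exactly at it, because the perturbed gradient does not vanish there.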

**The geometric perspective is fundamental to understanding modern AI!**

In [None]:
print("\n" + "="*70)
print("NOTEBOOK COMPLETE!")
print("="*70)
print("\nYou've learned about:")
print("  ✓ Loss landscape geometry")
print("  ✓ Hessian eigenvalue spectra")
print("  ✓ Sharp vs flat minima")
print("  ✓ Attention mechanism geometry")
print("  ✓ Connections to LLMs")
print("\nThese concepts are at the heart of modern deep learning!")
print("="*70)