# 🚀 Deep Dive: The Muon Optimizer

Welcome to this step-by-step tutorial on **Muon** (MomentUm Orthogonalized by Newton-schulz). 

In this notebook, we will:
1.  **Understand the Problem**: Why do we need another optimizer? What is "Gradient Orthogonalization"?
2.  **Explore the Math**: How Newton-Schulz iteration works to normalize matrices.
3.  **Implement Muon**: Step-by-step build of the optimizer class.
4.  **Compare Performance**: Pit Muon against SGD and Adam on a challenging optimization landscape.
5.  **Visualize Internals**: See how Muon affects the singular values of weight matrices.

## 1. The Problem: Scaling and Curvature

### The Landscape
Imagine training a neural network as walking down a mountain. 
-   **SGD** takes steps based on the slope immediately under your feet. If the slope is steep in one direction and flat in another (a ravine), SGD oscillates or moves slowly.
-   **Adam** scales the step size for each parameter individually (diagonal scaling). This helps with axis-aligned ravines but ignores correlations between parameters.

### The Muon Approach
Muon goes a step further. Instead of just scaling individual parameters, it considers the **entire geometry** of the weight matrix update. It tries to **orthogonalize** the update steps.

**Why?** 
In deep learning, weight matrices often develop a few very large "singular values" (dominant directions) that overshadow everything else. Gradients in these directions are huge, while gradients in other useful directions are tiny. 

Muon replaces the update matrix with a nearby, approximately orthogonal one: the singular vectors are kept, but the singular values are all pushed toward 1. Every direction the update touches then gets a step of similar size, instead of a few dominant directions taking nearly all of it. It's like turning a narrow ravine into a smooth bowl.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. The Core: Newton-Schulz Iteration

To orthogonalize a matrix $G$, we want to map it to $U V^T$, where $G = U \Sigma V^T$ is the (reduced) SVD. In other words, we keep the singular vectors and set every singular value to 1. When $G$ has full column rank, this is equivalent to $G (G^T G)^{-1/2}$.

Computing an SVD at every optimizer step is slow on GPUs. **Newton-Schulz iteration** is a fast iterative method that approximates $U V^T$ without any explicit decomposition. It uses only matrix multiplications!
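
To make the target concrete, here is a minimal sketch that computes the exact orthogonalization in two ways on a small random matrix and checks that they agree. The matrix size, the seed, and the names `O_svd` and `O_formula` are illustrative choices, not part of the original notebook.

```python
import torch

torch.manual_seed(0)
G = torch.randn(6, 4, dtype=torch.float64)  # small, full column rank

# Route 1: take the reduced SVD and drop the singular values.
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
O_svd = U @ Vh

# Route 2: G (G^T G)^{-1/2}, with the inverse square root built
# from the eigendecomposition of the symmetric matrix G^T G.
evals, evecs = torch.linalg.eigh(G.mT @ G)
inv_sqrt = evecs @ torch.diag(evals.rsqrt()) @ evecs.mT
O_formula = G @ inv_sqrt

print("routes agree:", torch.allclose(O_svd, O_formula, atol=1e-8))
print("singular values of result:", torch.linalg.svdvals(O_svd))
```

Both routes need a decomposition, which is exactly the cost Newton-Schulz is meant to avoid.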

Let's look at the provided implementation:

In [None]:
@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    We want to find X such that X @ X.T is approximately Identity.
    """
    assert G.ndim >= 2
    
    # Constants for the quintic (5th order) iteration
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.half()  # Run in FP16 for speed/memory; bfloat16 is a common alternative with a wider dynamic range

    # Ensure we are working with a 'short and fat' or square matrix for stability
    # If it's 'tall and skinny', we transpose it.
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Preconditioning: divide by the Frobenius norm so every singular value is at most 1.
    # The iteration relies on this bound to converge.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # The Iteration Loop
    for _ in range(steps):
        A = X @ X.mT
        # The update rule: X_new = aX + bAX + cA^2X
        B = b * A + c * A @ A
        X = a * X + B @ X

    # Untranspose if we transposed earlier
    if G.size(-2) > G.size(-1):
        X = X.mT

    return X
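
One way to see why the loop works: if $X = U S V^T$, then $A = X X^T = U S^2 U^T$, and the update $aX + bAX + cA^2X$ equals $U (aS + bS^3 + cS^5) V^T$. Each step therefore applies the scalar polynomial $f(s) = as + bs^3 + cs^5$ to every singular value and leaves the singular vectors alone. The sketch below iterates that scalar map in plain Python. The helper names `quintic` and `iterate` and the sample starting values are illustrative.

```python
# Same coefficients as in zeropower_via_newtonschulz5.
a, b, c = 3.4445, -4.7750, 2.0315

def quintic(s):
    """One Newton-Schulz step as seen by a single singular value."""
    return a * s + b * s**3 + c * s**5

def iterate(s, steps=5):
    """Apply the scalar map `steps` times."""
    for _ in range(steps):
        s = quintic(s)
    return s

# Singular values after Frobenius normalization lie in (0, 1].
for s0 in (0.001, 0.01, 0.1, 0.5, 1.0):
    print(f"start {s0:<6} -> after 5 steps {iterate(s0):.3f}")
```

With these coefficients the values do not settle exactly at 1. They are pulled into a band around it, and very small starting values (such as 0.001 here) are still below that band after five steps because each step grows them by a factor of at most about `a`.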

### 🧪 Interactive Experiment: Visualizing Orthogonalization

Let's create a random, ill-conditioned matrix (one direction is much stronger than others) and see what `zeropower_via_newtonschulz5` does to it.

In [None]:
# 1. Create a random matrix
torch.manual_seed(0)
G = torch.randn(100, 100).to(device)

# 2. Make it ill-conditioned (multiply one dimension by 100)
G[:, 0] *= 100

# 3. Apply Newton-Schulz
O = zeropower_via_newtonschulz5(G, steps=5)

# 4. Check Orthogonality: O @ O.T should be close to Identity
gram = O @ O.mT
identity = torch.eye(100, device=device).half()

# Visualize
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.title("Original G @ G.T (Log Scale)")
plt.imshow((G @ G.mT).abs().log().cpu().float(), cmap='viridis')
plt.colorbar()

plt.subplot(1, 2, 2)
plt.title("Orthogonalized O @ O.T")
plt.imshow(gram.cpu().float(), cmap='viridis')
plt.colorbar()

plt.show()

print(f"Max deviation from Identity: {(gram - identity).abs().max().item():.6f}")

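**Observation**: The original matrix product is dominated by a few huge values (yellow spots). The orthogonalized version is much closer to the Identity matrix: the diagonal sits near 1 and the off-diagonal entries are small. It is not exact, though. With five steps and these coefficients the singular values of `O` land in a band around 1 rather than exactly on it, and directions that started out extremely weak may not be fully lifted. The printed maximum deviation shows how far from Identity the result actually is. Still, `O` has largely equalized the energy across directions!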
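
A more direct check than the heatmap is to look at the singular values before and after. The sketch below does that on the CPU with a float32 re-implementation of the same quintic iteration, so it runs on its own without `torch.compile` or FP16. The helper name `newton_schulz_fp32` is hypothetical and it assumes a square input, so it skips the transpose logic.

```python
import torch

def newton_schulz_fp32(G, steps=5):
    # Illustrative float32 version of the quintic iteration (square input assumed).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)          # Frobenius normalization
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X

torch.manual_seed(0)
G = torch.randn(100, 100)
G[:, 0] *= 100                          # same ill-conditioning as in the experiment

before = torch.linalg.svdvals(G)
after = torch.linalg.svdvals(newton_schulz_fp32(G))

print(f"before: min={before.min().item():.3g}  max={before.max().item():.3g}")
print(f"after:  min={after.min().item():.3g}  "
      f"median={after.median().item():.3g}  max={after.max().item():.3g}")
```

The spread between the largest and smallest singular value should shrink by orders of magnitude, with most values ending up near 1 and the weakest directions lagging behind.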

## 3. Implementing the Optimizer

Now we wrap this logic into a PyTorch Optimizer.

**Key Features**:
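1.  **Momentum**: We don't just orthogonalize the raw gradient `g`. We maintain a momentum buffer `buf` (an exponential moving average of gradients) and orthogonalize *that*.
2.  **Nesterov**: Optionally, the matrix that gets orthogonalized is a blend of the current gradient and the buffer, `(1 - momentum) * g + momentum * buf`, rather than the buffer alone.
3.  **Scaling**: The update is scaled by `max(1, rows/cols)**0.5`. This is a heuristic: tall matrices (more rows than columns) get a proportionally larger step, while square and wide matrices are left at the base learning rate.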
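
The arithmetic behind these three features can be checked on toy numbers. The sketch below uses a two-element "gradient" and a hypothetical helper `shape_scale`; neither appears in the optimizer itself, they only mirror its formulas.

```python
import torch

momentum = 0.95
g = torch.tensor([1.0, -2.0])      # toy gradient
buf = torch.zeros_like(g)          # momentum buffer starts at zero

# EMA update: buf <- momentum * buf + (1 - momentum) * g
buf.lerp_(g, 1 - momentum)

# Nesterov blend: (1 - momentum) * g + momentum * buf
# (out-of-place here, so g itself is left untouched)
nesterov = g.lerp(buf, momentum)

def shape_scale(rows, cols):
    """Learning-rate multiplier used for a rows x cols weight matrix."""
    return max(1, rows / cols) ** 0.5

print("buffer after one step:", buf)
print("Nesterov blend:       ", nesterov)
print("scale for 256x64:", shape_scale(256, 64), " scale for 64x256:", shape_scale(64, 256))
```

After a single step from a zero buffer, the buffer holds `0.05 * g` and the Nesterov blend holds `0.0975 * g`. Since the result is orthogonalized afterwards, only its direction structure matters, not this overall magnitude.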

In [None]:
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                # 1. Initialize Momentum Buffer
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)

                # 2. Update Momentum: exponential moving average via lerp
                # buf = buf * momentum + g * (1 - momentum)
                buf = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])

                # 3. Nesterov Correction (optional)
                # If Nesterov, blend the current gradient with the buffer:
                # g = g * (1 - momentum) + buf * momentum
                # Note: lerp_ is in-place, so this overwrites p.grad.
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                
                # 4. Orthogonalize the Update
                # This is the MAGIC step. Instead of just subtracting g, 
                # we subtract a "whitened" version of g.
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                
                # 5. Apply Update
                # We scale the learning rate based on matrix shape.
                scale_factor = max(1, p.size(-2) / p.size(-1))**0.5
                p.add_(g.view_as(p), alpha=-group["lr"] * scale_factor)
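
Because `zeropower_via_newtonschulz5` asserts `G.ndim >= 2`, this class cannot be handed 1D parameters such as biases. One possible arrangement, sketched below under that assumption, is to split the parameters by dimensionality and give the non-matrix ones to a standard optimizer. The toy model, the variable names, and the AdamW learning rate are placeholders for illustration.

```python
import torch
import torch.nn as nn

# Toy model with biases, so it has both 2D and 1D parameters.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Weight matrices: candidates for the Muon class defined above.
matrix_params = [p for p in model.parameters() if p.ndim >= 2]

# Everything else (here: biases) goes to a conventional optimizer.
other_params = [p for p in model.parameters() if p.ndim < 2]
fallback_opt = torch.optim.AdamW(other_params, lr=1e-3)  # placeholder lr

print(f"{len(matrix_params)} matrix params, {len(other_params)} other params")
```

In a training loop, both optimizers would be stepped after each backward pass.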

## 4. The Showdown: Muon vs. Adam vs. SGD

We will train a Deep Linear Network on a synthetic task. Deep linear networks are notoriously hard to train because gradients vanish or explode, and the curvature becomes very ill-conditioned.

**Task**: Learn to map a 64-dim input to a 64-dim output via a 10-layer network.
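
To get a feel for why this setup is hard, the following sketch pushes a random batch through ten freshly initialized 64x64 linear layers and prints the activation scale after each one. It assumes PyTorch's default `nn.Linear` initialization and uses its own seed and batch, independent of the experiment's configuration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = [nn.Linear(64, 64, bias=False) for _ in range(10)]
x = torch.randn(128, 64)

with torch.no_grad():
    h = x
    for i, layer in enumerate(layers, start=1):
        h = layer(h)
        print(f"after layer {i:2d}: std = {h.std().item():.5f}")

# How much of the input scale survives the full stack.
ratio = (h.std() / x.std()).item()
print(f"output std / input std = {ratio:.5f}")
```

The signal shrinks at every layer, and gradients flowing backward through the same matrices are attenuated in a similar way, which is the vanishing behavior mentioned above.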

In [None]:
# Configuration
N_LAYERS = 10
DIM = 64
BATCH_SIZE = 128
STEPS = 500

# Synthetic Data
X_train = torch.randn(1000, DIM, device=device)
Y_train = torch.randn(1000, DIM, device=device) # Random target

def get_model():
    layers = []
    for _ in range(N_LAYERS):
        # bias=False is important for this specific Muon implementation 
        # because it expects 2D tensors (matrices).
        layers.append(nn.Linear(DIM, DIM, bias=False))
    return nn.Sequential(*layers).to(device)

def train_optimizer(optim_cls, name, lr, **kwargs):
    torch.manual_seed(42)
    model = get_model()
    optimizer = optim_cls(model.parameters(), lr=lr, **kwargs)
    
    losses = []
    
    for step in range(STEPS):
        # Batch sampling
        indices = torch.randint(0, X_train.size(0), (BATCH_SIZE,), device=device)
        x_batch = X_train[indices]
        y_batch = Y_train[indices]
        
        # Forward
        pred = model(x_batch)
        loss = F.mse_loss(pred, y_batch)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
    return losses

print("Training SGD...")
loss_sgd = train_optimizer(torch.optim.SGD, "SGD", lr=0.01, momentum=0.9)

print("Training Adam...")
loss_adam = train_optimizer(torch.optim.Adam, "Adam", lr=0.001)

print("Training Muon...")
loss_muon = train_optimizer(Muon, "Muon", lr=0.02)

In [None]:
# Plotting Results
plt.figure(figsize=(10, 6))
plt.plot(loss_sgd, label='SGD (0.01)', alpha=0.7)
plt.plot(loss_adam, label='Adam (0.001)', alpha=0.7)
plt.plot(loss_muon, label='Muon (0.02)', linewidth=2, color='red')

plt.yscale('log')
plt.title(f"Optimization Speed on {N_LAYERS}-Layer Linear Network")
plt.xlabel("Steps")
plt.ylabel("MSE Loss (Log Scale)")
plt.legend()
plt.grid(True, which="both", ls="-", alpha=0.2)
plt.show()

## 5. Analysis: Why did Muon win?

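In deep linear networks, the singular values of the weight matrices tend to spread out. Some become huge, some tiny. If Muon comes out ahead in the plot above, this is the likely reason:
-   **SGD** takes steps dominated by the directions with huge singular values, so it tends to oscillate there while barely moving elsewhere.
-   **Adam** rescales each parameter individually, which helps only when the problematic directions line up with individual parameters.
-   **Muon** forces the update to be approximately orthogonal. This means it pushes about equally hard in *all* directions, largely ignoring the fact that some directions are "steep" and others are "flat".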

### Summary
1.  **Muon** combines Momentum with Newton-Schulz orthogonalization.
2.  It is aimed at 2D weight matrices and is reported to be particularly effective for **large-scale training** (like Transformers), where spectral scaling issues are prominent.
3.  It adds a computational cost (the matrix multiplications in Newton-Schulz), but it can converge in fewer steps, which may more than make up for the extra work per step.
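
The extra cost can be estimated with a rough FLOP count. Each Newton-Schulz step performs three matrix multiplications on an `n x m` matrix with `n <= m`: `X @ X.mT`, `A @ A`, and `B @ X`. The sketch below compares that with the usual `6 * tokens * n * m` estimate for the forward and backward pass of one linear layer. The function names, the token counts, and the decision to ignore elementwise operations are all assumptions made for illustration.

```python
def newton_schulz_flops(n, m, steps=5):
    """Approximate FLOPs for `steps` quintic iterations on an n x m matrix."""
    n, m = min(n, m), max(n, m)           # the iteration works on the wide orientation
    per_step = (2 * n * n * m             # A = X @ X.mT
                + 2 * n ** 3              # A @ A
                + 2 * n * n * m)          # B @ X
    return steps * per_step

def linear_train_flops(n, m, tokens):
    """Approximate forward + backward FLOPs for one n x m linear layer."""
    return 6 * tokens * n * m

def overhead(n, m, tokens, steps=5):
    """Newton-Schulz cost relative to the layer's own training cost."""
    return newton_schulz_flops(n, m, steps) / linear_train_flops(n, m, tokens)

# The toy setting from this notebook: 64x64 layers, batch of 128.
print(f"64x64, 128 samples:        {overhead(64, 64, 128):.2%}")
# A hypothetical large-batch setting: 1024x1024 layer, 524288 tokens per step.
print(f"1024x1024, 524288 tokens:  {overhead(1024, 1024, 524288):.2%}")
```

For square matrices the ratio simplifies to `steps * n / tokens`, so the overhead is large for tiny batches like the one in this notebook and small when many tokens are processed per step.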