In [6]:
from lib.lorentz.layers import LorentzBatchNorm1d, LorentzBatchNorm2d
from lib.lorentz.manifold import CustomLorentz

In [7]:
man = CustomLorentz()

In [16]:
"""
Diagnostic script to visualize Lorentz BatchNorm train/eval discrepancy.

Run this in your environment where lib.lorentz is available.
"""

import torch
import torch.nn as nn
from lib.lorentz.layers import LorentzBatchNorm1d, LorentzBatchNorm2d
from lib.lorentz.manifold import CustomLorentz

torch.manual_seed(42)

def generate_lorentz_points(manifold, batch_size, channels, height=None, width=None, spread=0.5):
    """Generate random points on the hyperboloid."""
    if height is not None and width is not None:
        # For 2D: [B, C, H, W] where C includes time dimension
        shape = (batch_size, channels, height, width)
    else:
        # For 1D: [B, C]
        shape = (batch_size, channels)
    
    # Generate in tangent space at origin, then project
    space_components = torch.randn(*shape) * spread
    # Time component to satisfy hyperboloid constraint: t^2 - ||x||^2 = 1/k
    # For k=1: t = sqrt(1 + ||x||^2)
    if len(shape) == 4:
        space_norm_sq = (space_components[:, 1:, :, :] ** 2).sum(dim=1, keepdim=True)
    else:
        space_norm_sq = (space_components[:, 1:] ** 2).sum(dim=1, keepdim=True)
    
    time_component = torch.sqrt(1.0 + space_norm_sq)
    
    if len(shape) == 4:
        space_components[:, 0:1, :, :] = time_component
    else:
        space_components[:, 0:1] = time_component
    
    return space_components


def shift_points(manifold, points, shift_amount, dim=1):
    """Shift points in a given spatial direction."""
    shifted = points.clone()
    # Add to spatial component
    if len(points.shape) == 4:
        shifted[:, dim, :, :] += shift_amount
        # Recompute time to stay on manifold
        space_norm_sq = (shifted[:, 1:, :, :] ** 2).sum(dim=1, keepdim=True)
        shifted[:, 0:1, :, :] = torch.sqrt(1.0 + space_norm_sq)
    else:
        shifted[:, dim] += shift_amount
        space_norm_sq = (shifted[:, 1:] ** 2).sum(dim=1, keepdim=True)
        shifted[:, 0:1] = torch.sqrt(1.0 + space_norm_sq)
    return shifted


In [28]:
def test_train_eval_discrepancy_1d():
    """Test 1D BatchNorm discrepancy."""
    print("=" * 70)
    print("TEST: LorentzBatchNorm1d Train vs Eval Discrepancy")
    print("=" * 70)
    
    manifold = CustomLorentz()
    channels = 16  # includes time dimension
    bn = LorentzBatchNorm1d(manifold, channels)  # -1 because channels includes time
    
    # Simulate training with drifting data
    print("\n--- Training Phase (5 batches with drift) ---")
    bn.train()
    
    for i in range(5):
        # Generate batch, shift more each iteration
        batch = generate_lorentz_points(manifold, 32, channels, spread=0.3)
        batch = shift_points(manifold, batch, shift_amount=0.3 * i, dim=1)
        print(batch.shape)
        
        output = bn(batch, momentum=0.1)
        
        # Compute output statistics
        space_norms = torch.norm(output[:, 1:], dim=1)
        print(f"Batch {i}: input_space_norm={torch.norm(batch[:, 1:], dim=1).mean():.4f}, "
              f"output_space_norm={space_norms.mean():.4f}")
    
    # Now test in eval mode on similar data
    print("\n--- Eval Phase (same distribution as last training batch) ---")
    bn.eval()
    
    # Generate test data similar to last training batch
    test_batch = generate_lorentz_points(manifold, 32, channels, spread=0.3)
    test_batch = shift_points(manifold, test_batch, shift_amount=0.3 * 4, dim=1)
    
    with torch.no_grad():
        # Get output in eval mode
        output_eval = bn(test_batch, momentum=0.1)
        
        # For comparison, temporarily switch to train mode
        bn.train()
        output_train = bn(test_batch, momentum=0.1)
        bn.eval()
    
    print(f"\nSame input, different modes:")
    print(f"  Input space norm:        {torch.norm(test_batch[:, 1:], dim=1).mean():.4f}")
    print(f"  Output (train mode):     {torch.norm(output_train[:, 1:], dim=1).mean():.4f}")
    print(f"  Output (eval mode):      {torch.norm(output_eval[:, 1:], dim=1).mean():.4f}")
    
    # Check point-wise difference
    diff = torch.norm(output_train - output_eval, dim=1)
    print(f"\n  Point-wise difference (L2): mean={diff.mean():.4f}, max={diff.max():.4f}")
    
    # Check if outputs are on manifold
    def check_manifold(x, name):
        # Should satisfy t^2 - ||s||^2 = 1
        constraint = x[:, 0]**2 - (x[:, 1:]**2).sum(dim=1)
        print(f"  {name} manifold constraint (should be ~1): mean={constraint.mean():.4f}, std={constraint.std():.6f}")
    
    print("\n--- Manifold Constraint Check ---")
    check_manifold(test_batch, "Input")
    check_manifold(output_train, "Output (train)")
    check_manifold(output_eval, "Output (eval)")
test_train_eval_discrepancy_1d()

TEST: LorentzBatchNorm1d Train vs Eval Discrepancy

--- Training Phase (5 batches with drift) ---
torch.Size([32, 16])
Batch 0: input_space_norm=1.0538, output_space_norm=1.1903
torch.Size([32, 16])
Batch 1: input_space_norm=1.1937, output_space_norm=1.1892
torch.Size([32, 16])
Batch 2: input_space_norm=1.2777, output_space_norm=1.1854
torch.Size([32, 16])
Batch 3: input_space_norm=1.4315, output_space_norm=1.1880
torch.Size([32, 16])
Batch 4: input_space_norm=1.6288, output_space_norm=1.1858

--- Eval Phase (same distribution as last training batch) ---

Same input, different modes:
  Input space norm:        1.6486
  Output (train mode):     1.1864
  Output (eval mode):      1.4896

  Point-wise difference (L2): mean=0.9616, max=1.1411

--- Manifold Constraint Check ---
  Input manifold constraint (should be ~1): mean=1.0000, std=0.000000
  Output (train) manifold constraint (should be ~1): mean=1.0000, std=0.000000
  Output (eval) manifold constraint (should be ~1): mean=1.0000, std

In [31]:
def test_train_eval_discrepancy_2d():
    """Test 2D BatchNorm discrepancy."""
    print("\n" + "=" * 70)
    print("TEST: LorentzBatchNorm2d Train vs Eval Discrepancy")
    print("=" * 70)
    
    manifold = CustomLorentz()
    channels = 8  # includes time dimension
    height, width = 4, 4
    bn = LorentzBatchNorm2d(manifold, channels)
    
    # Simulate training
    print("\n--- Training Phase (5 batches with drift) ---")
    bn.train()
    
    for i in range(5):
        batch = generate_lorentz_points(manifold, 16, channels, height, width, spread=0.3)
        batch = shift_points(manifold, batch, shift_amount=0.5 * i, dim=1)

        batch = batch.permute(0, 2, 3, 1)
        
        output = bn(batch, momentum=0.1)
        
        space_norms = torch.norm(output[:, 1:], dim=1).mean()
        print(f"Batch {i}: output_space_norm={space_norms:.4f}")
    
    # Eval mode
    print("\n--- Eval Phase ---")
    bn.eval()
    
    test_batch = generate_lorentz_points(manifold, 16, channels, height, width, spread=0.3)
    test_batch = shift_points(manifold, test_batch, shift_amount=0.5 * 4, dim=1)
    
    with torch.no_grad():
        test_batch = test_batch.permute(0, 2, 3, 1)
        output_eval = bn(test_batch, momentum=0.1)
        
        bn.train()
        output_train = bn(test_batch, momentum=0.1)
        bn.eval()
    
    print(f"\nSame input, different modes:")
    print(f"  Output (train mode): space_norm={torch.norm(output_train[:, 1:], dim=1).mean():.4f}")
    print(f"  Output (eval mode):  space_norm={torch.norm(output_eval[:, 1:], dim=1).mean():.4f}")
    
    ratio = torch.norm(output_eval[:, 1:], dim=1).mean() / torch.norm(output_train[:, 1:], dim=1).mean()
    print(f"  Ratio (eval/train): {ratio:.4f}")
test_train_eval_discrepancy_2d()


TEST: LorentzBatchNorm2d Train vs Eval Discrepancy

--- Training Phase (5 batches with drift) ---
Batch 0: output_space_norm=0.9965
Batch 1: output_space_norm=0.9958
Batch 2: output_space_norm=0.9939
Batch 3: output_space_norm=0.9976
Batch 4: output_space_norm=1.0243

--- Eval Phase ---

Same input, different modes:
  Output (train mode): space_norm=1.0023
  Output (eval mode):  space_norm=1.2545
  Ratio (eval/train): 1.2517


In [33]:
def test_running_stats_analysis():
    """Analyze what's stored in running statistics."""
    print("\n" + "=" * 70)
    print("TEST: Running Statistics Analysis")
    print("=" * 70)
    
    manifold = CustomLorentz()
    channels = 16
    bn = LorentzBatchNorm1d(manifold, channels)
    
    bn.train()
    
    print("\n--- Tracking running statistics during training ---")
    
    for i in range(10):
        batch = generate_lorentz_points(manifold, 32, channels, spread=0.3)
        batch = shift_points(manifold, batch, shift_amount=0.2 * i, dim=1)
        
        # Before forward pass
        running_mean_before = bn.running_mean.clone() if hasattr(bn, 'running_mean') else None
        running_var_before = bn.running_var.clone() if hasattr(bn, 'running_var') else None
        
        output = bn(batch, momentum=0.1)
        
        # After forward pass
        if hasattr(bn, 'running_mean') and hasattr(bn, 'running_var'):
            print(f"Batch {i}: running_var={bn.running_var.item():.4f}, "
                  f"running_mean_norm={torch.norm(bn.running_mean):.4f}")
    
    print("\n--- Final running statistics ---")
    if hasattr(bn, 'running_mean'):
        print(f"running_mean: {bn.running_mean[:5]}...")  # First 5 elements
    if hasattr(bn, 'running_var'):
        print(f"running_var: {bn.running_var}")
test_running_stats_analysis()


TEST: Running Statistics Analysis

--- Tracking running statistics during training ---
Batch 0: running_var=0.9973, running_mean_norm=0.0166
Batch 1: running_var=0.9920, running_mean_norm=0.0209
Batch 2: running_var=0.9912, running_mean_norm=0.0399
Batch 3: running_var=0.9900, running_mean_norm=0.0747
Batch 4: running_var=0.9825, running_mean_norm=0.1196
Batch 5: running_var=0.9810, running_mean_norm=0.1699
Batch 6: running_var=0.9779, running_mean_norm=0.2260
Batch 7: running_var=0.9703, running_mean_norm=0.2896
Batch 8: running_var=0.9672, running_mean_norm=0.3587
Batch 9: running_var=0.9632, running_mean_norm=0.4264

--- Final running statistics ---
running_mean: tensor([ 0.0000,  0.4256,  0.0081, -0.0056,  0.0031])...
running_var: tensor([0.9632])


In [34]:
def test_extreme_drift():
    """Test with extreme distribution shift between train and eval."""
    print("\n" + "=" * 70)
    print("TEST: Extreme Distribution Shift")
    print("=" * 70)
    
    manifold = CustomLorentz()
    channels = 16
    bn = LorentzBatchNorm1d(manifold, channels)
    
    # Train on data near origin
    print("\n--- Training on data near origin ---")
    bn.train()
    for i in range(10):
        batch = generate_lorentz_points(manifold, 32, channels, spread=0.3)
        output = bn(batch, momentum=0.1)
    
    print(f"After training: running_var={bn.running_var.item():.4f}")
    
    # Eval on data FAR from origin
    print("\n--- Evaluating on data far from origin ---")
    bn.eval()
    
    for shift in [0, 1, 2, 5, 10]:
        test_batch = generate_lorentz_points(manifold, 32, channels, spread=0.3)
        test_batch = shift_points(manifold, test_batch, shift_amount=shift, dim=1)
        
        with torch.no_grad():
            output = bn(test_batch, momentum=0.1)
        
        output_norm = torch.norm(output[:, 1:], dim=1).mean()
        print(f"Shift={shift}: output_space_norm={output_norm:.4f}")
        
        # Check for NaN/Inf
        if torch.isnan(output).any() or torch.isinf(output).any():
            print(f"  WARNING: NaN or Inf detected!")
test_extreme_drift()


TEST: Extreme Distribution Shift

--- Training on data near origin ---
After training: running_var=0.9734

--- Evaluating on data far from origin ---
Shift=0: output_space_norm=1.1924
Shift=1: output_space_norm=1.5842
Shift=2: output_space_norm=2.4275
Shift=5: output_space_norm=5.4845
Shift=10: output_space_norm=10.9300


In [36]:
def test_validation_noise_simulation():
    """
    Simulate the validation noise you're seeing during cosine annealing.
    
    The hypothesis: validation uses eval mode (running stats) while
    training uses train mode (batch stats). During cosine annealing,
    the model changes but running stats become stale.
    """
    print("\n" + "=" * 70)
    print("TEST: Validation Noise Simulation (Cosine Annealing Scenario)")
    print("=" * 70)
    
    manifold = CustomLorentz()
    channels = 16
    bn = LorentzBatchNorm1d(manifold, channels)
    
    # Simulate: features evolve during training (as weights change)
    # but running stats lag behind
    
    print("\n--- Simulating training with evolving features ---")
    
    train_outputs = []
    eval_outputs = []
    
    for epoch in range(20):
        # Feature distribution evolves (simulating effect of changing weights)
        feature_scale = 0.3 + 0.02 * epoch  # gradually increasing
        feature_shift = 0.1 * epoch  # gradually drifting
        
        # Training step
        bn.train()
        batch = generate_lorentz_points(manifold, 32, channels, spread=feature_scale)
        batch = shift_points(manifold, batch, shift_amount=feature_shift, dim=1)
        train_out = bn(batch, momentum=0.1)
        train_outputs.append(torch.norm(train_out[:, 1:], dim=1).mean().item())
        
        # Validation step (same data distribution, but eval mode)
        bn.eval()
        val_batch = generate_lorentz_points(manifold, 32, channels, spread=feature_scale)
        val_batch = shift_points(manifold, val_batch, shift_amount=feature_shift, dim=1)
        with torch.no_grad():
            eval_out = bn(val_batch, momentum=0.1)
        eval_outputs.append(torch.norm(eval_out[:, 1:], dim=1).mean().item())
    
    print("\nEpoch | Train Output Norm | Eval Output Norm | Ratio")
    print("-" * 60)
    for i in range(20):
        ratio = eval_outputs[i] / train_outputs[i] if train_outputs[i] > 0 else float('inf')
        print(f"  {i:2d}  |      {train_outputs[i]:.4f}       |      {eval_outputs[i]:.4f}      | {ratio:.4f}")
    
    print("\n>>> Notice how the ratio diverges as training progresses!")
    print(">>> This is because running stats become increasingly stale.")
test_validation_noise_simulation()


TEST: Validation Noise Simulation (Cosine Annealing Scenario)

--- Simulating training with evolving features ---

Epoch | Train Output Norm | Eval Output Norm | Ratio
------------------------------------------------------------
   0  |      1.1883       |      1.1654      | 0.9807
   1  |      1.1844       |      1.2229      | 1.0326
   2  |      1.1851       |      1.3712      | 1.1571
   3  |      1.1827       |      1.4078      | 1.1904
   4  |      1.1872       |      1.3955      | 1.1755
   5  |      1.1850       |      1.5399      | 1.2995
   6  |      1.1889       |      1.5695      | 1.3201
   7  |      1.1848       |      1.5182      | 1.2814
   8  |      1.1874       |      1.4906      | 1.2553
   9  |      1.1818       |      1.5970      | 1.3513
  10  |      1.1868       |      1.6039      | 1.3514
  11  |      1.1827       |      1.6135      | 1.3642
  12  |      1.1852       |      1.5791      | 1.3323
  13  |      1.1819       |      1.6097      | 1.3619
  14  |      1