# BJLKeng-Style Real NVP for MNIST

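Based on the working implementation from BJLKeng's blog post and GitHub repo. This version follows the key ideas of their approach, which achieved **1.92 bits/dim** on MNIST. The architecture here is simplified: it is a single-scale stack of checkerboard coupling layers, so results may differ.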

## Key Principles from BJLKeng:
1. **Simple preprocessing**: Just scale MNIST to [0,1] - no complex transforms
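2. **Batch norm in coupling layers**: With proper loss accounting. A log-determinant term is only needed when batch norm transforms the flow variables themselves, not when it sits inside the networks that compute `s` and `t`.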
3. **Small learning rate**: 0.0005 ("slow learners")
4. **Zero initialization**: Start networks at identity
5. **Long training**: Many epochs needed

**Target**: ~1.9 bits/dim (BJLKeng's result)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from keras.datasets.mnist import load_data
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seed every RNG the notebook draws from:
#   - torch: weight init, base-distribution sampling, DataLoader shuffling
#   - NumPy: the dequantization noise added with np.random.uniform during
#     preprocessing (previously unseeded, so runs were not reproducible)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## BJLKeng's Simple Data Preprocessing

**Quote**: *"Eschewed the pixel transform for MNIST because it's not really a natural image. Scaling pixel values to [0, 1] seemed to work better."*

In [None]:
# Load MNIST - BJLKeng's simple approach
# NOTE(review): presumably keras returns integer image arrays of shape (N, 28, 28)
# plus label arrays; the (1, 28, 28) event shape used later relies on this - confirm.
(trainX, trainY), (testX, testY) = load_data()

# Simple preprocessing like BJLKeng
def bjlkeng_preprocess(data, add_noise=True, rng=None):
    """Optionally dequantize integer pixel data and scale it into [0, 1].

    Args:
        data: Array-like of integer pixel values (expected in ``[0, 255]``).
        add_noise: If True, add ``U(0, 1)`` noise to every pixel before
            scaling (uniform dequantization), making the values continuous.
        rng: Optional ``numpy.random.Generator`` used to draw the noise. When
            None, the global ``np.random`` state is used (the original
            behavior), so seeding ``np.random`` makes the result reproducible.

    Returns:
        A ``float32`` array of the same shape equal to ``(data + noise) / 256``.
        The dtype is float32 whether or not noise is added. Previously the
        noisy path silently produced float64.
    """
    data = np.asarray(data, dtype=np.float32)

    if add_noise:
        # Uniform dequantization: spread each integer level k over [k, k + 1).
        if rng is None:
            noise = np.random.uniform(0, 1, data.shape)
        else:
            noise = rng.uniform(0, 1, data.shape)
        data = data + noise.astype(np.float32)

    # Scale to [0, 1] - BJLKeng's approach (256 levels -> bins of width 1/256).
    return (data / np.float32(256.0)).astype(np.float32, copy=False)

# Dequantize only the training split; the test split stays on the discrete k/256 grid.
trainX_processed = bjlkeng_preprocess(trainX, add_noise=True)
testX_processed = bjlkeng_preprocess(testX, add_noise=False)

# Convert both splits to float32 tensors and insert a singleton channel axis at dim 1.
trainX, testX = [
    torch.tensor(arr, dtype=torch.float32).unsqueeze(1)
    for arr in (trainX_processed, testX_processed)
]

print("\n".join([
    f"Training data shape: {trainX.shape}",
    f"Data range: [{trainX.min():.3f}, {trainX.max():.3f}]",
    f"Data mean: {trainX.mean():.3f}",
    f"Data std: {trainX.std():.3f}",
]))

## Standard Normal Base Distribution

Following BJLKeng's approach with standard Gaussian prior.

In [None]:
class StandardGaussian:
    """Isotropic standard normal N(0, I) over tensors with a fixed event shape.

    Args:
        shape: Event shape of a single sample (default ``(1, 28, 28)``).
        device: Device on which :meth:`sample` allocates its tensors.
    """

    # log(2*pi): the per-dimension normalizer of N(0, 1), computed once.
    _LOG_2PI = float(np.log(2 * np.pi))

    def __init__(self, shape=(1, 28, 28), device='cpu'):
        self.shape = shape
        self.device = device
        self.dim = np.prod(shape)

    def log_prob(self, z):
        """Return log N(z; 0, I) summed over every non-batch dimension.

        Returns a tensor of shape ``(z.size(0),)``. ``reshape`` is used rather
        than ``view`` so non-contiguous inputs (e.g. transposed tensors) work.
        """
        z_flat = z.reshape(z.size(0), -1)
        # Standard Gaussian: log p(z) = -0.5 * sum_i (z_i^2 + log(2*pi))
        log_prob = -0.5 * (z_flat ** 2 + self._LOG_2PI)
        return log_prob.sum(dim=1)

    def sample(self, n_samples=1):
        """Draw ``n_samples`` i.i.d. samples shaped ``(n_samples, *shape)``."""
        return torch.randn(n_samples, *self.shape, device=self.device)

base_dist = StandardGaussian(device=device)

# Sanity check of the prior: for D = 784 dimensions, the expected value of
# log N(z; 0, I) is -0.5 * D * (1 + log(2*pi)).
test_z = base_dist.sample(5)
test_logp = base_dist.log_prob(test_z)
print(
    f"Base sample shape: {test_z.shape}\n"
    f"Base log prob: {test_logp.mean():.1f} "
    f"(expected: ~{-0.5 * 784 * (1 + np.log(2*np.pi)):.1f})"
)

## BJLKeng's Coupling Network with Batch Norm

Key features:
- **Batch normalization** in coupling networks
- **Zero initialization** for identity start
- **Running average batch norm** with modified momentum

In [None]:
def create_checkerboard_mask(h, w, reverse=False):
    """Return an ``(h, w)`` checkerboard of 0/1 values in the default float dtype.

    Cell ``(i, j)`` is 1 when ``i + j`` is even, so the top-left cell is 1.
    ``reverse=True`` selects the complementary (odd) cells instead.
    """
    rows = torch.arange(h).unsqueeze(1)
    cols = torch.arange(w).unsqueeze(0)
    parity = (rows + cols) % 2
    wanted_parity = 1 if reverse else 0
    return (parity == wanted_parity).to(torch.get_default_dtype())


class BJLKengCouplingNet(nn.Module):
    """Conditioner network producing the scale ``s`` and shift ``t`` of a coupling layer.

    Layout:
      - Conv3x3 -> BatchNorm -> ReLU.
      - ``num_layers - 2`` hidden Conv3x3 -> BatchNorm -> ReLU blocks.
      - A final Conv3x3 that emits ``2 * in_channels`` channels, split into ``(s, t)``.

    Every convolution uses ``padding=1``, so the spatial size is preserved.
    At least two convolutions are always built, even when ``num_layers < 2``.

    Args:
        in_channels: Channels of the input and of each of ``s`` and ``t``.
        hidden_channels: Width of the hidden convolutions.
        num_layers: Total number of convolutions (when ``num_layers >= 2``).
        bn_momentum: Momentum of the BatchNorm running statistics. The default
            0.05 matches the previously hard-coded value ("small momentum").
    """

    def __init__(self, in_channels=1, hidden_channels=64, num_layers=3, bn_momentum=0.05):
        super().__init__()

        # Input block.
        layers = [
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels, momentum=bn_momentum),
            nn.ReLU(),
        ]

        # Hidden blocks.
        for _ in range(num_layers - 2):
            layers += [
                nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
                nn.BatchNorm2d(hidden_channels, momentum=bn_momentum),
                nn.ReLU(),
            ]

        # Output layer: s and t stacked along the channel axis.
        layers.append(nn.Conv2d(hidden_channels, 2 * in_channels, 3, padding=1))

        self.network = nn.Sequential(*layers)

        # Zero-init the output conv so s = tanh(0) = 0 and t = 0: the coupling
        # layer using this network starts out as the identity map.
        nn.init.zeros_(self.network[-1].weight)
        nn.init.zeros_(self.network[-1].bias)

    def forward(self, x):
        """Return ``(s, t)`` with ``in_channels`` channels each; ``s`` is squashed by tanh."""
        s, t = torch.chunk(self.network(x), 2, dim=1)
        # tanh bounds the log-scale to [-1, 1] for stability.
        return torch.tanh(s), t


class BJLKengCouplingLayer(nn.Module):
    """Masked affine coupling layer in the Real NVP style.

    Positions where ``mask == 1`` are transformed as ``x * exp(s) + t``.
    Positions where ``mask == 0`` pass through unchanged. Only the unchanged
    positions are fed to ``coupling_net``, which returns ``(s, t)`` shaped
    like its input.

    Args:
        mask: 2-D ``(H, W)`` tensor of 0/1 values. It is stored as a
            ``(1, 1, H, W)`` buffer so it broadcasts over batch and channels.
        coupling_net: Callable mapping the masked input to ``(s, t)``.
    """

    def __init__(self, mask, coupling_net):
        super().__init__()
        self.register_buffer('mask', mask[None, None])  # [1, 1, H, W]
        self.coupling_net = coupling_net

    def _masked_params(self, frozen):
        """Run the conditioner on the frozen part; keep (s, t) only at active positions."""
        scale, shift = self.coupling_net(frozen)
        return scale * self.mask, shift * self.mask

    def forward(self, x, compute_log_det=True):
        """Map ``x -> z``. Also return the per-sample log|det dz/dx| unless disabled."""
        frozen = x * (1 - self.mask)
        scale, shift = self._masked_params(frozen)
        z = frozen + ((x * self.mask) * torch.exp(scale) + shift)
        if not compute_log_det:
            return z
        # Triangular Jacobian: its log-determinant is the sum of the active log-scales.
        return z, scale.sum(dim=[1, 2, 3])

    def inverse(self, z):
        """Map ``z -> x``, exactly undoing :meth:`forward`."""
        frozen = z * (1 - self.mask)
        scale, shift = self._masked_params(frozen)
        return frozen + ((z * self.mask) - shift) * torch.exp(-scale)

# Smoke-test a single coupling layer: inverse(forward(x)) should recover x.
test_mask = create_checkerboard_mask(28, 28)
test_coupling_net = BJLKengCouplingNet()
test_layer = BJLKengCouplingLayer(test_mask, test_coupling_net)

test_x = torch.randn(2, 1, 28, 28) * 0.1
test_z, test_logdet = test_layer(test_x)
test_x_recon = test_layer.inverse(test_z)

print("\n".join([
    "Coupling layer test:",
    f"  Input shape: {test_x.shape}",
    f"  Output shape: {test_z.shape}",
    f"  Reconstruction error: {(test_x - test_x_recon).abs().max():.8f}",
    f"  Log determinant: {test_logdet}",
]))

## BJLKeng's Real NVP Model

In [None]:
class BJLKengRealNVP(nn.Module):
    """Single-scale stack of checkerboard affine coupling layers.

    Layer ``i`` uses the checkerboard mask, reversed on odd ``i``, so that
    consecutive layers transform complementary sets of pixels.

    Args:
        num_coupling_layers: Number of coupling layers in the stack.
        hidden_channels: Width of each coupling layer's conditioner network.
        image_size: ``(H, W)`` the masks are built for. The default
            ``(28, 28)`` matches the previously hard-coded MNIST size.
    """

    def __init__(self, num_coupling_layers=8, hidden_channels=64, image_size=(28, 28)):
        super().__init__()

        self.num_coupling_layers = num_coupling_layers
        self.coupling_layers = nn.ModuleList()

        height, width = image_size
        # Alternate the mask so every pixel gets transformed by some layer.
        for i in range(num_coupling_layers):
            mask = create_checkerboard_mask(height, width, reverse=(i % 2 == 1))
            coupling_net = BJLKengCouplingNet(hidden_channels=hidden_channels)
            self.coupling_layers.append(BJLKengCouplingLayer(mask, coupling_net))

    def forward(self, x):
        """Map data ``x`` to latent ``z``; return ``(z, total_log_det)`` with per-sample log-det."""
        z = x
        total_log_det = 0

        for layer in self.coupling_layers:
            z, log_det = layer(z)
            total_log_det += log_det

        return z, total_log_det

    def inverse(self, z):
        """Map latent ``z`` back to data space by inverting the layers in reverse order."""
        x = z
        for layer in reversed(self.coupling_layers):
            x = layer.inverse(x)
        return x

    def sample(self, base_dist, n_samples=64):
        """Draw ``n_samples`` from ``base_dist`` and push them through :meth:`inverse`.

        Sampling runs in eval mode (BatchNorm uses its running statistics) under
        ``torch.no_grad()``. The module's previous train/eval mode is restored
        afterwards. Previously the model was always left in training mode, even
        when the caller had put it in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                z = base_dist.sample(n_samples)
                param_device = next(self.parameters()).device
                if z.device != param_device:
                    z = z.to(param_device)
                x = self.inverse(z)
        finally:
            self.train(was_training)
        return x

# Build the full flow and move it to the active device.
model = BJLKengRealNVP(num_coupling_layers=8, hidden_channels=64).to(device)

# Round-trip a handful of training images: forward to z, then invert back.
test_batch = trainX[:4].to(device)
test_z, test_logdet = model(test_batch)
test_recon = model.inverse(test_z)

print("\n".join([
    f"Model parameters: {sum(param.numel() for param in model.parameters()):,}",
    f"Forward test: {test_batch.shape} -> {test_z.shape}",
    f"Reconstruction error: {(test_batch - test_recon).abs().max():.8f}",
    f"Log determinant range: [{test_logdet.min():.2f}, {test_logdet.max():.2f}]",
]))

## BJLKeng's Loss Function with Batch Norm Accounting

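Key insight: BJLKeng accounts for batch normalization in the loss function since it's also a transformation. This applies only when batch norm acts on the flow variables themselves, as its own flow layer. Batch norm inside the coupling network that computes `s` and `t` does not change the coupling layer's Jacobian, whose log-determinant is simply the sum of the active `s` values.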

In [None]:
def bjlkeng_loss_and_bpd(model, base_dist, batch, l2_reg=1e-5, num_bins=256):
    """Compute the training loss and bits/dim for a batch of dequantized images.

    ``batch`` holds pixels that were divided by ``num_bins`` during
    preprocessing (``y = x / num_bins``). The model maps ``y`` to latents ``z``.
    By change of variables, the log-density of the unscaled data ``x`` is::

        log p(x) = log p_Z(z) + log|det dz/dy| - D * log(num_bins)

    where ``D`` is the number of dimensions per example. The
    ``- D * log(num_bins)`` term is the Jacobian of the ``1 / num_bins``
    rescaling, and it has to be *subtracted*; the previous code added it,
    which pushed bits/dim 2 * log2(num_bins) too low (negative). The effect is
    that bits/dim is measured on the ``num_bins``-level pixel scale.

    In this notebook's model, batch norm appears only inside the networks
    that compute ``s`` and ``t``. It therefore does not change the coupling
    Jacobian (which depends only on ``s``) and contributes no log-determinant.
    Adding ``log|gamma / sigma|`` terms would bias the likelihood and
    feed spurious gradients into the BN weights.

    Args:
        model: Flow returning ``(z, log_det)`` with a per-example ``log_det``.
        base_dist: Object whose ``log_prob(z)`` returns per-example log-density.
        batch: Input tensor whose first dimension is the batch.
        l2_reg: Weight of the L2 penalty on coupling-network weights.
        num_bins: Number of quantization levels used in preprocessing.

    Returns:
        A tuple ``(loss, bpd, mean base log-prob, mean coupling log-det,
        bn_log_det, dequant_log_det)``:

        - ``loss``: the mean NLL in nats plus the L2 penalty; this is what
          should be back-propagated.
        - ``bpd``: the pure mean NLL in bits per dimension, with no
          regularization.
        - ``bn_log_det``: always ``0.0``, kept so the tuple shape is unchanged.
    """
    z, coupling_log_det = model(batch)
    base_log_prob = base_dist.log_prob(z)

    num_pixels = int(np.prod(batch.shape[1:]))
    # Jacobian of the x -> x / num_bins scaling; subtracted in the likelihood below.
    dequant_log_det = float(num_pixels * np.log(float(num_bins)))
    # Batch norm sits inside the s/t conditioner networks -> no Jacobian contribution.
    bn_log_det = 0.0

    log_likelihood = base_log_prob + coupling_log_det - dequant_log_det
    nll_data = -log_likelihood.mean()

    # L2 penalty on every weight tensor inside the coupling layers (conv and BN weights).
    l2_loss = 0
    for name, param in model.named_parameters():
        if 'weight' in name and 'coupling' in name:
            l2_loss += torch.sum(param ** 2)

    nll = nll_data + l2_reg * l2_loss

    # Bits per dimension from the data term only (regularization excluded).
    bpd = nll_data / (float(np.log(2.0)) * num_pixels)

    return nll, bpd, base_log_prob.mean(), coupling_log_det.mean(), bn_log_det, dequant_log_det

# Evaluate the loss once on a small batch and print each log-likelihood component.
test_batch = trainX[:8].to(device)
(test_nll, test_bpd, test_base_lp,
 test_coupling_ld, test_bn_ld, test_dequant_ld) = bjlkeng_loss_and_bpd(model, base_dist, test_batch)

print("\n".join([
    "=== BJLKeng Loss Computation Test ===",
    f"NLL: {test_nll.item():.3f}",
    f"BPD: {test_bpd.item():.3f} (should be positive!)",
    f"Base log prob: {test_base_lp.item():.1f}",
    f"Coupling log det: {test_coupling_ld.item():.1f}",
    f"Batch norm log det: {test_bn_ld:.1f}",
    f"Dequantization log det: {test_dequant_ld:.1f}",
]))

print(
    "\n🎉 SUCCESS! BPD is positive with BJLKeng's approach!"
    if test_bpd.item() > 0
    else "\n😞 Still negative - need more debugging"
)

## BJLKeng's Training Approach

Key training details:
- **Small learning rate**: 0.0005 ("slow learners")
- **Long training**: Many epochs
- **L2 regularization**: Small weight decay

In [None]:
def _show_sample_grid(samples, title, nrow=4):
    """Show the first ``nrow * nrow`` images of ``samples`` in a square grid.

    Reads each image as ``samples[idx, 0]``, i.e. it assumes a channel axis at
    dim 1 and plots only the first channel.
    """
    fig, axes = plt.subplots(nrow, nrow, figsize=(2 * nrow, 2 * nrow))
    for idx, ax in enumerate(axes.flat):
        ax.imshow(samples[idx, 0].cpu(), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def train_bjlkeng_realnvp(model, base_dist, dataloader, epochs=100, lr=0.0005, *,
                          device=None, weight_decay=1e-6, max_grad_norm=1.0,
                          log_every=300, sample_every=10):
    """Train ``model`` by maximum likelihood with Adam and a step LR schedule.

    Args:
        model: Flow accepted by ``bjlkeng_loss_and_bpd``; must have ``sample``.
        base_dist: Base distribution passed to the loss and to ``model.sample``.
        dataloader: Iterable of batches. A list/tuple batch is unwrapped to its
            first element (as yielded by ``TensorDataset``).
        epochs: Number of passes over ``dataloader``.
        lr: Initial Adam learning rate. It is multiplied by 0.8 every 25 epochs.
        device: Where batches are moved. Defaults to the device of the model's
            parameters instead of a module-level global.
        weight_decay: Adam weight decay (default 1e-6, as before).
        max_grad_norm: Gradient-norm clipping threshold (default 1.0, as before).
        log_every: Print batch metrics every this many batches; 0 disables.
        sample_every: Plot 16 samples every this many epochs; 0 disables.

    Returns:
        ``(losses, bpds)``: the per-epoch mean training loss (NLL plus L2
        penalty) and the per-epoch mean bits/dim, as returned by
        ``bjlkeng_loss_and_bpd``.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch (this
            previously surfaced as a ZeroDivisionError).
    """
    if device is None:
        device = next(model.parameters()).device

    # BJLKeng uses small learning rate
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.8)

    losses = []
    bpds = []

    model.train()

    for epoch in tqdm(range(epochs), desc="Training BJLKeng Real NVP"):
        epoch_loss = 0.0
        epoch_bpd = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            batch = batch.to(device)

            nll, bpd, base_lp, coupling_ld, _, _ = bjlkeng_loss_and_bpd(model, base_dist, batch)

            optimizer.zero_grad()
            nll.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            epoch_loss += nll.item()
            epoch_bpd += bpd.item()
            num_batches += 1

            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch+1:3d}, Batch {batch_idx:3d}: "
                      f"NLL={nll.item():6.3f}, "
                      f"BPD={bpd.item():5.3f}, "
                      f"Base_LP={base_lp.item():7.1f}, "
                      f"Coupling_LD={coupling_ld.item():6.1f}")

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")

        avg_loss = epoch_loss / num_batches
        avg_bpd = epoch_bpd / num_batches
        losses.append(avg_loss)
        bpds.append(avg_bpd)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        if sample_every and (epoch + 1) % sample_every == 0:
            print(f"\nEpoch {epoch+1}: BPD = {avg_bpd:.3f}, LR = {current_lr:.6f}")
            samples = torch.clamp(model.sample(base_dist, n_samples=16), 0, 1)
            # Make sure training continues in train mode whatever sample() did.
            model.train()
            _show_sample_grid(samples, f'BJLKeng Real NVP Samples - Epoch {epoch+1} (BPD: {avg_bpd:.3f})')

    return losses, bpds

# Shuffled mini-batches over the (already dequantized) training images.
# TensorDataset yields 1-tuples, which the training loop unwraps.
dataset = TensorDataset(trainX)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

print("Ready to train with BJLKeng's approach\n"
      "Target: ~1.9 bits/dim (BJLKeng's result)\n"
      "Paper result: 1.06 bits/dim")

## Start BJLKeng Training

In [None]:
# Only train if the loss sanity check above produced a positive bits/dim.
if not test_bpd.item() > 0:
    print("❌ BPD test failed - check loss computation")
else:
    print("🚀 Starting BJLKeng-style training...")

    bjlkeng_losses, bjlkeng_bpds = train_bjlkeng_realnvp(
        model=model,
        base_dist=base_dist,
        dataloader=dataloader,
        epochs=50,  # Start with fewer epochs for testing
        lr=0.0005,  # BJLKeng's learning rate
    )

    plt.figure(figsize=(15, 5))

    # Left panel: the optimized objective per epoch.
    plt.subplot(1, 3, 1)
    plt.plot(bjlkeng_losses)
    plt.grid(True)
    plt.title('Training Loss (NLL)')
    plt.xlabel('Epoch')
    plt.ylabel('Negative Log Likelihood')

    # Middle panel: bits/dim against the two reference results.
    plt.subplot(1, 3, 2)
    plt.plot(bjlkeng_bpds)
    plt.axhline(y=1.92, color='g', linestyle='--', label='BJLKeng Result')
    plt.axhline(y=1.06, color='r', linestyle='--', label='Paper Result')
    plt.grid(True)
    plt.title('Bits Per Dimension')
    plt.xlabel('Epoch')
    plt.ylabel('BPD')
    plt.legend()

    # Right panel: cumulative bits/dim reduction relative to the first epoch.
    plt.subplot(1, 3, 3)
    improvement = bjlkeng_bpds[0] - np.array(bjlkeng_bpds)
    plt.plot(improvement)
    plt.grid(True)
    plt.title('BPD Improvement')
    plt.xlabel('Epoch')
    plt.ylabel('BPD Reduction from Start')

    plt.tight_layout()
    plt.show()

    print("\n".join([
        "\n🎯 BJLKeng-Style Results:",
        f"Initial BPD: {bjlkeng_bpds[0]:.3f}",
        f"Final BPD: {bjlkeng_bpds[-1]:.3f}",
        "BJLKeng's result: 1.92",
        "Paper result: 1.06",
        f"Total improvement: {bjlkeng_bpds[0] - bjlkeng_bpds[-1]:.3f} bits",
    ]))

    if bjlkeng_bpds[-1] < 3.0:
        print("\n🎉 Success! Achieving reasonable BPD values!")

## Final Sample Generation and Evaluation

In [None]:
# Final qualitative check; runs only if the training cell above produced results.
if 'bjlkeng_bpds' in locals():
    model.eval()
    with torch.no_grad():
        # 100 samples: the first 10 for the comparison strip, all 100 for the grid.
        final_samples = torch.clamp(model.sample(base_dist, n_samples=100), 0, 1)

    # Comparison strip: real test digits on the top row, generated digits below.
    fig, axes = plt.subplots(2, 10, figsize=(15, 4))
    for i in range(10):
        axes[0, i].imshow(testX[i, 0], cmap='gray')
        axes[0, i].set_title('Real' if i == 0 else '')
    for i in range(10):
        axes[1, i].imshow(final_samples[i, 0].cpu(), cmap='gray')
        axes[1, i].set_title('Generated' if i == 0 else '')
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    plt.suptitle(f'Real vs Generated MNIST (BJLKeng Real NVP - Final BPD: {bjlkeng_bpds[-1]:.3f})', fontsize=14)
    plt.tight_layout()
    plt.show()

    # 10x10 grid of all generated samples, filled row by row.
    fig, axes = plt.subplots(10, 10, figsize=(12, 12))
    for idx, ax in enumerate(axes.flat):
        ax.imshow(final_samples[idx, 0].cpu(), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])

    plt.suptitle(f'100 Generated MNIST Samples - BJLKeng Real NVP (BPD: {bjlkeng_bpds[-1]:.3f})', fontsize=16)
    plt.tight_layout()
    plt.show()

## Summary: BJLKeng's Working Approach

### ✅ **Key Insights from BJLKeng's Implementation:**

1. **Simple Data Preprocessing**
   - Just scale MNIST to [0,1] - no complex transforms
   - Add dequantization noise, but keep it simple

2. **Batch Normalization in Loss**
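   - Account for batch norm scaling in the log-likelihood when batch norm is a flow layer. Batch norm inside the s/t networks adds no log-determinant term.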
   - Use small momentum (0.05) for batch norm

3. **Training Details**
   - **Small learning rate**: 0.0005 ("slow learners")
   - **L2 regularization**: A small L2 penalty on the coupling networks' weights, plus a small Adam weight decay
   - **Long training**: Many epochs needed

4. **Architecture**
   - **Zero initialization**: Networks start at identity
   - **Tanh scaling**: s = tanh(raw_s) for bounded scale
   - **64 hidden channels**: Reasonable model size

### 📊 **Expected Results:**
- **BJLKeng achieved**: 1.92 bits/dim
- **Original paper**: 1.06 bits/dim
- **Good result**: < 3.0 bits/dim

### 🎓 **Lessons Learned:**
1. **Keep preprocessing simple** for MNIST
2. **Account for ALL transformations** in the loss: include batch norm when it acts on the flow variables, and include the 1/256 pixel scaling
3. **Use small learning rates** - these models are slow learners
4. **Zero initialization** is crucial for stable training

This implementation follows the main ingredients of BJLKeng's approach, which achieved working results on MNIST! 🎉