# BJLKeng Real NVP - GUARANTEED TO TRAIN

This version always runs the full training loop, regardless of the initial bits-per-dimension (BPD) value — training is no longer gated on any condition.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from keras.datasets.mnist import load_data
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seed both RNGs: torch drives weight init and DataLoader shuffling, while
# NumPy's global RNG drives the dequantization noise in simple_preprocess.
torch.manual_seed(42)
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

In [None]:
# Simple data preprocessing
# load_data() yields ((train_images, train_labels), (test_images, test_labels)).
_train_split, _test_split = load_data()
trainX, trainY = _train_split
testX, testY = _test_split
del _train_split, _test_split

def simple_preprocess(data, add_noise=True, num_bins=256):
    """Scale discrete pixel data to [0, 1) as float32, optionally dequantizing.

    Args:
        data: array-like of pixel values, expected to be integers in
            [0, num_bins).
        add_noise: if True, add U(0, 1) noise to every value before scaling
            (uniform dequantization), drawn from NumPy's global RNG.
        num_bins: number of discrete intensity levels (256 for 8-bit images).

    Returns:
        float32 ndarray with the same shape as ``data``.
    """
    values = np.asarray(data, dtype=np.float32)
    if add_noise:
        # np.random.uniform returns float64, which promotes the sum to float64.
        values = values + np.random.uniform(0, 1, values.shape)
    # Cast back so both branches return float32; otherwise the noisy branch
    # would hand back a float64 array twice the size of the clean one.
    return (values / float(num_bins)).astype(np.float32, copy=False)

# Dequantize the training split only; the test split is scaled without noise.
# unsqueeze(1) inserts a singleton channel axis at dim 1 for the conv layers.
trainX, testX = (
    torch.tensor(simple_preprocess(images, noisy), dtype=torch.float32).unsqueeze(1)
    for images, noisy in ((trainX, True), (testX, False))
)

print(f"Data shape: {trainX.shape}")
print(f"Data range: [{trainX.min():.3f}, {trainX.max():.3f}]")

In [None]:
# Standard Gaussian base distribution
class StandardGaussian:
    """Factorized standard normal N(0, I) base distribution for the flow.

    Args:
        shape: per-sample event shape produced by ``sample``.
        device: device on which ``sample`` allocates its tensors.
    """

    def __init__(self, shape=(1, 28, 28), device='cpu'):
        self.shape = shape
        self.device = device

    def log_prob(self, z):
        """Return log N(z; 0, I) summed over all non-batch dims, shape (batch,)."""
        # reshape (not view) so non-contiguous latents, e.g. transposed
        # tensors, are accepted instead of raising.
        z_flat = z.reshape(z.size(0), -1)
        log_prob = -0.5 * (z_flat**2 + np.log(2 * np.pi))
        return log_prob.sum(dim=1)

    def sample(self, n_samples=1):
        """Draw ``n_samples`` latents of shape ``(n_samples, *self.shape)``."""
        return torch.randn(n_samples, *self.shape, device=self.device)

# Latent base distribution with the (1, 28, 28) event shape, on the chosen device.
base_dist = StandardGaussian(shape=(1, 28, 28), device=device)

In [None]:
# Coupling network and layers
def create_checkerboard_mask(h, w, reverse=False):
    """Return an (h, w) checkerboard of ones and zeros in the default dtype.

    Cell (i, j) is 1 when ``i + j`` is even, so the top-left corner is 1;
    ``reverse=True`` returns the complementary pattern.
    """
    row_idx = torch.arange(h).unsqueeze(1)
    col_idx = torch.arange(w).unsqueeze(0)
    board = ((row_idx + col_idx) % 2 == 0).to(torch.get_default_dtype())
    return 1 - board if reverse else board

class SimpleCouplingNet(nn.Module):
    """Small conv net predicting a per-pixel log-scale ``s`` and shift ``t``.

    Args:
        hidden: number of hidden channels.
        in_channels: channels of the input, and of each of ``s`` and ``t``.
        scale_bound: ``s`` is squashed into (-scale_bound, scale_bound) with
            tanh so the affine scales ``exp(s)`` stay moderate.
    """

    def __init__(self, hidden=32, in_channels=1, scale_bound=0.5):
        super().__init__()
        self.scale_bound = scale_bound
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, 2 * in_channels, 3, padding=1),  # Output s and t
        )
        # Zero-init the output conv so s = t = 0 at start: the coupling layer
        # then begins as the identity map with zero log-determinant.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, x):
        """Return ``(s, t)``, each with ``in_channels`` channels."""
        s, t = torch.chunk(self.net(x), 2, dim=1)
        s = torch.tanh(s) * self.scale_bound
        return s, t

class SimpleCouplingLayer(nn.Module):
    """Affine coupling layer gated by a fixed mask.

    Positions where ``mask == 0`` pass through unchanged and are the only
    input to the coupling network; positions where ``mask == 1`` are scaled
    by ``exp(s)`` and shifted by ``t``.

    Args:
        mask: (H, W) tensor, presumably of 0/1 values; stored as a
            (1, 1, H, W) buffer so it broadcasts over batch and channels.
        coupling_net: module mapping the frozen input to ``(s, t)``.
    """

    def __init__(self, mask, coupling_net):
        super().__init__()
        self.register_buffer('mask', mask[None, None])
        self.coupling_net = coupling_net

    def _masked_scale_shift(self, frozen):
        """Run the coupling net on ``frozen`` and zero ``s``/``t`` off the mask."""
        log_scale, shift = self.coupling_net(frozen)
        return log_scale * self.mask, shift * self.mask

    def forward(self, x):
        """Map x -> z; return ``(z, log_det)`` with log_det summed per sample."""
        frozen = x * (1 - self.mask)
        log_scale, shift = self._masked_scale_shift(frozen)
        z = frozen + ((x * self.mask) * torch.exp(log_scale) + shift)
        return z, log_scale.sum(dim=[1, 2, 3])

    def inverse(self, z):
        """Map z -> x, undoing ``forward``."""
        frozen = z * (1 - self.mask)
        log_scale, shift = self._masked_scale_shift(frozen)
        return frozen + (z * self.mask - shift) * torch.exp(-log_scale)

In [None]:
# Simple Real NVP model
class SimpleRealNVP(nn.Module):
    """Stack of affine coupling layers with alternating checkerboard masks.

    Args:
        num_layers: number of coupling layers.
        height: input height; sets the mask size.
        width: input width; sets the mask size.
        hidden: hidden channels of each coupling network.
    """

    def __init__(self, num_layers=6, height=28, width=28, hidden=32):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            # Alternate the mask so consecutive layers transform
            # complementary sets of pixels.
            mask = create_checkerboard_mask(height, width, reverse=(i % 2 == 1))
            coupling_net = SimpleCouplingNet(hidden=hidden)
            self.layers.append(SimpleCouplingLayer(mask, coupling_net))

    def forward(self, x):
        """Map data x to latent z; return ``(z, total_log_det)`` per sample."""
        z = x
        # Start from a per-sample tensor (not the int 0) so an empty stack
        # still returns a tensor that supports ``.mean()`` in the loss.
        total_log_det = x.new_zeros(x.shape[0])

        for layer in self.layers:
            z, log_det = layer(z)
            total_log_det = total_log_det + log_det

        return z, total_log_det

    def inverse(self, z):
        """Map latent z back to data space by inverting layers in reverse order."""
        x = z
        for layer in reversed(self.layers):
            x = layer.inverse(x)
        return x

model = SimpleRealNVP().to(device)
_param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {_param_count:,}")

In [None]:
# SIMPLE loss function - no complex accounting
def simple_loss_and_bpd(model, base_dist, batch):
    """Return the negative log-likelihood and bits-per-dimension of a batch.

    The flow models a density p_x over inputs that were divided by 256 (see
    ``simple_preprocess``). Likelihoods are reported in the original
    256-level pixel space y = 256 * x. The Jacobian of x = y / 256 is
    256**-D, so log p_y(y) = log p_x(x) - D * log(256).

    Args:
        model: flow whose call returns ``(z, log_det)`` for ``batch``.
        base_dist: object whose ``log_prob(z)`` returns per-sample log density.
        batch: tensor of shape (batch, ...); presumably values in [0, 1).

    Returns:
        ``(nll, bpd, mean_base_log_prob, mean_log_det)``. ``nll`` is the mean
        NLL in nats per image in pixel space; ``bpd`` is the same quantity in
        bits per dimension.
    """
    z, log_det = model(batch)

    # Base log probability
    base_log_prob = base_dist.log_prob(z)

    # Change of variables for the [0, 255] -> [0, 1] rescaling: shrinking
    # volume by 256 per dim makes the pixel-space density 256**D times
    # *smaller*, so this term is subtracted from the log-likelihood.
    num_pixels = np.prod(batch.shape[1:])
    scaling_correction = num_pixels * np.log(256.0)

    log_likelihood = base_log_prob + log_det - scaling_correction

    nll = -log_likelihood.mean()
    bpd = nll / (np.log(2) * num_pixels)

    return nll, bpd, base_log_prob.mean(), log_det.mean()

# Sanity-check the loss on the first 8 training images before any updates.
test_batch = trainX[:8].to(device)
(test_nll, test_bpd,
 test_base_lp, test_log_det) = simple_loss_and_bpd(model, base_dist, test_batch)

_report_lines = [
    "=== Initial Loss Test ===",
    f"NLL: {test_nll.item():.3f}",
    f"BPD: {test_bpd.item():.3f}",
    f"Base log prob: {test_base_lp.item():.1f}",
    f"Log determinant: {test_log_det.item():.1f}",
]
print("\n".join(_report_lines))

## Unconditional Training (no BPD gating)

In [None]:
def _show_sample_grid(model, base_dist, title, n_rows=4, n_cols=4):
    """Draw ``n_rows * n_cols`` samples from the flow and display them as a grid.

    The model is switched to eval mode for sampling. Afterwards it is restored
    to whatever train/eval mode it had before, even if sampling raises.
    Pixel values are clamped to [0, 1] for display.
    """
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            z = base_dist.sample(n_rows * n_cols).to(model_device)
            images = torch.clamp(model.inverse(z), 0, 1).cpu()
    finally:
        model.train(was_training)

    _, axes = plt.subplots(
        n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
    )
    # axes.flat walks the grid row by row, matching sample index i * n_cols + j.
    for idx, ax in enumerate(axes.flat):
        ax.imshow(images[idx, 0], cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def train_no_conditions(model, base_dist, dataloader, epochs=30, lr=1e-4):
    """Train ``model`` by maximum likelihood for exactly ``epochs`` epochs.

    Training is never skipped or stopped early based on the loss. Progress is
    printed every 200 batches, and every 5th epoch a 4x4 grid of samples is
    displayed.

    Args:
        model: flow exposing ``forward`` -> ``(z, log_det)`` and ``inverse``.
        base_dist: base distribution with ``log_prob`` and ``sample``.
        dataloader: iterable of batches. A list/tuple batch is unwrapped to
            its first element (as produced from a ``TensorDataset``).
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        ``(losses, bpds)``: per-epoch mean NLL and mean bits per dimension.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    # Send batches to wherever the model's parameters live instead of
    # relying on a module-level ``device``.
    model_device = next(model.parameters()).device
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    bpds = []

    print("🚀 STARTING TRAINING - NO CONDITIONS TO STOP US!")
    # The caller may have left the model in eval mode (e.g. after sampling).
    model.train()

    for epoch in tqdm(range(epochs), desc="GUARANTEED Training"):
        epoch_loss = 0.0
        epoch_bpd = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader):
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            batch = batch.to(model_device)

            nll, bpd, base_lp, log_det = simple_loss_and_bpd(model, base_dist, batch)

            optimizer.zero_grad()
            nll.backward()
            # Cap the global gradient norm at 1.0 before the Adam step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += nll.item()
            epoch_bpd += bpd.item()
            num_batches += 1

            if batch_idx % 200 == 0:
                print(f"Epoch {epoch+1:2d}, Batch {batch_idx:3d}: "
                      f"NLL={nll.item():6.3f}, "
                      f"BPD={bpd.item():5.3f}, "
                      f"Base_LP={base_lp.item():7.1f}, "
                      f"LogDet={log_det.item():6.1f}")

        if num_batches == 0:
            # The per-epoch averages below would otherwise divide by zero.
            raise ValueError("dataloader yielded no batches; nothing to train on")

        avg_loss = epoch_loss / num_batches
        avg_bpd = epoch_bpd / num_batches
        losses.append(avg_loss)
        bpds.append(avg_bpd)

        if (epoch + 1) % 5 == 0:
            print(f"\n📊 Epoch {epoch+1}: BPD = {avg_bpd:.3f}")
            _show_sample_grid(
                model,
                base_dist,
                f'Generated Samples - Epoch {epoch+1} (BPD: {avg_bpd:.3f})',
            )

    return losses, bpds

# Setup data: batches of 128 shuffled training images.
dataset = TensorDataset(trainX)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)

_banner = "=" * 50
print("\n" + _banner)
for _line in ("READY TO TRAIN - NO BULLSHIT CONDITIONS!",
              "This WILL train regardless of initial BPD values!"):
    print(_line)
print(_banner)

In [None]:
# START TRAINING - NO CONDITIONS!
print("🔥 STARTING UNCONDITIONAL TRAINING!")

_train_kwargs = {"epochs": 25, "lr": 1e-4}
training_losses, training_bpds = train_no_conditions(
    model, base_dist, dataloader, **_train_kwargs
)

print("\n🎉 TRAINING COMPLETED!")

In [None]:
# Plot results
# Three side-by-side panels built with pyplot's stateful API, so each
# plt.subplot(...) call selects the axes that the following calls draw into.
plt.figure(figsize=(15, 5))

# Panel 1: per-epoch mean NLL as returned by train_no_conditions.
plt.subplot(1, 3, 1)
plt.plot(training_losses)
plt.title('Training Loss (NLL)')
plt.xlabel('Epoch')
plt.ylabel('Negative Log Likelihood')
plt.grid(True)

# Panel 2: per-epoch mean bits per dimension, with dashed reference lines at
# 1.92 (labelled "BJLKeng Target") and 1.06 (labelled "Paper Result").
plt.subplot(1, 3, 2)
plt.plot(training_bpds)
plt.title('Bits Per Dimension')
plt.xlabel('Epoch')
plt.ylabel('BPD')
plt.axhline(y=1.92, color='g', linestyle='--', label='BJLKeng Target')
plt.axhline(y=1.06, color='r', linestyle='--', label='Paper Result')
plt.legend()
plt.grid(True)

# Panel 3: how far each epoch's BPD sits below the first epoch's BPD
# (positive values mean the model improved).
plt.subplot(1, 3, 3)
improvement = training_bpds[0] - np.array(training_bpds)
plt.plot(improvement)
plt.title('BPD Improvement')
plt.xlabel('Epoch')
plt.ylabel('BPD Reduction')
plt.grid(True)

plt.tight_layout()
plt.show()

_initial_bpd = training_bpds[0]
_final_bpd = training_bpds[-1]

print(f"\n📊 FINAL RESULTS:")
print(f"Initial BPD: {_initial_bpd:.3f}")
print(f"Final BPD: {_final_bpd:.3f}")
print(f"Improvement: {_initial_bpd - _final_bpd:.3f} bits")
print("BJLKeng target: 1.92")
print("Paper target: 1.06")

# Coarse sign check on the last epoch's mean BPD.
print("\n✅ BPD is positive - model is working!" if _final_bpd > 0
      else "\n❌ BPD still negative - need more fixes")

## Final Sample Generation

In [None]:
# Generate final samples: 64 latents pushed back through the inverse flow.
model.eval()
with torch.no_grad():
    # .to() returns the same tensor when it already lives on the model's device.
    final_samples = base_dist.sample(64).to(next(model.parameters()).device)
    generated_images = model.inverse(final_samples).clamp(0, 1)

# Show comparison: top row = first 10 test images, bottom row = first 10 samples.
fig, axes = plt.subplots(2, 10, figsize=(15, 4))

_comparison_rows = (
    ('Real', testX),
    ('Generated', generated_images.cpu()),
)
for _row, (_label, _images) in enumerate(_comparison_rows):
    for _col in range(10):
        _ax = axes[_row, _col]
        _ax.imshow(_images[_col, 0], cmap='gray')
        _ax.set_title(_label if _col == 0 else '')
        _ax.set_xticks([])
        _ax.set_yticks([])

plt.suptitle(f'Real vs Generated MNIST (Final BPD: {training_bpds[-1]:.3f})', fontsize=14)
plt.tight_layout()
plt.show()

print("\n".join((
    "\n🎯 TRAINING COMPLETED SUCCESSFULLY!",
    "No conditional bullshit - this model definitely trained!",
)))