# Convolutional Normalizing Flows

Implementation based on https://github.com/AxelNathanson/pytorch-normalizing-flows/blob/main/flow_models.py

This notebook implements Real NVP normalizing flows (fully connected and convolutional variants) and trains the convolutional model on MNIST for image generation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Simple Affine Transform

In [None]:
class SimpleAffine(nn.Module):
    """Affine coupling layer with fully connected scale/shift networks.

    Entries where ``mask == 1`` are passed through unchanged and are the only
    input to the scale (``s``) and shift (``t``) networks; the remaining
    entries are transformed as ``x * exp(s) + t``.

    Args:
        mask: 0/1 tensor multiplied elementwise with the input.
        dim: input feature size (assumes inputs shaped ``(batch, dim)`` —
            the log-determinant is summed over dim 1).
    """

    def __init__(self, mask, dim=784):
        super().__init__()
        # Frozen Parameter: moves with the module but is never trained.
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.dim = dim

        hidden = 512
        # Log-scale network; tanh bounds each log-scale to (-1, 1).
        self.s_func = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, dim), nn.Tanh(),
        )
        # Shift network; output is unbounded.
        self.t_func = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def _scale_and_shift(self, z):
        """Return ``(s, t)`` computed from the masked part of ``z``, zeroed on masked entries."""
        conditioner = z * self.mask
        passthrough = 1 - self.mask
        log_scale = self.s_func(conditioner) * passthrough
        shift = self.t_func(conditioner) * passthrough
        return log_scale, shift

    def forward(self, x):
        """Map ``x -> x * exp(s) + t``; return ``(y, log|det dy/dx|)`` per sample."""
        log_scale, shift = self._scale_and_shift(x)
        return x * torch.exp(log_scale) + shift, log_scale.sum(dim=1)

    def inverse(self, y):
        """Map ``y -> (y - t) * exp(-s)``; return ``(x, log|det dx/dy|)`` per sample."""
        log_scale, shift = self._scale_and_shift(y)
        return (y - shift) * torch.exp(-log_scale), -log_scale.sum(dim=1)

## Stack Simple Affine

In [None]:
class StackSimpleAffine(nn.Module):
    """Normalizing flow: a stack of invertible transforms over a standard normal base.

    Sampling pushes base draws through ``transforms`` in order; density
    evaluation inverts them in reverse order and accumulates the inverse
    log-Jacobian determinants.

    Args:
        transforms: modules exposing ``forward(x) -> (y, log_det)`` and
            ``inverse(y) -> (x, inv_log_det)`` with per-sample log-determinants.
        dim: dimensionality of the base distribution N(0, I).
    """

    def __init__(self, transforms, dim=784):
        super().__init__()
        self.dim = dim
        self.transforms = nn.ModuleList(transforms)
        # Base N(0, I) parameters are buffers so ``.to(device)`` moves them with
        # the transforms (a plain distribution attribute stays on the CPU and
        # fails on GPU inputs). They are non-persistent, so the state_dict and
        # existing checkpoints are unchanged.
        self.register_buffer('_base_loc', torch.zeros(dim), persistent=False)
        self.register_buffer('_base_scale_tril', torch.eye(dim), persistent=False)

    @property
    def distribution(self):
        """Standard normal base distribution on the model's current device/dtype."""
        return MultivariateNormal(self._base_loc, scale_tril=self._base_scale_tril)

    def log_probability(self, x):
        """Return ``log p(x)`` per row of ``x`` (assumed shape ``(batch, dim)``)."""
        log_prob = torch.zeros(x.shape[0], device=x.device)
        # Undo the transforms last-to-first, adding log|det dz/dx| of each step.
        for transform in reversed(self.transforms):
            x, inv_log_det_jac = transform.inverse(x)
            log_prob += inv_log_det_jac

        log_prob += self.distribution.log_prob(x)
        return log_prob

    def rsample(self, num_samples):
        """Draw ``num_samples`` samples; return ``(samples, log_prob_of_samples)``."""
        base = self.distribution
        x = base.sample((num_samples,))
        log_prob = base.log_prob(x)

        for transform in self.transforms:
            x, log_det_jac = transform.forward(x)
            # Change of variables: log p(f(z)) = log p(z) - log|det df/dz|.
            log_prob -= log_det_jac

        return x, log_prob

## Real NVP Node

In [None]:
class RealNVPNode(nn.Module):
    """Real NVP affine coupling node with MLP conditioners.

    The entries selected by ``mask`` stay fixed and parameterise an
    elementwise affine map ``x * exp(s) + t`` of the complementary entries.

    Args:
        mask: 0/1 tensor multiplied elementwise with the input.
        dim: input feature size (assumes inputs shaped ``(batch, dim)``).
    """

    def __init__(self, mask, dim=784):
        super().__init__()
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.dim = dim

        hidden_dim = 512
        widths = [dim, hidden_dim, hidden_dim, dim]

        def mlp(final_activation):
            # Linear layers over consecutive widths, ReLU between hidden layers.
            layers = []
            for k, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
                layers.append(nn.Linear(fan_in, fan_out))
                if k < len(widths) - 2:
                    layers.append(nn.ReLU())
            if final_activation is not None:
                layers.append(final_activation)
            return nn.Sequential(*layers)

        # tanh keeps the log-scale in (-1, 1); the shift is unbounded.
        self.s_func = mlp(nn.Tanh())
        self.t_func = mlp(None)

    def forward(self, x):
        """Return ``(y, log|det dy/dx|)`` for ``y = x * exp(s) + t``."""
        cond = x * self.mask
        free = 1 - self.mask
        s = self.s_func(cond) * free
        t = self.t_func(cond) * free
        return x * s.exp() + t, torch.sum(s, dim=1)

    def inverse(self, y):
        """Return ``(x, log|det dx/dy|)`` for ``x = (y - t) * exp(-s)``."""
        cond = y * self.mask
        free = 1 - self.mask
        s = self.s_func(cond) * free
        t = self.t_func(cond) * free
        return (y - t) * (-s).exp(), -torch.sum(s, dim=1)

## Real NVP

In [None]:
class RealNVP(nn.Module):
    """Real NVP flow: coupling ``nodes`` stacked on a standard normal base.

    Args:
        nodes: modules exposing ``forward(x) -> (y, log_det)`` and
            ``inverse(y) -> (x, inv_log_det)`` with per-sample log-determinants.
        dim: dimensionality of the base distribution N(0, I).
    """

    def __init__(self, nodes, dim=784):
        super().__init__()
        self.dim = dim
        self.nodes = nn.ModuleList(nodes)
        # Base N(0, I) parameters as non-persistent buffers: they follow
        # ``.to(device)`` (a plain distribution attribute would stay on the CPU)
        # without adding entries to the state_dict.
        self.register_buffer('_base_loc', torch.zeros(dim), persistent=False)
        self.register_buffer('_base_scale_tril', torch.eye(dim), persistent=False)

    @property
    def distribution(self):
        """Standard normal base distribution on the model's current device/dtype."""
        return MultivariateNormal(self._base_loc, scale_tril=self._base_scale_tril)

    def log_probability(self, x):
        """Return ``log p(x)`` per row of ``x`` (assumed shape ``(batch, dim)``)."""
        log_prob = torch.zeros(x.shape[0], device=x.device)
        # Invert the nodes last-to-first, adding log|det dz/dx| of each step.
        for node in reversed(self.nodes):
            x, inv_log_det_jac = node.inverse(x)
            log_prob += inv_log_det_jac

        log_prob += self.distribution.log_prob(x)
        return log_prob

    def rsample(self, num_samples):
        """Draw ``num_samples`` samples; return ``(samples, log_prob_of_samples)``."""
        base = self.distribution
        x = base.sample((num_samples,))
        log_prob = base.log_prob(x)

        for node in self.nodes:
            x, log_det_jac = node.forward(x)
            # Change of variables: log p(f(z)) = log p(z) - log|det df/dz|.
            log_prob -= log_det_jac

        return x, log_prob

## Convolutional Real NVP Node

In [None]:
class RealNVPNodeCNN(nn.Module):
    """Affine coupling node whose scale/shift networks are small CNNs.

    Pixels where ``mask == 1`` are copied through and are the only input to
    the scale (``s``) and shift (``t``) networks; the other pixels become
    ``x * exp(s) + t``.

    Args:
        mask: 0/1 tensor broadcastable against the input batch (the notebook
            passes a ``[1, H, W]`` checkerboard).
        in_channels: number of image channels.
        cnn_channels: widths of the hidden 3x3 conv blocks; the default
            reproduces the original 32-64-32 architecture.
    """

    def __init__(self, mask, in_channels=1, cnn_channels=(32, 64, 32)):
        super().__init__()
        # The mask is fixed state, not a learnable weight: register it as a
        # buffer so it follows ``.to(device)`` and is saved under the same
        # ``mask`` state_dict key, but is not handed to the optimizer.
        self.register_buffer('mask', mask)

        # tanh bounds each log-scale to (-1, 1); the shift is unbounded.
        self.s_func = self._conv_net(in_channels, cnn_channels, nn.Tanh())
        self.t_func = self._conv_net(in_channels, cnn_channels, None)

    @staticmethod
    def _conv_net(in_channels, hidden_channels, final_activation):
        """Build (3x3 conv -> BatchNorm -> LeakyReLU) blocks plus a 3x3 output conv."""
        layers = []
        prev = in_channels
        for width in hidden_channels:
            layers += [
                nn.Conv2d(prev, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.LeakyReLU(),
            ]
            prev = width
        layers.append(nn.Conv2d(prev, in_channels, kernel_size=3, stride=1, padding=1))
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def _scale_and_shift(self, z):
        """Return ``(s, t)`` from the masked part of ``z``, zeroed on masked pixels."""
        conditioner = z * self.mask
        keep = 1 - self.mask
        s = self.s_func(conditioner) * keep
        t = self.t_func(conditioner) * keep
        return s, t

    def forward(self, x):
        """Return ``(y, log|det dy/dx|)`` for ``y = x * exp(s) + t``."""
        s, t = self._scale_and_shift(x)
        y = x * torch.exp(s) + t
        # Sum the log-scales over every non-batch dimension (C, H, W).
        log_det_jac = s.flatten(start_dim=1).sum(dim=1)
        return y, log_det_jac

    def inverse(self, y):
        """Return ``(x, log|det dx/dy|)`` for ``x = (y - t) * exp(-s)``."""
        s, t = self._scale_and_shift(y)
        x = (y - t) * torch.exp(-s)
        inv_log_det_jac = -s.flatten(start_dim=1).sum(dim=1)
        return x, inv_log_det_jac

## Convolutional Real NVP

In [None]:
class RealNVPCNN(nn.Module):
    """Convolutional Real NVP flow over images with a standard normal base.

    The base distribution lives on flattened vectors of size
    ``prod(image_shape)``; samples are reshaped to ``image_shape`` before the
    coupling nodes are applied.

    Args:
        nodes: modules exposing ``forward(x) -> (y, log_det)`` and
            ``inverse(y) -> (x, inv_log_det)`` on image batches.
        image_shape: per-example shape, e.g. ``(C, H, W)``.
    """

    def __init__(self, nodes, image_shape=(1, 28, 28)):
        super().__init__()
        self.image_shape = image_shape
        # Plain int rather than a numpy scalar.
        self.dim = int(np.prod(image_shape))
        self.nodes = nn.ModuleList(nodes)
        # Base N(0, I) parameters as non-persistent buffers: ``.to(device)``
        # moves them with the nodes (a plain distribution attribute stays on
        # the CPU and breaks GPU use) without changing the state_dict.
        self.register_buffer('_base_loc', torch.zeros(self.dim), persistent=False)
        self.register_buffer('_base_scale_tril', torch.eye(self.dim), persistent=False)

    @property
    def distribution(self):
        """Standard normal base distribution on the model's current device/dtype."""
        return MultivariateNormal(self._base_loc, scale_tril=self._base_scale_tril)

    def log_probability(self, x):
        """Return ``log p(x)`` for each image in the batch ``x``."""
        log_prob = torch.zeros(x.shape[0], device=x.device)
        # Invert the nodes last-to-first, adding log|det dz/dx| of each step.
        for node in reversed(self.nodes):
            x, inv_log_det_jac = node.inverse(x)
            log_prob += inv_log_det_jac

        # reshape (not view) also tolerates non-contiguous outputs.
        x_flat = x.reshape(x.shape[0], -1)
        log_prob += self.distribution.log_prob(x_flat)
        return log_prob

    def rsample(self, num_samples):
        """Draw ``num_samples`` images; return ``(images, log_prob_of_images)``."""
        base = self.distribution
        x = base.sample((num_samples,))
        log_prob = base.log_prob(x)

        x = x.view(num_samples, *self.image_shape)

        for node in self.nodes:
            x, log_det_jac = node.forward(x)
            # Change of variables: log p(f(z)) = log p(z) - log|det df/dz|.
            log_prob -= log_det_jac

        return x, log_prob

## Data Loading

In [None]:
# Load MNIST. ToTensor() yields values in [0, 1]; Normalize((0.5,), (0.5,))
# then computes (x - 0.5) / 0.5, so the flow is trained on values in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

# Only the training split is shuffled; the test split is read in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# Peek at one training batch to confirm the tensor shape and value range.
sample_batch = next(iter(train_loader))[0]
print(f'Sample batch shape: {sample_batch.shape}')
print(f'Data range: [{sample_batch.min():.3f}, {sample_batch.max():.3f}]')

## Create Checkerboard Masks

In [None]:
def create_checkerboard_mask(height, width, reverse=False):
    """Return a ``(height, width)`` checkerboard of 0s and 1s for spatial coupling.

    Entry ``(i, j)`` is 1 when ``i + j`` is even (so the top-left pixel is 1)
    and 0 otherwise. ``reverse=True`` swaps the two colours.
    """
    rows = torch.arange(height).unsqueeze(1)
    cols = torch.arange(width).unsqueeze(0)
    on_even_square = (rows + cols) % 2 == 0
    # Same dtype as torch.zeros would produce.
    mask = on_even_square.to(torch.get_default_dtype())
    return 1 - mask if reverse else mask

# Create masks for CNN. The leading singleton axis gives each mask shape
# [1, 28, 28] so it broadcasts over MNIST batches shaped [B, 1, 28, 28].
mask1 = create_checkerboard_mask(28, 28, reverse=False).unsqueeze(0).to(device)
mask2 = create_checkerboard_mask(28, 28, reverse=True).unsqueeze(0).to(device)

print(f'Mask 1 shape: {mask1.shape}')
print(f'Mask 2 shape: {mask2.shape}')

# Show the two complementary masks side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for panel, panel_mask, panel_title in zip(axes, (mask1, mask2), ('Mask 1', 'Mask 2')):
    panel.imshow(panel_mask[0].cpu(), cmap='RdBu')
    panel.set_title(panel_title)
    panel.set_xticks([])
    panel.set_yticks([])

plt.tight_layout()
plt.show()

## Create Convolutional Real NVP Model

In [None]:
# Build the flow. Nodes alternate between the two complementary checkerboard
# masks, so pixels frozen in one node are transformed in the next.
num_layers = 6
nodes = [
    RealNVPNodeCNN(mask1 if i % 2 == 0 else mask2, in_channels=1)
    for i in range(num_layers)
]

model = RealNVPCNN(nodes, image_shape=(1, 28, 28)).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total trainable parameters: {total_params:,}')

# Smoke-test both directions. Run in eval mode so BatchNorm uses, rather than
# updates, its running statistics (no_grad alone does not stop those updates),
# then switch back to training mode.
test_batch = next(iter(train_loader))[0][:4].to(device)
model.eval()
with torch.no_grad():
    log_prob = model.log_probability(test_batch)
    samples, sample_log_prob = model.rsample(4)
model.train()

print(f'Test batch shape: {test_batch.shape}')
print(f'Log probability shape: {log_prob.shape}')
print(f'Sample shape: {samples.shape}')
print(f'Sample log prob shape: {sample_log_prob.shape}')

## Training Loop

In [None]:
def train_model(model, train_loader, num_epochs=20, lr=1e-3, device=None, sample_every=5):
    """Train a flow by maximum likelihood; return the mean loss of each epoch.

    Each step minimises ``-model.log_probability(data).mean()`` with Adam.
    Gradients are clipped to global norm 1.0, and the learning rate is
    multiplied by 0.7 every 10 epochs (StepLR stepped once per epoch).

    Args:
        model: flow exposing ``log_probability(x)`` with one value per example.
        train_loader: iterable of ``(data, label)`` batches; labels are ignored.
        num_epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter instead of a notebook-level global.
        sample_every: call ``generate_samples`` after every ``sample_every``
            epochs; ``0`` or ``None`` disables sampling.

    Returns:
        List of average training losses, one per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)

    if device is None:
        # Adam has already rejected a parameter-less model, so this is safe.
        device = next(model.parameters()).device

    losses = []

    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for data, _ in progress_bar:
            data = data.to(device)

            optimizer.zero_grad()

            # Negative log-likelihood averaged over the batch.
            loss = -model.log_probability(data).mean()
            loss.backward()

            # Cap the global gradient norm at 1.0 to stabilise training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # One host sync per step instead of two.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1

            progress_bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Avg Loss': f'{epoch_loss/num_batches:.4f}'
            })

        if num_batches == 0:
            raise ValueError('train_loader yielded no batches')

        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)

        scheduler.step()

        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')

        if sample_every and (epoch + 1) % sample_every == 0:
            generate_samples(model, epoch + 1)

    return losses

def generate_samples(model, epoch, num_samples=64):
    """Sample from ``model`` and show the images in a near-square grid.

    Args:
        model: flow exposing ``rsample(n) -> (samples, log_probs)``; samples
            are indexed as ``samples[i, 0]``, i.e. assumed ``(n, C, H, W)``,
            and channel 0 is shown.
        epoch: epoch number used in the figure title.
        num_samples: number of samples to draw and plot (8x8 grid for 64).

    The model's previous train/eval mode is restored afterwards, even if
    sampling raises.
    """
    import math

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            samples, _ = model.rsample(num_samples)
            # Clamp to the data range [-1, 1], then map to [0, 1] for display.
            samples = (torch.clamp(samples, -1, 1) + 1) / 2
    finally:
        model.train(was_training)

    # ceil(sqrt(n)) columns and as many rows as needed.
    ncols = math.isqrt(num_samples - 1) + 1 if num_samples > 0 else 1
    nrows = max(1, -(-num_samples // ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(1.5 * ncols, 1.5 * nrows), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        if idx < num_samples:
            ax.imshow(samples[idx, 0].cpu(), cmap='gray')
        ax.axis('off')

    plt.suptitle(f'Generated Samples - Epoch {epoch}', fontsize=16)
    plt.tight_layout()
    plt.show()

# Launch training; `losses` receives the average training NLL of each epoch.
print('Starting training...')
losses = train_model(model, train_loader, lr=1e-3, num_epochs=20)

## Plot Training Loss

In [None]:
# Per-epoch training NLL. The x-axis starts at 1 to match the epoch numbers
# printed during training (plotting the bare list would start at 0).
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(losses) + 1), losses, marker='o')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Negative Log Likelihood')
plt.grid(True)
plt.show()

print(f'Final loss: {losses[-1]:.4f}')
print(f'Best loss: {min(losses):.4f}')

## Generate Final Samples

In [None]:
# Draw 100 samples in eval mode; clamp to [-1, 1] and rescale to [0, 1] for display.
model.eval()
with torch.no_grad():
    final_samples, _ = model.rsample(100)
    final_samples = (torch.clamp(final_samples, -1, 1) + 1) / 2

# First ten test images, mapped from [-1, 1] back to [0, 1].
real_data = (next(iter(test_loader))[0][:10] + 1) / 2

# Top row: real test images; bottom row: generated samples.
fig, axes = plt.subplots(2, 10, figsize=(15, 4))
for col in range(10):
    top, bottom = axes[0, col], axes[1, col]
    top.imshow(real_data[col, 0], cmap='gray')
    top.axis('off')
    bottom.imshow(final_samples[col, 0].cpu(), cmap='gray')
    bottom.axis('off')
axes[0, 0].set_title('Real Data', fontsize=12)
axes[1, 0].set_title('Generated Data', fontsize=12)

plt.tight_layout()
plt.show()

# 10x10 grid of all generated samples, filled row by row.
fig, axes = plt.subplots(10, 10, figsize=(15, 15))
for idx, ax in enumerate(axes.flat):
    ax.imshow(final_samples[idx, 0].cpu(), cmap='gray')
    ax.axis('off')

plt.suptitle('100 Generated MNIST Samples', fontsize=16)
plt.tight_layout()
plt.show()

## Model Evaluation

In [None]:
def evaluate_model(model, test_loader, device=None):
    """Compute the mean log-likelihood and bits per dimension of ``model``.

    Args:
        model: flow exposing ``log_probability(x)`` with one log-density (nats)
            per example.
        test_loader: iterable of ``(data, label)`` batches; labels are ignored.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or the CPU if it has none.

    Returns:
        ``(avg_log_prob, bits_per_dim)``. ``avg_log_prob`` is the mean
        log-density per example. ``bits_per_dim`` is
        ``-avg_log_prob / (D * ln 2)``, where ``D`` is the number of values
        per example, taken from the data's non-batch shape.

    Raises:
        ValueError: if the loader yields no examples.

    NOTE(review): both values are computed on the data exactly as the loader
    yields it. For the [-1, 1]-normalised, non-dequantised MNIST used in this
    notebook, they are not directly comparable to 8-bit bits/dim figures.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_log_prob = 0.0
    num_samples = 0
    dims_per_sample = None

    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            log_prob = model.log_probability(data)
            total_log_prob += log_prob.sum().item()
            num_samples += data.size(0)
            # Values per example (784 for 1x28x28 images), read from the data.
            dims_per_sample = data.shape[1:].numel()

    if num_samples == 0:
        raise ValueError('test_loader yielded no examples')

    avg_log_prob = total_log_prob / num_samples
    bits_per_dim = -avg_log_prob / (np.log(2) * dims_per_sample)

    return avg_log_prob, bits_per_dim

# Held-out likelihood on the test split.
test_log_prob, test_bpd = evaluate_model(model, test_loader)
print(f'Test Log Probability: {test_log_prob:.4f}')
print(f'Test Bits per Dimension: {test_bpd:.4f}')

# Moments of 1000 raw (unclamped) samples as a rough sanity check.
with torch.no_grad():
    samples, _ = model.rsample(1000)
    sample_mean, sample_std = samples.mean().item(), samples.std().item()

print(
    f'\nSample Statistics:',
    f'Sample Mean: {sample_mean:.4f}',
    f'Sample Std: {sample_std:.4f}',
    f'Sample Range: [{samples.min():.4f}, {samples.max():.4f}]',
    sep='\n',
)

## Summary

In [None]:
# One print call with newline separators emits the same lines as one print per line.
print(
    '=== Model Summary ===',
    f'Architecture: Convolutional Real NVP',
    f'Number of layers: {len(model.nodes)}',
    f'Total parameters: {total_params:,}',
    f'Image shape: {model.image_shape}',
    f'\n=== Training Results ===',
    f'Final training loss: {losses[-1]:.4f}',
    f'Best training loss: {min(losses):.4f}',
    f'\n=== Evaluation Results ===',
    f'Test log probability: {test_log_prob:.4f}',
    f'Test bits per dimension: {test_bpd:.4f}',
    f'\nTraining completed successfully!',
    sep='\n',
)