In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from timeit import default_timer

torch.manual_seed(0)
np.random.seed(0)

# ============================================================================
# VQ-VAE Components
# ============================================================================

class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer with exponential moving average updates.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25, 
                 decay=0.99, epsilon=1e-5):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
        
        # Codebook with EMA
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.normal_()  # Better initialization
        
        # EMA tracking
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('_ema_w', self.embedding.weight.data.clone())
        
    def forward(self, z):
        # z: [B, C, H, W]
        # Reshape to [B*H*W, C]
        z_flattened = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z_flattened.view(-1, self.embedding_dim)
        
        # Calculate distances to codebook vectors
        # |z - e|^2 = |z|^2 + |e|^2 - 2*z*e
        distances = (torch.sum(z_flattened**2, dim=1, keepdim=True) 
                    + torch.sum(self.embedding.weight**2, dim=1)
                    - 2 * torch.matmul(z_flattened, self.embedding.weight.t()))
        
        # Find nearest codebook entry
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize
        quantized = torch.matmul(encodings, self.embedding.weight)
        quantized = quantized.view(z.shape[0], z.shape[2], z.shape[3], self.embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        
        # EMA codebook update (only during training)
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self.decay + \
                                     (1 - self.decay) * torch.sum(encodings, dim=0)
            
            n = torch.sum(self._ema_cluster_size)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self.epsilon) /
                (n + self.num_embeddings * self.epsilon) * n
            )
            
            dw = torch.matmul(encodings.t(), z_flattened)
            self._ema_w = self._ema_w * self.decay + (1 - self.decay) * dw
            
            self.embedding.weight.data = self._ema_w / self._ema_cluster_size.unsqueeze(1)
        
        # Loss - only commitment loss with EMA
        e_latent_loss = F.mse_loss(quantized.detach(), z)
        loss = self.commitment_cost * e_latent_loss
        
        # Straight-through estimator
        quantized = z + (quantized - z).detach()
        
        # Perplexity (measure of codebook usage)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        return quantized, loss, perplexity, encoding_indices


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 1)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return F.relu(x + residual)


class Encoder(nn.Module):
    """
    Encoder: 85x85 -> 21x21 (or 22x22 depending on architecture)
    Uses strided convolutions for downsampling
    """
    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128], latent_dim=64):
        super().__init__()
        
        layers = []
        prev_dim = in_channels
        
        # Downsampling layers
        # 85 -> 43 -> 22 -> 11 or similar
        for h_dim in hidden_dims:
            layers.append(nn.Conv2d(prev_dim, h_dim, 4, stride=2, padding=1))
            layers.append(nn.ReLU())
            layers.append(ResidualBlock(h_dim))
            prev_dim = h_dim
        
        # Final projection to latent space
        layers.append(nn.Conv2d(prev_dim, latent_dim, 3, padding=1))
        
        self.encoder = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    """
    Decoder: upsamples back to 85x85
    """
    def __init__(self, latent_dim=64, hidden_dims=[128, 64, 32], out_channels=1):
        super().__init__()
        
        layers = []
        prev_dim = latent_dim
        
        # Upsampling layers
        for h_dim in hidden_dims:
            layers.append(nn.ConvTranspose2d(prev_dim, h_dim, 4, stride=2, padding=1))
            layers.append(nn.ReLU())
            layers.append(ResidualBlock(h_dim))
            prev_dim = h_dim
        
        # Final layer - may need adjustment to get exact 85x85
        layers.append(nn.ConvTranspose2d(prev_dim, out_channels, 4, stride=2, padding=1))
        
        self.decoder = nn.Sequential(*layers)
        
    def forward(self, x):
        out = self.decoder(x)
        # Handle size mismatch if needed
        return out


class VQVAE(nn.Module):
    def __init__(self, in_channels=1, hidden_dims=[32, 64, 128], 
                 latent_dim=64, num_embeddings=512, commitment_cost=0.25):
        super().__init__()
        
        self.encoder = Encoder(in_channels, hidden_dims, latent_dim)
        self.vq = VectorQuantizer(num_embeddings, latent_dim, commitment_cost, 
                                  decay=0.99, epsilon=1e-5)
        self.decoder = Decoder(latent_dim, hidden_dims[::-1], in_channels)
        
    def forward(self, x):
        z = self.encoder(x)
        quantized, vq_loss, perplexity, indices = self.vq(z)
        recon = self.decoder(quantized)
        
        # Crop or pad to match input size
        if recon.shape[-2:] != x.shape[-2:]:
            recon = F.interpolate(recon, size=x.shape[-2:], mode='bilinear', align_corners=False)
        
        return recon, vq_loss, perplexity
    
    def encode(self, x):
        z = self.encoder(x)
        _, _, _, indices = self.vq(z)
        return indices
    
    def decode_indices(self, indices, spatial_shape):
        # Reconstruct from indices
        quantized = self.vq.embedding(indices)
        quantized = quantized.view(indices.shape[0], spatial_shape[0], spatial_shape[1], -1)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        recon = self.decoder(quantized)
        if recon.shape[-2:] != (85, 85):
            recon = F.interpolate(recon, size=(85, 85), mode='bilinear', align_corners=False)
        return recon


# ============================================================================
# Training Function
# ============================================================================

def train_vqvae(model, train_loader, test_loader, epochs=100, lr=1e-3, device='cuda'):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                           factor=0.5, patience=10)
    
    train_losses = []
    test_losses = []
    perplexities = []
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_vq_loss = 0
        train_perplexity = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            recon, vq_loss, perplexity = model(data)
            
            # Reconstruction loss
            recon_loss = F.mse_loss(recon, data)
            
            # Total loss
            loss = recon_loss + vq_loss
            
            loss.backward()
            optimizer.step()
            
            train_loss += recon_loss.item()
            train_vq_loss += vq_loss.item()
            train_perplexity += perplexity.item()
        
        train_loss /= len(train_loader)
        train_vq_loss /= len(train_loader)
        train_perplexity /= len(train_loader)
        train_losses.append(train_loss)
        perplexities.append(train_perplexity)
        
        # Testing
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(device)
                recon, vq_loss, _ = model(data)
                recon_loss = F.mse_loss(recon, data)
                test_loss += (recon_loss + vq_loss).item()
        
        test_loss /= len(test_loader)
        test_losses.append(test_loss)
        
        scheduler.step(test_loss)
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Train Loss: {train_loss:.6f}, VQ Loss: {train_vq_loss:.6f}')
            print(f'  Test Loss: {test_loss:.6f}')
            print(f'  Perplexity: {train_perplexity:.2f} (codebook usage)')
            print(f'  LR: {optimizer.param_groups[0]["lr"]:.6f}')
    
    return train_losses, test_losses, perplexities


# ============================================================================
# Visualization
# ============================================================================

def visualize_reconstruction(model, test_data, num_samples=5, device='cuda'):
    model.eval()
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6))
    
    with torch.no_grad():
        for i in range(num_samples):
            data = test_data[i:i+1].to(device)
            recon, _, _ = model(data)
            
            # Original
            axes[0, i].imshow(data[0, 0].cpu().numpy(), cmap='viridis')
            axes[0, i].set_title(f'Original {i+1}')
            axes[0, i].axis('off')
            
            # Reconstruction
            axes[1, i].imshow(recon[0, 0].cpu().numpy(), cmap='viridis')
            axes[1, i].set_title(f'Reconstructed {i+1}')
            axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.savefig('vqvae_reconstruction.png', dpi=150, bbox_inches='tight')
    plt.close()


def plot_training_curves(train_losses, test_losses, perplexities):
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    axes[0].plot(train_losses, label='Train')
    axes[0].plot(test_losses, label='Test')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Reconstruction Loss')
    axes[0].set_title('Training Progress')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].semilogy(train_losses, label='Train')
    axes[1].semilogy(test_losses, label='Test')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Log Loss')
    axes[1].set_title('Training Progress (Log Scale)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    axes[2].plot(perplexities)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Perplexity')
    axes[2].set_title('Codebook Usage')
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('vqvae_training.png', dpi=150, bbox_inches='tight')
    plt.close()


# ============================================================================
# Main Training Script
# ============================================================================

if __name__ == '__main__':
    # Hyperparameters
    BATCH_SIZE = 16
    EPOCHS = 200  # More epochs with EMA
    LR = 3e-4  # Lower learning rate
    LATENT_DIM = 32  # Smaller latent space to force compression
    NUM_EMBEDDINGS = 256  # Smaller codebook - easier to use
    COMMITMENT_COST = 1.0  # Higher commitment - prevent encoder from drifting
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    print(f"Using device: {DEVICE}")
    print(f"Training VQ-VAE on Darcy Flow Dataset")
    print(f"Codebook size: {NUM_EMBEDDINGS}, Latent dim: {LATENT_DIM}")
    
    # NOTE: You need to load your data here
    # This assumes x_train, x_test are already loaded and normalized
    # Shape should be [N, 85, 85, 1] or [N, 1, 85, 85]
    
    # Example data loading (replace with your actual data):
    # x_train = torch.from_numpy(x_train).float()  # [1000, 85, 85, 1]
    # x_test = torch.from_numpy(x_test).float()
    
    # Ensure channel-first format [N, C, H, W]
    # if x_train.shape[-1] == 1:
    #     x_train = x_train.permute(0, 3, 1, 2)
    #     x_test = x_test.permute(0, 3, 1, 2)
    
    # Create dummy data for demonstration
    x_train = torch.randn(1000, 1, 85, 85)
    x_test = torch.randn(200, 1, 85, 85)
    
    # Create data loaders
    train_dataset = TensorDataset(x_train, torch.zeros(len(x_train)))
    test_dataset = TensorDataset(x_test, torch.zeros(len(x_test)))
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Initialize model
    model = VQVAE(
        in_channels=1,
        hidden_dims=[32, 64, 128],
        latent_dim=LATENT_DIM,
        num_embeddings=NUM_EMBEDDINGS,
        commitment_cost=COMMITMENT_COST
    ).to(DEVICE)
    
    print(f"\nModel Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Train
    print("\nStarting training...")
    t0 = default_timer()
    train_losses, test_losses, perplexities = train_vqvae(
        model, train_loader, test_loader, 
        epochs=EPOCHS, lr=LR, device=DEVICE
    )
    t1 = default_timer()
    print(f"\nTraining completed in {t1-t0:.2f} seconds")
    
    # Visualize
    print("\nGenerating visualizations...")
    visualize_reconstruction(model, x_test, num_samples=5, device=DEVICE)
    plot_training_curves(train_losses, test_losses, perplexities)
    
    # Save model
    torch.save(model.state_dict(), 'vqvae_darcy.pth')
    print("\nModel saved to 'vqvae_darcy.pth'")
    print("Visualizations saved to 'vqvae_reconstruction.png' and 'vqvae_training.png'")
    
    # Final metrics
    print(f"\nFinal Test Loss: {test_losses[-1]:.6f}")
    print(f"Final Perplexity: {perplexities[-1]:.2f} / {NUM_EMBEDDINGS}")
    print(f"Codebook usage: {perplexities[-1]/NUM_EMBEDDINGS*100:.1f}%")



Using device: cuda
Training VQ-VAE on Darcy Flow Dataset
Codebook size: 256, Latent dim: 32

Model Parameters: 870,753

Starting training...
Epoch 10/200:
  Train Loss: 0.999541, VQ Loss: 0.000001
  Test Loss: 0.997808
  Perplexity: 1.00 (codebook usage)
  LR: 0.000300
Epoch 20/200:
  Train Loss: 0.999508, VQ Loss: 0.000000
  Test Loss: 0.997841
  Perplexity: 1.00 (codebook usage)
  LR: 0.000150
Epoch 30/200:
  Train Loss: 0.999405, VQ Loss: 0.000000
  Test Loss: 0.997869
  Perplexity: 1.00 (codebook usage)
  LR: 0.000075
Epoch 40/200:
  Train Loss: 0.999409, VQ Loss: 0.000000
  Test Loss: 0.997879
  Perplexity: 1.00 (codebook usage)
  LR: 0.000037
Epoch 50/200:
  Train Loss: 0.999419, VQ Loss: 0.000000
  Test Loss: 0.997885
  Perplexity: 1.00 (codebook usage)
  LR: 0.000019
Epoch 60/200:
  Train Loss: 0.999406, VQ Loss: 0.000000
  Test Loss: 0.997887
  Perplexity: 1.00 (codebook usage)
  LR: 0.000009
Epoch 70/200:
  Train Loss: 0.999448, VQ Loss: 0.000000
  Test Loss: 0.997888
  Perpl

RuntimeError: NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_() INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":1098, please report a bug to PyTorch. 