# Wasserstein GAN (WGAN) Implementation

This notebook implements a comprehensive Wasserstein GAN with proper critic module and Wasserstein distance optimization. The implementation includes:

- **Critic Module**: Replaces the discriminator with a critic that estimates Wasserstein distance
- **Wasserstein Loss**: Uses Wasserstein distance instead of binary cross-entropy
- **Weight Clipping**: Enforces a Lipschitz constraint on the critic by clamping its weights to a small range
- **WGAN-GP**: Optional gradient penalty implementation
- **Multi-Dataset Support**: Works with both MNIST and CIFAR10
- **Professional Training Pipeline**: Complete training loop with logging and visualization

## Key Differences from Standard GAN:
1. **Critic vs Discriminator**: The critic outputs a real-valued score rather than a probability
2. **Wasserstein Distance**: Measures the cost of transforming one distribution into another
3. **No Sigmoid**: Critic output is unbounded
4. **Weight Clipping**: Maintains Lipschitz constraint (or gradient penalty in WGAN-GP)
5. **Different Training Ratio**: Typically train critic more frequently than generator


In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
import time
import logging
from collections import defaultdict

# Configure the root logger: every record goes to the console and is also
# written to 'wgan_training.log' in the current working directory.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('wgan_training.log'),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Notebook-wide logger used by the top-level cells
logger = logging.getLogger('WGAN')

# Global compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

def get_cuda_device_info():
    """Log identity and memory statistics of the current CUDA device.

    When CUDA is unavailable a single warning is logged instead. The function
    returns nothing; all output goes through the module-level ``logger``.
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available")
        return

    separator = "=" * 50
    bytes_per_gb = 1024**3

    logger.info(separator)
    logger.info("CUDA DEVICE INFORMATION")
    logger.info(separator)

    # Which device are we on?
    n_devices = torch.cuda.device_count()
    dev_index = torch.cuda.current_device()
    dev_name = torch.cuda.get_device_name(dev_index)

    logger.info(f"Number of CUDA devices: {n_devices}")
    logger.info(f"Current device index: {dev_index}")
    logger.info(f"Device name: {dev_name}")

    # Memory figures, converted from bytes to GB
    allocated_gb = torch.cuda.memory_allocated(dev_index) / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved(dev_index) / bytes_per_gb
    total_gb = torch.cuda.get_device_properties(dev_index).total_memory / bytes_per_gb

    logger.info("MEMORY INFORMATION:")
    logger.info(f"Total GPU memory: {total_gb:.2f} GB")
    logger.info(f"Allocated memory: {allocated_gb:.2f} GB")
    logger.info(f"Reserved memory: {reserved_gb:.2f} GB")
    # "Free" here is simply total minus what this process has allocated
    logger.info(f"Free memory: {total_gb - allocated_gb:.2f} GB")

    logger.info(separator)

# Log the CUDA device summary once (or a warning when CUDA is unavailable)
get_cuda_device_info()


In [None]:
class DataSetLoader:
    """
    Builds shuffled training DataLoaders for MNIST and CIFAR10.

    Both datasets are resized to 64x64, converted to tensors and normalized
    to the [-1, 1] range. Data is downloaded into ``<data_dir>/dataset``.
    """
    def __init__(self, data_dir="./data", batch_size=64):
        self.logger = logging.getLogger(f'{__name__}.DataSetLoader')
        self.data_dir, self.batch_size = data_dir, batch_size
        self.logger.info(f"DataSetLoader initialized with batch_size={batch_size}, data_dir={data_dir}")

    def _make_loader(self, label, dataset_cls, channels):
        """Shared pipeline: transform, download, wrap in a DataLoader, log."""
        self.logger.info(f"Loading {label} dataset...")

        # Resize -> tensor -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1] range
        pipeline = transforms.Compose([
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * channels, [0.5] * channels),
        ])

        root = '{}/dataset'.format(self.data_dir)
        self.logger.debug(f"Loading {label} from directory: {root}")
        dataset = dataset_cls(root=root, train=True, transform=pipeline, download=True)

        # drop_last keeps every batch at exactly batch_size samples
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        batch_count = len(loader)
        size = [64, 64]

        self.logger.info(f"{label} dataset loaded: {batch_count} batches, image size: {size}, channels: {channels}")
        return loader, batch_count, size, channels

    def load_mnist(self):
        """Return (data_loader, num_batches, img_size, channels) for MNIST (1 channel)."""
        return self._make_loader("MNIST", datasets.MNIST, 1)

    def load_cifar10(self):
        """Return (data_loader, num_batches, img_size, channels) for CIFAR10 (3 channels)."""
        return self._make_loader("CIFAR10", datasets.CIFAR10, 3)

# Cell-completion marker: confirms the DataSetLoader definition was executed
logger.info("DataSetLoader class created successfully!")


In [None]:
class Critic(nn.Module):
    """
    WGAN Critic (replaces Discriminator)

    Four stride-2 convolutions halve the spatial size each time; a final
    convolution whose kernel covers the whole remaining feature map reduces
    every image to a single unbounded score (no sigmoid).

    Args:
        in_channels: Number of channels of the input images
        hidden_dim: Channel width of the first conv layer (doubled per layer)
        img_size: [height, width] of the input images; each must be >= 16.
            Height and width may differ.
    """
    def __init__(self, in_channels=1, hidden_dim=64, img_size=[64, 64]):
        super(Critic, self).__init__()

        self.logger = logging.getLogger(f'{__name__}.Critic')
        # The default list is only read, never mutated
        self.img_size = img_size
        self.in_channels = in_channels

        # Spatial size left after the four stride-2 convolutions
        self.final_height = img_size[0] // (2**4)
        self.final_width = img_size[1] // (2**4)

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            # First layer: in_channels -> hidden_dim (no batch norm on the input layer)
            nn.Conv2d(in_channels, hidden_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # Second layer: hidden_dim -> hidden_dim*2
            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # Third layer: hidden_dim*2 -> hidden_dim*4
            nn.Conv2d(hidden_dim * 2, hidden_dim * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # Fourth layer: hidden_dim*4 -> hidden_dim*8
            nn.Conv2d(hidden_dim * 4, hidden_dim * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Final layer: the kernel spans the full remaining feature map (height
        # AND width, so non-square inputs work) and yields one value per image.
        self.final_layer = nn.Sequential(
            nn.Conv2d(hidden_dim * 8, 1,
                      kernel_size=(self.final_height, self.final_width),
                      stride=1, padding=0, bias=False),
            nn.Flatten()
        )

        self.logger.info("Critic initialized:")
        self.logger.info(f"  Input channels: {in_channels}")
        self.logger.info(f"  Hidden dim: {hidden_dim}")
        self.logger.info(f"  Image size: {img_size}")
        self.logger.info(f"  Final spatial size: {self.final_height}x{self.final_width}")

    def forward(self, x):
        """
        Forward pass through critic

        Args:
            x: Input images [batch_size, channels, height, width]

        Returns:
            Real-valued scores [batch_size] (NOT probabilities)
        """
        x = self.conv_layers(x)
        x = self.final_layer(x)  # [batch_size, 1]
        # Drop only the singleton score dimension; a bare squeeze() would also
        # remove the batch dimension when batch_size == 1.
        return x.squeeze(1)

    def clip_weights(self, clip_value=0.01):
        """
        Clamp every parameter in place to [-clip_value, clip_value].

        This is the original WGAN way of enforcing the Lipschitz constraint
        (WGAN-GP uses a gradient penalty instead).
        """
        self.logger.debug(f"Clipping weights with clip_value={clip_value}")
        with torch.no_grad():
            for param in self.parameters():
                param.clamp_(-clip_value, clip_value)

# Smoke-test the Critic on a random batch of four 1-channel 64x64 inputs
logger.info("Testing Critic:")
critic_mnist = Critic(in_channels=1, hidden_dim=64, img_size=[64, 64])
test_input = torch.randn(4, 1, 64, 64)  # random noise stands in for real images
output = critic_mnist(test_input)
logger.info(f"Input shape: {test_input.shape}")
logger.info(f"Output shape: {output.shape}")
logger.debug(f"Output values: {output}")  # only visible at DEBUG level
logger.info("Note: Output values are unbounded (not probabilities)!")


In [None]:
class Generator(nn.Module):
    """
    WGAN Generator

    Maps latent noise to images with transposed convolutions: an initial
    projection to a (img_size // 16) feature map followed by four stride-2
    upsampling layers. The final Tanh keeps outputs in [-1, 1] to match the
    data normalization.

    Args:
        latent_dim: Length of the noise vector
        hidden_dim: Channel width of the last hidden layer (wider before it)
        out_channels: Number of channels of the generated images
        img_size: [height, width] of the generated images. Height and width
            may differ; the output is 16 * (size // 16) per side, so each
            should be a multiple of 16 to be matched exactly.
    """
    def __init__(self, latent_dim=100, hidden_dim=64, out_channels=1, img_size=[64, 64]):
        super(Generator, self).__init__()

        self.logger = logging.getLogger(f'{__name__}.Generator')
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.out_channels = out_channels
        # The default list is only read, never mutated
        self.img_size = img_size

        # Starting feature-map size (reverse of critic's final size)
        self.start_height = img_size[0] // (2**4)  # 4 upsampling layers
        self.start_width = img_size[1] // (2**4)

        # Initial projection from a 1x1 latent "image" to the starting feature
        # map. The kernel uses height AND width so non-square outputs work.
        self.initial_layer = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, hidden_dim * 8,
                               kernel_size=(self.start_height, self.start_width),
                               stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_dim * 8),
            nn.ReLU(inplace=True)
        )

        # Upsampling layers (each doubles height and width)
        self.conv_layers = nn.Sequential(
            # First upsampling: hidden_dim*8 -> hidden_dim*4
            nn.ConvTranspose2d(hidden_dim * 8, hidden_dim * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 4),
            nn.ReLU(inplace=True),

            # Second upsampling: hidden_dim*4 -> hidden_dim*2
            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim * 2),
            nn.ReLU(inplace=True),

            # Third upsampling: hidden_dim*2 -> hidden_dim
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),

            # Final layer: hidden_dim -> out_channels
            nn.ConvTranspose2d(hidden_dim, out_channels,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()  # Output in [-1, 1] range
        )

        self.logger.info("Generator initialized:")
        self.logger.info(f"  Latent dim: {latent_dim}")
        self.logger.info(f"  Hidden dim: {hidden_dim}")
        self.logger.info(f"  Output channels: {out_channels}")
        self.logger.info(f"  Output size: {img_size}")
        self.logger.info(f"  Starting size: {self.start_height}x{self.start_width}")

    def forward(self, z):
        """
        Forward pass through generator

        Args:
            z: Random noise, either [batch_size, latent_dim] or
               [batch_size, latent_dim, 1, 1]

        Returns:
            Generated images [batch_size, out_channels, height, width]
        """
        # 2D noise is reshaped to a 1x1 "image" so the conv layers accept it
        if z.dim() == 2:
            z = z.view(z.size(0), z.size(1), 1, 1)

        x = self.initial_layer(z)
        x = self.conv_layers(x)
        return x

def weights_init(m):
    """
    Initialize weights according to DCGAN paper

    Intended for ``module.apply(weights_init)``. Modules are matched by class
    name: anything containing 'Conv' gets weights ~ N(0, 0.02); anything
    containing 'BatchNorm' gets weights ~ N(1, 0.02) and zero bias. All other
    modules are left untouched.

    Modules without the relevant parameter (e.g. BatchNorm created with
    ``affine=False``, where weight and bias are None) are skipped rather than
    raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# Smoke-test the Generator with both accepted noise layouts
logger.info("Testing Generator:")
generator_mnist = Generator(latent_dim=100, hidden_dim=64, out_channels=1, img_size=[64, 64])
generator_mnist.apply(weights_init)  # apply() visits every submodule

# 2D noise [batch, latent_dim]: forward() reshapes it to [batch, latent_dim, 1, 1]
test_noise_2d = torch.randn(4, 100)
output_2d = generator_mnist(test_noise_2d)
logger.info(f"2D Noise shape: {test_noise_2d.shape}")
logger.info(f"Generated image shape: {output_2d.shape}")
logger.info(f"Output range: [{output_2d.min():.3f}, {output_2d.max():.3f}]")

# 4D noise [batch, latent_dim, 1, 1]: passed to the conv layers unchanged
test_noise_4d = torch.randn(4, 100, 1, 1)
output_4d = generator_mnist(test_noise_4d)
logger.info(f"4D Noise shape: {test_noise_4d.shape}")
logger.info(f"Generated image shape: {output_4d.shape}")
logger.info(f"Output range: [{output_4d.min():.3f}, {output_4d.max():.3f}]")


In [None]:
class WGAN:
    """
    Wasserstein GAN Implementation

    Key features:
    1. Uses Wasserstein distance instead of binary cross-entropy
    2. Critic outputs real-valued scores (not probabilities)
    3. Weight clipping to enforce Lipschitz constraint
    4. Train critic more frequently than generator (typically 5:1 ratio)

    Args:
        latent_dim: Length of the generator's noise vector
        img_size: [height, width] of the images
        channels: Number of image channels
        hidden_dim: Base channel width passed to Generator and Critic
        batch_size: Default number of noise vectors per generator step
        lr: Learning rate of both RMSprop optimizers
        device: Device to run on; CUDA if available, else CPU, when omitted
        clip_value: Critic weights are clamped to [-clip_value, clip_value]
        critic_iterations: Critic updates performed per generator update
    """
    def __init__(self, latent_dim=100, img_size=[64, 64], channels=1,
                 hidden_dim=64, batch_size=64, lr=5e-5, device=None,
                 clip_value=0.01, critic_iterations=5):

        self.logger = logging.getLogger(f'{__name__}.WGAN')

        # Set device
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"WGAN initialized on device: {self.device}")

        # Model parameters (hidden_dim is kept so checkpoints can describe
        # the full architecture)
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.lr = lr
        self.clip_value = clip_value
        self.critic_iterations = critic_iterations

        # Initialize networks
        self.generator = Generator(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            out_channels=channels,
            img_size=img_size
        ).to(self.device)

        self.critic = Critic(
            in_channels=channels,
            hidden_dim=hidden_dim,
            img_size=img_size
        ).to(self.device)

        # Initialize weights
        self.generator.apply(weights_init)
        self.critic.apply(weights_init)
        self.logger.debug("Applied weight initialization to Generator and Critic")

        # Optimizers (RMSprop is recommended for WGAN)
        self.g_optimizer = optim.RMSprop(self.generator.parameters(), lr=lr)
        self.c_optimizer = optim.RMSprop(self.critic.parameters(), lr=lr)

        # Fixed noise so periodic sample grids are comparable over training
        self.fixed_noise = torch.randn(16, latent_dim, device=self.device)

        self.logger.info("WGAN Architecture:")
        self.logger.info(f"  Generator parameters: {sum(p.numel() for p in self.generator.parameters()):,}")
        self.logger.info(f"  Critic parameters: {sum(p.numel() for p in self.critic.parameters()):,}")
        self.logger.info(f"  Critic iterations per generator step: {critic_iterations}")
        self.logger.info(f"  Weight clip value: {clip_value}")

    def generate_noise(self, batch_size=None):
        """Return [batch_size, latent_dim] standard-normal noise on self.device.

        ``batch_size`` defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        return torch.randn(batch_size, self.latent_dim, device=self.device)

    def train_critic(self, real_images):
        """
        Train critic for one step

        The critic maximizes E[C(real)] - E[C(fake)], which is also the
        critic's estimate of the Wasserstein distance. The optimizer
        therefore minimizes: -E[C(real)] + E[C(fake)]

        Args:
            real_images: Batch of real images [batch_size, channels, H, W]

        Returns:
            Dict of Python floats: 'critic_loss', 'real_loss', 'fake_loss',
            'real_score', 'fake_score' and 'wasserstein_distance'
            (= real_score - fake_score = -critic_loss).
        """
        batch_size = real_images.size(0)
        real_images = real_images.to(self.device)

        # Reset gradients
        self.c_optimizer.zero_grad()

        # Score real images
        real_scores = self.critic(real_images)
        real_mean = torch.mean(real_scores)
        real_loss = -real_mean  # Negative because we want to maximize

        # Score fake images; the generator is not trained in this step, so
        # its forward pass does not need an autograd graph.
        noise = self.generate_noise(batch_size)
        with torch.no_grad():
            fake_images = self.generator(noise)
        fake_scores = self.critic(fake_images)
        fake_mean = torch.mean(fake_scores)
        fake_loss = fake_mean

        # Total critic loss
        critic_loss = real_loss + fake_loss
        critic_loss.backward()

        # Update critic
        self.c_optimizer.step()

        # Clip weights to enforce Lipschitz constraint
        self.critic.clip_weights(self.clip_value)

        return {
            'critic_loss': critic_loss.item(),
            'real_loss': real_loss.item(),
            'fake_loss': fake_loss.item(),
            'real_score': real_mean.item(),
            'fake_score': fake_mean.item(),
            # Estimate is E[C(real)] - E[C(fake)]: positive when the critic
            # separates real from fake (it is the negated critic loss).
            'wasserstein_distance': (real_mean - fake_mean).item()
        }

    def train_generator(self):
        """
        Train generator for one step

        WGAN Generator Loss: -E[C(G(z))]
        We want to maximize E[C(G(z))] (make fake images get high scores)
        So we minimize: -E[C(G(z))]

        Uses ``self.batch_size`` noise vectors.

        Returns:
            Dict of Python floats: 'generator_loss' and 'fake_score'.
        """
        # Reset gradients
        self.g_optimizer.zero_grad()

        # Generate fake images
        noise = self.generate_noise()
        fake_images = self.generator(noise)
        fake_scores = self.critic(fake_images)

        # Generator loss: we want to maximize fake scores
        fake_mean = torch.mean(fake_scores)
        generator_loss = -fake_mean
        generator_loss.backward()

        # Only the generator is stepped; critic gradients accumulated here
        # are cleared by the next train_critic() call.
        self.g_optimizer.step()

        return {
            'generator_loss': generator_loss.item(),
            'fake_score': fake_mean.item()
        }

    def train_step(self, real_images):
        """
        Complete training step: train critic multiple times, then generator once

        The same ``real_images`` batch is reused for every critic iteration
        (fresh noise is drawn each time).

        Returns:
            Dict with the critic metrics averaged over the critic iterations
            ('critic_loss', 'wasserstein_distance', 'real_score',
            'fake_score') plus the single 'generator_loss'.
        """
        critic_runs = [self.train_critic(real_images)
                       for _ in range(self.critic_iterations)]

        # Train generator once
        generator_metrics = self.train_generator()

        # Return averaged metrics
        return {
            'critic_loss': np.mean([m['critic_loss'] for m in critic_runs]),
            'generator_loss': generator_metrics['generator_loss'],
            'wasserstein_distance': np.mean([m['wasserstein_distance'] for m in critic_runs]),
            'real_score': np.mean([m['real_score'] for m in critic_runs]),
            'fake_score': np.mean([m['fake_score'] for m in critic_runs])
        }

    def generate_samples(self, num_samples=16, noise=None):
        """Generate sample images without tracking gradients.

        Args:
            num_samples: Number of images when ``noise`` is not given
            noise: Optional noise to use instead of drawing fresh noise

        Returns:
            Generated images as a CPU tensor.

        The generator runs in eval mode and is put back into whatever mode
        (train or eval) it was in before the call.
        """
        if noise is None:
            noise = self.generate_noise(num_samples)
        else:
            noise = noise.to(self.device)

        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                samples = self.generator(noise)
        finally:
            # Restore the previous mode instead of forcing train mode
            self.generator.train(was_training)

        return samples.cpu()

    def save_models(self, path_prefix):
        """Save networks, optimizers and config to ``<path_prefix>_wgan.pth``."""
        torch.save({
            'generator_state_dict': self.generator.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'g_optimizer_state_dict': self.g_optimizer.state_dict(),
            'c_optimizer_state_dict': self.c_optimizer.state_dict(),
            'config': {
                'latent_dim': self.latent_dim,
                'img_size': self.img_size,
                'channels': self.channels,
                'hidden_dim': self.hidden_dim,
                'batch_size': self.batch_size,
                'lr': self.lr,
                'clip_value': self.clip_value,
                'critic_iterations': self.critic_iterations
            }
        }, f"{path_prefix}_wgan.pth")
        self.logger.info(f"Models saved to {path_prefix}_wgan.pth")

    def load_models(self, path):
        """Load network and optimizer states from a save_models() checkpoint.

        The stored 'config' is not applied; this instance must already have
        been built with a matching architecture.
        """
        self.logger.info(f"Loading models from {path}")
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        self.c_optimizer.load_state_dict(checkpoint['c_optimizer_state_dict'])
        self.logger.info(f"Models loaded successfully from {path}")

# Cell-completion marker: confirms the WGAN definition was executed
logger.info("WGAN class created successfully!")


In [None]:
class WGANLogger:
    """
    Logger for WGAN training (similar to gan_utils.py Logger)

    Provides logging, visualization, and model saving functionality.
    Images are written to ``<save_dir>/images/<model_name>/<data_name>`` and
    checkpoints to ``<save_dir>/models/<model_name>/<data_name>``; both
    directories are created on construction.
    """
    def __init__(self, model_name, data_name, save_dir="./data"):
        self.model_name = model_name
        self.data_name = data_name
        self.save_dir = save_dir
        self.logger = logging.getLogger(f'{__name__}.WGANLogger')

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # Create directories
        self.images_dir = f'{save_dir}/images/{self.data_subdir}'
        self.models_dir = f'{save_dir}/models/{self.data_subdir}'
        self._make_dir(self.images_dir)
        self._make_dir(self.models_dir)

        # Training metrics storage: metric name -> list of logged values
        self.metrics_history = defaultdict(list)

        self.logger.info("WGANLogger initialized:")
        self.logger.info(f"  Model: {model_name}")
        self.logger.info(f"  Dataset: {data_name}")
        self.logger.info(f"  Save directory: {save_dir}")

    def log_metrics(self, metrics, epoch, batch_idx):
        """Append every metric to the history; log a summary every 50 batches."""
        # Store metrics
        for key, value in metrics.items():
            self.metrics_history[key].append(value)

        # Log progress
        if batch_idx % 50 == 0:
            self.logger.info(f"Epoch [{epoch}] Batch [{batch_idx}] "
                           f"C_loss: {metrics.get('critic_loss', 0):.4f} "
                           f"G_loss: {metrics.get('generator_loss', 0):.4f} "
                           f"W_dist: {metrics.get('wasserstein_distance', 0):.4f}")

    @staticmethod
    def _grid_dims(count):
        """Return (rows, cols) of the smallest near-square grid with >= count cells.

        Perfect squares give a square grid (16 -> 4x4, 64 -> 8x8).
        """
        count = max(int(count), 1)
        cols = int(np.sqrt(count))
        if cols * cols < count:
            cols += 1
        rows = -(-count // cols)  # ceiling division
        return rows, cols

    def _render_grid(self, samples, figsize):
        """Draw the samples ([N, C, H, W] CPU tensor) on a new figure and return it.

        1-channel samples are shown in grayscale on a fixed [-1, 1] scale;
        3-channel samples are mapped from [-1, 1] to [0, 1] and shown as RGB.
        Other channel counts leave their cell empty.
        """
        rows, cols = self._grid_dims(len(samples))
        # squeeze=False keeps `axes` a 2D array even for a 1x1 or 1xN grid
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        cells = axes.ravel()

        # Hide the axes of every cell, including unused trailing ones
        for ax in cells:
            ax.axis('off')

        for ax, sample in zip(cells, samples):
            sample = sample.numpy()
            if sample.shape[0] == 1:  # Grayscale
                ax.imshow(sample.squeeze(0), cmap='gray', vmin=-1, vmax=1)
            elif sample.shape[0] == 3:  # RGB
                img = sample.transpose(1, 2, 0)  # Convert to (H, W, C)
                img = (img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
                ax.imshow(np.clip(img, 0, 1))

        return fig

    def log_images(self, wgan, epoch, batch_idx, num_images=16):
        """
        Generate and save sample images

        Samples come from ``wgan.fixed_noise[:num_images]``, so at most as
        many images as there are fixed noise vectors are drawn. The grid is
        saved as ``epoch_<epoch>_batch_<batch_idx>.png`` and additionally
        shown when ``batch_idx`` is a multiple of 200.

        Args:
            wgan: WGAN model instance
            epoch: Current epoch
            batch_idx: Current batch index
            num_images: Number of images to generate
        """
        samples = wgan.generate_samples(num_images, wgan.fixed_noise[:num_images])

        fig = self._render_grid(samples, figsize=(8, 8))
        plt.tight_layout()

        # Save image
        filename = f'{self.images_dir}/epoch_{epoch}_batch_{batch_idx}.png'
        plt.savefig(filename, bbox_inches='tight', dpi=150)

        # Display if in notebook
        if batch_idx % 200 == 0:  # Show every 200 batches
            plt.show()
        # Always release the figure so long runs do not accumulate open figures
        plt.close(fig)

    def display_status(self, epoch, num_epochs, batch_idx, num_batches, metrics):
        """Log epoch/batch progress and the current metric values."""
        # Guard against an empty loader instead of dividing by zero
        progress = (batch_idx / num_batches) * 100 if num_batches else 0.0

        self.logger.info(f'Epoch: [{epoch}/{num_epochs}] '
                        f'Batch: [{batch_idx}/{num_batches}] '
                        f'({progress:.1f}%)')
        self.logger.info(f'Critic Loss: {metrics.get("critic_loss", 0):.4f}')
        self.logger.info(f'Generator Loss: {metrics.get("generator_loss", 0):.4f}')
        self.logger.info(f'Wasserstein Distance: {metrics.get("wasserstein_distance", 0):.4f}')
        self.logger.info(f'Real Score: {metrics.get("real_score", 0):.4f}')
        self.logger.info(f'Fake Score: {metrics.get("fake_score", 0):.4f}')

    def save_models(self, wgan, epoch):
        """Save WGAN models under the models directory.

        ``wgan.save_models`` receives the prefix
        ``<models_dir>/wgan_epoch_<epoch>``.
        """
        wgan.save_models(f'{self.models_dir}/wgan_epoch_{epoch}')

    def plot_training_progress(self):
        """Plot critic loss, generator loss, Wasserstein distance and critic scores."""
        if not self.metrics_history:
            self.logger.warning("No metrics to plot yet")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # (axis, metric key, title, y label) for the three single-curve panels
        panels = [
            (axes[0, 0], 'critic_loss', 'Critic Loss', 'Loss'),
            (axes[0, 1], 'generator_loss', 'Generator Loss', 'Loss'),
            (axes[1, 0], 'wasserstein_distance', 'Wasserstein Distance', 'Distance'),
        ]
        for ax, key, title, ylabel in panels:
            if key in self.metrics_history:
                ax.plot(self.metrics_history[key])
                ax.set_title(title)
                ax.set_xlabel('Iteration')
                ax.set_ylabel(ylabel)

        # Plot real vs fake scores
        if 'real_score' in self.metrics_history and 'fake_score' in self.metrics_history:
            axes[1, 1].plot(self.metrics_history['real_score'], label='Real Score')
            axes[1, 1].plot(self.metrics_history['fake_score'], label='Fake Score')
            axes[1, 1].set_title('Critic Scores')
            axes[1, 1].set_xlabel('Iteration')
            axes[1, 1].set_ylabel('Score')
            axes[1, 1].legend()

        plt.tight_layout()
        plt.show()
        plt.close(fig)

    def generate_final_samples(self, wgan, num_samples=64):
        """Generate a large grid of fresh samples, save it as final_samples.png and show it."""
        samples = wgan.generate_samples(num_samples)

        fig = self._render_grid(samples, figsize=(12, 12))
        plt.suptitle(f'Final Generated Samples - {self.comment}', fontsize=16)
        plt.tight_layout()

        # Save final samples
        filename = f'{self.images_dir}/final_samples.png'
        plt.savefig(filename, bbox_inches='tight', dpi=200)
        plt.show()
        plt.close(fig)

    @staticmethod
    def _make_dir(directory):
        """Create directory if it doesn't exist"""
        os.makedirs(directory, exist_ok=True)

# Cell-completion marker: confirms the WGANLogger definition was executed
logger.info("WGANLogger class created successfully!")


In [None]:
# Training Function
def train_wgan(wgan, dataloader, wgan_logger, num_epochs=25, save_every=5,
               image_every=100, status_every=200, plot_every=5):
    """
    Train WGAN with proper logging and visualization

    Runs ``num_epochs`` passes over ``dataloader``. Every batch is handed to
    ``wgan.train_step`` and the returned metrics are forwarded to
    ``wgan_logger.log_metrics``. Sample images, status lines, checkpoints and
    progress plots are produced at the configured intervals, and a final
    64-image sample grid is generated after the last epoch.

    Args:
        wgan: Model exposing ``train_step(real_images)`` that returns a
            metrics object for the batch.
        dataloader: Iterable yielding ``(real_images, labels)`` pairs; the
            labels are ignored. ``len(dataloader)`` is used for status output.
        wgan_logger: WGANLogger instance
        num_epochs: Number of training epochs
        save_every: Save models every N epochs
        image_every: Log sample images every N batches (batch 0 included)
        status_every: Display a status line every N batches (batch 0 included)
        plot_every: Plot training progress every N epochs

    Any interval that is ``None`` or not positive disables the corresponding
    periodic action.

    Returns:
        The ``wgan_logger`` that was passed in.
    """
    def _due(count, interval):
        # A missing or non-positive interval switches the action off instead
        # of raising ZeroDivisionError on the modulo.
        return interval is not None and interval > 0 and count % interval == 0

    train_logger = logging.getLogger('WGAN.Training')
    train_logger.info(f"Starting WGAN training for {num_epochs} epochs...")
    train_logger.info("="*80)

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        for batch_idx, (real_images, _) in enumerate(dataloader):
            # Train WGAN
            metrics = wgan.train_step(real_images)

            # Log metrics
            wgan_logger.log_metrics(metrics, epoch, batch_idx)

            # Generate and save images periodically
            if _due(batch_idx, image_every):
                wgan_logger.log_images(wgan, epoch, batch_idx)

            # Display detailed status periodically
            if _due(batch_idx, status_every):
                wgan_logger.display_status(epoch, num_epochs, batch_idx, len(dataloader), metrics)

        # End of epoch summary (epoch index is zero-based, matching the
        # epoch value handed to the logger callbacks)
        epoch_time = time.time() - epoch_start
        train_logger.info(f"Epoch [{epoch}/{num_epochs}] completed in {epoch_time:.2f}s")

        # Epoch-based actions count completed epochs, hence epoch + 1
        if _due(epoch + 1, save_every):
            wgan_logger.save_models(wgan, epoch)

        if _due(epoch + 1, plot_every):
            wgan_logger.plot_training_progress()

    total_time = time.time() - start_time
    train_logger.info(f"Training completed in {total_time/60:.2f} minutes!")

    # Generate final samples
    wgan_logger.generate_final_samples(wgan, num_samples=64)

    return wgan_logger

# Notebook progress marker: emitted once train_wgan has been defined in this cell.
logger.info("Training function created successfully!")


In [None]:
# MNIST Training Example
mnist_logger = logging.getLogger('WGAN.MNIST_Example')
for banner_line in ("="*80, "WGAN TRAINING ON MNIST DATASET", "="*80):
    mnist_logger.info(banner_line)

# Dataset loader shared by the example cells (batch size 64, data dir ./data)
dataset_loader = DataSetLoader(batch_size=64, data_dir='./data')

# MNIST dataloader plus the metadata used to size the model
mnist_dataloader, num_batches, img_size, channels = dataset_loader.load_mnist()
mnist_logger.info(f"MNIST - Batches: {num_batches}, Image size: {img_size}, Channels: {channels}")

# Weight-clipping WGAN configured for MNIST
mnist_wgan_kwargs = dict(
    latent_dim=100,
    img_size=img_size,
    channels=channels,
    hidden_dim=64,
    batch_size=64,
    lr=5e-5,  # Lower learning rate for WGAN
    device=device,
    clip_value=0.01,
    critic_iterations=5,
)
wgan_mnist = WGAN(**mnist_wgan_kwargs)

# Echo the effective configuration as stored on the model
mnist_logger.info("WGAN MNIST Configuration:")
mnist_logger.info(f"  Latent dimension: {wgan_mnist.latent_dim}")
mnist_logger.info(f"  Image size: {wgan_mnist.img_size}")
mnist_logger.info(f"  Channels: {wgan_mnist.channels}")
mnist_logger.info(f"  Learning rate: {wgan_mnist.lr}")
mnist_logger.info(f"  Clip value: {wgan_mnist.clip_value}")
mnist_logger.info(f"  Critic iterations: {wgan_mnist.critic_iterations}")

# Metrics/image logger for this run
logger_mnist = WGANLogger("WGAN", "MNIST", save_dir="./data")

# 4x4 preview of what the untrained generator produces
mnist_logger.info("Generating initial samples (before training)...")
initial_samples = wgan_mnist.generate_samples(16)
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    # axes.flat walks the grid row by row, so cell i shows sample i;
    # channel 0 is drawn on a fixed [-1, 1] gray scale
    ax.imshow(initial_samples[i, 0].numpy(), cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')
plt.suptitle('Initial Generated Samples (Before Training)', fontsize=16)
plt.tight_layout()
plt.show()

# Short demonstration run; train_wgan hands back the logger it was given
mnist_logger.info("Starting MNIST training...")
logger_mnist = train_wgan(
    wgan=wgan_mnist,
    dataloader=mnist_dataloader,
    wgan_logger=logger_mnist,
    num_epochs=10,  # Start with 10 epochs for quick demo
    save_every=5,
)


In [None]:
# CIFAR10 Training Example
cifar_logger = logging.getLogger('WGAN.CIFAR10_Example')
for banner_line in ("="*80, "WGAN TRAINING ON CIFAR10 DATASET", "="*80):
    cifar_logger.info(banner_line)

# CIFAR10 dataloader plus the metadata used to size the model
# (reuses the dataset_loader created in the MNIST cell)
cifar_dataloader, num_batches, img_size, channels = dataset_loader.load_cifar10()
cifar_logger.info(f"CIFAR10 - Batches: {num_batches}, Image size: {img_size}, Channels: {channels}")

# Weight-clipping WGAN configured for CIFAR10
cifar_wgan_kwargs = dict(
    latent_dim=100,
    img_size=img_size,
    channels=channels,
    hidden_dim=64,
    batch_size=64,
    lr=5e-5,  # Lower learning rate for WGAN
    device=device,
    clip_value=0.01,
    critic_iterations=5,
)
wgan_cifar = WGAN(**cifar_wgan_kwargs)

# Echo the effective configuration as stored on the model
cifar_logger.info("WGAN CIFAR10 Configuration:")
cifar_logger.info(f"  Latent dimension: {wgan_cifar.latent_dim}")
cifar_logger.info(f"  Image size: {wgan_cifar.img_size}")
cifar_logger.info(f"  Channels: {wgan_cifar.channels}")
cifar_logger.info(f"  Learning rate: {wgan_cifar.lr}")
cifar_logger.info(f"  Clip value: {wgan_cifar.clip_value}")
cifar_logger.info(f"  Critic iterations: {wgan_cifar.critic_iterations}")

# Metrics/image logger for this run
logger_cifar = WGANLogger("WGAN", "CIFAR10", save_dir="./data")

# 4x4 preview of what the untrained generator produces
cifar_logger.info("Generating initial samples (before training)...")
initial_samples_cifar = wgan_cifar.generate_samples(16)
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    # axes.flat walks the grid row by row, so cell i shows sample i
    channels_last = initial_samples_cifar[i].numpy().transpose(1, 2, 0)
    rescaled = (channels_last + 1) / 2  # Denormalize to [0, 1]
    ax.imshow(np.clip(rescaled, 0, 1))
    ax.axis('off')
plt.suptitle('Initial Generated Samples (Before Training) - CIFAR10', fontsize=16)
plt.tight_layout()
plt.show()

# Short demonstration run; train_wgan hands back the logger it was given
cifar_logger.info("Starting CIFAR10 training...")
logger_cifar = train_wgan(
    wgan=wgan_cifar,
    dataloader=cifar_dataloader,
    wgan_logger=logger_cifar,
    num_epochs=10,  # Start with 10 epochs for quick demo
    save_every=5,
)


In [None]:
# WGAN-GP (Gradient Penalty) Implementation
class WGAN_GP(WGAN):
    """
    WGAN with Gradient Penalty

    Replaces weight clipping with gradient penalty for better stability
    and improved training dynamics.
    """
    def __init__(self, latent_dim=100, img_size=None, channels=1,
                 hidden_dim=64, batch_size=64, lr=1e-4, device=None,
                 lambda_gp=10, critic_iterations=5):
        """
        Build the parent WGAN and switch it to gradient-penalty training.

        Args:
            latent_dim, channels, hidden_dim, batch_size, device,
            critic_iterations: Forwarded unchanged to ``WGAN.__init__``.
            img_size: Forwarded to ``WGAN.__init__``; ``None`` (the default)
                means ``[64, 64]``.
            lr: Learning rate, forwarded to the parent and used for both
                Adam optimizers created here.
            lambda_gp: Weight of the gradient penalty term in the critic loss.
        """
        # Build the default per call: a list literal in the signature would
        # be one object shared by every instance that relies on the default.
        if img_size is None:
            img_size = [64, 64]

        # Initialize parent class (but don't use clip_value)
        super().__init__(
            latent_dim=latent_dim,
            img_size=img_size,
            channels=channels,
            hidden_dim=hidden_dim,
            batch_size=batch_size,
            lr=lr,
            device=device,
            clip_value=None,  # Not used in WGAN-GP
            critic_iterations=critic_iterations
        )

        self.lambda_gp = lambda_gp

        # Install Adam optimizers (betas 0.0 / 0.9) on the generator and
        # critic; presumably this replaces the optimizers set up by the
        # parent constructor -- confirm against WGAN.__init__.
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.0, 0.9))
        self.c_optimizer = optim.Adam(self.critic.parameters(), lr=lr, betas=(0.0, 0.9))

        self.logger.info(f"WGAN-GP initialized with gradient penalty lambda: {lambda_gp}")

    def gradient_penalty(self, real_images, fake_images):
        """
        Calculate gradient penalty for WGAN-GP

        Scores random per-sample interpolations between real and fake images
        with the critic and penalizes the squared deviation of the gradient's
        L2 norm (taken over all non-batch dimensions) from 1.

        Args:
            real_images: Real images from dataset; batch on the first axis.
            fake_images: Generated fake images with the same shape, on the
                same device as ``real_images``.

        Returns:
            Scalar tensor holding the mean gradient penalty; it stays attached
            to the autograd graph so it can be backpropagated into the critic.
        """
        batch_size = real_images.size(0)

        # One interpolation weight per sample, broadcast over every remaining
        # axis. Deriving the shape and device from the input keeps this valid
        # for inputs of any rank and on any device.
        alpha_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_images.device)

        # Get interpolated images
        interpolated = alpha * real_images + (1 - alpha) * fake_images
        interpolated.requires_grad_(True)

        # Get critic scores for interpolated images
        interpolated_scores = self.critic(interpolated)

        # Gradient of the scores w.r.t. the interpolated inputs.
        # create_graph=True makes the penalty itself differentiable; the
        # graph is retained automatically in that case.
        gradients = torch.autograd.grad(
            outputs=interpolated_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_scores),
            create_graph=True,
        )[0]

        # reshape (not view) so a non-contiguous gradient layout cannot raise
        gradients = gradients.reshape(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = torch.mean((gradient_norm - 1) ** 2)

        return gradient_penalty

    def train_critic(self, real_images):
        """
        Train critic with gradient penalty instead of weight clipping

        Performs one critic optimizer step on a batch of real images and an
        equally sized batch of freshly generated images.

        Args:
            real_images: Batch of real images; moved to ``self.device`` here.

        Returns:
            Dict of Python floats with the keys ``critic_loss``,
            ``real_loss``, ``fake_loss``, ``gradient_penalty``,
            ``real_score``, ``fake_score`` and ``wasserstein_distance``.
        """
        batch_size = real_images.size(0)
        real_images = real_images.to(self.device)

        # Reset gradients
        self.c_optimizer.zero_grad()

        # Train on real images
        real_scores = self.critic(real_images)
        real_mean = real_scores.mean()
        real_loss = -real_mean

        # Train on fake images. The generator is not updated in this step, so
        # run it without building an autograd graph at all.
        noise = self.generate_noise(batch_size)
        with torch.no_grad():
            fake_images = self.generator(noise)
        fake_scores = self.critic(fake_images)
        fake_mean = fake_scores.mean()
        fake_loss = fake_mean

        # Calculate gradient penalty
        gp = self.gradient_penalty(real_images, fake_images)

        # Total critic loss
        critic_loss = real_loss + fake_loss + self.lambda_gp * gp
        critic_loss.backward()

        # Update critic (no weight clipping needed!)
        self.c_optimizer.step()

        return {
            'critic_loss': critic_loss.item(),
            'real_loss': real_loss.item(),
            'fake_loss': fake_loss.item(),
            'gradient_penalty': gp.item(),
            'real_score': real_mean.item(),
            'fake_score': fake_mean.item(),
            # NOTE(review): this evaluates to fake_mean - real_mean, i.e. the
            # negated critic gap. Kept as-is so the metric stays comparable
            # with the value reported elsewhere under this key -- confirm the
            # intended sign against WGAN.train_critic.
            'wasserstein_distance': -(real_mean - fake_mean).item()
        }

# Extended Training Examples and Utilities
utils_logger = logging.getLogger('WGAN.Utilities')
# Section banner: a title framed by two 80-character rules
for banner_line in ("="*80, "EXTENDED TRAINING EXAMPLES AND UTILITIES", "="*80):
    utils_logger.info(banner_line)

def compare_wgan_vs_wgan_gp(dataset_name="mnist", num_epochs=15):
    """
    Compare WGAN vs WGAN-GP on the same dataset

    Trains a weight-clipping WGAN and then a WGAN-GP on the same dataloader,
    each for ``num_epochs // 2`` epochs, and shows a 2x2 figure overlaying
    their critic loss, generator loss and Wasserstein distance histories.

    Relies on the module-level ``dataset_loader`` and ``device``.

    Args:
        dataset_name: ``"mnist"`` (case-insensitive) selects MNIST; any other
            value loads CIFAR10.
        num_epochs: Total epoch budget; each model is trained for half of it
            (integer division).
    """
    compare_logger = logging.getLogger('WGAN.Comparison')
    compare_logger.info(f"Comparing WGAN vs WGAN-GP on {dataset_name.upper()}...")

    # Load dataset
    if dataset_name.lower() == "mnist":
        dataloader, _, img_size, channels = dataset_loader.load_mnist()
    else:
        dataloader, _, img_size, channels = dataset_loader.load_cifar10()

    # Create both models
    wgan_original = WGAN(
        latent_dim=100, img_size=img_size, channels=channels,
        hidden_dim=64, batch_size=64, lr=5e-5, device=device,
        clip_value=0.01, critic_iterations=5
    )

    wgan_gp = WGAN_GP(
        latent_dim=100, img_size=img_size, channels=channels,
        hidden_dim=64, batch_size=64, lr=1e-4, device=device,
        lambda_gp=10, critic_iterations=5
    )

    # Create loggers
    logger_original = WGANLogger("WGAN_Original", dataset_name.upper())
    logger_gp = WGANLogger("WGAN_GP", dataset_name.upper())

    # The epoch budget is split evenly between the two models
    epochs_per_model = num_epochs // 2

    compare_logger.info("Training original WGAN...")
    train_wgan(wgan_original, dataloader, logger_original, num_epochs=epochs_per_model, save_every=5)

    compare_logger.info("Training WGAN-GP...")
    train_wgan(wgan_gp, dataloader, logger_gp, num_epochs=epochs_per_model, save_every=5)

    # Compare results
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # One panel per metric, both models overlaid
    panels = (
        (axes[0, 0], 'critic_loss', 'Critic Loss Comparison'),
        (axes[0, 1], 'generator_loss', 'Generator Loss Comparison'),
        (axes[1, 0], 'wasserstein_distance', 'Wasserstein Distance Comparison'),
    )
    for ax, metric_key, title in panels:
        ax.plot(logger_original.metrics_history[metric_key], label='WGAN', alpha=0.7)
        ax.plot(logger_gp.metrics_history[metric_key], label='WGAN-GP', alpha=0.7)
        ax.set_title(title)
        ax.legend()

    # Fourth panel only points the reader at the sample grids produced
    # during training. A real newline is used here; an escaped backslash
    # would be drawn literally as "\n" in the figure.
    axes[1, 1].axis('off')
    axes[1, 1].text(0.5, 0.5, 'Check generated samples\nfrom both models above',
                    ha='center', va='center', fontsize=12)

    plt.suptitle(f'WGAN vs WGAN-GP Comparison on {dataset_name.upper()}', fontsize=16)
    plt.tight_layout()
    plt.show()

# Usage instructions: a plain-text cheat sheet that is only logged, never
# executed. The snippets inside refer to objects created in earlier cells
# (wgan_mnist, mnist_dataloader, logger_mnist) and to definitions above.
usage_info = """
USAGE EXAMPLES:

1. Quick WGAN training (already done above):
   - MNIST and CIFAR10 models trained for 10 epochs each

2. Extended training:
   train_wgan(wgan_mnist, mnist_dataloader, logger_mnist, num_epochs=50, save_every=10)

3. WGAN-GP training:
   wgan_gp = WGAN_GP(latent_dim=100, img_size=[64,64], channels=1, lambda_gp=10)
   logger_gp = WGANLogger("WGAN_GP", "MNIST")
   train_wgan(wgan_gp, mnist_dataloader, logger_gp, num_epochs=25)

4. Compare WGAN vs WGAN-GP:
   compare_wgan_vs_wgan_gp("mnist", num_epochs=20)
   compare_wgan_vs_wgan_gp("cifar10", num_epochs=20)

5. Load saved models:
   wgan_mnist.load_models("./data/models/WGAN/MNIST/wgan_epoch_9_wgan.pth")

6. Generate new samples:
   samples = wgan_mnist.generate_samples(64)
   # Display samples...

KEY FEATURES IMPLEMENTED:
✓ Proper Wasserstein distance calculation
✓ Critic instead of discriminator (no sigmoid)
✓ Weight clipping for Lipschitz constraint
✓ WGAN-GP with gradient penalty
✓ Professional training pipeline with structured logging
✓ Comprehensive logging and visualization
✓ Support for both MNIST and CIFAR10
✓ Model saving and loading
✓ Training progress visualization
✓ File and console logging with timestamps
"""

# Emit the whole cheat sheet as a single multi-line INFO record
utils_logger.info(usage_info)
