# Unit Tests for Pix2Pix Manga Colorization

This notebook contains unit tests for the model architectures and data loading functionality.
Tests verify correct shapes, data ranges, and basic functionality before training.

Tests cover:
- Generator architecture (UNet)
- Discriminator architecture (PatchGAN)
- Dataset loading and preprocessing
- Loss functions and optimizers
- Checkpoint loading

In [None]:
"""Setup and imports."""
import torch
from src.model import UNetGenerator, PatchGANDiscriminator
from src.dataset import MangaDataset

# Report the runtime environment so test output can be tied to a specific setup.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

## Generator Architecture Tests

In [None]:
def test_generator_instantiation():
    """Check that UNetGenerator builds as a torch Module with trainable parameters.

    A bare ``is not None`` check cannot fail (a constructor either returns an
    object or raises), so this verifies the object is an ``nn.Module`` and
    actually owns parameters.
    """
    generator = UNetGenerator()
    assert isinstance(generator, torch.nn.Module)
    n_params = sum(p.numel() for p in generator.parameters())
    assert n_params > 0, "generator has no parameters"
    print("Generator instantiation: PASS")

test_generator_instantiation()

In [None]:
def test_generator_output_shape():
    """Check that one 1x256x256 grayscale image maps to a 3x256x256 color image."""
    expected = (1, 3, 256, 256)
    model = UNetGenerator()
    result = model(torch.randn(1, 1, 256, 256))

    assert result.shape == expected
    print(f"Generator output shape: PASS {result.shape}")

test_generator_output_shape()

In [None]:
def test_generator_batch_processing():
    """Run the generator on several batch sizes and verify the full output shape.

    The forward passes run under ``torch.no_grad()`` because only output shapes
    are inspected. Building an autograd graph for batches of up to 16
    full-resolution images would keep activations in memory for nothing.
    """
    generator = UNetGenerator()

    with torch.no_grad():
        for batch_size in [1, 4, 8, 16]:
            test_input = torch.randn(batch_size, 1, 256, 256)
            output = generator(test_input)
            # Check every dimension, not only the batch size, so a wrong
            # channel count or spatial size is caught as well.
            assert output.shape == (batch_size, 3, 256, 256), (
                f"batch {batch_size}: got {tuple(output.shape)}"
            )
            print(f"Batch size {batch_size}: PASS")

test_generator_batch_processing()

In [None]:
def test_generator_output_range():
    """Check that generator outputs for a random input stay within [-1, 1]."""
    model = UNetGenerator()
    result = model(torch.randn(1, 1, 256, 256))

    lowest, highest = result.min(), result.max()
    in_range = lowest >= -1.0 and highest <= 1.0
    assert in_range
    print(f"Output range: [{lowest:.3f}, {highest:.3f}] PASS")

test_generator_output_range()

## Discriminator Architecture Tests

In [None]:
def test_discriminator_instantiation():
    """Check that PatchGANDiscriminator builds as a torch Module with parameters.

    A bare ``is not None`` check cannot fail, so this verifies the object is an
    ``nn.Module`` and actually owns trainable parameters.
    """
    discriminator = PatchGANDiscriminator()
    assert isinstance(discriminator, torch.nn.Module)
    n_params = sum(p.numel() for p in discriminator.parameters())
    assert n_params > 0, "discriminator has no parameters"
    print("Discriminator instantiation: PASS")

test_discriminator_instantiation()

In [None]:
def test_discriminator_output_shape():
    """Check that the discriminator returns a non-empty 4-D map per input sample.

    The discriminator is called with a grayscale image (N, 1, H, W) and a color
    image (N, 3, H, W). The output must be 4-D and keep the batch dimension.
    """
    discriminator = PatchGANDiscriminator()
    gray = torch.randn(1, 1, 256, 256)
    color = torch.randn(1, 3, 256, 256)

    # Only the shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        output = discriminator(gray, color)

    assert output.dim() == 4, f"expected 4-D output, got {tuple(output.shape)}"
    assert output.shape[0] == gray.shape[0], "batch dimension not preserved"
    assert output.numel() > 0, "empty discriminator output"
    print(f"Discriminator output shape: {output.shape} PASS")

test_discriminator_output_shape()

In [None]:
def test_discriminator_batch_processing():
    """Check that the discriminator keeps the batch dimension for several sizes."""
    model = PatchGANDiscriminator()

    sizes = (1, 4, 8)
    for n in sizes:
        # Arguments are evaluated left to right: grayscale first, then color.
        result = model(torch.randn(n, 1, 256, 256), torch.randn(n, 3, 256, 256))
        assert result.shape[0] == n
        print(f"Batch size {n}: PASS")

test_discriminator_batch_processing()

## Dataset Tests

In [None]:
def test_dataset_loading():
    """Try to load the training split and report its size, or SKIP on any error."""
    try:
        train_set = MangaDataset("data/train", image_size=256)
        count = len(train_set)
        print(f"Dataset loaded: {count} images PASS")
    except Exception as err:  # best-effort: data may be absent in this environment
        print(f"SKIP: {err}")

test_dataset_loading()

In [None]:
def test_dataset_output():
    """Check the type and shape of the first dataset sample.

    Only loading the dataset and fetching the first sample may turn into a SKIP,
    for example when ``data/train`` is not available. The assertions run outside
    the ``try``. Otherwise a failing assertion would be caught by the broad
    ``except`` and misreported as a SKIP instead of a failure.
    Also prints the value range of the grayscale tensor for inspection.
    """
    try:
        dataset = MangaDataset("data/train", image_size=256)
        gray, color = dataset[0]
    except Exception as e:  # best-effort: data may be missing in this environment
        print(f"SKIP: {e}")
        return

    assert isinstance(gray, torch.Tensor)
    assert isinstance(color, torch.Tensor)
    assert gray.shape == (1, 256, 256)
    assert color.shape == (3, 256, 256)

    print(f"Gray: {gray.shape}, Color: {color.shape}")
    print(f"Gray range: [{gray.min():.3f}, {gray.max():.3f}]")
    print("Dataset output: PASS")

test_dataset_output()

## Loss Function Tests

In [None]:
def test_loss_functions():
    """Check the GAN (MSE) and L1 reconstruction losses.

    Losses on random inputs must be finite, non-negative scalars. Each loss is
    also evaluated on inputs with a known answer, so the test can actually fail
    if a criterion is misconfigured.
    """
    criterion_gan = torch.nn.MSELoss()
    criterion_l1 = torch.nn.L1Loss()

    pred = torch.randn(1, 1, 30, 30)
    target = torch.ones(1, 1, 30, 30)
    loss_gan = criterion_gan(pred, target)

    fake = torch.randn(1, 3, 256, 256)
    real = torch.randn(1, 3, 256, 256)
    loss_l1 = criterion_l1(fake, real)

    for name, loss in (("GAN", loss_gan), ("L1", loss_l1)):
        assert loss.dim() == 0, f"{name} loss should be a scalar"
        assert torch.isfinite(loss), f"{name} loss is not finite: {loss}"
        assert loss.item() >= 0.0, f"{name} loss is negative: {loss}"

    # Known values with the default 'mean' reduction: a perfect prediction
    # gives 0. A constant offset of 0.5 gives 0.25 for MSE (0.5**2) and 0.5 for L1.
    assert criterion_gan(target, target).item() == 0.0
    assert abs(criterion_gan(target - 0.5, target).item() - 0.25) < 1e-6
    assert abs(criterion_l1(real + 0.5, real).item() - 0.5) < 1e-4

    print(f"GAN loss: {loss_gan.item():.4f} PASS")
    print(f"L1 loss: {loss_l1.item():.4f} PASS")

test_loss_functions()

In [None]:
def test_optimizer_setup():
    """Build Adam optimizers for G and D and verify their configuration.

    Each optimizer must use lr=2e-4 and betas=(0.5, 0.999), and must manage
    every parameter of its model.
    """
    lr, betas = 2e-4, (0.5, 0.999)
    generator = UNetGenerator()
    discriminator = PatchGANDiscriminator()

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    for model, opt in ((generator, opt_g), (discriminator, opt_d)):
        assert len(opt.param_groups) == 1
        group = opt.param_groups[0]
        assert group["lr"] == lr
        assert tuple(group["betas"]) == betas
        # The optimizer must cover all model parameters, and there must be some.
        n_opt = sum(p.numel() for p in group["params"])
        n_model = sum(p.numel() for p in model.parameters())
        assert n_opt == n_model > 0, f"optimizer covers {n_opt} of {n_model} params"

    print("Optimizer setup: PASS")
    print("Config: Adam(lr=2e-4, betas=(0.5, 0.999))")

test_optimizer_setup()

## Checkpoint Loading

In [None]:
import os

def test_checkpoint_loading():
    """Load the baseline generator checkpoint, if present, and run one inference pass.

    Prints a SKIP message when the checkpoint file does not exist. Otherwise
    loads the weights on CPU, switches to eval mode and checks the output shape
    of a single random grayscale input.
    """
    ckpt = "src/experiments/baseline/checkpoints/G_final.pth"

    if os.path.exists(ckpt):
        model = UNetGenerator()
        weights = torch.load(ckpt, map_location='cpu', weights_only=True)
        model.load_state_dict(weights)
        model.eval()

        with torch.no_grad():
            result = model(torch.randn(1, 1, 256, 256))

        assert result.shape == (1, 3, 256, 256)
        print("Checkpoint loading: PASS")
    else:
        print(f"SKIP: Checkpoint not found at {ckpt}")

test_checkpoint_loading()

---

## Summary

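Every test above either prints `PASS` or raises an error. Tests that print `SKIP` were not exercised, for example because the dataset or checkpoint was not available. Rerun them once those files are present.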