## 1. Setup and Import Dependencies

In [None]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import os
from pathlib import Path
from torch.utils.data import DataLoader, random_split
import datetime
from tqdm import tqdm

In [None]:
# Import components from mlops
from mlops.src.components.generator import define_G
from mlops.src.components.discriminator import define_D
from mlops.src.components.losses import GANLoss, VGGLoss
from mlops.src.components.replay_pool import ReplayPool
from mlops.src.components.functions import print_network, show_tensor
from mlops.src.models.pix2pixhd_module import Pix2PixHD, Pix2PixHDDataset

## 2. Initialize Device and Models

In [None]:
# Run on the CUDA device when PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Generator configuration: build the network first, then move it to `device`.
# gpu_ids is left empty because placement is handled by the .to(device) call.
generator = define_G(
    # 3-channel input image -> 3-channel output image.
    input_nc=3,
    output_nc=3,
    # Architecture variant and normalisation layer.
    netG="global",
    norm="instance",
    # Capacity of the global generator.
    ngf=64,
    n_downsample_global=3,
    n_blocks_global=9,
    # Local-enhancer settings. NOTE(review): presumably unused when
    # netG="global" — confirm in define_G.
    n_local_enhancers=1,
    n_blocks_local=3,
    gpu_ids=[],
)
generator = generator.to(device)

print("\nGenerator Architecture:")
print_network(generator)

In [None]:
# Discriminator configuration: conditional, so it sees the input image and the
# (real or generated) output concatenated channel-wise: 3 + 3 = 6 channels.
discriminator = define_D(
    input_nc=6,
    num_outputs=1,
    # Per-scale network depth/width and normalisation.
    ndf=64,
    n_layers_D=3,
    norm="instance",
    # No sigmoid on the output (use_lsgan=True is used for the GAN loss below).
    use_sigmoid=False,
    # Multi-scale discriminator with 3 scales.
    num_D=3,
    # NOTE(review): presumably exposes intermediate features for the
    # feature-matching loss — confirm in define_D.
    getIntermFeat=True,
    gpu_ids=[],
)
discriminator = discriminator.to(device)

print("\nDiscriminator Architecture:")
print_network(discriminator)

## 3. Configure Loss Functions and Optimizers

In [None]:
# Loss functions, all moved to the training device:
#   GANLoss (least-squares variant), L1 for feature matching, VGG perceptual.
criterion_gan = GANLoss(use_lsgan=True).to(device)
criterion_feat = nn.L1Loss().to(device)
criterion_vgg = VGGLoss().to(device)

print("Loss functions initialized:")
for loss_label, loss_fn in (
    ("GAN Loss", criterion_gan),
    ("Feature Loss", criterion_feat),
    ("VGG Loss", criterion_vgg),
):
    print(f"  - {loss_label}: {type(loss_fn).__name__}")

In [None]:
# Create replay buffer for discriminator training.
# The size lives in one variable so the log message always matches the
# configured value (previously "50" was duplicated in a placeholder-free
# f-string).
replay_pool_size = 50
replay_pool = ReplayPool(pool_size=replay_pool_size)
print(f"Replay pool initialized with size: {replay_pool_size}")

In [None]:
# One AdamW optimizer per network, sharing learning rate and betas
# (beta1=0.5, as configured below).
# NOTE(review): no weight_decay is passed, so AdamW's library default applies —
# confirm this is intended.
learning_rate = 1e-4
g_optimizer, d_optimizer = (
    torch.optim.AdamW(net.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

print(f"Optimizers configured with learning rate: {learning_rate}")

## 4. Prepare Dataset and DataLoaders

In [None]:
# Dataset configuration — adjust the root to your dataset location.
dataset_path = Path("./data")
# Input (sketch) and target (photo) sub-folders. NOTE(review): the leading and
# trailing slashes suggest Pix2PixHDDataset joins these to images_dir by plain
# string concatenation — confirm before changing their format.
feature_folder, label_folder = "/sketches/", "/photos/"

print(f"Loading dataset from: {dataset_path}")
for role, sub_dir in (("Feature", feature_folder), ("Label", label_folder)):
    print(f"  {role} folder: {sub_dir}")

In [None]:
# Paired dataset: inputs from feature_folder, targets from label_folder.
# NOTE(review): img_size=256 presumably means images are resized to 256 px —
# confirm in Pix2PixHDDataset.
full_dataset = Pix2PixHDDataset(
    img_size=256,
    images_dir=str(dataset_path),
    feature_fold=feature_folder,
    label_fold=label_folder,
)

print("Total dataset size:", len(full_dataset))

In [None]:
# Create an 80/20 train/test split.
# A fixed-seed generator makes the partition reproducible. Without it, every
# run (including a resumed one) draws a new split, and images that were held
# out before can end up in the training set.
split_seed = 42
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_ds, test_ds = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

print("\nData split:")
print(f"  Train size: {len(train_ds)}")
print(f"  Test size: {len(test_ds)}")

In [None]:
# Create DataLoaders (DataLoader is imported at the top of the notebook).
batch_size = 4
num_workers = 4

# Training: reshuffle every epoch; drop_last discards a final partial batch so
# every training batch has exactly batch_size samples.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,
)

# Evaluation: fixed order (shuffle=False) so the samples visualised or scored
# from test_loader are the same every epoch and every run, which makes results
# comparable.
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

print("\nDataLoaders created:")
print(f"  Batch size: {batch_size}")
print(f"  Num workers: {num_workers}")

In [None]:
# Pull one batch from the training loader to sanity-check tensor shapes.
src_sample, tgt_sample = next(iter(train_loader))
print("Sample batch shapes:")
for tag, batch_tensor in (("Input (src)", src_sample), ("Target (tgt)", tgt_sample)):
    print(f"  {tag}: {batch_tensor.shape}")

In [None]:
# Show the first input/target pair from the sampled batch.
for caption, image in (
    ("Input image sample:", src_sample[0]),
    ("\nTarget image sample:", tgt_sample[0]),
):
    print(caption)
    show_tensor(image)

## 5. Initialize Pix2PixHD Model

In [None]:
# Create the checkpoint directory and its "images" sub-directory.
# checkpoint_dir stays a plain string because it is handed to Pix2PixHD as-is.
checkpoint_dir = "./checkpoints/pix2pixhd/"
checkpoint_root = Path(checkpoint_dir)
for target_dir in (checkpoint_root, checkpoint_root / "images"):
    target_dir.mkdir(parents=True, exist_ok=True)

print(f"Checkpoint directory: {checkpoint_dir}")

In [None]:
# Wire networks, losses and training utilities into the Pix2PixHD wrapper.
model = Pix2PixHD(
    # Networks
    generator=generator,
    discriminator=discriminator,
    # Losses. NOTE(review): lambda_feat presumably weights the
    # feature-matching term — confirm in Pix2PixHD.
    criterion_gan=criterion_gan,
    criterion_feat=criterion_feat,
    criterion_vgg=criterion_vgg,
    lambda_feat=10.0,
    # Training utilities and output location
    replay_pool=replay_pool,
    device=device,
    checkpoint_dir=checkpoint_dir,
)

print("Pix2PixHD model initialized successfully!")

## 6. Execute Training

In [None]:
# Training configuration
num_epochs = 15
resume_from_checkpoint = None  # Set to checkpoint path to resume training
# 0-based epoch index to continue from when resuming. The return value of
# load_checkpoint() is not used here, so the resume point must be given
# explicitly. Without it, a resumed run would start again at epoch 0 and
# repeat all num_epochs.
resume_start_epoch = 0

# Load checkpoint if resuming
start_epoch = 0
if resume_from_checkpoint is not None:
    # Validate before loading so a bad setting fails fast.
    if not 0 <= resume_start_epoch < num_epochs:
        raise ValueError(
            f"resume_start_epoch must be in [0, {num_epochs}), "
            f"got {resume_start_epoch}"
        )
    print(f"Loading checkpoint from: {resume_from_checkpoint}")
    model.load_checkpoint(resume_from_checkpoint)
    start_epoch = resume_start_epoch
    if start_epoch == 0:
        print(
            "Warning: resume_start_epoch is 0, so the epoch count restarts "
            "from the beginning."
        )
    print("Checkpoint loaded successfully!")

In [None]:
# Training loop: run one model.train_epoch() per epoch. Any failure is reported
# with its epoch number and then re-raised so the cell still fails loudly.
rule = "=" * 80
print("\n" + rule)
print("Starting Pix2PixHD Training")
print(rule + "\n")

for epoch in range(start_epoch, num_epochs):
    epoch_no = epoch + 1
    print(f"\nEpoch {epoch_no}/{num_epochs}")
    print("-" * 80)

    try:
        model.train_epoch(
            epoch=epoch,
            train_loader=train_loader,
            test_loader=test_loader,
            g_optimizer=g_optimizer,
            d_optimizer=d_optimizer,
        )
        print(f"✓ Epoch {epoch_no} completed successfully")
    except Exception as err:
        print(f"✗ Error during epoch {epoch_no}: {str(err)}")
        raise

print("\n" + rule)
print("Training completed successfully!")
print(f"Checkpoints saved to: {checkpoint_dir}")
print(rule)

## 7. Inference and Visualization

In [None]:
# Generate predictions on test set
def generate_test_output(model, test_loader, num_samples=4, device=None):
    """Display input, generated and target images for a few test batches.

    Runs ``model.generator_ema`` in eval mode under ``torch.no_grad()``. For
    each of the first ``num_samples`` batches of ``test_loader``, it shows the
    first element of the batch via ``show_tensor``. The generator's previous
    train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: Object exposing a ``generator_ema`` module (e.g. ``Pix2PixHD``).
        test_loader: Iterable yielding ``(src, tgt)`` tensor batches.
        num_samples: Number of *batches* to visualise; only element 0 of each
            batch is shown. Values <= 0 show nothing.
        device: Device to run inference on. Defaults to the device of the
            generator's first parameter, or CPU if it has no parameters.
    """
    import itertools

    gen = model.generator_ema
    if device is None:
        first_param = next(gen.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Remember the current mode so it is restored exactly, instead of
    # unconditionally switching the generator to train mode afterwards.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            # islice stops before fetching an extra, unused batch from the loader.
            batches = itertools.islice(test_loader, max(num_samples, 0))
            for i, (src, tgt) in enumerate(batches):
                src = src.to(device)
                tgt = tgt.to(device)

                # Generate predictions
                pred = gen(src)

                # show_tensor is called on CPU tensors elsewhere in this
                # notebook, so move the displayed images to CPU first.
                print(f"\nSample {i+1}:")
                print("Input:")
                show_tensor(src[0].cpu())

                print("\nGenerated:")
                show_tensor(pred[0].cpu())

                print("\nTarget:")
                show_tensor(tgt[0].cpu())
    finally:
        gen.train(was_training)

In [None]:
# Visualise predictions for the first two test batches.
print("Generating test outputs...")
generate_test_output(model=model, test_loader=test_loader, num_samples=2)

## 8. Save Final Model

In [None]:
# Save final weights: "G" holds the EMA generator, "D" the discriminator.
final_checkpoint_path = os.path.join(checkpoint_dir, "final_model.pt")
final_state = {
    "G": model.generator_ema.state_dict(),
    "D": model.discriminator.state_dict(),
}
torch.save(final_state, final_checkpoint_path)

print(f"Final model saved to: {final_checkpoint_path}")