# Changelog

In [None]:
# import os
# working_directory = os.getcwd() + "/image_image_translation"
# print(working_directory)
# os.listdir(working_directory)

In [None]:
!pip install torch_fidelity
!pip install opencv-python
!pip install matplotlib
!pip install pandas

In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import numpy as np
import random
import cv2
import matplotlib.pyplot as plt
import itertools
import pandas as pd
import torch_fidelity
import shutil
import math

from torch.nn.utils import spectral_norm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

# Seed every random number generator used in this notebook (PyTorch CPU,
# PyTorch CUDA on all devices, NumPy, Python's `random`) with the same value.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Request deterministic cuDNN kernels and turn off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Generator and Discriminator

In [4]:

"""
Step 1. Define Generator with Edge Awareness
"""
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + InstanceNorm stages wrapped in a skip connection.

    The channel count and spatial size of the input are preserved; a ReLU
    sits between the two stages only.

    Args:
        in_features: Number of channels of the input (and output) tensor.
    """

    def __init__(self, in_features):
        super().__init__()

        stages = []
        for needs_activation in (True, False):
            stages.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, kernel_size=3),
                nn.InstanceNorm2d(in_features),
            ])
            if needs_activation:
                stages.append(nn.ReLU(inplace=True))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
        
        

class SelfAttention(nn.Module):
    """Memory-efficient self attention module for Generator enhancement.

    Queries and keys are projected to ``in_dim // 16`` channels (at least 1),
    values keep ``in_dim`` channels. The attended features are blended into
    the input through a learnable scalar ``gamma`` that starts at zero, so the
    module is an identity mapping at initialisation.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        # Reduce channel dimensions more aggressively to save memory
        self.channel_reduction = 16  # More aggressive reduction factor (was 8)

        # Never let the projection collapse to zero channels for small in_dim;
        # that would also make the attention scale zero.
        reduced_dim = max(1, in_dim // self.channel_reduction)

        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable weight

        # Scaled dot-product attention for better numerical stability.
        # Registered as a NON-persistent buffer: it follows .to()/.cuda() and
        # module replication, but adds no key to the state_dict, so existing
        # checkpoints keep loading unchanged.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(reduced_dim)])),
            persistent=False,
        )

    def forward(self, x):
        """Apply self attention over all spatial positions of ``x``.

        Args:
            x: 4-D tensor laid out as (batch, channels, dim2, dim3).

        Returns:
            Tensor with the same shape as ``x``: ``gamma * attended + x``.
        """
        batch_size, channels, dim2, dim3 = x.size()
        positions = dim2 * dim3

        # Project queries, keys, and values
        query = self.query_conv(x).view(batch_size, -1, positions).permute(0, 2, 1)  # B x N x C'
        key = self.key_conv(x).view(batch_size, -1, positions)  # B x C' x N
        value = self.value_conv(x).view(batch_size, -1, positions)  # B x C x N

        # Attention map with scaling for stability; softmax over the key axis.
        attention = torch.bmm(query, key) / self.scale  # B x N x N
        attention = F.softmax(attention, dim=-1)

        # Apply attention to values
        out = torch.bmm(value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(batch_size, channels, dim2, dim3)

        # Apply learnable weight and residual connection
        return self.gamma * out + x

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: 7x7 stem -> two stride-2 downsampling convs -> optional self
    attention -> ``num_residual_blocks`` residual blocks -> optional self
    attention -> two stride-2 transposed convs -> 7x7 output conv with Tanh.
    When ``edge_aware`` is set, a scaled edge map is added to the output.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        edge_aware: If True, add ``0.15 * extract_edges(...)`` to the output.
        num_residual_blocks: Number of ``ResidualBlock`` modules in the bottleneck.
        use_attention: If True, insert a ``SelfAttention`` before and after the
            residual blocks.
    """

    def __init__(self, in_channels, edge_aware=True, num_residual_blocks=7, use_attention=True):
        super(Generator, self).__init__()

        # Edge awareness - if True, we'll process with internal edge detection
        self.edge_aware = edge_aware
        self.use_attention = use_attention
        # Kept as an attribute so checkpoint code can record the real depth
        # instead of falling back to a guessed default.
        self.num_residual_blocks = num_residual_blocks

        # Initial convolution block
        model_initial = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling: 64 -> 128 -> 256 channels, each step halves the spatial size
        model_down = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model_down += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Add attention after downsampling if enabled
        # This is where feature maps are at 1/4 size but rich in information
        if use_attention:
            self.attention1 = SelfAttention(in_features)

        # Residual blocks
        model_res = [ResidualBlock(in_features) for _ in range(num_residual_blocks)]

        # Add attention after residual blocks if enabled
        # This allows attention at the bottleneck where transformation is most critical
        if use_attention:
            self.attention2 = SelfAttention(in_features)

        # Upsampling: 256 -> 128 -> 64 channels, each step doubles the spatial size
        model_up = []
        out_features = in_features // 2
        for _ in range(2):
            model_up += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer (in_features is back to 64 here)
        model_output = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, in_channels, 7),
            nn.Tanh()
        ]

        # Store model sections separately
        self.model_initial = nn.Sequential(*model_initial)
        self.model_down = nn.Sequential(*model_down)
        self.model_res = nn.Sequential(*model_res)
        self.model_up = nn.Sequential(*model_up)
        self.model_output = nn.Sequential(*model_output)

    @staticmethod
    def _kernel3x3(rows, device):
        """Build a float32 (1, 1, 3, 3) convolution kernel on ``device``."""
        return torch.tensor(rows, dtype=torch.float32, device=device).view(1, 1, 3, 3)

    def extract_edges(self, x):
        """Return a single-channel soft edge map of ``x``.

        The channel mean of ``x`` is filtered with Sobel and Laplacian
        kernels, combined 0.85 / 0.15, divided by each sample's maximum and
        finally gated with a sigmoid centred at 0.2.

        Args:
            x: 4-D tensor (batch, channels, rows, cols).

        Returns:
            Tensor of shape (batch, 1, rows, cols).
        """
        x_gray = torch.mean(x, dim=1, keepdim=True)

        # Apply simple edge detection using Sobel filters
        sobel_x = self._kernel3x3([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], x.device)
        sobel_y = self._kernel3x3([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], x.device)
        edge_x = nn.functional.conv2d(x_gray, sobel_x, padding=1)
        edge_y = nn.functional.conv2d(x_gray, sobel_y, padding=1)

        # Edge magnitude; the epsilon keeps sqrt differentiable at zero
        edge_mag = torch.sqrt(edge_x ** 2 + edge_y ** 2 + 1e-6)

        # Laplacian response as a second, lower-weighted detector
        laplacian = self._kernel3x3([[0, 1, 0], [1, -4, 1], [0, 1, 0]], x.device)
        edge_laplacian = torch.abs(nn.functional.conv2d(x_gray, laplacian, padding=1))

        # Combine edge detectors with more balanced weighting (increased Sobel influence)
        edge_mag = 0.85 * edge_mag + 0.15 * edge_laplacian

        # Per-sample normalization by the maximum response
        batch_max = torch.max(edge_mag.view(edge_mag.size(0), -1), dim=1, keepdim=True)[0].view(edge_mag.size(0), 1, 1, 1)
        edge_mag = edge_mag / (batch_max + 1e-6)

        # Soft threshold: responses well below 0.2 are suppressed
        edge_mag = torch.sigmoid((edge_mag - 0.2) * 10) * edge_mag

        return edge_mag

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: 4-D tensor (batch, in_channels, rows, cols).

        Returns:
            Tensor of the same shape with values in [-1, 1].
        """
        # Initial processing
        x = self.model_initial(x)

        # Downsampling
        x = self.model_down(x)

        # Apply first attention after downsampling if enabled
        if self.use_attention:
            x = self.attention1(x)

        # Residual blocks
        x = self.model_res(x)

        # Apply second attention after residual blocks if enabled
        if self.use_attention:
            x = self.attention2(x)

        # Upsampling
        x = self.model_up(x)

        # Output layer
        output = self.model_output(x)

        # Apply edge enhancement if edge-aware
        if self.edge_aware:
            # NOTE(review): the edge map is computed from the upsampled feature
            # tensor (its channel mean), not from the network input - confirm
            # this is intended before changing it, trained weights depend on it.
            edges = self.extract_edges(x)
            edge_contribution = 0.15 * edges
            output = torch.clamp(output + edge_contribution, -1, 1)

        return output

"""
Step 2. Define Discriminator
"""
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of per-patch scores.

    Four spectrally-normalized stride-2 4x4 convolutions (64, 128, 256, 512
    channels) are followed by a spectrally-normalized 4x4 convolution with a
    single output channel. Every stage but the first uses InstanceNorm, and
    every stage but the last uses LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the images being judged.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Scale factor for output size calculation
        self.scale_factor = 16

        layers = []
        previous = in_channels
        for stage, width in enumerate((64, 128, 256, 512)):
            # Spectral normalization on every conv to stabilize training
            layers.append(
                spectral_norm(nn.Conv2d(previous, width, kernel_size=4, stride=2, padding=1))
            )
            # The very first stage is left without instance norm
            if stage > 0:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            previous = width

        # Final scoring layer: raw, un-normalized output
        layers.append(spectral_norm(nn.Conv2d(previous, 1, kernel_size=4, padding=1)))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.model(x)
        return scores


In [5]:
"""
Step 3. Define Loss
"""
# Adversarial loss: mean-squared error (NOT binary cross entropy), averaged over all elements
criterion_GAN = nn.MSELoss(reduction='mean')  # Using MSE loss for stability
# L1 loss for cycle consistency, averaged over all elements
criterion_cycle = nn.L1Loss(reduction='mean')
# L1 loss for identity preservation, averaged over all elements
criterion_identity = nn.L1Loss(reduction='mean')


# Helper function to bound gradient magnitude and expose the norm for NaN/Inf checks
def clip_gradients(parameters, max_norm=1.0):
    """Clip gradients in place to prevent exploding gradients.

    Args:
        parameters: Iterable of parameters (or a single tensor) whose ``.grad``
            fields are rescaled.
        max_norm: Upper bound for the total gradient 2-norm.

    Returns:
        The total gradient norm measured BEFORE clipping, as returned by
        ``torch.nn.utils.clip_grad_norm_``. Callers can test it with
        ``torch.isfinite`` to detect NaN/Inf gradients; ignoring the return
        value keeps the previous behaviour.
    """
    return torch.nn.utils.clip_grad_norm_(parameters, max_norm)

# Initialisation and Data Loading

In [None]:
"""
Step 4. Initialize G and D
"""
# Initialize generators with attention mechanism
# Both use 8 residual blocks; only G_AB adds the edge map to its output.
G_AB = Generator(3, edge_aware=True, num_residual_blocks=8, use_attention=True)  # Real to Comic with dual attention
G_BA = Generator(3, edge_aware=False, num_residual_blocks=8, use_attention=True)  # Comic to Real with dual attention


# Initialize discriminators
D_A = Discriminator(3)  # Discriminator for domain A (real faces)
D_B = Discriminator(3)  # Discriminator for domain B (comic faces)

## Total parameters in CycleGAN should be less than 60MB
# Sum of element counts over the parameters of all four networks.
total_params = sum(p.numel() for p in G_AB.parameters()) + \
               sum(p.numel() for p in G_BA.parameters()) + \
               sum(p.numel() for p in D_A.parameters()) + \
               sum(p.numel() for p in D_B.parameters())


"""
# modification of parameters computation is forbidden
"""
# NOTE: the divisor is 1024 * 1024 (2**20), so the printed "million" figure is
# in units of 2**20 parameters rather than 10**6.
total_params_million = total_params / (1024 * 1024)
print(f'Total parameters in CycleGAN model: {total_params_million:.2f} million')

# Move the networks to the GPU when one is available.
cuda = torch.cuda.is_available()
print(f'cuda: {cuda}')
if cuda:
    G_AB = G_AB.cuda()
    D_B = D_B.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()


# Move the loss modules alongside the networks.
criterion_GAN = criterion_GAN.cuda() if cuda else criterion_GAN
criterion_cycle = criterion_cycle.cuda() if cuda else criterion_cycle
criterion_identity = criterion_identity.cuda() if cuda else criterion_identity

"""
Step 5. Configure Optimizers
"""
lr = 0.0001  # Reduced learning rate for more stability (from 0.0002 to 0.0001)
# Add weight decay to improve stability and prevent overfitting
# One Adam optimizer drives both generators jointly; each discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), 
                              lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

# Learning rate scheduler: linear warmup followed by a modified cosine decay
def create_warmup_cosine_scheduler(optimizer, n_epochs, warmup_epochs=3):
    """Build a ``LambdaLR`` with linear warmup and modified cosine annealing.

    The multiplier applied to the optimizer's base learning rate is:

    * ``(epoch + 1) / warmup_epochs`` while ``epoch < warmup_epochs``, so the
      very first epoch already trains with a non-zero learning rate and the
      full rate is reached on the last warmup epoch;
    * ``0.5 * (1 + cos(pi * progress ** 1.5))`` afterwards, where ``progress``
      runs from 0 to 1 over the remaining epochs. The exponent keeps the rate
      near its peak for longer before it decays.

    ``progress`` is clamped to [0, 1], so stepping past ``n_epochs`` holds the
    multiplier at 0 instead of letting the cosine rise again. If there are no
    epochs left after warmup (``n_epochs <= warmup_epochs``) the multiplier
    stays at 1 once warmup is over.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        n_epochs: Total number of epochs the schedule spans.
        warmup_epochs: Number of initial warmup epochs.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` to be stepped once per epoch.
    """
    decay_epochs = n_epochs - warmup_epochs

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Gradual warmup, never exactly zero
            return (epoch + 1) / warmup_epochs
        if decay_epochs <= 0:
            # Nothing to anneal over; avoids a division by zero
            return 1.0
        progress = min(max((epoch - warmup_epochs) / decay_epochs, 0.0), 1.0)
        # Modified cosine that starts with less decay
        return 0.5 * (1 + math.cos(math.pi * progress ** 1.5))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Create schedulers
# One schedule per optimizer, all spanning 150 epochs with the default warmup.
# NOTE(review): 150 presumably matches the epoch count of the training loop - confirm.
scheduler_G = create_warmup_cosine_scheduler(optimizer_G, n_epochs=150)
scheduler_D_A = create_warmup_cosine_scheduler(optimizer_D_A, n_epochs=150)
scheduler_D_B = create_warmup_cosine_scheduler(optimizer_D_B, n_epochs=150)

"""
Step 6. DataLoader
"""
class ImageDataset(Dataset):
    """Paired-by-index dataset over the two image domains.

    Files are taken from the two ``train`` folders in sorted order. The first
    3200 names form the ``'train'`` split and names 3200..3999 form the
    ``'valid'`` split.

    Args:
        data_dir: Root directory that contains ``VAE_generation/train`` and
            ``VAE_generation_Cartoon/train``.
        mode: ``'train'`` or ``'valid'``.
        transforms: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'valid'``.
    """

    # Slice of the sorted file list that belongs to each split.
    _SPLITS = {
        'train': slice(None, 3200),
        'valid': slice(3200, 4000),
    }

    def __init__(self, data_dir, mode='train', transforms=None):
        A_dir = os.path.join(data_dir, 'VAE_generation/train')  # modification forbidden
        B_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/train')  # modification forbidden

        try:
            split = self._SPLITS[mode]
        except KeyError:
            # Fail at construction instead of with an AttributeError on first use.
            raise ValueError(
                f"mode must be one of {sorted(self._SPLITS)}, got {mode!r}"
            ) from None

        self.files_A = [os.path.join(A_dir, name) for name in sorted(os.listdir(A_dir))[split]]
        self.files_B = [os.path.join(B_dir, name) for name in sorted(os.listdir(B_dir))[split]]

        self.transforms = transforms

    def __len__(self):
        """Return the number of domain-A files in this split."""
        return len(self.files_A)

    def __getitem__(self, index):
        """Return the ``(img_A, img_B)`` pair for ``index``."""
        file_A = self.files_A[index]
        file_B = self.files_B[index % len(self.files_B)]  # Ensure we don't go out of bounds

        # Force three channels so grayscale / RGBA / palette files do not break
        # a 3-channel normalization or the 3-channel networks.
        img_A = Image.open(file_A).convert('RGB')
        img_B = Image.open(file_B).convert('RGB')

        if self.transforms is not None:
            img_A = self.transforms(img_A)
            img_B = self.transforms(img_B)

        return img_A, img_B

data_dir = '/kaggle/input/image-image-translation/image_image_translation'  # TODO - Update this when submitting (Kaggle)
# data_dir = working_directory  # TODO - Update this when submitting (Jupyter)

image_size = (256, 256)
# Resize to 256x256, convert to a tensor, then normalize every channel with
# mean 0.5 and std 0.5 (sample_images undoes this with `* 0.5 + 0.5`).
transforms_ = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 4

# Training batches are reshuffled every epoch.
trainloader = DataLoader(
    ImageDataset(data_dir, mode='train', transforms=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=3
)

# Validation batches keep the sorted file order.
validloader = DataLoader(
    ImageDataset(data_dir, mode='valid', transforms=transforms_),
    batch_size=batch_size,
    shuffle=False,
    num_workers=3
)

Total parameters in CycleGAN model: 25.01 million
cuda: True


In [7]:
def sample_images(real_A, real_B):
    """
    Generate and display sample translations with flexible batch size support.

    Shows a 4-row grid (Real A, Fake B, Real B, Fake A) with one column per
    sample. Uses the module-level generators ``G_AB`` and ``G_BA``, which are
    switched to eval mode for the forward passes and put back into train mode
    afterwards, even if plotting fails.

    Args:
        real_A: Batch of images from domain A (real faces)
        real_B: Batch of images from domain B (comic style)
    """
    G_AB.eval()
    G_BA.eval()

    try:
        # Only the two translations that are displayed are computed.
        with torch.no_grad():
            fake_B = G_AB(real_A)
            fake_A = G_BA(real_B)

        # Get actual batch size from input
        n_cols = min(real_A.size(0), real_B.size(0))

        # squeeze=False always returns a 2-D axes array, including n_cols == 1
        fig, ax = plt.subplots(4, n_cols, figsize=(3 * n_cols, 12), squeeze=False)

        rows = (
            ("Real A", real_A),
            ("Fake B", fake_B),
            ("Real B", real_B),
            ("Fake A", fake_A),
        )
        for row, (title, batch) in enumerate(rows):
            for col in range(n_cols):
                # CHW -> HWC, then undo the (x - 0.5) / 0.5 normalization
                image = batch[col].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
                ax[row, col].imshow(image)
                ax[row, col].set_title(title)
                ax[row, col].axis("off")

        plt.tight_layout()
        plt.show()
    finally:
        G_AB.train()
        G_BA.train()

In [8]:
def _save_translations(generator, src_dir, dst_dir, preprocess, batch_size, cuda):
    """Run ``generator`` over every image in ``src_dir`` and save results in ``dst_dir``."""
    to_image = transforms.ToPILImage()
    device = torch.device('cuda' if cuda else 'cpu')
    paths = [os.path.join(src_dir, name) for name in os.listdir(src_dir)]

    for start in range(0, len(paths), batch_size):
        chunk = paths[start:start + batch_size]

        tensors = []
        for path in chunk:
            # Close the file handle right away; force 3 channels for the network.
            with Image.open(path) as img:
                tensors.append(preprocess(img.convert('RGB')))
        batch = torch.stack(tensors, 0).to(device=device, dtype=torch.float32)

        with torch.no_grad():
            fakes = generator(batch).detach().cpu()

        for path, fake in zip(chunk, fakes):
            arr = fake.permute(1, 2, 0).numpy()
            # Per-image min-max stretch to 0..255; a constant image maps to
            # zeros instead of dividing by zero.
            low = np.min(arr)
            span = np.max(arr) - low
            if span > 0:
                arr = (arr - low) * 255 / span
            else:
                arr = np.zeros_like(arr)
            to_image(arr.astype(np.uint8)).save(
                os.path.join(dst_dir, os.path.basename(path))
            )

        # Clear memory
        del batch, fakes
        if cuda:
            torch.cuda.empty_cache()


def _fid_is_gms(generated_dir, reference_dir, cuda):
    """Return ``(inception_score, fid, gms)`` with ``gms = sqrt(fid / is)``."""
    metrics = torch_fidelity.calculate_metrics(
        input1=generated_dir,
        input2=reference_dir,
        cuda=cuda,
        fid=True,
        isc=True
    )
    fid_score = metrics["frechet_inception_distance"]
    is_score = metrics["inception_score_mean"]
    gms = np.sqrt(fid_score / is_score) if is_score > 0 else float('inf')
    return is_score, fid_score, gms


def _residual_block_count(generator, default):
    """Return the generator's residual block count, or ``default`` if unknown."""
    count = getattr(generator, 'num_residual_blocks', None)
    if count is None and hasattr(generator, 'model_res'):
        # The attribute is missing: count the blocks actually present.
        count = len(generator.model_res)
    return default if count is None else count


def evaluate_and_save_best_model(epoch, G_AB, G_BA, data_dir, save_path='./best_models'):
    """
    Evaluates current model performance using GMS score and saves the best model.
    Updated to work with attention-enhanced generators.

    Both test folders are translated, FID and Inception Score are computed
    with ``torch_fidelity`` against the other domain's test folder, and the
    average of the two ``sqrt(FID / IS)`` values is compared with the best
    value stored in ``<save_path>/best_gms.txt``. On improvement the generator
    weights, a small architecture config, the epoch and the metrics are
    written to ``save_path``. Temporary output folders are removed and the
    generators are returned to train mode in all cases.

    Args:
        epoch (int): Current training epoch
        G_AB (nn.Module): Generator for Real to Comic transformation
        G_BA (nn.Module): Generator for Comic to Real transformation
        data_dir (str): Path to the data directory
        save_path (str): Directory to save the best model

    Returns:
        float: Current average GMS score (``inf`` if a RuntimeError occurred)
    """
    import gc  # For garbage collection to manage memory

    os.makedirs(save_path, exist_ok=True)

    temp_save_dir_cartoon = './temp_generated_cartoon'
    temp_save_dir_real = './temp_generated_real'
    temp_dirs = [temp_save_dir_cartoon, temp_save_dir_real]

    # Remove and recreate temporary directories
    for dir_path in temp_dirs:
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(dir_path)

    # Setup for evaluation
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    batch_size = 4  # Reduced batch size to accommodate attention mechanism
    cuda = torch.cuda.is_available()

    # Set models to evaluation mode
    G_AB.eval()
    G_BA.eval()

    try:
        test_dir_A = os.path.join(data_dir, 'VAE_generation/test')
        test_dir_B = os.path.join(data_dir, 'VAE_generation_Cartoon/test')

        # 1. Raw Image to Cartoon Image (A→B), scored against the cartoon test set
        _save_translations(G_AB, test_dir_A, temp_save_dir_cartoon, preprocess, batch_size, cuda)
        is_score_AB, fid_score_AB, gms_AB = _fid_is_gms(temp_save_dir_cartoon, test_dir_B, cuda)

        # 2. Cartoon Image to Raw Image (B→A), scored against the raw test set
        _save_translations(G_BA, test_dir_B, temp_save_dir_real, preprocess, batch_size, cuda)
        is_score_BA, fid_score_BA, gms_BA = _fid_is_gms(temp_save_dir_real, test_dir_A, cuda)

        # Calculate average GMS (plain float so it is written back cleanly)
        avg_gms = float(np.round((gms_AB + gms_BA) / 2, 5))

        # Print metrics
        print(f"\nEvaluation at Epoch {epoch}:")
        print(f"Raw to Cartoon - IS: {is_score_AB}, FID: {fid_score_AB}, GMS: {gms_AB}")
        print(f"Cartoon to Raw - IS: {is_score_BA}, FID: {fid_score_BA}, GMS: {gms_BA}")
        print(f"Average GMS: {avg_gms}")

        # Load best GMS so far
        best_gms_file = os.path.join(save_path, 'best_gms.txt')
        if os.path.exists(best_gms_file):
            with open(best_gms_file, 'r') as f:
                best_gms = float(f.read().strip())
        else:
            best_gms = float('inf')

        # If current model is better (lower GMS), save it
        if avg_gms < best_gms:
            print(f"New best model found! Average GMS improved from {best_gms} to {avg_gms}")

            # Save the model weights
            torch.save(G_AB.state_dict(), os.path.join(save_path, 'G_AB_best.pth'))
            torch.save(G_BA.state_dict(), os.path.join(save_path, 'G_BA_best.pth'))

            # Save the architecture information for loading
            with open(os.path.join(save_path, 'model_config.txt'), 'w') as f:
                f.write(f"G_AB_attention: {getattr(G_AB, 'use_attention', False)}\n")
                f.write(f"G_BA_attention: {getattr(G_BA, 'use_attention', False)}\n")
                f.write(f"G_AB_residual_blocks: {_residual_block_count(G_AB, 7)}\n")
                f.write(f"G_BA_residual_blocks: {_residual_block_count(G_BA, 6)}\n")

            # Save the epoch number and metrics
            with open(os.path.join(save_path, 'best_epoch.txt'), 'w') as f:
                f.write(str(epoch))

            with open(best_gms_file, 'w') as f:
                f.write(str(avg_gms))

            # Save detailed metrics
            with open(os.path.join(save_path, 'best_metrics.txt'), 'w') as f:
                f.write(f"Epoch: {epoch}\n")
                f.write(f"Raw to Cartoon - IS: {is_score_AB}, FID: {fid_score_AB}, GMS: {gms_AB}\n")
                f.write(f"Cartoon to Raw - IS: {is_score_BA}, FID: {fid_score_BA}, GMS: {gms_BA}\n")
                f.write(f"Average GMS: {avg_gms}")
        else:
            print(f"No improvement. Current GMS: {avg_gms}, Best GMS: {best_gms}")

    except RuntimeError as e:
        print(f"Runtime error during evaluation: {e}")
        print("This is likely due to memory issues. Try reducing the batch size further.")
        # Return a high GMS to avoid saving this model
        return float('inf')

    finally:
        # Reset models back to training mode
        G_AB.train()
        G_BA.train()

        # Clean up
        for dir_path in temp_dirs:
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path)

        # Force garbage collection
        gc.collect()
        if cuda:
            torch.cuda.empty_cache()

    return avg_gms


def _parse_model_config(config_path):
    """Read a ``key: value`` text file into a dict.

    Values spelled ``true``/``false`` (any case) become booleans, values that
    parse as integers become ``int``, everything else stays a string. Blank
    lines and lines without a ``:`` separator are skipped instead of raising,
    and only the first ``:`` splits the line so values may contain colons.
    """
    config = {}
    with open(config_path, 'r') as f:
        for line in f:
            key, sep, value = line.strip().partition(':')
            if not sep:
                continue  # blank or malformed line
            key, value = key.strip(), value.strip()
            lowered = value.lower()
            if lowered == 'true':
                config[key] = True
            elif lowered == 'false':
                config[key] = False
            else:
                try:
                    config[key] = int(value)
                except ValueError:
                    config[key] = value
    return config


def _module_device(module):
    """Return the device of the module's first parameter (CPU if it has none)."""
    first_param = next(iter(module.parameters()), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _architecture_differs(model, config, prefix):
    """Tell whether ``model`` disagrees with the saved config for ``prefix``.

    Compares ``use_attention`` against ``<prefix>_attention`` (missing key is
    treated as False) and, when both sides are available, ``num_residual_blocks``
    against ``<prefix>_residual_blocks``. Attributes the model does not expose
    are not compared.
    """
    if hasattr(model, 'use_attention') and \
            model.use_attention != config.get(f'{prefix}_attention', False):
        return True
    saved_blocks = config.get(f'{prefix}_residual_blocks')
    return (
        isinstance(saved_blocks, int)
        and hasattr(model, 'num_residual_blocks')
        and model.num_residual_blocks != saved_blocks
    )


def load_best_model(G_AB, G_BA, save_path='./best_models'):
    """
    Loads the best model saved during training.
    Updated to handle attention-enhanced generators.

    If ``model_config.txt`` exists in ``save_path`` and describes a different
    architecture than the generators passed in, the affected generator is
    re-created from the saved configuration and placed on the same device as
    the generator it replaces. Checkpoints are loaded with ``map_location`` set
    to that device, so weights saved on GPU can be restored on a CPU-only host.

    Args:
        G_AB (nn.Module): Generator for Real to Comic transformation
        G_BA (nn.Module): Generator for Comic to Real transformation
        save_path (str): Directory where the best model is saved

    Returns:
        tuple: ``(G_AB, G_BA, best_epoch)``. The generators are returned
        unchanged with ``best_epoch=None`` when the weight files are missing;
        ``best_epoch`` is also ``None`` when ``best_epoch.txt`` is absent or
        does not contain an integer.
    """
    g_ab_path = os.path.join(save_path, 'G_AB_best.pth')
    g_ba_path = os.path.join(save_path, 'G_BA_best.pth')
    epoch_path = os.path.join(save_path, 'best_epoch.txt')
    config_path = os.path.join(save_path, 'model_config.txt')

    # Nothing to load: hand the caller's models back untouched.
    if not (os.path.exists(g_ab_path) and os.path.exists(g_ba_path)):
        print("Best model files not found.")
        return G_AB, G_BA, None

    # Remember where the current models live so replacements and loaded
    # tensors end up on the same device.
    device_AB = _module_device(G_AB)
    device_BA = _module_device(G_BA)

    # Re-create a generator when the checkpoint was saved from a different
    # architecture; otherwise load_state_dict below would fail on key/shape
    # mismatches.
    if os.path.exists(config_path):
        config = _parse_model_config(config_path)

        if _architecture_differs(G_AB, config, 'G_AB'):
            print(f"Reinitializing G_AB with attention={config.get('G_AB_attention', False)}")
            # Assuming Generator class and required imports are available
            G_AB = Generator(
                3,
                edge_aware=True,
                num_residual_blocks=config.get('G_AB_residual_blocks', 7),
                use_attention=config.get('G_AB_attention', False)
            ).to(device_AB)

        if _architecture_differs(G_BA, config, 'G_BA'):
            print(f"Reinitializing G_BA with attention={config.get('G_BA_attention', False)}")
            # Assuming Generator class and required imports are available
            G_BA = Generator(
                3,
                edge_aware=False,
                num_residual_blocks=config.get('G_BA_residual_blocks', 6),
                use_attention=config.get('G_BA_attention', False)
            ).to(device_BA)

    # Load best model weights onto each model's own device.
    G_AB.load_state_dict(torch.load(g_ab_path, map_location=device_AB))
    G_BA.load_state_dict(torch.load(g_ba_path, map_location=device_BA))

    # Get best epoch; a damaged file should not discard the loaded weights.
    best_epoch = None
    if os.path.exists(epoch_path):
        with open(epoch_path, 'r') as f:
            raw_epoch = f.read().strip()
        try:
            best_epoch = int(raw_epoch)
        except ValueError:
            print(f"Could not parse best epoch from {epoch_path}")

    print(f"Loaded best model from epoch {best_epoch}")

    # Display best metrics if available
    metrics_path = os.path.join(save_path, 'best_metrics.txt')
    if os.path.exists(metrics_path):
        with open(metrics_path, 'r') as f:
            metrics = f.read()
        print("Best model metrics:")
        print(metrics)

    return G_AB, G_BA, best_epoch

# Training Loop

In [None]:
"""
Step 7. Training
"""
# Training driver: alternates one generator update and one update per
# discriminator for every batch, reports the last batch's losses once per
# epoch, periodically samples/evaluates on the validation set, and finally
# reloads the best checkpoint from ./best_models.

# Tensor type every batch is cast to (CUDA float tensor when `cuda` is truthy).
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

save_dir = './checkpoints'
os.makedirs(save_dir, exist_ok=True)

# Number of epochs to train for
n_epochs = 150
for epoch in range(n_epochs):
    for i, (real_A, real_B) in enumerate(trainloader):
        real_A, real_B = real_A.type(Tensor), real_B.type(Tensor)

        # Switch the generators to training mode *before* their first forward
        # pass of the step, so the fakes below are never produced in eval mode
        # (e.g. if a validation helper left a generator in eval mode).
        G_AB.train()
        G_BA.train()

        # Translate once per step; the same fakes feed both the generator
        # losses and (detached) the discriminator losses.
        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

        # Discriminator outputs, computed up front so the label tensors can
        # simply mirror their shapes. Fakes are detached so the discriminator
        # losses do not back-propagate into the generators.
        real_A_out = D_A(real_A)
        real_B_out = D_B(real_B)
        fake_A_out = D_A(fake_A.detach())
        fake_B_out = D_B(fake_B.detach())

        # Create target tensors with correct shapes
        valid_A = torch.ones_like(real_A_out).type(Tensor)
        fake_A_target = torch.zeros_like(fake_A_out).type(Tensor)
        valid_B = torch.ones_like(real_B_out).type(Tensor)
        fake_B_target = torch.zeros_like(fake_B_out).type(Tensor)

        # ------------------------------------------------------------------
        # Train Generators
        # ------------------------------------------------------------------
        optimizer_G.zero_grad()

        # Identity loss - helps preserve color composition
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss - train G to make D think the generated images are real
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid_A)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss - ensure we can reconstruct the original image
        recov_A = G_BA(fake_B)
        recov_B = G_AB(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total generator loss: identity x2, GAN x1, cycle x10 per direction,
        # averaged over the two directions. (The former "empty batch" fallback
        # was unreachable and algebraically the same expression, so it is gone.)
        # For G_AB (Real→Cartoon)
        loss_G_AB = 2.0 * loss_id_B + 1.0 * loss_GAN_AB + 10.0 * loss_cycle_B
        # For G_BA (Cartoon→Real)
        loss_G_BA = 2.0 * loss_id_A + 1.0 * loss_GAN_BA + 10.0 * loss_cycle_A
        loss_G = (loss_G_AB + loss_G_BA) / 2

        # Skip the update when the loss is NaN so the weights are not corrupted
        if torch.isnan(loss_G):
            print("NaN detected in generator loss! Skipping backward pass.")
            # Drop any gradients accumulated so far
            optimizer_G.zero_grad()
        else:
            loss_G.backward()
            # Apply gradient clipping to prevent exploding gradients
            clip_gradients(itertools.chain(G_AB.parameters(), G_BA.parameters()), max_norm=1.0)
            optimizer_G.step()

        # ------------------------------------------------------------------
        # Train Discriminator A
        # ------------------------------------------------------------------
        # zero_grad also clears the gradients the generator backward pass
        # pushed into D_A's parameters.
        optimizer_D_A.zero_grad()

        loss_real = criterion_GAN(real_A_out, valid_A)
        loss_fake = criterion_GAN(fake_A_out, fake_A_target)
        loss_D_A = (loss_real + loss_fake) / 2

        if torch.isnan(loss_D_A):
            print("NaN detected in discriminator A loss! Skipping backward pass.")
            optimizer_D_A.zero_grad()
        else:
            loss_D_A.backward()
            # Apply gradient clipping
            clip_gradients(D_A.parameters(), max_norm=1.0)
            optimizer_D_A.step()

        # ------------------------------------------------------------------
        # Train Discriminator B
        # ------------------------------------------------------------------
        optimizer_D_B.zero_grad()

        loss_real = criterion_GAN(real_B_out, valid_B)
        loss_fake = criterion_GAN(fake_B_out, fake_B_target)
        loss_D_B = (loss_real + loss_fake) / 2

        if torch.isnan(loss_D_B):
            print("NaN detected in discriminator B loss! Skipping backward pass.")
            optimizer_D_B.zero_grad()
        else:
            loss_D_B.backward()
            # Apply gradient clipping
            clip_gradients(D_B.parameters(), max_norm=1.0)
            optimizer_D_B.step()

    # Per-epoch report; all values below come from the LAST batch of the epoch
    loss_D = (loss_D_A + loss_D_B) / 2

    # Check for NaN values and stop early
    if torch.isnan(loss_G) or torch.isnan(loss_D_A) or torch.isnan(loss_D_B):
        print("NaN detected in loss values. Stopping training early.")
        break

    print(f'[Epoch {epoch + 1}/{n_epochs}]')
    print(f'[G loss: {loss_G.item():.4f} | identity: {loss_identity.item():.4f} GAN: {loss_GAN.item():.4f} cycle: {loss_cycle.item():.4f}]')
    print(f'[D loss: {loss_D.item():.4f} | D_A: {loss_D_A.item():.4f} D_B: {loss_D_B.item():.4f}]')

    # Generate validation samples and evaluate every 2 epochs
    if (epoch + 1) % 2 == 0:
        # A fresh iterator is created each time, so this takes the first
        # batch the validation loader yields.
        real_A, real_B = next(iter(validloader))
        real_A, real_B = real_A.type(Tensor), real_B.type(Tensor)
        sample_images(real_A, real_B)

        avg_gms = evaluate_and_save_best_model(
            epoch + 1,
            G_AB,
            G_BA,
            data_dir,
            save_path='./best_models'
        )

        print(f'Average GMS for {epoch+1}: {avg_gms}')

    # Step the schedulers at the end of each epoch
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()


# Swap the in-memory generators for the best checkpoint found during training
# (load_best_model returns the models unchanged and best_epoch=None if no
# checkpoint files exist).
G_AB, G_BA, best_epoch = load_best_model(G_AB, G_BA, save_path='./best_models')
print(f"Using best model from epoch {best_epoch} for final evaluation")

# Test Function (Do not Modify)

In [None]:
"""
Step 8. Generate Images
"""
## Translation 1: Raw Image --> Cartoon Image
# Overview of this cell (code intentionally left untouched — the notebook marks
# it "Do not Modify"):
#   1. Translate every file in the raw test folder with G_AB and save the
#      results under ./Cartoon_images, then score them against the cartoon
#      test folder with torch_fidelity (FID and Inception Score).
#   2. Do the mirror-image run with G_BA, saving under ./Raw_images.
#   3. For each direction compute sqrt(FID / IS) (printed as "Geometric Mean
#      Score"), average the two values and write them to a one-row CSV.
# Relies on names defined in earlier cells: data_dir, batch_size, transforms_,
# Tensor, G_AB and G_BA.
test_dir = os.path.join(data_dir, 'VAE_generation/test') # modification forbidden

# Full paths of all entries in the test folder, in os.listdir order.
files = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
len(files)  # bare expression: the value is not used by the code below

# Start from an empty output folder: delete any previous run, then recreate it.
save_dir = './Cartoon_images'
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)  # Deletes the folder and its contents
    
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Converter from array/tensor to PIL image; reused by the second translation too.
to_image = transforms.ToPILImage()

G_AB.eval()
# NOTE(review): the forward passes below are not wrapped in torch.no_grad();
# outputs are only detached afterwards.
for i in range(0, len(files), batch_size):
    # read images
    imgs = []
    # The last batch may be shorter than batch_size, hence the min().
    for j in range(i, min(len(files), i+batch_size)):
        img = Image.open(files[j])
        img = transforms_(img)
        imgs.append(img)
    imgs = torch.stack(imgs, 0).type(Tensor)

    # generate
    fake_imgs = G_AB(imgs).detach().cpu()

    # save
    for j in range(fake_imgs.size(0)):
        # presumably (C, H, W) -> (H, W, C) for image export — TODO confirm
        # against the generator's output layout
        img = fake_imgs[j].squeeze().permute(1, 2, 0)
        img_arr = img.numpy()
        # Per-image min-max rescale to the 0..255 range.
        # NOTE(review): a constant image makes max - min zero here, i.e. a
        # division by zero.
        img_arr = (img_arr - np.min(img_arr)) * 255 / (np.max(img_arr) - np.min(img_arr))
        img_arr = img_arr.astype(np.uint8)

        img = to_image(img_arr)
        # files[i+j] is the input this output was generated from; the output
        # keeps the same file name.
        _, name = os.path.split(files[i+j])
        img.save(os.path.join(save_dir, name))

# Reference images for this direction: the cartoon test folder.
gt_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/test')

# NOTE(review): cuda=True is hard-coded here regardless of the `cuda` flag used
# for Tensor — confirm a GPU is always available when this cell runs.
metrics = torch_fidelity.calculate_metrics(
    input1=save_dir,
    input2=gt_dir,
    cuda=True,
    fid=True,
    isc=True
)

fid_score = metrics["frechet_inception_distance"]
is_score = metrics["inception_score_mean"]

# NOTE(review): s_value_1 is only assigned when is_score > 0; otherwise the
# averaging at the end of this cell raises NameError.
if is_score > 0:
    s_value_1 = np.sqrt(fid_score / is_score)
    print("Geometric Mean Score:", s_value_1)
else:
    print("IS is 0, GMS cannot be computed!")


## Translation 2: Cartoon Image --> Raw Image
# Same procedure as above with the roles of the two folders swapped and G_BA
# as the generator.
test_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/test') # modification forbidden

files = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
len(files)  # bare expression: the value is not used by the code below

# Start from an empty output folder for this direction as well.
save_dir = './Raw_images'
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)  # Deletes the folder and its contents
    
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

G_BA.eval()
for i in range(0, len(files), batch_size):
    # read images
    imgs = []
    for j in range(i, min(len(files), i+batch_size)):
        img = Image.open(files[j])
        img = transforms_(img)
        imgs.append(img)
    imgs = torch.stack(imgs, 0).type(Tensor)

    # generate
    fake_imgs = G_BA(imgs).detach().cpu()

    # save
    for j in range(fake_imgs.size(0)):
        img = fake_imgs[j].squeeze().permute(1, 2, 0)
        img_arr = img.numpy()
        # Per-image min-max rescale to 0..255 (same zero-range caveat as in
        # Translation 1).
        img_arr = (img_arr - np.min(img_arr)) * 255 / (np.max(img_arr) - np.min(img_arr))
        img_arr = img_arr.astype(np.uint8)

        img = to_image(img_arr)
        _, name = os.path.split(files[i+j])
        img.save(os.path.join(save_dir, name))

# Reference images for this direction: the raw test folder.
gt_dir = os.path.join(data_dir, 'VAE_generation/test')

metrics = torch_fidelity.calculate_metrics(
    input1 = save_dir,
    input2 = gt_dir,
    cuda=True,
    fid=True,
    isc=True
)

fid_score = metrics["frechet_inception_distance"]
is_score = metrics["inception_score_mean"]

# NOTE(review): as above, s_value_2 stays unassigned when is_score <= 0.
if is_score > 0:
    s_value_2 = np.sqrt(fid_score / is_score)
    print("Geometric Mean Score:", s_value_2)
else:
    print("IS is 0, GMS cannot be computed!")


# Final score: mean of the two per-direction scores, rounded to 5 decimals,
# written as a single-row CSV with columns `id` and `label`.
s_value = np.round((s_value_1+s_value_2)/2, 5)
df = pd.DataFrame({'id': [1], 'label': [s_value]})

print("Average GMS:", s_value)

csv_path = "aaron.kwah.2021.csv"
df.to_csv(csv_path, index=False)

print(f"CSV saved to {csv_path}")