In [1]:
import torch
from torchinfo import summary
import wandb

import argparse
import os
import numpy as np
import math
import itertools
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image 
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


from model_general_cyclegan_vae import *

from datasets import *
from utils import *

import torch.nn as nn
import torch.nn.functional as F
import torch


from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image import StructuralSimilarityIndexMeasure

In [2]:
import torch
import random
import numpy as np
import os

def set_seed(seed):
    """Sets the seed for reproducibility across different components."""
    random.seed(seed)  # Python's built-in random module
    np.random.seed(seed)  # NumPy
    torch.manual_seed(seed)  # PyTorch CPU
    # torch.cuda.manual_seed(seed)  # PyTorch CUDA (single GPU)
    torch.cuda.manual_seed_all(seed)  # PyTorch CUDA (multi-GPU)
    # os.environ['PYTHONHASHSEED'] = str(seed)  # For hash-based operations

    # # Enforce deterministic behavior for CuDNN (can impact performance)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Enforce deterministic algorithms for all PyTorch operations (can raise errors if not available)
    # torch.use_deterministic_algorithms(True) # Uncomment if strict determinism is required and all ops support it

# Example usage:
seed_value = 42
set_seed(seed_value)

In [3]:
# Set the model configuration
# This configuration is used to define the parameters for training the CycleGAN with VAE.
# The values can be adjusted based on the specific requirements of the training process.

model_config = {
    'epoch': 0,
    'n_epochs': 601 ,# max channels
    'dataset_name': 'maps', # dataset
    'batch_size': 16, #0,
    'lr': 0.0002,
    'b1': 0.5,   #'leaky_relu',  # 'relu', 'leaky_relu', 'sin'
    'b2': 0.999,  ##4, 3
    'decay_epoch': 2,
    'n_cpu': 8, # 1.0, #0.0001, # 1e-6, #1e-6,
    'img_height': 64,
    'img_width': 64,
    'channels': 3,
    'sample_interval': 500,  #100
    'checkpoint_interval': 1,
    'n_residual_blocks': 1, #1
    'lambda_cyc': 10.0,
    'lambda_id': 5.0, # 0.0, # 1.0, #0.0001, # 1e-6, #1e-6,
    'latent_dim': 256, # Squashed latent dimension used for VAE
    'n_layers': 2, #2 # Number of layers in the generator
    'lambda_kl': 1e-05, # KL divergence loss weight for VAE
    'run_name': 'VAE_e601', #pixel_shuffle_sm_ldim128_l4_r2_pos_enc

    'saved_models_file': "saved_models_VAE",
    'images_file' : "images_VAE",
    'images_cycle_file': "images_VAE_cycle",
    'test_results_file': "test_results_VAE",

}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:
import wandb

wandb.init(project="CycleGAN", name=model_config['run_name'], config=model_config)
# wandb.watch(model, log="all", log_freq=model_config['log_interval'])

# model.train_loop(train_dataloader, val_dataloader, epochs=model_config['epochs'], device=device)


[34m[1mwandb[0m: Currently logged in as: [33msautoor[0m ([33msautoor-university-of-west-florida[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import torch
from torchmetrics import MeanSquaredError
from torchmetrics.image import PeakSignalNoiseRatio
from torchvision.utils import save_image


# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create necessary directories
os.makedirs(f"{model_config['images_file']}/epochs_{model_config['n_epochs']}", exist_ok=True)
os.makedirs(f"{model_config['images_cycle_file']}/epochs_{model_config['n_epochs']}", exist_ok=True)
os.makedirs(f"{model_config['saved_models_file']}/epochs_{model_config['n_epochs']}", exist_ok=True)
os.makedirs(f"{model_config['test_results_file']}", exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (model_config['channels'], model_config['img_height'], model_config['img_width'])

# Initialize generator and discriminator
G_AB = GeneratorResNet(input_shape, model_config['n_residual_blocks'], model_config['n_layers'], model_config['latent_dim'], model_config['lambda_kl'])
G_BA = GeneratorResNet(input_shape, model_config['n_residual_blocks'], model_config['n_layers'], model_config['latent_dim'], model_config['lambda_kl'])
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if model_config['epoch'] != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load(f"{model_config['saved_models_file']}/epochs_%s/G_AB_%d.pth" % (model_config['n_epochs'], model_config['epoch'])))
    G_BA.load_state_dict(torch.load(f"{model_config['saved_models_file']}/epochs_%s/G_BA_%d.pth" % (model_config['n_epochs'], model_config['epoch'])))
    D_A.load_state_dict(torch.load( f"{model_config['saved_models_file']}/epochs_%s/D_A_%d.pth" % (model_config['n_epochs'], model_config['epoch'])))
    D_B.load_state_dict(torch.load (f"{model_config['saved_models_file']}/epochs_%s/D_B_%d.pth" % (model_config['n_epochs'], model_config['epoch'])))


else:
    # Initialize weights
    # G_AB.apply(weights_init_normal)
    # G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=model_config['lr'], betas=(model_config['b1'], model_config['b2'])
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=model_config['lr'], betas=(model_config['b1'], model_config['b2']))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=model_config['lr'], betas=(model_config['b1'], model_config['b2']))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(model_config['n_epochs'], model_config['epoch'], model_config['decay_epoch']).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(model_config['n_epochs'], model_config['epoch'], model_config['decay_epoch']).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(model_config['n_epochs'], model_config['epoch'], model_config['decay_epoch']).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
# fake_A_buffer = ReplayBuffer()
# fake_B_buffer = ReplayBuffer()

# Image transformations
transforms_ = [
    transforms.Resize(int(model_config['img_height'] * 1.12), Image.BICUBIC),
    transforms.RandomCrop((model_config['img_height'], model_config['img_width'])),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataset_path = "../../data/%s" % model_config['dataset_name']
print(f"Dataset path: {dataset_path}")

# Create the dataset instance
dataset = ImageDataset(dataset_path, transforms_=transforms_, unaligned=True)

# Print dataset length and some samples info
print(f"Number of samples in dataset: {len(dataset)}")
if len(dataset) > 0:
    print(f"First sample shape: {dataset[0]['A'].shape if 'A' in dataset[0] else 'No A found'}")
    print(f"First sample shape: {dataset[0]['B'].shape if 'B' in dataset[0] else 'No B found'}")


# Training data loader
dataloader = DataLoader(
    ImageDataset("../../data/%s" % model_config['dataset_name'], transforms_=transforms_, unaligned=True),
    batch_size=model_config['batch_size'],
    shuffle=True,
    num_workers=model_config['n_cpu'],
    pin_memory=True,  # Faster data transfer to GPU
    prefetch_factor=2,  # Preload next batches
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % model_config['dataset_name'], transforms_=transforms_, unaligned=False, mode="test"),
    batch_size=model_config['batch_size'], #5,
    shuffle=False,
    num_workers=1,
)

# Add these near your other initialization code
train_history = {
    'D_loss': [],
    'G_loss': [],
    'G_GAN': [],
    'G_cycle': [],
    'G_identity': [],
    'batches': []
}


def calculate_cr(original_size, latent_size):
    """Calculate compression ratio between original and latent representation"""
    # Assuming original_size is (channels, height, width)
    original_elements = original_size[1] * original_size[2]
   
    # Calculate compression ratio
    cr = original_elements / latent_size
    return cr

# You can call this with your image dimensions and latent dimension
latent_shape =input_shape[1]//(2**(model_config['n_layers']))
latent_shape = latent_shape**2  # Adjusted to match the latent dimension used in the model

cr = calculate_cr(input_shape, latent_shape)
print(f"Compression Ratio: {cr:.2f}:1")

def plot_metric_curves(metrics, save_path="metric_curves.png"):
    plt.figure(figsize=(12, 10))
    
    # Plot MSE
    plt.subplot(2, 2, 1)
    plt.plot(metrics['train']['mse'], label='Train MSE')
    plt.plot(metrics['val']['mse'], label='Val MSE')
    plt.title('MSE')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.legend()
    plt.grid(True)
    
    # Plot PSNR
    plt.subplot(2, 2, 2)
    plt.plot(metrics['train']['psnr'], label='Train PSNR')
    plt.plot(metrics['val']['psnr'], label='Val PSNR')
    plt.title('PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.legend()
    plt.grid(True)
    
    # Plot SSIM
    plt.subplot(2, 2, 3)
    plt.plot(metrics['train']['ssim'], label='Train SSIM')
    plt.plot(metrics['val']['ssim'], label='Val SSIM')
    plt.title('SSIM')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.legend()
    plt.grid(True)
    
    # Plot FID
    plt.subplot(2, 2, 4)
    plt.plot(metrics['train']['fid'], label='Train FID')
    plt.plot(metrics['val']['fid'], label='Val FID')
    plt.title('FID')
    plt.xlabel('Epoch')
    plt.ylabel('FID')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


# Function to plot learning curves
def plot_learning_curves(history, save_path="learning_curves.png"):
    plt.figure(figsize=(12, 8))
    
    # Plot Generator losses
    plt.subplot(2, 1, 1)
    plt.plot(history['batches'], history['G_loss'], label='Total G Loss')
    plt.plot(history['batches'], history['G_GAN'], label='GAN Loss')
    plt.plot(history['batches'], history['G_cycle'], label='Cycle Loss')
    plt.plot(history['batches'], history['G_identity'], label='Identity Loss')
    plt.title('Generator Learning Curves')
    plt.xlabel('Batch Number')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot Discriminator loss
    plt.subplot(2, 1, 2)
    plt.plot(history['batches'], history['D_loss'], label='D Loss', color='red')
    plt.title('Discriminator Learning Curve')
    plt.xlabel('Batch Number')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


# Function to save colored images 
@torch.no_grad()  
def sample_images(batches_done):
    """Saves generated samples in a 4x4 grid with clear labels"""
    G_AB.eval()
    G_BA.eval()

    try:
        # Get a batch of test images
        batch = next(iter(val_dataloader))

        real_A = batch["A"].cuda()      # Real A (Domain A)
        real_B = batch["B"].cuda()     # Real B (Domain B)

        
        with torch.no_grad():
            fake_B = G_AB(real_A)           # Fake B (A → B)
            fake_A = G_BA(real_B)           # Fake A (B → A)
            
            cycle_A = G_BA(fake_B)           # Fake B → Cycle-Consistent A          
            cycle_B= G_AB(cycle_A)           # cycle_A -> Cycle-Consistent B

        # Denormalize images
        def denorm(tensor):
            return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

        # Denormalize images for visualization
        real_A = denorm(real_A[:4])
        fake_B = denorm(fake_B[:4])
        real_B = denorm(real_B[:4])

        fake_A = denorm(fake_A[:4])        
        cycle_A = denorm(cycle_A[:4])
        cycle_B = denorm(cycle_B[:4])
        

        grid_rows1 = [
            real_A,     # Row 1: Real A
            fake_B,     # Row 2: Fake B (generated from real A)
            real_B,     # Row 3: Real B
            fake_A      # Row 4: Fake A (generated from real B)
        ]
        grid1 = torch.cat(grid_rows1, dim=0)
        
        # Save the grid

        save_image(
            grid1,
            f"{model_config['images_file']}/epochs_{model_config['n_epochs']}/{batches_done}.png",
            nrow=4,       # 4 columns (4 images per row)
            normalize=False,
            padding=2,
            pad_value=1.0  # White padding between images
        )
      
        grid_rows2 = [
            real_A,      # Row 1: Real A
            fake_B,      # Row 2: Fake B (A → B)
            cycle_A,     # Row 3: Cycle-A (B → A cycle)
            cycle_B      # Row 4: cycle-B (cycle A → cycle B)
        ]
        grid2= torch.cat(grid_rows2, dim=0)

        #Save the grid
        save_image(
            grid2,
            f"{model_config['images_cycle_file']}/epochs_{model_config['n_epochs']}/{batches_done}.png",
            nrow=4,       # 4 columns (4 images per row)
            normalize=False,
            padding=2,
            pad_value=1.0  # White padding between images
        )

        
    except Exception as e:
        print(f"Error saving sample images: {str(e)}")


##Function to save colored images 

@torch.no_grad()  

def save_generated_images(epoch_num):
    """Saves generated samples in a 4x4 grid with clear labels"""
    G_AB.eval()
    G_BA.eval()

    try:
        # Get a batch of test images
        batch = next(iter(val_dataloader))

        real_A = batch["A"].cuda()      # Real A (Domain A)
        real_B = batch["B"].cuda()     # Real B (Domain B)

        
        with torch.no_grad():
            fake_B = G_AB(real_A)           # Fake B (A → B)
            fake_A = G_BA(real_B)           # Fake A (B → A)
            
            cycle_A = G_BA(fake_B)           # Fake B → Cycle-Consistent A          
            cycle_B= G_AB(cycle_A)           # cycle_A -> Cycle-Consistent B

        # Denormalize images
        def denorm(tensor):
            return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

        # Denormalize images for visualization
        _real_A = denorm(real_A[:4])
        _fake_B = denorm(fake_B[:4])
        _real_B = denorm(real_B[:4])

        _fake_A = denorm(fake_A[:4])        
        _cycle_A = denorm(cycle_A[:4])
        _cycle_B = denorm(cycle_B[:4])
        

        grid_rows1 = [
            _real_A,     # Row 1: Real A
            _fake_B,     # Row 2: Fake B (generated from real A)
            _real_B,     # Row 3: Real B
            _fake_A      # Row 4: Fake A (generated from real B)
        ]
        grid1 = torch.cat(grid_rows1, dim=0)
        
        # Save the grid

        save_image(
            grid1,
            f"{model_config['images_file']}/epochs_{model_config['n_epochs']}/{epoch_num}.png",
            nrow=4,       # 4 columns (4 images per row)
            normalize=False,
            padding=2,
            pad_value=1.0  # White padding between images
        )
      

        # Arrange in 4x4 grid for cycle-consistency images:
        # Row 1: Real A (4 images)
        # Row 2: Fake B (A→B translations)
        # Row 3: Cycle-Consistent A (fake B → A)
        # Row 4: Fake A (B→A translations)
        grid_rows2 = [
            _real_A,      # Row 1: Real A
            _fake_B,      # Row 2: Fake B (A → B)
            _cycle_A,     # Row 3: Cycle-A (B → A cycle)
            _cycle_B      # Row 4: cycle-B (cycle A → cycle B)
        ]
        grid2= torch.cat(grid_rows2, dim=0)

        #Save the grid
        save_image(
            grid2,
            f"{model_config['images_cycle_file']}/epochs_{model_config['n_epochs']}/{epoch_num}.png",
            nrow=4,       # 4 columns (4 images per row)
            normalize=False,
            padding=2,
            pad_value=1.0  # White padding between images
        )

    except Exception as e:
        print(f"Error saving sample images: {str(e)}")

    return real_A, real_B, fake_A, fake_B, cycle_A, cycle_B


# ----------
#  Training
# ----------

prev_time = time.time()

# At start of training
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# dictionary to track epoch metrics
epoch_metrics = {
    'train': {'mse': [], 'psnr': [], 'ssim': [], 'fid': []},
    'val': {'mse': [], 'psnr': [], 'ssim': [], 'fid': []}
}

for epoch in range(model_config['epoch'], model_config['n_epochs']):

    # Set models to train mode
    G_AB.train()
    G_BA.train()
    D_A.train()
    D_B.train()

    for i, batch in enumerate(dataloader):

        # Set model input
        real_A =batch["A"].to(device)  
        real_B = batch["B"].to(device) 

        # Adversarial ground truths
        valid = torch.ones(real_A.size(0), *D_A.output_shape, 
                  dtype=torch.float32, device=real_A.device)

        fake = torch.tensor(np.zeros((real_A.size(0), *D_A.output_shape)), dtype=torch.float32, device='cuda', requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        # G_AB.train()
        # G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

     
        # Total loss
        loss_G = loss_GAN + model_config['lambda_cyc'] * loss_cycle  + model_config['lambda_id'] * loss_identity

        # Log losses to wandb        
        loss_dict = {
            'loss_GAN_AB': loss_GAN_AB.item(),
            'loss_GAN_BA': loss_GAN_BA.item(),
            'loss_GAN': loss_GAN.item(),
            'loss_identity_A': loss_id_A.item(),
            'loss_identity_B': loss_id_B.item(),
            'loss_identity': loss_identity.item(),
            'loss_cycle_A': loss_cycle_A.item(),
            'loss_cycle_B': loss_cycle_B.item(),
            'loss_cycle': loss_cycle.item(),
            'total_loss_G': loss_G.item(),
            'epoch': epoch
        }
        wandb.log(loss_dict)

        # Backward pass
        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_A(real_A), valid)     

        # Fake loss 
        fake_A_ = G_BA(real_B) 
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        # backward pass
        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = criterion_GAN(D_B(real_B), valid)

        # Fake loss 
        # fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        fake_B_ = G_AB(real_A) # fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
       
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = model_config['n_epochs'] * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

    
        # Store losses
        train_history['D_loss'].append(loss_D.item())
        train_history['G_loss'].append(loss_G.item())
        train_history['G_GAN'].append(loss_GAN.item())
        train_history['G_cycle'].append(loss_cycle.item())
        train_history['G_identity'].append(loss_identity.item())
        train_history['batches'].append(batches_done)


        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                model_config['n_epochs'],
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )
        
        # =============================================
        
   
    if ((epoch + 1) % 10 == 0 or (epoch + 1) == model_config['epoch'] or (epoch + 1) == model_config['n_epochs']  or (epoch + 1) >(model_config['n_epochs'] -5)):
        val_real_A, val_real_B, val_fake_A, val_fake_B, val_cycle_A, val_cycle_B = save_generated_images(epoch)  # Save sample images for this epoch
   
    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()


    # if model_config['checkpoint_interval'] != -1 and epoch % model_config['checkpoint_interval'] == 0:
    if model_config['checkpoint_interval']!= -1 and (((epoch +1) %100 ==0) or (epoch +1) == model_config['n_epochs']):
        # Save model checkpoints
        torch.save(G_AB.state_dict(), f"{model_config['saved_models_file']}/epochs_%s/G_AB_%d.pth" % (model_config['n_epochs'], epoch))
        torch.save(G_BA.state_dict(), f"{model_config['saved_models_file']}/epochs_%s/G_BA_%d.pth" % (model_config['n_epochs'], epoch))
        torch.save(D_A.state_dict(), f"{model_config['saved_models_file']}/epochs_%s/D_A_%d.pth" % (model_config['n_epochs'], epoch))
        torch.save(D_B.state_dict(), f"{model_config['saved_models_file']}/epochs_%s/D_B_%d.pth" % (model_config['n_epochs'], epoch))


    # Clip gradients to prevent exploding gradients
    torch.nn.utils.clip_grad_norm_(G_AB.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(G_BA.parameters(), 1.0)




####################################################################


# Initialize metrics
fid = FrechetInceptionDistance(feature=2048).to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)


# Add these near your other metric initializations
mse = MeanSquaredError().to(device)
psnr = PeakSignalNoiseRatio().to(device)


# Save metrics to file
metrics_path = f"{model_config['test_results_file']}/epochs_{model_config['n_epochs']}_metrics.txt"

# Function to calculate metrics between real and fake images
# Modify your calculate_metrics function:
@torch.no_grad()
def calculate_metrics(real, fake):
    """Calculate various metrics between real and fake images"""
    metrics = {}
    

    """Handle batch size mismatches by trimming to smaller size"""
    min_batch = min(real.shape[0], fake.shape[0])
    real = real[:min_batch]
    fake = fake[:min_batch]

    # Denormalize images (assuming they're in [-1, 1] range)
    real_denorm = (real + 1) / 2  # Scale to [0, 1]
    fake_denorm = (fake + 1) / 2
    
    # Reset metrics
    mse.reset()
    psnr.reset()
    ssim.reset()
    fid.reset()
    
    # Calculate metrics
    metrics['mse'] = mse(real_denorm, fake_denorm).item()
    metrics['psnr'] = psnr(real_denorm, fake_denorm).item()
    metrics['ssim'] = ssim(real_denorm, fake_denorm).item()
    
    # For FID, we need uint8 images in 0-255 range
    real_uint8 = (real_denorm * 255).byte()
    fake_uint8 = (fake_denorm * 255).byte()
    
    # Update FID
    fid.update(real_uint8, real=True)
    fid.update(fake_uint8, real=False)
    metrics['fid'] = fid.compute().item()

    # Convert all metrics to Python floats
    # return {k: float(v) if hasattr(v, 'item') else float(v) for k, v in metrics.items()}
    
    return metrics

##############################
# dictionary to track epoch metrics
epoch_metrics = {
    'train': {'mse': [], 'psnr': [], 'ssim': [], 'fid': []},
    'val': {'mse': [], 'psnr': [], 'ssim': [], 'fid': []}
}


    # Initialize epoch metrics
train_metrics_A = {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid':0, 'count': 0}
val_metrics_A =  {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid': 0, 'count': 0}  

train_metrics_B = {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid':0, 'count': 0}
val_metrics_B =  {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid': 0, 'count': 0} 

train_metrics_B = {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid':0, 'count': 0}
val_metrics_B =  {'mse': 0, 'psnr': 0, 'ssim': 0, 'fid': 0, 'count': 0} 

train_metrics_A = calculate_metrics(real_A, fake_A)
val_metrics_A = calculate_metrics(val_real_A, val_fake_A)

train_metrics_B = calculate_metrics(real_B, fake_B)
val_metrics_B = calculate_metrics(val_real_B, val_fake_B)

cycle_metrics_A = calculate_metrics(real_A, val_cycle_A)
cycle_metrics_B = calculate_metrics(real_B, val_cycle_B)


print(train_metrics_A)
print(val_metrics_A)

print(train_metrics_B)
print(val_metrics_B)

print(cycle_metrics_A)
print(cycle_metrics_B)

# Save metrics to file
metrics_path = f"{model_config['test_results_file']}/epochs_{model_config['n_epochs']}_metrics.txt"

with open(metrics_path, "w") as f:
    # f.write("Epoch\tTrain MSE\tTrain PSNR\tTrain SSIM\tTrain FID\tVal MSE\tVal PSNR\tVal SSIM\tVal FID\n")
    f.write(f"Compression Ratio: {cr:.2f}:1\n\n")

    ################ TRAIN A and VAL A

    f.write("Train_A and Val_A\n")
    f.write("Epoch\tTrain MSE\tTrain PSNR\tTrain SSIM\tTrain FID\tVal MSE \tVal PSNR\tVal SSIM\tVal FID\n")

    f.write(f"{epoch}\t")
    f.write(f"{train_metrics_A['mse']:.4f}\t\t")
    f.write(f"{train_metrics_A['psnr']:.2f}\t\t")
    f.write(f"{train_metrics_A['ssim']:.4f}\t\t")
    f.write(f"{train_metrics_A['fid']:.2f}\t\t")
    

    f.write(f"{val_metrics_A['mse']:.4f}\t\t")
    f.write(f"{val_metrics_A['psnr']:.2f}\t\t")
    f.write(f"{val_metrics_A['ssim']:.4f}\t\t")
    f.write(f"{val_metrics_A['fid']:.2f}\n\n")
   
    ####### TRAIN B and VAL B 
    f.write("Train_B and Val_B\n")
    f.write("Epoch\tTrain MSE\tTrain PSNR\tTrain SSIM\tTrain FID\tVal MSE \tVal PSNR\tVal SSIM\tVal FID\n")
    f.write(f"{epoch}\t")
    f.write(f"{train_metrics_B['mse']:.4f}\t\t")
    f.write(f"{train_metrics_B['psnr']:.2f}\t\t")
    f.write(f"{train_metrics_B['ssim']:.4f}\t\t")
    f.write(f"{train_metrics_B['fid']:.2f}\t\t")
    

    f.write(f"{val_metrics_B['mse']:.4f}\t\t")
    f.write(f"{val_metrics_B['psnr']:.2f}\t\t")
    f.write(f"{val_metrics_B['ssim']:.4f}\t\t")
    f.write(f"{val_metrics_B['fid']:.2f}\n\n")

    ####### TRAIN A and CYCLE A
    f.write("Train_A and Cycle_A\n")
    f.write("Epoch\tTrain MSE\tTrain PSNR\tTrain SSIM\tTrain FID\tCyc MSE \tCyc PSNR\tCyc SSIM\tCyc FID\n")

    f.write(f"{epoch}\t")
    f.write(f"{train_metrics_A['mse']:.4f}\t\t")
    f.write(f"{train_metrics_A['psnr']:.2f}\t\t")
    f.write(f"{train_metrics_A['ssim']:.4f}\t\t")
    f.write(f"{train_metrics_A['fid']:.2f}\t\t")
    

    f.write(f"{cycle_metrics_A['mse']:.4f}\t\t")
    f.write(f"{cycle_metrics_A['psnr']:.2f}\t\t")
    f.write(f"{cycle_metrics_A['ssim']:.4f}\t\t")
    f.write(f"{cycle_metrics_A['fid']:.2f}\n\n")

    ############ TRAIN B and CYCLE B

    f.write("Train_B and Cycle_B\n")
    f.write("Epoch\tTrain MSE\tTrain PSNR\tTrain SSIM\tTrain FID\tCyc MSE \tCyc PSNR\tCyc SSIM\tCyc FID\n")
    f.write(f"{epoch}\t")
    f.write(f"{train_metrics_B['mse']:.4f}\t\t")
    f.write(f"{train_metrics_B['psnr']:.2f}\t\t")
    f.write(f"{train_metrics_B['ssim']:.4f}\t\t")
    f.write(f"{train_metrics_B['fid']:.2f}\t\t")
    

    f.write(f"{cycle_metrics_B['mse']:.4f}\t\t")
    f.write(f"{cycle_metrics_B['psnr']:.2f}\t\t")
    f.write(f"{cycle_metrics_B['ssim']:.4f}\t\t")
    f.write(f"{cycle_metrics_B['fid']:.2f}\n\n")

                
 
print("Testing complete! Results saved to test_results/")
wandb.finish()


Using device: cuda
Dataset path: ../../data/maps
Number of samples in dataset: 1096
First sample shape: torch.Size([3, 64, 64])
First sample shape: torch.Size([3, 64, 64])
Compression Ratio: 16.00:1
GPU Memory Allocated: 0.06 GB
[Epoch 600/601] [Batch 68/69] [D loss: 0.030714] [G loss: 1.843898, adv: 0.750222, cycle: 0.075210, identity: 0.068315] ETA: 0:00:00.084649{'mse': 0.03864389285445213, 'psnr': 13.663229942321777, 'ssim': 0.20813539624214172, 'fid': 364.8481140136719}
{'mse': 0.03551651909947395, 'psnr': 13.803985595703125, 'ssim': 0.04967421293258667, 'fid': 269.29742431640625}
{'mse': 0.0128769027069211, 'psnr': 18.893125534057617, 'ssim': 0.5429456233978271, 'fid': 324.21075439453125}
{'mse': 0.0037293164059519768, 'psnr': 24.25999641418457, 'ssim': 0.46013012528419495, 'fid': 241.2958526611328}
{'mse': 0.0337069109082222, 'psnr': 12.233099937438965, 'ssim': 0.10632374882698059, 'fid': 396.5418701171875}
{'mse': 0.017027586698532104, 'psnr': 17.66900062561035, 'ssim': 0.47559

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
kl_loss,▅▂▁▁▁▁▄▁▁▁▁▂▁▁▁▁▅▁▁▁▁▅▁▂▁▁▁▁▅▁▅▂▁▁█▁▁▁▁▁
loss_GAN,▁▃▇▄▆▄▄▆▅▆▅▇▆▆▇█▅▅▇▅▅▆▆▇▆▅▆▆▅▆▆▆▆▆▆▆▆▆▆▆
loss_GAN_AB,▃▄▅▁▂▂█▅▄▆▆▅▅▇▅▇▅▅▅▄▅▅▅▆▅▅▅▆▅▅▅▅▆▅▅▅▅▅▅▄
loss_GAN_BA,▃▂▁▂▂▅▄▃▃▃▄▄▃▄▄▃▅▅▆▅▄█▄▆▅▆▆▆▅▅▄▄▆▅▅▅▆▄▅▄
loss_cycle,▆█▆▅▅▅▆▄▄▄▃▃▄▃▄▃▂▄▃▂▄▄▂▃▃▃▃▃▃▂▂▂▂▁▃▃▂▃▂▂
loss_cycle_A,█▇▇▅▇▄▅▆▄▆▃▆█▄▄▅▃▃▃▄▅▃▄▄▃▁▄▃▂▂▆▅▄▃▄▂▃▄▅▄
loss_cycle_B,█▆▄▄▃▃▃▃▄▂▂▂▂▂▂▂▂▁▂▁▁▂▁▂▁▂▁▂▁▂▂▂▁▁▂▂▁▁▁▁
loss_identity,█▅▅▃▄▃▂▄▃▃▂▃▂▂▂▃▂▂▂▃▂▂▂▃▂▂▂▁▂▂▁▂▂▁▂▂▁▁▂▁
loss_identity_A,█▇▆█▆▅▅▇▅▆▅▃▃▄▄▃▃▄▂▂▃▄▃▄▄▄▂▂▃▂▂▃▃▅▂▃▂▁▁▂

0,1
epoch,600.0
kl_loss,2820.98657
loss_GAN,0.75022
loss_GAN_AB,0.79831
loss_GAN_BA,0.70213
loss_cycle,0.07521
loss_cycle_A,0.11018
loss_cycle_B,0.04024
loss_identity,0.06832
loss_identity_A,0.10279
