In [2]:
import sys


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from scipy import fftpack
from scipy.interpolate import UnivariateSpline
from torch.utils.data import DataLoader, Dataset

print("Python version")
print(sys.version)
print("PyTorch version:", torch.__version__)

# Define the generator
class ResnetBlock(nn.Module):
    def __init__(self, dim):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
        super(Generator, self).__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(True)
            ]

        # Resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Define the discriminator
class Discriminator(nn.Module):
    def __init__(self, input_nc, ndf=64):
        super(Discriminator, self).__init__()
        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True)
        ]

        model += [
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True)
        ]

        model += [nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Define the dataset
class CycleGANDataset(Dataset):
    def __init__(self, npz_file, transform=None, fraction=0.03):
        data = np.load(npz_file)
        total_samples = len(data['arr_0'])
        num_samples = int(total_samples * fraction)
        
        self.A = data['arr_0'][:num_samples]
        self.B = data['arr_1'][:num_samples]
        self.transform = transform

    def __len__(self):
        return len(self.A)

    def __getitem__(self, idx):
        A = self.A[idx].astype(np.float32) / 255.0  # Normalize to [0, 1]
        B = self.B[idx].astype(np.float32) / 255.0  # Normalize to [0, 1]
        
        A = torch.from_numpy(A).permute(2, 0, 1)  # Change from (H, W, C) to (C, H, W)
        B = torch.from_numpy(B).permute(2, 0, 1)  # Change from (H, W, C) to (C, H, W)
        
        if self.transform:
            A = self.transform(A)
            B = self.transform(B)
        
        return A, B
# Define the loss functions
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
    
    # Metrics functions
def generate_fake_samples(g_model, samples, device):
    g_model = g_model.to(device)
    samples = samples.to(device)
    
    with torch.no_grad():
        X = g_model(samples)
    return X


def generate_real_samples(dataset, n_samples, device):
    indices = torch.randint(0, len(dataset), (n_samples,))
    samples = [dataset[i] for i in indices]
    X1 = torch.stack([s[0] for s in samples])
    X2 = torch.stack([s[1] for s in samples])
    return X1.to(device), X2.to(device)

def psnr(g_model_AtoB, g_model_BtoA, dataset, n_samples=15, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    dataloader = DataLoader(dataset, batch_size=n_samples, shuffle=True)
    X_realA, X_realB = next(iter(dataloader))
    X_realA, X_realB = X_realA.to(device), X_realB.to(device)

    # Ensure models are on the correct device
    g_model_AtoB = g_model_AtoB.to(device)
    g_model_BtoA = g_model_BtoA.to(device)
    
    X_fakeB = generate_fake_samples(g_model_AtoB, X_realA, device)
    X_fakeA = generate_fake_samples(g_model_BtoA, X_realB, device)

    # Proceed with the PSNR calculation
    mse_AtoB = nn.functional.mse_loss(X_fakeB, X_realB, reduction='none').mean(dim=[1,2,3])
    psnr_AtoB = 20 * torch.log10(2.0 / torch.sqrt(mse_AtoB))

    mse_BtoA = nn.functional.mse_loss(X_fakeA, X_realA, reduction='none').mean(dim=[1,2,3])
    psnr_BtoA = 20 * torch.log10(2.0 / torch.sqrt(mse_BtoA))

    return (psnr_AtoB.mean().item(), psnr_AtoB.std().item(), mse_AtoB.mean().item(), mse_AtoB.std().item(),
            psnr_BtoA.mean().item(), psnr_BtoA.std().item(), mse_BtoA.mean().item(), mse_BtoA.std().item())



def structural_similarity(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    # Las imágenes ya están en formato (batch, channels, height, width)
    # y tienen 3 canales (RGB), entonces ajustamos el kernel a 3 canales.
    
    # Cambiar el kernel para que sea aplicado a cada canal por separado
    kernel = torch.ones(3, 1, 11, 11) / 121  # Kernel de 3 canales
    kernel = kernel.to(img1.device)

    # Aplicamos convoluciones canal por canal con groups=3
    mu1 = nn.functional.conv2d(img1, kernel, padding=5, groups=3)  # groups=3 para cada canal (RGB)
    mu2 = nn.functional.conv2d(img2, kernel, padding=5, groups=3)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = nn.functional.conv2d(img1 * img1, kernel, padding=5, groups=3) - mu1_sq
    sigma2_sq = nn.functional.conv2d(img2 * img2, kernel, padding=5, groups=3) - mu2_sq
    sigma12 = nn.functional.conv2d(img1 * img2, kernel, padding=5, groups=3) - mu1_mu2

    # Calcular el mapa SSIM
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    # Retornar el valor promedio del SSIM por batch
    return ssim_map.mean([1, 2, 3])





def avg_SSIM(g_model_AtoB, g_model_BtoA, dataset, n_samples=15, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    dataloader = DataLoader(dataset, batch_size=n_samples, shuffle=True)
    X_realA, X_realB = next(iter(dataloader))
    X_realA, X_realB = X_realA.to(device), X_realB.to(device)

    g_model_AtoB = g_model_AtoB.to(device)
    g_model_BtoA = g_model_BtoA.to(device)
    
    X_fakeB = generate_fake_samples(g_model_AtoB, X_realA, device)
    X_fakeA = generate_fake_samples(g_model_BtoA, X_realB, device)

    # Convert from [-1, 1] to [0, 1] for SSIM calculation
    X_realA = (X_realA + 1) / 2
    X_realB = (X_realB + 1) / 2
    X_fakeA = (X_fakeA + 1) / 2
    X_fakeB = (X_fakeB + 1) / 2

    ssim_AtoB = structural_similarity(X_fakeB, X_realB)
    ssim_BtoA = structural_similarity(X_fakeA, X_realA)

    return (ssim_AtoB.mean().item(), ssim_AtoB.std().item(),
            ssim_BtoA.mean().item(), ssim_BtoA.std().item())
def cutoff_metric(image):
    # Convert PyTorch tensor to numpy array
    image_np = image.squeeze().cpu().numpy()

    # If the image has more than two dimensions, select only the spatial dimensions
    if len(image_np.shape) == 3:  # (channels, height, width)
        # Select only one channel (e.g., channel 0)
        image_np = image_np[0]

    # Ensure the image is in [0, 255] range for FFT calculation
    image_np = (image_np * 127.5 + 127.5).astype(np.uint8)

    # Compute the 2-D FFT and shift the zero-frequency component to the center of the array
    fft = fftpack.fftshift(fftpack.fft2(image_np))

    # Compute the magnitude spectrum
    magnitude_spectrum = np.abs(fft)

    # Determine the size of the magnitude spectrum (height, width)
    Nx, Ny = magnitude_spectrum.shape
    x0, y0 = Nx // 2, Ny // 2

    # Define a set of radii to crop the magnitude spectrum
    radius = np.arange(1, Nx, 5)

    # Compute distances from center point to all pixels
    y, x = np.ogrid[-y0:Ny-y0, -x0:Nx-x0]
    distances = np.sqrt(x*x + y*y)

    # Create mask and compute profile
    ft_profile = np.array([np.sum(magnitude_spectrum[(distances >= r) & (distances < r+5)]) for r in radius[:-1]])

    # Compute cutoff energy and cumulative energy
    cutoff_energy = 0.95 * np.sum(ft_profile)
    energy = np.cumsum(ft_profile)

    # Fit smoothing spline to the cumulative energy data
    xData, yData = np.arange(len(energy)), energy
    fitresult = UnivariateSpline(xData, yData, s=0.01, k=3)

    # Evaluate the fitted curve and find cutoff frequency
    vect_frequency = np.arange(1, 300, 0.1)
    a = fitresult(vect_frequency)
    n = np.argwhere(a > cutoff_energy)[0]

    # Estimate the cutoff frequency using linear interpolation
    y1, y0, y = a[n], a[n-1], cutoff_energy
    x0, x1 = vect_frequency[n-1], vect_frequency[n]
    cutoff_freq = x0 + (y-y0)/(y1-y0)*(x1-x0)

    return cutoff_freq

def cutoff_batch(g_model_AtoB, g_model_BtoA, dataset, n_samples=15, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    dataloader = DataLoader(dataset, batch_size=n_samples, shuffle=True)
    X_realA, X_realB = next(iter(dataloader))
    X_realA, X_realB = X_realA.to(device), X_realB.to(device)

    g_model_AtoB = g_model_AtoB.to(device)
    g_model_BtoA = g_model_BtoA.to(device)
    
    X_fakeB = generate_fake_samples(g_model_AtoB, X_realA, device)
    X_fakeA = generate_fake_samples(g_model_BtoA, X_realB, device)

    # Convert from [-1, 1] to [0, 1] for cutoff metric calculation
    X_fakeA = (X_fakeA + 1) / 2
    X_fakeB = (X_fakeB + 1) / 2
    
    cutoff_AtoB = torch.tensor([cutoff_metric(img) for img in X_fakeB]).to(device)
    cutoff_BtoA = torch.tensor([cutoff_metric(img) for img in X_fakeA]).to(device)
    
    return (cutoff_AtoB.mean().item(), cutoff_AtoB.std().item(),
            cutoff_BtoA.mean().item(), cutoff_BtoA.std().item())



# Define the training function
def train(netG_A2B, netG_B2A, netD_A, netD_B, train_loader, valid_loader, valid_paired_loader, num_epochs, device):
    criterionGAN = GANLoss().to(device)
    criterionCycle = nn.L1Loss()
    criterionIdt = nn.L1Loss()

    optimizer_G = optim.Adam(list(netG_A2B.parameters()) + list(netG_B2A.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Learning rate scheduler
    scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=30, gamma=0.1)
    scheduler_D_A = optim.lr_scheduler.StepLR(optimizer_D_A, step_size=30, gamma=0.1)
    scheduler_D_B = optim.lr_scheduler.StepLR(optimizer_D_B, step_size=30, gamma=0.1)

    train_metrics = []
    valid_metrics = []
    losses = {'G': [], 'D_A': [], 'D_B': []}

    for epoch in range(num_epochs):
        epoch_losses = {'G': [], 'D_A': [], 'D_B': []}
        for i, (real_A, real_B) in enumerate(train_loader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Train Generators
            optimizer_G.zero_grad()

            fake_B = netG_A2B(real_A)
            fake_A = netG_B2A(real_B)

            loss_idt_A = criterionIdt(netG_A2B(real_B), real_B) * 5.0
            loss_idt_B = criterionIdt(netG_B2A(real_A), real_A) * 5.0

            loss_GAN_A2B = criterionGAN(netD_B(fake_B), True)
            loss_GAN_B2A = criterionGAN(netD_A(fake_A), True)

            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterionCycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterionCycle(recovered_B, real_B) * 10.0

            loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB + loss_idt_A + loss_idt_B
            loss_G.backward()
            optimizer_G.step()

            # Train Discriminator A
            optimizer_D_A.zero_grad()
            loss_D_A = criterionGAN(netD_A(real_A), True) + criterionGAN(netD_A(fake_A.detach()), False)
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train Discriminator B
            optimizer_D_B.zero_grad()
            loss_D_B = criterionGAN(netD_B(real_B), True) + criterionGAN(netD_B(fake_B.detach()), False)
            loss_D_B.backward()
            optimizer_D_B.step()

            epoch_losses['G'].append(loss_G.item())
            epoch_losses['D_A'].append(loss_D_A.item())
            epoch_losses['D_B'].append(loss_D_B.item())

            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                      f'D_A_loss: {loss_D_A.item():.4f}, D_B_loss: {loss_D_B.item():.4f}, '
                      f'G_loss: {loss_G.item():.4f}')

        # Step the schedulers
        scheduler_G.step()
        scheduler_D_A.step()
        scheduler_D_B.step()

        for key in losses.keys():
            losses[key].append(np.mean(epoch_losses[key]))

        # Compute metrics
        try:
            train_metrics.append(compute_metrics(netG_A2B, netG_B2A, train_loader.dataset, device, n_samples=100))
            valid_metrics.append(compute_metrics(netG_A2B, netG_B2A, valid_paired_loader.dataset, device, n_samples=100))
        except RuntimeError as e:
            print(f"Error computing metrics: {e}")
            continue

        # Save metrics
        save_metrics(train_metrics[-1], valid_metrics[-1], epoch)

        # Save models every 20 epochs
        if (epoch + 1) % 20 == 0:
            save_models(netG_A2B, netG_B2A, epoch)

        print_metrics(train_metrics[-1], valid_metrics[-1], epoch, num_epochs)

    return train_metrics, valid_metrics, losses

def compute_metrics(netG_A2B, netG_B2A, dataset, device, n_samples=100):
    psnr_values = psnr(netG_A2B, netG_B2A, dataset, n_samples, device)
    ssim_values = avg_SSIM(netG_A2B, netG_B2A, dataset, n_samples, device)
    cutoff_values = cutoff_batch(netG_A2B, netG_B2A, dataset, n_samples, device)
    
    return {
        'psnr': psnr_values,
        'ssim': ssim_values,
        'cutoff': cutoff_values
    }

def save_metrics(train_metrics, valid_metrics, epoch):
    with open(f'metrics_epoch_{epoch+1}.txt', 'w') as f:
        f.write(f"Epoch {epoch+1}\n")
        f.write(f"Train - PSNR A->B: {train_metrics['psnr'][0]:.4f}±{train_metrics['psnr'][1]:.4f}, PSNR B->A: {train_metrics['psnr'][4]:.4f}±{train_metrics['psnr'][5]:.4f}\n")
        f.write(f"Train - SSIM A->B: {train_metrics['ssim'][0]:.4f}±{train_metrics['ssim'][1]:.4f}, SSIM B->A: {train_metrics['ssim'][2]:.4f}±{train_metrics['ssim'][3]:.4f}\n")
        f.write(f"Train - Cutoff A->B: {train_metrics['cutoff'][0]:.4f}±{train_metrics['cutoff'][1]:.4f}, Cutoff B->A: {train_metrics['cutoff'][2]:.4f}±{train_metrics['cutoff'][3]:.4f}\n")
        f.write(f"Valid - PSNR A->B: {valid_metrics['psnr'][0]:.4f}±{valid_metrics['psnr'][1]:.4f}, PSNR B->A: {valid_metrics['psnr'][4]:.4f}±{valid_metrics['psnr'][5]:.4f}\n")
        f.write(f"Valid - SSIM A->B: {valid_metrics['ssim'][0]:.4f}±{valid_metrics['ssim'][1]:.4f}, SSIM B->A: {valid_metrics['ssim'][2]:.4f}±{valid_metrics['ssim'][3]:.4f}\n")
        f.write(f"Valid - Cutoff A->B: {valid_metrics['cutoff'][0]:.4f}±{valid_metrics['cutoff'][1]:.4f}, Cutoff B->A: {valid_metrics['cutoff'][2]:.4f}±{valid_metrics['cutoff'][3]:.4f}\n")

def save_models(netG_A2B, netG_B2A, epoch):
    torch.save(netG_A2B.state_dict(), f'netG_A2B_epoch_{epoch+1}.pth')
    torch.save(netG_B2A.state_dict(), f'netG_B2A_epoch_{epoch+1}.pth')

def print_metrics(train_metrics, valid_metrics, epoch, num_epochs):
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f"Train - PSNR A->B: {train_metrics['psnr'][0]:.4f}±{train_metrics['psnr'][1]:.4f}, PSNR B->A: {train_metrics['psnr'][4]:.4f}±{train_metrics['psnr'][5]:.4f}")
    print(f"Train - SSIM A->B: {train_metrics['ssim'][0]:.4f}±{train_metrics['ssim'][1]:.4f}, SSIM B->A: {train_metrics['ssim'][2]:.4f}±{train_metrics['ssim'][3]:.4f}")
    print(f"Train - Cutoff A->B: {train_metrics['cutoff'][0]:.4f}±{train_metrics['cutoff'][1]:.4f}, Cutoff B->A: {train_metrics['cutoff'][2]:.4f}±{train_metrics['cutoff'][3]:.4f}")
    print(f"Valid - PSNR A->B: {valid_metrics['psnr'][0]:.4f}±{valid_metrics['psnr'][1]:.4f}, PSNR B->A: {valid_metrics['psnr'][4]:.4f}±{valid_metrics['psnr'][5]:.4f}")
    print(f"Valid - SSIM A->B: {valid_metrics['ssim'][0]:.4f}±{valid_metrics['ssim'][1]:.4f}, SSIM B->A: {valid_metrics['ssim'][2]:.4f}±{valid_metrics['ssim'][3]:.4f}")
    print(f"Valid - Cutoff A->B: {valid_metrics['cutoff'][0]:.4f}±{valid_metrics['cutoff'][1]:.4f}, Cutoff B->A: {valid_metrics['cutoff'][2]:.4f}±{valid_metrics['cutoff'][3]:.4f}")

# Main execution
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize models
    netG_A2B = Generator(3, 3).to(device)
    netG_B2A = Generator(3, 3).to(device)
    netD_A = Discriminator(3).to(device)
    netD_B = Discriminator(3).to(device)

    # Load datasets
    transform = transforms.Compose([
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    train_dataset = CycleGANDataset('./data/confocal_exper_altogether_trainR_256.npz', transform=transform)
    valid_dataset = CycleGANDataset('./data/confocal_exper_non_sat_filt_validR_256.npz', transform=transform)
    valid_paired_dataset = CycleGANDataset('./data/confocal_exper_paired_filt_validsetR_256.npz', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    valid_paired_loader = DataLoader(valid_paired_dataset, batch_size=1, shuffle=False)

    # Train the model
    num_epochs = 1
    train_metrics, valid_metrics, losses = train(netG_A2B, netG_B2A, netD_A, netD_B, train_loader, valid_loader, valid_paired_loader, num_epochs, device)

    # Plot metrics
    def plot_metrics(train_metrics, valid_metrics, losses, num_epochs):
        epochs = range(1, num_epochs + 1)

        plt.figure(figsize=(20, 15))
        
        # PSNR A->B
        plt.subplot(331)
        plt.errorbar(epochs, [m['psnr'][0] for m in train_metrics], yerr=[m['psnr'][1] for m in train_metrics], label='Train A->B')
        plt.errorbar(epochs, [m['psnr'][0] for m in valid_metrics], yerr=[m['psnr'][1] for m in valid_metrics], label='Valid A->B')
        plt.title('PSNR A->B')
        plt.legend()

        # PSNR B->A
        plt.subplot(332)
        plt.errorbar(epochs, [m['psnr'][4] for m in train_metrics], yerr=[m['psnr'][5] for m in train_metrics], label='Train B->A')
        plt.errorbar(epochs, [m['psnr'][4] for m in valid_metrics], yerr=[m['psnr'][5] for m in valid_metrics], label='Valid B->A')
        plt.title('PSNR B->A')
        plt.legend()

        # SSIM A->B
        plt.subplot(333)
        plt.errorbar(epochs, [m['ssim'][0] for m in train_metrics], yerr=[m['ssim'][1] for m in train_metrics], label='Train A->B')
        plt.errorbar(epochs, [m['ssim'][0] for m in valid_metrics], yerr=[m['ssim'][1] for m in valid_metrics], label='Valid A->B')
        plt.title('SSIM A->B')
        plt.legend()

        # SSIM B->A
        plt.subplot(334)
        plt.errorbar(epochs, [m['ssim'][2] for m in train_metrics], yerr=[m['ssim'][3] for m in train_metrics], label='Train B->A')
        plt.errorbar(epochs, [m['ssim'][2] for m in valid_metrics], yerr=[m['ssim'][3] for m in valid_metrics], label='Valid B->A')
        plt.title('SSIM B->A')
        plt.legend()

        # Cutoff Frequency A->B
        plt.subplot(335)
        plt.errorbar(epochs, [m['cutoff'][0] for m in train_metrics], yerr=[m['cutoff'][1] for m in train_metrics], label='Train A->B')
        plt.errorbar(epochs, [m['cutoff'][0] for m in valid_metrics], yerr=[m['cutoff'][1] for m in valid_metrics], label='Valid A->B')
        plt.title('Cutoff Frequency A->B')
        plt.legend()

        # Cutoff Frequency B->A
        plt.subplot(336)
        plt.errorbar(epochs, [m['cutoff'][2] for m in train_metrics], yerr=[m['cutoff'][3] for m in train_metrics], label='Train B->A')
        plt.errorbar(epochs, [m['cutoff'][2] for m in valid_metrics], yerr=[m['cutoff'][3] for m in valid_metrics], label='Valid B->A')
        plt.title('Cutoff Frequency B->A')
        plt.legend()

        # Losses
        plt.subplot(337)
        plt.plot(epochs, losses['G'], label='Generator')
        plt.plot(epochs, losses['D_A'], label='Discriminator A')
        plt.plot(epochs, losses['D_B'], label='Discriminator B')
        plt.title('Losses')
        plt.legend()

        plt.tight_layout()
        plt.savefig('metrics_and_losses_plot.png')
        plt.close()

        print("Training completed. Metrics and losses plot saved as 'metrics_and_losses_plot.png'.")

    # Call the plot_metrics function
    

Python version
3.12.1 (tags/v3.12.1:2305ca5, Dec  7 2023, 22:03:25) [MSC v.1937 64 bit (AMD64)]
PyTorch version: 2.2.1+cu118
Epoch [1/1], Step [1/117], D_A_loss: 1.7988, D_B_loss: 1.0306, G_loss: 32.8086
Epoch [1/1], Step [101/117], D_A_loss: 0.4688, D_B_loss: 0.4135, G_loss: 3.2834


The maximal number of iterations maxit (set to 20 by the program)
allowed for finding a smoothing spline with fp=s has been reached: s
too small.
There is an approximation returned but the corresponding weighted sum
of squared residuals does not satisfy the condition abs(fp-s)/s < tol.
  fitresult = UnivariateSpline(xData, yData, s=0.01, k=3)
A theoretically impossible result was found during the iteration
process for finding a smoothing spline with fp = s: s too small.
There is an approximation returned but the corresponding weighted sum
of squared residuals does not satisfy the condition abs(fp-s)/s < tol.
  fitresult = UnivariateSpline(xData, yData, s=0.01, k=3)
  cutoff_AtoB = torch.tensor([cutoff_metric(img) for img in X_fakeB]).to(device)


Epoch [1/1]
Train - PSNR A->B: 8.3369±1.7672, PSNR B->A: 11.5372±2.0123
Train - SSIM A->B: 0.9792±0.0072, SSIM B->A: 0.9895±0.0043
Train - Cutoff A->B: 25.9896±0.4121, Cutoff B->A: 27.3916±0.6080
Valid - PSNR A->B: 17.2726±nan, PSNR B->A: 11.7358±nan
Valid - SSIM A->B: 0.9983±nan, SSIM B->A: 0.9911±nan
Valid - Cutoff A->B: 26.6464±nan, Cutoff B->A: 27.3737±nan


  return (psnr_AtoB.mean().item(), psnr_AtoB.std().item(), mse_AtoB.mean().item(), mse_AtoB.std().item(),
  psnr_BtoA.mean().item(), psnr_BtoA.std().item(), mse_BtoA.mean().item(), mse_BtoA.std().item())
  return (ssim_AtoB.mean().item(), ssim_AtoB.std().item(),
  ssim_BtoA.mean().item(), ssim_BtoA.std().item())
  return (cutoff_AtoB.mean().item(), cutoff_AtoB.std().item(),
  cutoff_BtoA.mean().item(), cutoff_BtoA.std().item())


In [2]:
import numpy as np

def inspect_dataset(file_path):
    print(f"\nInspecting dataset: {file_path}")
    data = np.load(file_path)
    
    for key in data.keys():
        array = data[key]
        print(f"\nArray: {key}")
        print(f"Shape: {array.shape}")
        
        if len(array.shape) == 4:
            num_samples, channels, height, width = array.shape
            print(f"Number of samples: {num_samples}")
            print(f"Number of channels: {channels}")
            print(f"Image dimensions: {height}x{width}")
        elif len(array.shape) == 3:
            num_samples, height, width = array.shape
            print(f"Number of samples: {num_samples}")
            print(f"Number of channels: 1 (grayscale)")
            print(f"Image dimensions: {height}x{width}")
        else:
            print("Unexpected array shape")
        
        print(f"Data type: {array.dtype}")
        print(f"Min value: {array.min()}")
        print(f"Max value: {array.max()}")
        print(f"Mean value: {array.mean()}")
        print(f"Standard deviation: {array.std()}")

# List of dataset files to inspect
dataset_files = [
    'confocal_exper_altogether_trainR_256.npz',
    'confocal_exper_non_sat_filt_validR_256.npz',
    'confocal_exper_paired_filt_validsetR_256.npz'
]

# Inspect each dataset
for file in dataset_files:
    inspect_dataset(file)

print("\nDataset inspection complete.")


Inspecting dataset: confocal_exper_altogether_trainR_256.npz


FileNotFoundError: [Errno 2] No such file or directory: 'confocal_exper_altogether_trainR_256.npz'

In [4]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os



def load_model(model_path, device):
    model = Generator(3, 3).to(device)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def generate_predictions(generator, dataloader, device, output_dir, direction):
    os.makedirs(output_dir, exist_ok=True)
    
    with torch.no_grad():
        for i, (real_A, real_B) in enumerate(dataloader):
            if direction == 'A2B':
                real = real_A.to(device)
                fake = generator(real)
            else:  # B2A
                real = real_B.to(device)
                fake = generator(real)
            
            # Denormalize images
            real = (real + 1) / 2
            fake = (fake + 1) / 2
            
            # Save images
            save_image(real, os.path.join(output_dir, f'real60_{direction}_{i}.png'))
            save_image(fake, os.path.join(output_dir, f'fake60_{direction}_{i}.png'))
            
            if i >= 99:  # Generate 100 images
                break

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load the trained models
    netG_A2B = load_model('./models/netG_A2B_epoch_60.pth', device)  
    netG_B2A = load_model('./models/netG_B2A_epoch_60.pth', device)  

    # Load the dataset
    test_dataset = CycleGANDataset('./data/confocal_exper_paired_filt_validsetR_256.npz', transform=None)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Generate and save predictions
    generate_predictions(netG_A2B, test_loader, device, 'predictions_A2B', 'A2B')
    generate_predictions(netG_B2A, test_loader, device, 'predictions_B2A', 'B2A')

if __name__ == '__main__':
    main()