In [4]:
import torch
import torch.nn as nn


class VGGLoss(nn.Module):
    """Perceptual L1 loss between VGG19 feature maps of two image batches.

    Inputs are expected in the generator's output range [-1, 1]. They are
    mapped to [0, 1], clamped, normalized with the ImageNet mean/std the
    pretrained VGG19 expects, passed through a frozen VGG19 feature
    extractor, and compared with an L1 loss.

    Args:
        device: Device the whole module (VGG weights and the normalization
            buffers) is moved to. Inputs to ``forward`` must be on it too.
    """

    def __init__(self, device):
        super(VGGLoss, self).__init__()
        # Imported here so this cell does not rely on `torchvision` having
        # been imported as a global by another cell.
        import torchvision

        # features[:26] keeps VGG19 layers 0..25, i.e. it ends at conv4_4
        # (before its ReLU). Use features[:27] to end at relu4_4 instead.
        # IMAGENET1K_V1 is what the deprecated `pretrained=True` loaded.
        vgg = torchvision.models.vgg19(
            weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
        ).features[:26]
        for param in vgg.parameters():
            param.requires_grad = False
        # Frozen extractor; VGG19 `features` has no dropout/batch-norm, so
        # eval() is purely defensive.
        self.vgg = vgg.eval()
        self.criterion = nn.L1Loss()

        # ImageNet normalization statistics expected by the pretrained VGG.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # Move weights *and* buffers. Moving only `vgg` left `mean`/`std` on
        # the CPU, which raises a device mismatch for CUDA inputs.
        self.to(device)

    def forward(self, x, y):
        """Return the mean absolute difference between VGG features of x and y.

        Args:
            x: Image batch in [-1, 1]; presumably (N, 3, H, W) — 3 channels
                are required by the normalization buffers and VGG.
            y: Image batch in [-1, 1] with the same shape as ``x``.

        Returns:
            Scalar loss tensor.
        """
        # [-1, 1] -> [0, 1], clamped *before* normalizing. Clamping after
        # the ImageNet normalization truncated the normalized values to
        # [0, 1] and discarded most of the signal.
        x = ((x + 1) * 0.5).clamp(0, 1)
        y = ((y + 1) * 0.5).clamp(0, 1)

        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std

        return self.criterion(self.vgg(x), self.vgg(y))

In [1]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

# Record the interpreter and PyTorch versions in the notebook output.
print("Python version", sys.version, sep="\n")
print(f"PyTorch version: {torch.__version__}")

# Define the generator
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is ReflectionPad(1) -> 3x3 conv -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> 3x3 conv -> InstanceNorm. It keeps both the channel
    count ``dim`` and the spatial size unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        # Layer order and the attribute name `conv_block` must not change:
        # saved state dicts address these layers by their index.
        stages = []
        for add_relu in (True, False):
            stages += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3),
                nn.InstanceNorm2d(dim),
            ]
            if add_relu:
                stages.append(nn.ReLU(True))
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-based encoder/decoder image-to-image generator.

    Layout: reflection-padded 7x7 stem -> two stride-2 downsampling convs ->
    ``n_blocks`` residual blocks -> two stride-2 transposed-conv upsampling
    stages -> reflection-padded 7x7 output conv -> Tanh (outputs in [-1, 1]).

    Args:
        input_nc: Number of input channels.
        output_nc: Number of output channels.
        ngf: Number of filters in the first conv layer.
        n_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
        super().__init__()
        n_downsampling = 2

        def conv_norm_relu(conv, channels):
            # conv -> InstanceNorm -> ReLU, the unit repeated throughout.
            return [conv, nn.InstanceNorm2d(channels), nn.ReLU(True)]

        # Stem.
        layers = [nn.ReflectionPad2d(3)]
        layers += conv_norm_relu(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), ngf)

        # Downsampling: channels double at every stride-2 stage.
        channels = ngf
        for _ in range(n_downsampling):
            layers += conv_norm_relu(
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                channels * 2,
            )
            channels *= 2

        # Bottleneck residual blocks.
        layers.extend(ResnetBlock(channels) for _ in range(n_blocks))

        # Upsampling: mirror of the downsampling path.
        for _ in range(n_downsampling):
            layers += conv_norm_relu(
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                channels // 2,
            )
            channels //= 2

        # Output head.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        # Layer order must match previously saved checkpoints (`model.<idx>`).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        output = self.model(x)
        return output

# Define the discriminator
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a map of per-patch scores.

    Five 4x4 convs (strides 2, 2, 2, 1, 1) give each output element a
    70x70 receptive field. The final conv has no activation, so the outputs
    are raw scores (logits).

    Args:
        input_nc: Number of input channels.
        ndf: Number of filters in the first conv layer.
    """

    def __init__(self, input_nc, ndf=64):
        super().__init__()
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        # (input multiplier, output multiplier, stride) for the normalized stages.
        for in_mult, out_mult, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            layers.extend([
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ])
        layers.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# Define the dataset
class CycleGANDataset(Dataset):
    """Paired image dataset backed by a ``.npz`` archive.

    The archive must contain arrays ``arr_0`` (domain A) and ``arr_1``
    (domain B). Items with the same index are returned together as a pair.

    Args:
        npz_file: Path to the ``.npz`` file.
        transform: Optional callable applied to each (C, H, W) float tensor.
        fraction: Fraction of samples to keep, taken from the start of the
            arrays (default 0.1, i.e. only the first 10%).
    """

    def __init__(self, npz_file, transform=None, fraction=0.1):
        # Use the NpzFile as a context manager so its file handle is closed;
        # indexing it reads each array fully into memory (once per key).
        with np.load(npz_file) as data:
            images_a = data['arr_0']
            images_b = data['arr_1']

        # Use the shorter of the two arrays so an archive with unequal
        # lengths cannot produce an IndexError in __getitem__.
        total_samples = min(len(images_a), len(images_b))
        num_samples = int(total_samples * fraction)

        self.A = images_a[:num_samples]
        self.B = images_b[:num_samples]
        self.transform = transform

    def __len__(self):
        return len(self.A)

    def __getitem__(self, idx):
        """Return the ``(A, B)`` pair at ``idx`` as float32 (C, H, W) tensors.

        Each stored image is indexed as (H, W, C), divided by 255 and moved to
        channels-first before ``transform`` (if any) is applied. The /255
        scaling presumably assumes 8-bit source data — TODO confirm.
        """
        A = self._to_tensor(self.A[idx])
        B = self._to_tensor(self.B[idx])

        if self.transform:
            A = self.transform(A)
            B = self.transform(B)

        return A, B

    @staticmethod
    def _to_tensor(image):
        # Scale to [0, 1] (for 8-bit input) and permute (H, W, C) -> (C, H, W).
        return torch.from_numpy(image.astype(np.float32) / 255.0).permute(2, 0, 1)

# Define the loss functions
class GANLoss(nn.Module):
    """Adversarial loss of discriminator outputs against a constant target.

    Args:
        use_lsgan: If True (default), least-squares (MSE) loss; otherwise
            binary cross-entropy computed on raw logits.
        target_real_label: Target value for "real".
        target_fake_label: Target value for "fake".
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Registered as buffers so `.to(device)` moves them with the module;
        # cast to float so integer labels still give a float target.
        self.register_buffer('real_label', torch.tensor(float(target_real_label)))
        self.register_buffer('fake_label', torch.tensor(float(target_fake_label)))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            # Expects logits: the discriminator here ends in a bare Conv2d (no
            # sigmoid). Plain BCELoss requires probabilities in [0, 1] and
            # errors out or saturates on logits.
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to ``input``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Loss of ``input`` against the real (True) or fake (False) target.

        Implemented as ``forward`` (not ``__call__``) so nn.Module hooks run.
        """
        return self.loss(input, self.get_target_tensor(input, target_is_real))



# Training loop, extended with a VGG perceptual loss
def _set_requires_grad(nets, requires_grad):
    """Enable or disable gradient tracking for every parameter of ``nets``."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad


def train(netG_A2B, netG_B2A, netD_A, netD_B, train_loader, num_epochs, device,
          lambda_cycle=10.0, lambda_idt=5.0, lambda_vgg=10.0, lr=0.0002,
          save_every=1):
    """Train a CycleGAN with an additional VGG perceptual loss.

    Args:
        netG_A2B, netG_B2A: Generators for the two translation directions.
        netD_A, netD_B: Discriminators for domains A and B.
        train_loader: Iterable of ``(real_A, real_B)`` batches. The VGG term
            compares translations with the other domain's image at the same
            index, so the data presumably must be paired — confirm.
        num_epochs: Number of epochs.
        device: Device the batches and loss modules are moved to.
        lambda_cycle: Weight of each cycle-consistency L1 term.
        lambda_idt: Weight of each identity L1 term.
        lambda_vgg: Weight of the VGG perceptual loss.
        lr: Initial Adam learning rate for all optimizers (decayed 10x every
            30 epochs).
        save_every: Save both generators every ``save_every`` epochs
            (0/None disables checkpointing).

    Returns:
        dict mapping 'G', 'D_A', 'D_B', 'VGG' to lists of per-epoch mean losses.
    """
    criterionGAN = GANLoss().to(device)
    criterionCycle = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    # `.to(device)` also moves VGGLoss's normalization buffers.
    criterionVGG = VGGLoss(device).to(device)

    betas = (0.5, 0.999)
    optimizer_G = optim.Adam(list(netG_A2B.parameters()) + list(netG_B2A.parameters()),
                             lr=lr, betas=betas)
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=lr, betas=betas)
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=lr, betas=betas)

    schedulers = [optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)
                  for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)]

    for net in (netG_A2B, netG_B2A, netD_A, netD_B):
        net.train()

    loss_keys = ('G', 'D_A', 'D_B', 'VGG')
    losses = {key: [] for key in loss_keys}

    for epoch in range(num_epochs):
        epoch_losses = {key: [] for key in loss_keys}

        for i, (real_A, real_B) in enumerate(train_loader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # ---- Generators ----
            # Freeze the discriminators: gradients still flow through them to
            # the generators, but their own parameter grads are not computed.
            _set_requires_grad((netD_A, netD_B), False)
            optimizer_G.zero_grad()

            fake_B = netG_A2B(real_A)
            fake_A = netG_B2A(real_B)

            # Identity loss
            loss_idt_A = criterionIdt(netG_A2B(real_B), real_B) * lambda_idt
            loss_idt_B = criterionIdt(netG_B2A(real_A), real_A) * lambda_idt

            # GAN loss: generators try to make the discriminators say "real".
            loss_GAN_A2B = criterionGAN(netD_B(fake_B), True)
            loss_GAN_B2A = criterionGAN(netD_A(fake_A), True)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            recovered_B = netG_A2B(fake_A)
            loss_cycle_ABA = criterionCycle(recovered_A, real_A) * lambda_cycle
            loss_cycle_BAB = criterionCycle(recovered_B, real_B) * lambda_cycle

            # Perceptual loss against the same-index image of the target domain.
            loss_vgg = (criterionVGG(fake_A, real_A) + criterionVGG(fake_B, real_B)) * lambda_vgg

            loss_G = (loss_GAN_A2B + loss_GAN_B2A
                      + loss_cycle_ABA + loss_cycle_BAB
                      + loss_idt_A + loss_idt_B
                      + loss_vgg)
            loss_G.backward()
            optimizer_G.step()

            # ---- Discriminators ----
            _set_requires_grad((netD_A, netD_B), True)

            optimizer_D_A.zero_grad()
            loss_D_A = criterionGAN(netD_A(real_A), True) + criterionGAN(netD_A(fake_A.detach()), False)
            loss_D_A.backward()
            optimizer_D_A.step()

            optimizer_D_B.zero_grad()
            loss_D_B = criterionGAN(netD_B(real_B), True) + criterionGAN(netD_B(fake_B.detach()), False)
            loss_D_B.backward()
            optimizer_D_B.step()

            for key, value in (('G', loss_G), ('D_A', loss_D_A), ('D_B', loss_D_B), ('VGG', loss_vgg)):
                epoch_losses[key].append(value.item())

            if i % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                      f'D_A_loss: {loss_D_A.item():.4f}, D_B_loss: {loss_D_B.item():.4f}, '
                      f'G_loss: {loss_G.item():.4f}, VGG_loss: {loss_vgg.item():.4f}')

        for scheduler in schedulers:
            scheduler.step()

        # Per-epoch averages; NaN (without a RuntimeWarning) for an empty loader.
        for key in loss_keys:
            values = epoch_losses[key]
            losses[key].append(np.mean(values) if values else float('nan'))

        if save_every and (epoch + 1) % save_every == 0:
            torch.save(netG_A2B.state_dict(), f'netG_A2B_epoch_{epoch+1}.pth')
            torch.save(netG_B2A.state_dict(), f'netG_B2A_epoch_{epoch+1}.pth')

    return losses

import matplotlib.pyplot as plt

def plot_losses(losses, save_path='training_losses.png'):
    """
    Plot the per-epoch training losses and save the figure as an image.

    Args:
        losses (dict): Loss history with keys 'G', 'D_A', 'D_B' and 'VGG',
            each a list of per-epoch values.
        save_path (str): Output image path.
    """
    # (key in `losses`, matplotlib format string, legend label)
    curves = (
        ('G', 'b-', 'Generator Loss'),
        ('D_A', 'r-', 'Discriminator A Loss'),
        ('D_B', 'g-', 'Discriminator B Loss'),
        ('VGG', 'm-', 'VGG Loss'),
    )

    fig, ax = plt.subplots(figsize=(12, 8))

    # x axis: 1-based epoch numbers.
    epochs = range(1, len(losses['G']) + 1)
    for key, fmt, label in curves:
        ax.plot(epochs, losses[key], fmt, label=label, linewidth=2)

    ax.set_title('Training Losses Over Time', size=14)
    ax.set_xlabel('Epoch', size=12)
    ax.set_ylabel('Loss', size=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(fontsize=10)

    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

    print(f"Loss plot saved as {save_path}")

# Entry point, extended to plot the training losses
if __name__ == '__main__':
    # Imported at module level so cells that reference `torchvision` as a
    # global can find it.
    import torchvision

    # Set RUN_TRAINING = True to launch the (long) training run from this cell.
    RUN_TRAINING = False
    NUM_EPOCHS = 100
    SEED = 0

    # Seed weight initialization and DataLoader shuffling for reproducibility.
    torch.manual_seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models
    netG_A2B = Generator(3, 3).to(device)
    netG_B2A = Generator(3, 3).to(device)
    netD_A = Discriminator(3).to(device)
    netD_B = Discriminator(3).to(device)

    # Load datasets; Normalize maps [0, 1] tensors to [-1, 1].
    transform = transforms.Compose([
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = CycleGANDataset('./data/confocal_exper_altogether_trainR_256.npz', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    if RUN_TRAINING:
        losses = train(netG_A2B, netG_B2A, netD_A, netD_B, train_loader, NUM_EPOCHS, device)
        plot_losses(losses)

Python version
3.12.1 (tags/v3.12.1:2305ca5, Dec  7 2023, 22:03:25) [MSC v.1937 64 bit (AMD64)]
PyTorch version: 2.2.1+cu118


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Generator and CycleGANDataset are defined in the previous cell.

def denormalize_image(image):
    """Map values from the model range [-1, 1] back to [0, 1] for display.

    Works on anything supporting ``+`` and scalar multiplication (tensors,
    NumPy arrays, Python numbers). No clamping is applied.
    """
    shifted = image + 1
    return 0.5 * shifted

def visualize_results(model_path, data_path, num_images=5, device='cuda',
                      save_path='cyclegan_results.png'):
    """
    Plot Input / Ground Truth / Generated rows for an A->B generator.

    Args:
        model_path (str): Path to the generator ``.pth`` state dict.
        data_path (str): Path to the ``.npz`` file with paired data.
        num_images (int): Number of rows (samples) to plot; must be >= 1.
        device (str): Preferred device; 'cpu' is used when CUDA is unavailable.
        save_path (str): Output image path.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Load the generator (3 input and 3 output channels).
    model = Generator(3, 3)
    # weights_only=True restricts unpickling to tensors/containers, so a
    # tampered checkpoint cannot execute arbitrary code.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # NOTE: CycleGANDataset defaults to fraction=0.1, so samples are drawn
    # only from the first 10% of the file.
    dataset = CycleGANDataset(data_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)
    fig.subplots_adjust(hspace=0.3)

    def to_display(img):
        # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1]; clamping avoids
        # imshow clipping out-of-range floats with a warning.
        return denormalize_image(img).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)

    titles = ('Input', 'Ground Truth', 'Generated')
    with torch.no_grad():
        # zip stops after num_images rows (or earlier if the data runs out).
        for row_axes, (real_A, real_B) in zip(axes, dataloader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            fake_B = model(real_A)

            for ax, img, title in zip(row_axes, (real_A[0], real_B[0], fake_B[0]), titles):
                ax.imshow(to_display(img))
                ax.set_title(title)
                ax.axis('off')

    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"Results saved as '{save_path}'")

# Example usage
if __name__ == "__main__":
    # Point these at your own checkpoint and evaluation data.
    model_path = 'netG_A2B_epoch_75.pth'
    data_path = './data/confocal_exper_paired_filt_validsetR_256.npz'

    visualize_results(model_path, data_path, num_images=5, device='cuda')

Results saved as 'cyclegan_results.png'
