# Drive & paths

In [None]:
%pip install -r requirements.txt

In [None]:
# Locations of the Raindrop dataset and of every artifact written by this notebook.
data_path, results_path = 'raindrop_data', 'results'

# Imports and helper functions

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
import os

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.utils.data import Dataset
from PIL import Image
import os

# Dataset

In [5]:
class RaindropDataset(Dataset):
    """Paired rainy/clean image dataset stored as ``<root>/<mode>/<mode>/{data,gt}``.

    Rainy and clean images are paired by position after sorting both directory
    listings, so corresponding files must sort into the same order.

    Args:
        root_dir: Dataset root directory.
        mode: Split name (e.g. 'train', 'test_a', 'test_b'). It is also used as
            the nested sub-directory name.
        transform: Optional callable applied to both PIL images of a pair.

    Raises:
        ValueError: If the rainy and clean directories contain different
            numbers of files.
    """

    def __init__(self, root_dir, mode='train', transform=None):
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.image_dir = os.path.join(root_dir, mode, mode, "data")
        self.clean_dir = os.path.join(root_dir, mode, mode, "gt")

        # Pairing is positional, so both listings must sort into matching order.
        self.rainy_images = sorted(os.listdir(self.image_dir))
        self.clean_images = sorted(os.listdir(self.clean_dir))

        # Raise instead of assert so the check survives `python -O`.
        if len(self.rainy_images) != len(self.clean_images):
            raise ValueError(
                "Rainy and clean image count mismatch! "
                f"({len(self.rainy_images)} files in {self.image_dir}, "
                f"{len(self.clean_images)} files in {self.clean_dir})"
            )

    def __len__(self):
        return len(self.rainy_images)

    def __getitem__(self, idx):
        """Return ``{'data': rainy_image, 'gt': clean_image}`` for pair ``idx``.

        Both images are converted to RGB and, if a transform was given, passed
        through it.
        """
        rain_path = os.path.join(self.image_dir, self.rainy_images[idx])
        clean_path = os.path.join(self.clean_dir, self.clean_images[idx])

        # Open inside `with` so the file handles are closed right away;
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(rain_path) as src:
            rain_img = src.convert("RGB")
        with Image.open(clean_path) as src:
            clean_img = src.convert("RGB")

        if self.transform:
            rain_img = self.transform(rain_img)
            clean_img = self.transform(clean_img)

        return {'data': rain_img, 'gt': clean_img}

# Shared preprocessing: resize to 480x720 (height x width), then convert to a tensor.
transform = transforms.Compose([transforms.Resize((480, 720)), transforms.ToTensor()])

# One dataset per split, all sharing the same transform.
train_dataset, test_a_dataset, test_b_dataset = (
    RaindropDataset(root_dir=data_path, mode=split, transform=transform)
    for split in ('train', 'test_a', 'test_b')
)

# Sanity check on the first training pair.
sample = train_dataset[0]
print("Data image shape:", sample['data'].shape)
print("GT image shape:", sample['gt'].shape)

Data image shape: torch.Size([3, 480, 720])
GT image shape: torch.Size([3, 480, 720])


In [None]:
# prompt: show one of the images of the train dataset (rain and clean)

import matplotlib.pyplot as plt
import numpy as np

# Display the first training pair side by side: rainy input and clean target.
sample = train_dataset[0]

# Tensors are (C, H, W); matplotlib expects (H, W, C) for colour images.
rain_image_tensor = sample['data']
clean_image_tensor = sample['gt']
rain_image_np = rain_image_tensor.permute(1, 2, 0).numpy()
clean_image_np = clean_image_tensor.permute(1, 2, 0).numpy()

plt.figure(figsize=(12, 6))
panels = [(rain_image_np, "Rainy Image"), (clean_image_np, "Clean Image")]
for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')

plt.tight_layout()
plt.show()

# First Experiment

Model

In [None]:
import torch
import torch.nn as nn

# (Optional pooling) + Convolution + BatchNormalization + ReLU block for the encoder
class ConvBNReLU(nn.Module):
    """Encoder block: (optional 2x2 average pool) -> 3x3 Conv -> BatchNorm -> ReLU.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        pooling (bool): If True, downsample the input with a 2x2, stride-2
            average pool *before* the convolution, halving (with floor) the
            height and width.
    """

    def __init__(self, in_channels: int, out_channels: int, pooling: bool = False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AvgPool2d(2, 2) if pooling else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (pool ->) conv -> BN -> ReLU.

        Args:
            x (torch.Tensor): Input of shape (N, in_channels, H, W).

        Returns:
            torch.Tensor: Non-negative tensor with ``out_channels`` channels.
            H and W are halved when pooling is enabled.
        """
        if self.pool is not None:
            x = self.pool(x)
        return self.relu(self.bn(self.conv(x)))

# BatchNorm + ReLU + (optional upsample) + Convolution block for the decoder
class BNReLUConv(nn.Module):
    """Pre-activation decoder block: BatchNorm -> ReLU -> (optional 2x upsample) -> 3x3 Conv.

    Args:
        in_channels: Channels of the input; BatchNorm normalizes these.
        out_channels: Channels produced by the convolution.
        upsampling: If True, double H and W with nearest-neighbour upsampling
            before the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, upsampling: bool = False):
        super().__init__()
        # Registration order (bn, relu, conv, upsample) fixes the state_dict layout.
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if upsampling:
            self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        else:
            self.upsample = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.relu(self.bn(x))
        if self.upsample is not None:
            activated = self.upsample(activated)
        return self.conv(activated)

class Encoder(nn.Module):
    """Convolutional encoder that maps an image tensor to a flat feature vector.

    Four ConvBNReLU stages (two of them average-pool first) are followed by an
    adaptive average pool to a fixed 8x8 grid. The final linear layer therefore
    does not depend on the input resolution.

    Args:
        in_channels: Number of input channels. The Generator uses 3 (RGB); the
            Discriminator uses 6 (the concatenated image pair).
        out_features: Length of the output vector.
        base_channels: Width multiplier ``bc`` for the convolutional stages.
    """

    def __init__(self, in_channels: int, out_features: int, base_channels: int = 16):
        super().__init__()
        # Shapes in the comments assume a 480x720 input (the dataset transform's size).
        self.layer1 = ConvBNReLU(in_channels, base_channels, pooling=False)  # (bc, 480, 720)
        self.layer2 = ConvBNReLU(base_channels, base_channels * 2, pooling=True)  # (bc*2, 240, 360)

        self.layer3 = ConvBNReLU(base_channels * 2, base_channels * 4, pooling=False)  # (bc*4, 240, 360)
        self.layer4 = ConvBNReLU(base_channels * 4, base_channels * 8, pooling=True)  # (bc*8, 120, 180)

        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))  # (bc*8, 8, 8) for any input size
        self.fc = nn.Linear(8 * 8 * base_channels * 8, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (N, in_channels, H, W) into (N, out_features)."""
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = layer(x)
        x = self.adaptive_pool(x)
        x = torch.flatten(x, start_dim=1)  # (N, bc*8*8*8)
        return self.fc(x)

class Decoder(nn.Module):
    """Decode a latent vector into an RGB image with values in (0, 1).

    The vector is projected to a (bc*8, 8, 8) feature map. Two upsampling stages
    bring it to 32x32, and nearest-neighbour interpolation then resizes it to
    ``output_size``. A sigmoid is applied at the end.

    Args:
        out_features: Length of the latent vector (the Encoder's ``out_features``).
        base_channels: Width multiplier ``bc``; must match the Encoder's.
        output_size: (height, width) of the produced image. The default is
            (480, 720), the resolution of the dataset transform.
    """

    def __init__(self, out_features: int, base_channels: int = 16, output_size: tuple = (480, 720)):
        super().__init__()
        self.base_channels = base_channels
        self.output_size = tuple(output_size)
        self.fc = nn.Linear(out_features, 8 * 8 * base_channels * 8)  # mirrors the Encoder's fc

        self.layer4 = BNReLUConv(base_channels * 8, base_channels * 4, upsampling=True)  # (bc*4, 16, 16)
        self.layer3 = BNReLUConv(base_channels * 4, base_channels * 2, upsampling=False)  # (bc*2, 16, 16)

        self.layer2 = BNReLUConv(base_channels * 2, base_channels, upsampling=True)  # (bc, 32, 32)
        self.layer1 = BNReLUConv(base_channels, 3, upsampling=False)  # (3, 32, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (N, out_features) to (N, 3, *output_size)."""
        x = self.fc(x)
        x = x.view(x.size(0), self.base_channels * 8, 8, 8)
        x = self.layer4(x)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        # The conv stack only reaches 32x32, so the rest of the resize is done here.
        x = torch.nn.functional.interpolate(x, size=self.output_size, mode='nearest')
        return torch.sigmoid(x)

class Generator(nn.Module):
    """Encoder/decoder generator: image -> latent vector -> restored image.

    Args:
        in_features: Size of the latent bottleneck vector. This is not a channel
            count; the input is always 3-channel RGB.
        base_channels: Width multiplier shared by the encoder and decoder.
    """

    def __init__(self, in_features: int, base_channels: int = 16):
        super().__init__()
        self.encoder = Encoder(in_channels=3, out_features=in_features, base_channels=base_channels)
        self.decoder = Decoder(out_features=in_features, base_channels=base_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latent = self.encoder(x)
        restored = self.decoder(latent)
        return restored

class Discriminator(nn.Module):
    """Conditional image-level discriminator for the first experiment.

    The rainy image and a clean (real or generated) image are stacked into a
    6-channel tensor and encoded to one score per sample. The score is passed
    through a sigmoid, so the output is a probability of shape (N, 1). Pair it
    with ``nn.BCELoss``.
    """

    def __init__(self, base_channels: int = 16):
        super().__init__()
        self.classifier = Encoder(in_channels=6, out_features=1, base_channels=base_channels)

    def forward(self, rainy_img: torch.Tensor, clean_or_fake: torch.Tensor) -> torch.Tensor:
        # Condition on the rainy input by channel-wise concatenation.
        logits = self.classifier(torch.cat((rainy_img, clean_or_fake), dim=1))
        return logits.sigmoid()

Training

In [None]:
def train_conditional_GAN(gen, disc, train_loader, optimizer_gen, optimizer_disc,
                          num_epochs=10, device='cuda', model_name='derain_gan',
                          save_dir=None):
    """Train a conditional GAN (generator + image-level discriminator) for deraining.

    For every batch:
    1. The generator is updated so the discriminator classifies ``gen(rainy)``
       as real.
    2. The discriminator is updated on real and (detached) generated pairs.

    Both steps are conditioned on the rainy image. ``disc`` must return
    probabilities of shape (batch, 1), because ``nn.BCELoss`` is used.

    Args:
        gen: Generator mapping rainy images to clean images.
        disc: Discriminator called as ``disc(rainy, clean_or_fake)``.
        train_loader: Iterable of dicts with ``'data'`` (rainy) and ``'gt'``
            (clean) tensors.
        optimizer_gen: Optimizer for ``gen``.
        optimizer_disc: Optimizer for ``disc``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device to train on; both models are moved there.
        model_name: File name of the generator checkpoint saved after each epoch.
        save_dir: Directory for that checkpoint. Defaults to the notebook-level
            ``results_path`` and is created if missing.

    Returns:
        List with one ``(mean_generator_loss, mean_discriminator_loss)`` tuple
        per epoch. The values are averaged over all batches of the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if save_dir is None:
        save_dir = results_path
    os.makedirs(save_dir, exist_ok=True)

    gen.to(device).train()
    disc.to(device).train()
    criterion = nn.BCELoss()

    losses_list = []

    for epoch in range(num_epochs):
        sum_gen = 0.0
        sum_disc = 0.0
        num_batches = 0

        for images in train_loader:
            real_images = images['gt'].to(device)
            rainy_images = images['data'].to(device)
            batch_size = real_images.size(0)

            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # === Train Generator ===
            optimizer_gen.zero_grad()
            fake_images = gen(rainy_images)
            pred_fake = disc(rainy_images, fake_images)
            loss_gen = criterion(pred_fake, valid)  # want D to believe G's output is real
            loss_gen.backward()
            optimizer_gen.step()

            # === Train Discriminator ===
            optimizer_disc.zero_grad()
            pred_real = disc(rainy_images, real_images)
            # Detach so this update does not backpropagate into the generator.
            pred_fake = disc(rainy_images, fake_images.detach())
            loss_real = criterion(pred_real, valid)
            loss_fake = criterion(pred_fake, fake)
            loss_disc = (loss_real + loss_fake) / 2
            loss_disc.backward()
            optimizer_disc.step()

            sum_gen += loss_gen.item()
            sum_disc += loss_disc.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # Average over the epoch; the last minibatch alone is far too noisy to plot.
        mean_gen = sum_gen / num_batches
        mean_disc = sum_disc / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss_G: {mean_gen:.4f}, Loss_D: {mean_disc:.4f}')
        torch.save(gen.state_dict(), os.path.join(save_dir, model_name))
        losses_list.append((mean_gen, mean_disc))

    return losses_list

In [None]:
# Define the Generator and Discriminator networks.
# Training hyperparameters
latent_dim = 128
base_channels = 32
num_epochs = 20
batch_size = 8
learning_rate_generator = 0.001
learning_rate_discriminator = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# No earlier cell defines `train_loader`, so build it here from the 480x720
# training set; without this the training call below raises NameError.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize Generator and Discriminator
generator = Generator(in_features=latent_dim, base_channels=base_channels)
discriminator = Discriminator(base_channels=base_channels)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate_generator)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate_discriminator)

# Start training
losses = train_conditional_GAN(generator, discriminator, train_loader, optimizer_G, optimizer_D, num_epochs, device)

Plots

In [None]:
import matplotlib.pyplot as plt
# Extract per-epoch generator and discriminator losses.
gen_losses = [loss[0] for loss in losses]
disc_losses = [loss[1] for loss in losses]
# Derive the x-axis from the recorded history so it always matches its length.
epochs = range(1, len(losses) + 1)

# Build the path with os.path.join; plain `+` concatenation produced
# "resultsplotslosses.png" in the working directory.
plots_dir = os.path.join(results_path, 'plots')
os.makedirs(plots_dir, exist_ok=True)

# Plot the losses
plt.figure(figsize=(10, 5))
plt.plot(epochs, gen_losses, label='Generator Loss')
plt.plot(epochs, disc_losses, label='Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss over Epochs')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(plots_dir, 'losses.png'))
plt.show()

In [None]:
def show_generated_images_from_dataset(generator, dataset, device='cpu', n_images=4):
    """Show and save rainy inputs next to the generator's restored outputs.

    Each figure is saved as
    ``<results_path>/images/generated_image_model1_<i>.png``; the directory is
    created if needed. At most ``len(dataset)`` samples are shown.

    Args:
        generator: Model mapping a (1, 3, H, W) batch to a restored image. It is
            switched to eval mode.
        dataset: Indexable dataset returning dicts with a ``'data'`` tensor.
        device: Device the generator lives on.
        n_images: Number of leading samples to show.
    """
    images_dir = os.path.join(results_path, 'images')
    os.makedirs(images_dir, exist_ok=True)

    generator.eval()
    with torch.no_grad():
        for i in range(min(n_images, len(dataset))):
            sample = dataset[i]
            rainy = sample['data'].unsqueeze(0).to(device)  # [1, 3, H, W]
            fake_clean = generator(rainy).cpu().squeeze(0)  # [3, H, W]

            # Show the original and the generated image side by side.
            fig, axs = plt.subplots(1, 2, figsize=(8, 4))
            axs[0].imshow(sample['data'].permute(1, 2, 0))
            axs[0].set_title('Imagen con lluvia')
            axs[0].axis('off')

            axs[1].imshow(fake_clean.permute(1, 2, 0))
            axs[1].set_title('Imagen generada (limpia)')
            axs[1].axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(images_dir, f'generated_image_model1_{i}.png'))
            plt.show()
            plt.close(fig)  # avoid accumulating open figures

show_generated_images_from_dataset(generator, test_a_dataset, device)

# Second Experiment

Model

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

# ---------------------- U-NET GENERATOR ----------------------
class UNetDown(nn.Module):
    """U-Net encoder step: 4x4 stride-2 conv (halves H and W) [-> BatchNorm] -> LeakyReLU(0.2) [-> Dropout].

    Args:
        in_size: Input channels.
        out_size: Output channels.
        normalize: Insert BatchNorm after the convolution.
        dropout: Dropout probability; 0 omits the dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        stages += [nn.BatchNorm2d(out_size)] if normalize else []
        stages.append(nn.LeakyReLU(0.2, inplace=True))
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    """U-Net decoder step with a skip connection.

    Applies a 4x4 stride-2 transposed conv (doubles H and W), BatchNorm, ReLU and
    an optional Dropout. The result is then concatenated with ``skip_input``
    along the channel axis.

    Args:
        in_size: Input channels.
        out_size: Channels produced before concatenation.
        dropout: Dropout probability; 0 omits the dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)

class GeneratorUNet(nn.Module):
    """Six-level U-Net generator (pix2pix style) ending in a Tanh.

    The input passes through six stride-2 downsamplings, so H and W must be
    multiples of 64; otherwise the skip-connection concatenations fail.
    NOTE(review): Tanh yields values in [-1, 1], while the ToTensor-based
    targets used in this notebook are in [0, 1]; the evaluation code clamps to
    [0, 1].
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Encoder (attribute names and order define the state_dict layout).
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)

        # Decoder: input widths are doubled by the concatenated skip features.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 256)
        self.up4 = UNetUp(512, 128)
        self.up5 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder: keep every intermediate feature map (d1..d6) for the skips.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            h = down(h)
            skips.append(h)

        # Decoder: start from the bottleneck d6; up_k is paired with d(6-k).
        h = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4, self.up5):
            h = up(h, skips.pop())

        return self.final(h)

# ---------------------- PATCHGAN DISCRIMINATOR ----------------------
class Discriminator(nn.Module):
    """PatchGAN discriminator that returns a map of real/fake logits.

    Four 4x4 stride-2 conv stages (BatchNorm on all but the first, LeakyReLU
    after each) are followed by a 4x4 stride-1 conv to a single channel. A
    256x256 image pair yields a 15x15 logit map. No sigmoid is applied, so pair
    it with ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Channels of the concatenated (condition, image) pair.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        stage_specs = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        layers = []
        for c_in, c_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            if use_norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        return self.model(torch.cat((img_A, img_B), dim=1))

# ---------------------- TRANSFORMS AND DATASETS ----------------------
# GeneratorUNet halves the resolution six times, so H and W must be multiples of 64.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Assumes RaindropDataset and data_path are already defined.
train_dataset = RaindropDataset(root_dir=data_path, mode='train', transform=transform)
# Rebuild the test splits with the same transform. The versions created earlier
# produce 480x720 images, and 480 is not a multiple of 64, so the U-Net's
# skip-connection concatenation fails on them.
test_a_dataset = RaindropDataset(root_dir=data_path, mode='test_a', transform=transform)
test_b_dataset = RaindropDataset(root_dir=data_path, mode='test_b', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# ---------------------- TRAINING ----------------------

# Per-epoch history consumed by the plotting cells.
g_losses = []
d_losses = []
l1_losses = []
psnr_list = []


num_epochs = 100
lr = 0.0002
lambda_l1 = 100  # weight of the L1 reconstruction term relative to the adversarial term
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models
gen = GeneratorUNet().to(device)
disc = Discriminator().to(device)

# Optimizers and loss functions
optimizer_G = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion_GAN = nn.BCEWithLogitsLoss()  # the PatchGAN discriminator returns raw logits
criterion_L1 = nn.L1Loss()

# Training loop
for epoch in range(num_epochs):
    gen.train()
    disc.train()

    # Epoch accumulators
    epoch_g_loss = 0
    epoch_d_loss = 0
    epoch_l1_loss = 0
    num_batches = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for batch in loop:
        real_rain = batch['data'].to(device)  # input images with raindrops
        real_clean = batch['gt'].to(device)  # ground-truth clean images

        # ----------------------
        # Train Generator
        # ----------------------
        optimizer_G.zero_grad()
        fake_clean = gen(real_rain)
        pred_fake = disc(real_rain, fake_clean)

        # Real/fake targets shaped like the discriminator's patch map. Deriving
        # them from this prediction replaces a separate no_grad discriminator
        # pass. That pass cost an extra forward and, in train mode, still
        # updated the BatchNorm running statistics.
        valid = torch.ones_like(pred_fake)
        fake = torch.zeros_like(pred_fake)

        loss_GAN = criterion_GAN(pred_fake, valid)  # try to fool the discriminator
        loss_L1 = criterion_L1(fake_clean, real_clean)  # pixel-wise similarity
        loss_G = loss_GAN + lambda_l1 * loss_L1
        loss_G.backward()
        optimizer_G.step()

        # ----------------------
        # Train Discriminator
        # ----------------------
        optimizer_D.zero_grad()
        pred_real = disc(real_rain, real_clean)
        loss_real = criterion_GAN(pred_real, valid)

        # Detached so the discriminator update does not reach the generator.
        pred_fake = disc(real_rain, fake_clean.detach())
        loss_fake = criterion_GAN(pred_fake, fake)

        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        epoch_g_loss += loss_G.item()
        epoch_d_loss += loss_D.item()
        epoch_l1_loss += loss_L1.item()
        num_batches += 1

    # Store average losses for the epoch
    g_losses.append(epoch_g_loss / num_batches)
    d_losses.append(epoch_d_loss / num_batches)
    l1_losses.append(epoch_l1_loss / num_batches)

    # PSNR on up to the first 50 *training* images: a training-set metric, not
    # a held-out validation score.
    gen.eval()
    with torch.no_grad():
        psnr_total = 0
        samples = min(50, len(train_dataset))

        for i in range(samples):
            sample = train_dataset[i]
            input_img = sample['data'].unsqueeze(0).to(device)  # add batch dimension
            target_img = sample['gt'].permute(1, 2, 0).cpu().numpy()
            output_img = gen(input_img).squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
            psnr_total += psnr(target_img, output_img, data_range=1.0)

        psnr_avg = psnr_total / samples
        psnr_list.append(psnr_avg)
        print(f"Epoch {epoch+1}: PSNR = {psnr_avg:.2f} dB")

Plots

In [None]:
import matplotlib.pyplot as plt

def show_generated_images_from_dataset(generator, dataset, device='cpu', n_images=4):
    """Show the first ``n_images`` rainy inputs beside the generator's outputs.

    Both images are clamped to [0, 1] before display as a safety measure.
    Nothing is written to disk.
    """
    generator.eval()
    with torch.no_grad():
        for idx in range(n_images):
            pair = dataset[idx]
            restored = generator(pair['data'].unsqueeze(0).to(device)).cpu().squeeze(0)  # [3, H, W]

            panels = [
                (pair['data'], 'Imagen con lluvia'),
                (restored, 'Imagen generada (limpia)'),
            ]
            fig, axs = plt.subplots(1, 2, figsize=(12, 6))
            for ax, (tensor, title) in zip(axs, panels):
                ax.imshow(tensor.clamp(0, 1).permute(1, 2, 0).cpu().numpy())
                ax.set_title(title)
                ax.axis('off')

            plt.tight_layout()
            plt.show()

show_generated_images_from_dataset(gen, train_dataset, device=device, n_images=4)

In [None]:
# Epochs are printed 1-based during training, so plot them 1-based too.
loss_epochs = range(1, len(g_losses) + 1)
os.makedirs(results_path, exist_ok=True)

plt.figure(figsize=(10, 6))
plt.plot(loss_epochs, g_losses, label='Generator Loss')
plt.plot(loss_epochs, d_losses, label='Discriminator Loss')
plt.plot(loss_epochs, l1_losses, label='L1 Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Pérdidas por Epoch')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{results_path}/losses.png")
plt.show()

plt.figure(figsize=(8, 5))
plt.plot(range(1, len(psnr_list) + 1), psnr_list, label='PSNR', color='purple')
plt.xlabel('Epoch')
plt.ylabel('PSNR (dB)')
plt.title('PSNR por Epoch')
plt.legend()  # the 'PSNR' label was never displayed without this
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{results_path}/psnr.png")
plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch

def show_generated_images_with_patch_confidence(generator, discriminator, dataset, device='cpu', n_images=4):
    """Plot, per sample: rainy input, generated image and the PatchGAN confidence map.

    The confidence map is the sigmoid of the discriminator's logits for the
    (rainy, generated) pair, i.e. a per-patch probability of "real". It is shown
    on a fixed [0, 1] colour scale. Both networks are switched to eval mode.
    """
    generator.eval()
    discriminator.eval()

    with torch.no_grad():
        for idx in range(n_images):
            item = dataset[idx]
            rainy = item['data'].unsqueeze(0).to(device)
            generated = generator(rainy)

            # Discriminator confidence matrix (logits -> probabilities).
            confidence = torch.sigmoid(discriminator(rainy, generated)).squeeze().cpu().numpy()

            # Prepare images for display.
            rainy_img = item['data'].clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            fake_img = generated.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

            fig, axs = plt.subplots(1, 3, figsize=(18, 6))
            image_panels = (
                (axs[0], rainy_img, 'Imagen con lluvia'),
                (axs[1], fake_img, 'Imagen generada (limpia)'),
            )
            for ax, img, title in image_panels:
                ax.imshow(img)
                ax.set_title(title)
                ax.axis('off')

            heat = axs[2].imshow(confidence, cmap='viridis', vmin=0, vmax=1)
            axs[2].set_title('Confianza por patch (PatchGAN)')
            axs[2].axis('off')
            plt.colorbar(heat, ax=axs[2], shrink=0.7)

            plt.tight_layout()
            plt.show()

show_generated_images_with_patch_confidence(gen, disc, train_dataset, device=device, n_images=4)

# Third Experiment

Model

In [None]:
# ---------------------- U-NET GENERATOR ----------------------
class UNetDown(nn.Module):
    """U-Net encoder step: 4x4 stride-2 conv (halves H and W) [-> BatchNorm] -> LeakyReLU(0.2) [-> Dropout].

    Args:
        in_size: Input channels.
        out_size: Output channels.
        normalize: Insert BatchNorm after the convolution.
        dropout: Dropout probability; 0 omits the dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        stages += [nn.BatchNorm2d(out_size)] if normalize else []
        stages.append(nn.LeakyReLU(0.2, inplace=True))
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    """U-Net decoder step with a skip connection.

    Applies a 4x4 stride-2 transposed conv (doubles H and W), BatchNorm, ReLU and
    an optional Dropout. The result is then concatenated with ``skip_input``
    along the channel axis.

    Args:
        in_size: Input channels.
        out_size: Channels produced before concatenation.
        dropout: Dropout probability; 0 omits the dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)

class GeneratorUNet(nn.Module):
    """Six-level U-Net generator (pix2pix style) ending in a Tanh.

    The input passes through six stride-2 downsamplings, so H and W must be
    multiples of 64; otherwise the skip-connection concatenations fail.
    NOTE(review): Tanh yields values in [-1, 1], while the ToTensor-based
    targets used in this notebook are in [0, 1]; the evaluation code clamps to
    [0, 1].
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Encoder (attribute names and order define the state_dict layout).
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)

        # Decoder: input widths are doubled by the concatenated skip features.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 256)
        self.up4 = UNetUp(512, 128)
        self.up5 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder: keep every intermediate feature map (d1..d6) for the skips.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            h = down(h)
            skips.append(h)

        # Decoder: start from the bottleneck d6; up_k is paired with d(6-k).
        h = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4, self.up5):
            h = up(h, skips.pop())

        return self.final(h)

# ---------------------- PATCHGAN DISCRIMINATOR ----------------------
class Discriminator(nn.Module):
    """PatchGAN discriminator that returns a map of real/fake logits.

    Four 4x4 stride-2 conv stages (BatchNorm on all but the first, LeakyReLU
    after each) are followed by a 4x4 stride-1 conv to a single channel. A
    256x256 image pair yields a 15x15 logit map. No sigmoid is applied, so pair
    it with ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Channels of the concatenated (condition, image) pair.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        stage_specs = [
            (in_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        layers = []
        for c_in, c_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            if use_norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        return self.model(torch.cat((img_A, img_B), dim=1))

Training

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr

def train_GAN(gen, disc, train_loader, optimizer_G, optimizer_D, num_epochs=10,
              lambda_l1=100, device="cuda", name='more_stable_CGAN', save_dir='.'):
    """Train a pix2pix-style conditional GAN with label smoothing (0.9 real / 0.1 fake).

    For every batch:
    1. The generator is updated with the adversarial loss plus
       ``lambda_l1`` * L1 loss.
    2. The discriminator is updated on real and (detached) generated pairs.

    ``disc`` must return logits, because ``BCEWithLogitsLoss`` is used.

    Args:
        gen: Generator. It must already be on ``device``; this function does not
            move it.
        disc: Discriminator called as ``disc(rainy, clean_or_fake)``. It must
            already be on ``device``.
        train_loader: DataLoader yielding dicts with ``'data'`` (rainy) and
            ``'gt'`` (clean) batches. Its ``.dataset`` is also indexed for the
            PSNR estimate.
        optimizer_G: Optimizer for ``gen``.
        optimizer_D: Optimizer for ``disc``.
        num_epochs: Number of epochs.
        lambda_l1: Weight of the L1 term in the generator loss.
        device: Device that batches are moved to.
        name: File-name prefix for checkpoints.
        save_dir: Directory for checkpoints. It is created if missing; the
            default ``'.'`` keeps the previous behaviour (current directory).

    Returns:
        ``(losses_list, dic_accuracies_list, psnr_list)``, one entry per epoch:
        - ``losses_list``: ``(mean G loss, mean D loss)`` tuples.
        - ``dic_accuracies_list``: discriminator patch accuracy in percent.
        - ``psnr_list``: mean PSNR in dB over up to 50 training images.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    # Imported locally instead of relying on the notebook-level name `psnr`.
    # The training cell rebinds `psnr` to this function's returned list, which
    # would otherwise make any later call fail.
    from skimage.metrics import peak_signal_noise_ratio

    losses_list = []  # (mean generator loss, mean discriminator loss) per epoch
    dic_accuracies_list = []  # discriminator accuracy (%) per epoch
    psnr_list = []  # mean PSNR (dB) per epoch

    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()

    for epoch in range(num_epochs):
        gen.train()
        disc.train()

        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        epoch_l1_loss = 0.0
        running_corrects = 0
        total_preds = 0
        num_batches = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for batch in loop:
            real_rain = batch['data'].to(device)
            real_clean = batch['gt'].to(device)

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()
            fake_clean = gen(real_rain)
            pred_fake = disc(real_rain, fake_clean)

            # Smoothed targets shaped like the patch map. They are built from this
            # prediction instead of a separate no_grad discriminator pass. That
            # pass cost an extra forward and, in train mode, still updated the
            # BatchNorm running statistics.
            valid = torch.full_like(pred_fake, 0.9)
            fake = torch.full_like(pred_fake, 0.1)

            loss_GAN = criterion_GAN(pred_fake, valid)
            loss_L1 = criterion_L1(fake_clean, real_clean)
            loss_G = loss_GAN + lambda_l1 * loss_L1
            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            pred_real = disc(real_rain, real_clean)
            loss_real = criterion_GAN(pred_real, valid)
            # Detached so the discriminator update does not reach the generator.
            pred_fake = disc(real_rain, fake_clean.detach())
            loss_fake = criterion_GAN(pred_fake, fake)
            loss_D = 0.5 * (loss_real + loss_fake)
            loss_D.backward()
            optimizer_D.step()

            epoch_g_loss += loss_G.item()
            epoch_d_loss += loss_D.item()
            epoch_l1_loss += loss_L1.item()
            num_batches += 1

            # A patch is "correct" when its probability is > 0.5 for real pairs
            # and <= 0.5 for generated pairs.
            with torch.no_grad():
                running_corrects += int((torch.sigmoid(pred_real) > 0.5).sum().item())
                running_corrects += int((torch.sigmoid(pred_fake) <= 0.5).sum().item())
                total_preds += pred_real.numel() + pred_fake.numel()

            loop.set_postfix({
                "G_loss": f"{loss_G.item():.4f}",
                "D_loss": f"{loss_D.item():.4f}",
                "D_acc": f"{100 * running_corrects / total_preds:.2f}%",
            })

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg_g_loss = epoch_g_loss / num_batches
        avg_d_loss = epoch_d_loss / num_batches
        avg_l1_loss = epoch_l1_loss / num_batches
        epoch_accuracy = 100 * running_corrects / total_preds

        # PSNR on up to 50 images of the *training* set (not a held-out split).
        gen.eval()
        with torch.no_grad():
            psnr_total = 0
            samples = min(50, len(train_loader.dataset))
            for i in range(samples):
                sample = train_loader.dataset[i]
                input_img = sample['data'].unsqueeze(0).to(device)
                target_img = sample['gt'].permute(1, 2, 0).cpu().numpy()
                output_img = gen(input_img).squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
                psnr_total += peak_signal_noise_ratio(target_img, output_img, data_range=1.0)
            psnr_avg = psnr_total / samples
            psnr_list.append(psnr_avg)

        print(f"\nEpoch {epoch+1}/{num_epochs} | "
              f"G_loss: {avg_g_loss:.4f} | D_loss: {avg_d_loss:.4f} | L1: {avg_l1_loss:.4f} | "
              f"D_acc: {epoch_accuracy:.2f}% | PSNR: {psnr_avg:.2f} dB")

        losses_list.append((avg_g_loss, avg_d_loss))
        dic_accuracies_list.append(epoch_accuracy)

        # Save checkpoints every 10 epochs and after the final epoch.
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(gen.state_dict(), os.path.join(save_dir, f"{name}_gen_epoch{epoch+1}.pth"))
            torch.save(disc.state_dict(), os.path.join(save_dir, f"{name}_disc_epoch{epoch+1}.pth"))

    return losses_list, dic_accuracies_list, psnr_list

In [None]:
# Hyperparameters for the third experiment
num_epochs = 100
gen_lr = 5e-4  # generator learning rate
dic_lr = 2e-4  # discriminator learning rate

lambda_l1 = 100  # weight of the L1 reconstruction term
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh models
gen = GeneratorUNet().to(device)
disc = Discriminator().to(device)

# Optimizers (the generator uses beta1 = 0.7)
optimizer_G = torch.optim.Adam(gen.parameters(), lr=gen_lr, betas=(0.7, 0.999))
optimizer_D = torch.optim.Adam(disc.parameters(), lr=dic_lr, betas=(0.5, 0.999))

# Training.
# NOTE: binding the result to `psnr` shadows the skimage `psnr` function imported
# earlier; any later cell that calls `psnr(...)` as a function will then fail.
losses, accuracies, psnr = train_GAN(
    gen, disc, train_loader, optimizer_G, optimizer_D,
    num_epochs=num_epochs, lambda_l1=lambda_l1, device=device,
)

Plots

In [None]:
import matplotlib.pyplot as plt

def show_generated_images_from_dataset(generator, dataset, device='cpu', n_images=4, name='comparaciones_disc_mejorado'):
    """Show and save rainy input / generated image / ground-truth triplets.

    Each figure is saved as
    ``<results_path>/plots/comparaciones_disc_mejorado/<name>_<i>.png``; the
    directory is created if needed. At most ``len(dataset)`` samples are shown.

    Args:
        generator: Model mapping a (1, 3, H, W) batch to a restored image. It is
            switched to eval mode.
        dataset: Indexable dataset returning dicts with ``'data'`` and ``'gt'``.
        device: Device the generator lives on.
        n_images: Number of leading samples to show.
        name: File-name prefix for the saved figures.
    """
    out_dir = os.path.join(results_path, 'plots', 'comparaciones_disc_mejorado')
    os.makedirs(out_dir, exist_ok=True)

    generator.eval()
    with torch.no_grad():
        for i in range(min(n_images, len(dataset))):
            sample = dataset[i]
            rainy = sample['data'].unsqueeze(0).to(device)  # [1, 3, H, W]
            fake_clean = generator(rainy).cpu().squeeze(0)  # [3, H, W]

            # Clamp to [0, 1] for display as a safety measure.
            rainy_img = sample['data'].clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            fake_img = fake_clean.clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            ground_truth = sample['gt'].permute(1, 2, 0).cpu().numpy()

            # Rainy input, generated image and ground truth side by side.
            fig, axs = plt.subplots(1, 3, figsize=(12, 6))
            axs[0].imshow(rainy_img)
            axs[0].set_title('Imagen con lluvia')
            axs[0].axis('off')

            axs[1].imshow(fake_img)
            axs[1].set_title('Imagen generada (limpia)')
            axs[1].axis('off')

            axs[2].imshow(ground_truth)
            axs[2].set_title('Ground Truth')
            axs[2].axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(out_dir, name + f'_{i}.png'))
            plt.show()
            plt.close(fig)  # up to 40 figures are produced; don't keep them open

# NOTE: GeneratorUNet needs H and W to be multiples of 64, so `test_a_dataset`
# must be built with the 256x256 transform for this call to work.
show_generated_images_from_dataset(gen, test_a_dataset, device=device, n_images=40, name='comparaciones_disc_mejorado')

In [None]:
import matplotlib.pyplot as plt
# Two figures:
# 1. Generator and discriminator losses per epoch.
# 2. PSNR and discriminator accuracy per epoch. They have different units, so
#    each gets its own y-axis.

epochs_list = range(1, len(losses) + 1)
plots_dir = os.path.join(results_path, 'plots')
os.makedirs(plots_dir, exist_ok=True)

# Plot 1: Generator and Discriminator Losses
plt.figure(figsize=(10, 5))
plt.plot(epochs_list, [pair[0] for pair in losses], label='Generator Loss')
plt.plot(epochs_list, [pair[1] for pair in losses], label='Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss over Epochs')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(plots_dir, 'losses_plot_improved_disc.png'))
plt.show()

# Plot 2: Discriminator accuracy (%) on the left axis, PSNR (dB) on the right axis.
fig, ax_acc = plt.subplots(figsize=(10, 5))
ax_psnr = ax_acc.twinx()
line_acc, = ax_acc.plot(epochs_list, accuracies, label='Discriminator Accuracy (%)', color='tab:red')
line_psnr, = ax_psnr.plot(epochs_list, psnr, label='PSNR (dB)', color='tab:green')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Discriminator accuracy (%)')
ax_psnr.set_ylabel('PSNR (dB)')
ax_acc.set_title('PSNR and Discriminator Accuracy over Epochs')
ax_acc.legend(handles=[line_acc, line_psnr])
ax_acc.grid(True)
fig.savefig(os.path.join(plots_dir, 'psnr_disc_accuracy_plot.png'))
plt.show()