In [None]:
import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import torchvision.utils as vutils

# Dataset and DataLoader

In [None]:
import os
import torch
from torchvision import transforms
from PIL import Image
import random
import matplotlib.pyplot as plt

# Custom Dataset for Cityscapes
class CityscapesDataset(torch.utils.data.Dataset):
    """Paired dataset built from side-by-side (concatenated) images.

    Every file in ``directory`` is split vertically at ``width // 2``: the
    right part is returned as the model input and the left part as the target.

    Args:
        directory (str): Folder containing the concatenated ``.png``/``.jpg`` files.
        transform (callable, optional): Callable that receives the list
            ``[input_image, target_image]`` (two PIL images) and returns the
            transformed pair. Defaults to None (PIL images are returned as-is).
    """

    def __init__(self, directory, transform=None):
        self.directory = directory
        # Sorted so that an index always maps to the same file.
        self.image_files = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(('.png', '.jpg'))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of concatenated image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return the ``(input, target)`` pair."""
        # convert() returns an independent RGB copy, so the file handle can be
        # released immediately by the context manager.
        with Image.open(self.image_files[idx]) as raw_image:
            image = raw_image.convert("RGB")
        width, height = image.size
        split_x = width // 2

        # Names avoid shadowing the builtin ``input``.
        input_image = image.crop((split_x, 0, width, height))  # Right half (input)
        target_image = image.crop((0, 0, split_x, height))  # Left half (output)

        # The transform receives both halves so random operations stay aligned.
        if self.transform:
            input_image, target_image = self.transform([input_image, target_image])

        return input_image, target_image


# Custom transform class to apply the same transformation to both input and output
class PairedResize(torch.nn.Module):
    """Resize every image in a list with the same ``transforms.Resize``.

    Args:
        size: Target size, forwarded unchanged to ``transforms.Resize``.
    """

    def __init__(self, size):
        super().__init__()
        self.resize = transforms.Resize(size)

    def forward(self, imgs):
        """Return a new list holding the resized version of each image in ``imgs``."""
        # Defined as forward() (not __call__) so that nn.Module.__call__ and
        # its hook machinery stay intact; calling the instance works as before.
        return [self.resize(img) for img in imgs]


class PairedRandomCrop(torch.nn.Module):
    """Cut the same randomly placed window out of every image in a list.

    Args:
        size (int or sequence): ``(crop_width, crop_height)``. A single int is
            interpreted as a square crop.

    Raises:
        ValueError: When called with images smaller than the crop window.
    """

    def __init__(self, size):
        super().__init__()
        self.size = (size, size) if isinstance(size, int) else size

    def forward(self, imgs):
        """Return a list with the identical crop applied to each image.

        The window position is derived from the size of ``imgs[0]``; the other
        images are assumed to have the same size (not checked here).
        """
        width, height = imgs[0].size
        crop_width, crop_height = self.size

        # Fail with a readable message instead of randint's "empty range" error.
        if crop_width > width or crop_height > height:
            raise ValueError(
                f"Crop size {(crop_width, crop_height)} is larger than "
                f"image size {(width, height)}"
            )

        # Draw order (top, then left) is kept so seeded runs stay reproducible.
        top = random.randint(0, height - crop_height)
        left = random.randint(0, width - crop_width)

        box = (left, top, left + crop_width, top + crop_height)
        return [img.crop(box) for img in imgs]


class PairedRandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip all images in a list together, or none of them.

    Args:
        flip_prob (float, optional): Probability of flipping. Defaults to 0.5.
    """

    def __init__(self, flip_prob=0.5):
        super().__init__()
        self.flip_prob = flip_prob

    def forward(self, imgs):
        """Return ``imgs`` unchanged, or a new list with every image flipped."""
        # A single random draw decides for the whole list so the pair stays aligned.
        # Defined as forward() so nn.Module.__call__ handles dispatch and hooks.
        if random.random() >= self.flip_prob:
            return imgs
        return [transforms.functional.hflip(img) for img in imgs]


class PairedToTensor(torch.nn.Module):
    """Convert every image in a list to a tensor with ``transforms.ToTensor``."""

    def __init__(self):
        super().__init__()
        # Built once and reused instead of constructing a converter per image.
        self.to_tensor = transforms.ToTensor()

    def forward(self, imgs):
        """Return a new list with each image in ``imgs`` converted to a tensor."""
        # Defined as forward() so nn.Module.__call__ handles dispatch and hooks.
        return [self.to_tensor(img) for img in imgs]


# Paired transformation pipelines. Training adds augmentation
# (resize -> random crop -> random flip); evaluation only converts to tensors.
# NOTE(review): ToTensor output is not rescaled afterwards, while UnetGenerator
# ends in Tanh — confirm the value ranges are meant to differ.
paired_transform_train = transforms.Compose(
    [
        PairedResize((286, 286)),
        PairedRandomCrop((256, 256)),
        PairedRandomHorizontalFlip(),
        PairedToTensor(),
    ]
)
paired_transform_test = transforms.Compose([PairedToTensor()])

# Dataset locations
train_dir = "/kaggle/input/cityscapes-pix2pix-dataset/train"
val_dir = "/kaggle/input/cityscapes-pix2pix-dataset/val"

# One dataset per split, each with its own pipeline
train_dataset, val_dataset = (
    CityscapesDataset(folder, transform=pipeline)
    for folder, pipeline in (
        (train_dir, paired_transform_train),
        (val_dir, paired_transform_test),
    )
)

# One image pair per batch for both splits.
# NOTE(review): the validation loader is shuffled as well, so the samples drawn
# from it differ between runs — confirm this is intended.
train_loader, val_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    for dataset in (train_dataset, val_dataset)
)

In [None]:
# Preview a few (input, output) pairs from the training loader.
num_preview_pairs = 3
fig, axes = plt.subplots(num_preview_pairs, 2, figsize=(10, 15))

# zip() with a range stops after the requested number of batches without
# pulling (and augmenting) one extra batch just to break out of the loop.
for index, (inputs, outputs) in zip(range(num_preview_pairs), train_loader):
    # Drop the batch dimension and convert CHW -> HWC for matplotlib.
    input_img = inputs.squeeze(0).permute(1, 2, 0).numpy()
    output_img = outputs.squeeze(0).permute(1, 2, 0).numpy()

    axes[index, 0].imshow(input_img)
    axes[index, 0].set_title(f"Input Image (Train Set) {index+1}")
    axes[index, 0].axis('off')

    axes[index, 1].imshow(output_img)
    axes[index, 1].set_title(f"Output Image (Train Set) {index+1}")
    axes[index, 1].axis('off')

plt.show()

In [None]:
# Preview a few (input, output) pairs from the validation loader.
num_preview_pairs = 3
fig, axes = plt.subplots(num_preview_pairs, 2, figsize=(10, 15))

# zip() with a range stops after the requested number of batches without
# pulling one extra batch just to break out of the loop.
for index, (inputs, outputs) in zip(range(num_preview_pairs), val_loader):
    # Drop the batch dimension and convert CHW -> HWC for matplotlib.
    input_img = inputs.squeeze(0).permute(1, 2, 0).numpy()
    output_img = outputs.squeeze(0).permute(1, 2, 0).numpy()

    axes[index, 0].imshow(input_img)
    axes[index, 0].set_title(f"Input Image (Test Set) {index+1}")
    axes[index, 0].axis('off')

    axes[index, 1].imshow(output_img)
    axes[index, 1].set_title(f"Output Image (Test Set) {index+1}")
    axes[index, 1].axis('off')

plt.show()

# ConvBlock Class

This class defines a reusable building block for convolutional layers.
* Initializes with `in_channels`, `out_channels`, `apply_bn`, and `apply_leaky` flags.
* Creates a 2D convolutional layer with kernel size 4, stride 2, padding 1, and no bias.
* Optionally adds a BatchNorm2d layer.
* Sets the activation function to LeakyReLU or ReLU.
* In the forward pass, it applies convolution, batch normalization (only when enabled and both spatial dimensions of the output are larger than 1), and the activation function.

# UpSample Class

This class defines a building block for upsampling operations in the decoder.
* Initializes with `in_channels`, `out_channels`, and `apply_dropout` flags.
* Creates a sequence of layers:
    * Transposed convolutional layer to upsample the feature map.
    * BatchNorm2d layer.
    * ReLU activation.
    * Optional Dropout layer.
* In the forward pass, it applies the sequence of layers to the input.

# PatchDiscriminator Class

This class implements a patch-based discriminator network for image generation.
* Initializes with the number of input channels.
* Creates a sequence of layers using `nn.Sequential`:
    * First ConvBlock with 64 channels and no batch normalization.
    * Subsequent ConvBlocks with increasing channels and batch normalization.
    * Convolutional layers with 512 and 1 channels.
    * Sigmoid activation for output.
* In the forward pass, concatenates the input image with the target or generated image along the channel dimension, and passes the result through the layers.

# UnetGenerator Class

This class implements a U-Net generator architecture for image-to-image translation.
* Initializes with input and output channels.
* Creates encoder blocks (`en1` to `en8`) using ConvBlock.
* Creates decoder blocks (`de1` to `de7`) using custom `UpSample` class.
* Creates a final transposed convolution layer (`de8`) and Tanh activation.
* In the forward pass,
    * Passes input through encoder,
    * Passes encoder output through decoder with upsampling and concatenation,
    * Applies final convolutional transpose and Tanh activation.

In [None]:
class ConvBlock(nn.Module):
    """
    Strided 4x4 convolution followed by optional batch norm and an activation.

    The convolution uses stride 2 and padding 1 without a bias term.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        apply_bn (bool, optional): Create and use a BatchNorm2d layer. Defaults to True.
        apply_leaky (bool, optional): Use LeakyReLU(0.2) when True, plain ReLU
            otherwise. Defaults to True.

    """
    def __init__(self, in_channels, out_channels, apply_bn=True, apply_leaky=True):
        super().__init__()
        self.apply_bn = apply_bn
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        if apply_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        if apply_leaky:
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> (batch norm) -> activation to ``x``."""
        features = self.conv(x)
        # Batch norm is skipped whenever either spatial dimension of the conv
        # output is 1, even if the block was built with apply_bn=True.
        if self.apply_bn and min(features.size(2), features.size(3)) > 1:
            features = self.bn(features)
        return self.activation(features)

In [None]:
class UpSample(nn.Module):
    """
    Decoder block: transposed 4x4 convolution -> batch norm -> ReLU (-> dropout).

    The transposed convolution uses stride 2 and padding 1 without a bias term.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        apply_dropout (bool, optional): Append Dropout(0.5) after the ReLU.
            Defaults to False.

    """
    def __init__(self, in_channels, out_channels, apply_dropout=False):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if apply_dropout:
            stages.append(nn.Dropout(0.5))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the upsampling stages in order."""
        upsampled = self.block(x)
        return upsampled

In [None]:
class PatchDiscriminator(nn.Module):
    """
    Patch-based discriminator scoring an (input, candidate) image pair.

    Args:
        in_channels (int, optional): Channels of ONE image; the first layer
            takes ``2 * in_channels`` because both images are concatenated.
            Defaults to 3.

    Attributes:
        layers: nn.Sequential holding the discriminator stages, ending in Sigmoid.
    """
    def __init__(self, in_channels=3):
        super().__init__()
        # Shape comments assume 256x256 inputs.
        stages = [
            ConvBlock(2 * in_channels, 64, apply_bn=False),  # 64x128x128
            ConvBlock(64, 128),  # 128x64x64
            ConvBlock(128, 256),  # 256x32x32
            # NOTE(review): no normalization/activation sits between the next
            # two convolutions, so they compose into one linear map. If a
            # BatchNorm + LeakyReLU stage was intended here, adding it would
            # change state_dict keys of saved checkpoints — confirm before changing.
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),  # 512x31x31
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # 1x30x30
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, input_image, generated_image):
        """
        Score an image pair.

        Args:
            input_image: Conditioning image tensor.
            generated_image: Candidate image tensor paired with ``input_image``.

        Returns:
            Output of the final Sigmoid (a map of values in (0, 1)).
        """
        # Both images are stacked along the channel axis before scoring.
        pair = torch.cat([input_image, generated_image], dim=1)
        return self.layers(pair)

In [None]:
class UnetGenerator(nn.Module):
    """
    U-Net generator: eight downsampling blocks, eight upsampling stages, skip links.

    Args:
        in_channels (int, optional): Number of input channels. Defaults to 3.
        out_channels (int, optional): Number of output channels. Defaults to 3.

    Attributes:
        en1-en8: Encoder blocks (ConvBlock).
        de1-de7: Decoder blocks (UpSample).
        de8: Final transposed convolution producing ``out_channels`` channels.
        final_activation: Tanh applied to the output of ``de8``.

    """
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # (in, out, apply_bn) per encoder block; shape comments assume 256x256 input.
        encoder_specs = [
            (in_channels, 64, False),  # 64x128x128
            (64, 128, True),  # 128x64x64
            (128, 256, True),  # 256x32x32
            (256, 512, True),  # 512x16x16
            (512, 512, True),  # 512x8x8
            (512, 512, True),  # 512x4x4
            (512, 512, True),  # 512x2x2
            (512, 512, True),  # 512x1x1
        ]
        for position, (c_in, c_out, use_bn) in enumerate(encoder_specs, start=1):
            setattr(self, f"en{position}", ConvBlock(c_in, c_out, apply_bn=use_bn))

        # (in, out, apply_dropout) per decoder block; inputs from de2 onwards
        # are doubled by the concatenated skip connection.
        decoder_specs = [
            (512, 512, True),  # 512x2x2
            (1024, 512, True),  # 512x4x4
            (1024, 512, True),  # 512x8x8
            (1024, 512, False),  # 512x16x16
            (1024, 256, False),  # 256x32x32
            (512, 128, False),  # 128x64x64
            (256, 64, False),  # 64x128x128
        ]
        for position, (c_in, c_out, use_dropout) in enumerate(decoder_specs, start=1):
            setattr(self, f"de{position}", UpSample(c_in, c_out, apply_dropout=use_dropout))

        self.de8 = nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1)  # 3x256x256
        self.final_activation = nn.Tanh()

    def forward(self, x):
        """
        Forward pass of the U-Net generator.

        Args:
            x: Input image tensor (e.g., 3x256x256).

        Returns:
            Generated output image tensor (e.g., 3x256x256).
        """
        # Encoder: keep every intermediate activation for the skip connections.
        skips = []
        hidden = x
        for position in range(1, 9):
            hidden = getattr(self, f"en{position}")(hidden)
            skips.append(hidden)

        # The deepest activation (en8) feeds de1 directly, without a skip.
        hidden = self.de1(skips.pop())

        # de2..de7 consume the encoder outputs in reverse order (en7 .. en2).
        for position in range(2, 8):
            hidden = getattr(self, f"de{position}")(torch.cat([hidden, skips.pop()], dim=1))

        # The last stage pairs the decoder output with en1's activation.
        output = self.de8(torch.cat([hidden, skips.pop()], dim=1))
        return self.final_activation(output)

# Training

In [None]:
# Hyperparameters
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
num_epochs = 50
lambda_l1 = 100  # Weight for L1 loss
checkpoint_path = "/kaggle/working/pix2pix_checkpoint.pth"

# Initialize Generator and Discriminator
generator = UnetGenerator().cuda()
discriminator = PatchDiscriminator().cuda()

# Define Loss Functions
# PatchDiscriminator already ends in nn.Sigmoid, so its output is a probability
# map. nn.BCELoss is the matching criterion; nn.BCEWithLogitsLoss would apply a
# second sigmoid on top of those probabilities.
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Track losses
generator_losses = []
discriminator_losses = []
start_epoch = 0

# Load checkpoint if it exists, so an interrupted run resumes where it stopped
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    generator.load_state_dict(checkpoint["generator_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"])
    optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"])
    generator_losses = checkpoint["generator_losses"]
    discriminator_losses = checkpoint["discriminator_losses"]
    start_epoch = checkpoint["epoch"] + 1
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")

# Training Loop
for epoch in range(start_epoch, num_epochs):
    # Back to train mode every epoch: the preview below switches the generator to eval.
    generator.train()
    discriminator.train()

    loop = tqdm(train_loader, leave=True)
    epoch_G_loss = 0.0
    epoch_D_loss = 0.0

    for batch_idx, (input_image, target_image) in enumerate(loop):
        input_image = input_image.cuda()
        target_image = target_image.cuda()

        ### Train Discriminator ###
        fake_image = generator(input_image)

        # Discriminator on real images (ones_like already matches device/dtype)
        D_real = discriminator(input_image, target_image)
        D_real_loss = criterion_GAN(D_real, torch.ones_like(D_real))

        # Discriminator on fake images; detach() keeps this step from
        # propagating gradients into the generator.
        D_fake = discriminator(input_image, fake_image.detach())
        D_fake_loss = criterion_GAN(D_fake, torch.zeros_like(D_fake))

        # Total Discriminator loss
        D_loss = (D_real_loss + D_fake_loss) / 2
        optimizer_D.zero_grad()
        D_loss.backward()
        optimizer_D.step()

        ### Train Generator ###
        # Re-score the fake image with the just-updated discriminator, this
        # time keeping the graph so gradients reach the generator.
        D_fake_for_G = discriminator(input_image, fake_image)
        G_GAN_loss = criterion_GAN(D_fake_for_G, torch.ones_like(D_fake_for_G))
        G_L1_loss = criterion_L1(fake_image, target_image) * lambda_l1
        G_loss = G_GAN_loss + G_L1_loss
        optimizer_G.zero_grad()
        G_loss.backward()
        optimizer_G.step()

        # Update epoch loss
        epoch_D_loss += D_loss.item()
        epoch_G_loss += G_loss.item()

        # Update progress bar
        loop.set_description(f"Epoch [{epoch}/{num_epochs}]")
        loop.set_postfix(D_loss=f"{D_loss.item():.4f}", G_loss=f"{G_loss.item():.4f}")

    # Average losses for the epoch
    avg_G_loss = epoch_G_loss / len(train_loader)
    avg_D_loss = epoch_D_loss / len(train_loader)
    print(f"Generator Loss is:{avg_G_loss}")
    print(f"Discriminator Loss is:{avg_D_loss}")
    generator_losses.append(avg_G_loss)
    discriminator_losses.append(avg_D_loss)

    # Save checkpoint (overwrites the previous one)
    checkpoint = {
        "epoch": epoch,
        "generator_state_dict": generator.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_G_state_dict": optimizer_G.state_dict(),
        "optimizer_D_state_dict": optimizer_D.state_dict(),
        "generator_losses": generator_losses,
        "discriminator_losses": discriminator_losses,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch}.")

    # Every fifth epoch, preview one validation sample
    if (epoch) % 5 == 0:
        with torch.no_grad():
            generator.eval()

            # Generate fake sample
            input_image, target_image = next(iter(val_loader))
            input_image, target_image = input_image.cuda(), target_image.cuda()
            fake_sample = generator(input_image)

            # Drop the batch dimension, move to CPU and convert CHW -> HWC for Matplotlib
            input_img = input_image.squeeze(0).cpu().permute(1, 2, 0)
            target_img = target_image.squeeze(0).cpu().permute(1, 2, 0)
            generated_img = fake_sample.squeeze(0).detach().cpu().permute(1, 2, 0)

            # Plot images
            plt.figure(figsize=(12, 4))

            plt.subplot(1, 3, 1)
            plt.imshow(input_img, cmap="gray" if input_img.shape[-1] == 1 else None)
            plt.title("Input Image")
            plt.axis("off")

            plt.subplot(1, 3, 2)
            plt.imshow(target_img, cmap="gray" if target_img.shape[-1] == 1 else None)
            plt.title("Target Image")
            plt.axis("off")

            plt.subplot(1, 3, 3)
            plt.imshow(generated_img, cmap="gray" if generated_img.shape[-1] == 1 else None)
            plt.title("Generated Image")
            plt.axis("off")

            # Show the plot
            plt.show()
    print("################################################################")

In [None]:
import matplotlib.pyplot as plt

# Plot Generator and Discriminator losses, one point per recorded epoch.
# The x-axis is derived from the number of recorded losses; a fixed range only
# matches when exactly that many epochs were trained and fails otherwise.
fig, (ax_generator, ax_discriminator) = plt.subplots(1, 2, figsize=(12, 6))

# Plot Generator Loss
ax_generator.plot(range(len(generator_losses)), generator_losses, label="Generator Loss", color='blue')
ax_generator.set_title("Generator Loss Over Epochs")
ax_generator.set_xlabel("Epoch")
ax_generator.set_ylabel("Loss")
ax_generator.legend()

# Plot Discriminator Loss
ax_discriminator.plot(range(len(discriminator_losses)), discriminator_losses, label="Discriminator Loss", color='red')
ax_discriminator.set_title("Discriminator Loss Over Epochs")
ax_discriminator.set_xlabel("Epoch")
ax_discriminator.set_ylabel("Loss")
ax_discriminator.legend()

# Show plots
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import torch

# Ensure the generator is in evaluation mode
generator.eval()

# Number of samples to display
num_samples = 10

# Lists to hold images (kept on the CPU, which is where plotting needs them)
input_images_list = []
target_images_list = []
generated_images_list = []

# Collect images from the validation loader. zip() with a range stops after
# num_samples batches without fetching one extra batch just to break.
with torch.no_grad():
    for _, (input_image, target_image) in zip(range(num_samples), val_loader):
        generated_image = generator(input_image.cuda())
        # Only the generator input has to visit the GPU; drop the batch dimension.
        input_images_list.append(input_image.squeeze(0))
        target_images_list.append(target_image.squeeze(0))
        generated_images_list.append(generated_image.squeeze(0).cpu())

# Use the number of samples actually collected, so a loader with fewer than
# num_samples batches does not cause an IndexError below.
num_rows = len(input_images_list)

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_rows, 3, figsize=(12, num_rows * 3), squeeze=False)

column_titles = ("Input Image", "Target Image", "Generated Image")
rows = zip(input_images_list, target_images_list, generated_images_list)

for row_idx, row_images in enumerate(rows):
    for col_idx, (image, title) in enumerate(zip(row_images, column_titles)):
        image_hwc = image.permute(1, 2, 0)  # Convert to HWC
        ax = axes[row_idx, col_idx]
        ax.imshow(image_hwc, cmap="gray" if image_hwc.shape[-1] == 1 else None)
        ax.axis("off")
        # Column headers only on the first row
        if row_idx == 0:
            ax.set_title(title)

plt.tight_layout()
plt.show()
