# Import required libraries

# In [1]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from torchvision import transforms
from itertools import zip_longest
from torch.utils.tensorboard import SummaryWriter

# Dataloader

# In [3]:
# Define the ImageDataset class
class ImageDataset(Dataset):
    """Unlabelled dataset of every image file directly inside one folder.

    Each item is a single RGB image (no label), optionally passed through
    ``transform``.  Only regular, non-hidden files are used, so sub-directories
    and artefacts such as ``.ipynb_checkpoints`` or ``.DS_Store`` do not make
    ``__getitem__`` crash.  Paths are sorted so the index -> file mapping is
    reproducible between runs.

    Args:
        folder_path: Directory containing the images (not searched recursively).
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, folder_path, transform=None):
        self.file_paths = sorted(
            os.path.join(folder_path, file_name)
            for file_name in os.listdir(folder_path)
            # Image.open would fail on hidden entries and directories.
            if not file_name.startswith(".")
            and os.path.isfile(os.path.join(folder_path, file_name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        # The context manager releases the file handle once convert() has
        # produced an in-memory RGB copy, instead of waiting for GC.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)  # Apply transformations if provided
        return image

# Define transformation pipeline
# Step 1: ToTensor turns a PIL image into a float tensor with values in [0, 1].
# Step 2: subtracting 0.5 and dividing by 0.5 per RGB channel maps that to
#         [-1, 1], the same range as the generator's tanh output.
transform_pipeline = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)

def load_images(folder_path, batch_size=4, num_workers=4, shuffle=True):
    """Build a DataLoader over every image in ``folder_path``.

    Images are read through ``ImageDataset`` with ``transform_pipeline``
    applied.

    Args:
        folder_path: Directory of images.
        batch_size: Images per batch.  The default of 4 equals ``BATCH_SIZE`` in
            the hyperparameters cell.  It is a parameter, not a read of the
            global, because this cell calls ``load_images`` before
            ``BATCH_SIZE`` is defined.
        num_workers: Background loader processes; adjust to your CPU cores.
        shuffle: Reshuffle the samples every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of image tensors.
    """
    dataset = ImageDataset(folder_path, transform=transform_pipeline)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies.  Without CUDA it
        # does nothing useful except emit a warning, so enable it only then.
        pin_memory=torch.cuda.is_available(),
    )

# Load datasets.
# `source` supplies domain-X batches (translated by generator_g) and `target`
# supplies domain-Y batches (translated by generator_f); loaders are built in
# that order.
source, target = (
    load_images(folder)
    for folder in (
        '/home/umang.shikarvar/instaformer/wb_small_airshed/images',  # Give source path
        '/home/umang.shikarvar/instaformer/delhi_ncr_small/images',   # Give target path
    )
)

# Model

# In [None]:
class ConvolutionalBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        is_downsampling: True selects ``nn.Conv2d`` with reflect padding;
            False selects ``nn.ConvTranspose2d`` (default padding mode).
        add_activation: True appends an in-place ReLU; False uses
            ``nn.Identity`` so the block is always a three-stage Sequential.
        **kwargs: Forwarded to the convolution (kernel_size, stride, padding,
            output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, is_downsampling=True, add_activation=True, **kwargs):
        super().__init__()
        if not is_downsampling:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        else:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        norm_layer = nn.InstanceNorm2d(out_channels)
        act_layer = nn.ReLU(inplace=True) if add_activation else nn.Identity()
        self.conv = nn.Sequential(conv_layer, norm_layer, act_layer)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 ConvolutionalBlocks of equal width plus an identity skip.

    The first inner block ends in ReLU, the second has no activation, and the
    input is added to their output.  Kernel 3 with padding 1 (stride 1) keeps
    the spatial size, so the output has the same shape as ``x``.

    Args:
        channels: Input and output channel count.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        self.block = nn.Sequential(
            ConvolutionalBlock(channels, channels, add_activation=True, **same_size),
            ConvolutionalBlock(channels, channels, add_activation=False, **same_size),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Stages, in order:
    1. A 7x7 reflect-padded stem.
    2. Two stride-2 downsampling blocks (channels x2, then x4).
    3. ``num_residuals`` residual blocks.
    4. Two stride-2 transposed-conv upsampling blocks (back to ``num_features``).
    5. A 7x7 projection to ``img_channels``, then tanh.

    Output values therefore lie in [-1, 1].  When height and width are
    divisible by 4, the output has the input's spatial size.

    Args:
        img_channels: Channels of the input and output images.
        num_features: Width of the stem; deeper stages use 2x and 4x this.
        num_residuals: Number of residual blocks at the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        nf = num_features
        self.initial_layer = nn.Sequential(
            nn.Conv2d(img_channels, nf, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(nf),
            nn.ReLU(inplace=True),
        )
        # Each step halves H and W and doubles the channels: nf -> 2nf -> 4nf.
        self.downsampling_layers = nn.ModuleList([
            ConvolutionalBlock(nf * mult, nf * mult * 2, is_downsampling=True,
                               kernel_size=3, stride=2, padding=1)
            for mult in (1, 2)
        ])
        self.residual_layers = nn.Sequential(
            *[ResidualBlock(nf * 4) for _ in range(num_residuals)]
        )
        # Each step doubles H and W (output_padding=1 restores even sizes) and
        # halves the channels: 4nf -> 2nf -> nf.
        self.upsampling_layers = nn.ModuleList([
            ConvolutionalBlock(nf * mult, nf * (mult // 2), is_downsampling=False,
                               kernel_size=3, stride=2, padding=1, output_padding=1)
            for mult in (4, 2)
        ])
        self.last_layer = nn.Conv2d(nf, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        features = self.initial_layer(x)
        stages = (*self.downsampling_layers, self.residual_layers, *self.upsampling_layers)
        for stage in stages:
            features = stage(features)
        return torch.tanh(self.last_layer(features))

class Discriminator(nn.Module):
    """PatchGAN-style discriminator that maps an image to a grid of scores.

    Layout:
    - A 4x4 stride-2 conv + LeakyReLU.
    - For each entry of ``features[1:]``: a 4x4 conv + InstanceNorm +
      LeakyReLU, with stride 2 except on the last entry, which uses stride 1.
    - A final 4x4 stride-1 conv to one channel.

    No sigmoid is applied; the raw scores are compared against 1/0 targets
    with an MSE loss during training.

    Args:
        in_channels: Channels of the input image.
        features: Output channels of each conv stage (at least one entry).
            Any sequence is accepted; the default is an immutable tuple.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial_layer = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )
        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final stage keeps stride 1.  The stride is chosen by
            # position, not by comparing values, so an earlier stage with the
            # same width as the last one still downsamples.
            stride = 1 if index == last_index else 2
            layers.append(
                nn.Conv2d(prev_channels, feature, kernel_size=4, stride=stride, padding=1, padding_mode="reflect"),
            )
            layers.append(nn.InstanceNorm2d(feature))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev_channels = feature
        layers.append(nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial_layer(x)
        return self.model(x)

# Loss functions

# In [None]:
# Loss functions
# Least-squares adversarial loss: mean squared error between discriminator
# scores and a target map of ones (real) or zeros (fake).
mse_loss = nn.MSELoss()


def cycle_consistency_loss(real, cycled, weight=10.0):
    """Return ``weight * mean(|real - cycled|)``, the weighted L1 cycle loss.

    ``cycled`` is the round-trip reconstruction of ``real`` (e.g. F(G(x))).
    The tensors must be broadcast-compatible.
    """
    return torch.mean(torch.abs(real - cycled)) * weight


def identity_loss(real, same, weight=5.0):
    """Return ``weight * mean(|real - same|)``, the weighted L1 identity loss.

    ``same`` is a generator's output when fed an image already in its target
    domain (e.g. G(y)), which should leave the image unchanged.
    """
    return torch.mean(torch.abs(real - same)) * weight

# Hyperparameters, models, optimizers and learning-rate schedulers

# In [None]:
# Set device to GPU or CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 4
EPOCHS = 250
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)
# The learning rate stays constant for this many epochs, then decays linearly
# so that it reaches zero at the end of training (epoch EPOCHS).
DECAY_START_EPOCH = 100

# Instantiate models
generator_g = Generator(img_channels=3).to(device)
generator_f = Generator(img_channels=3).to(device)
discriminator_x = Discriminator().to(device)
discriminator_y = Discriminator().to(device)

# Optimizers (identical settings for all four networks)
generator_g_optimizer = optim.Adam(generator_g.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
generator_f_optimizer = optim.Adam(generator_f.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
discriminator_x_optimizer = optim.Adam(discriminator_x.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
discriminator_y_optimizer = optim.Adam(discriminator_y.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)


def linear_lr_decay(epoch):
    """Return the LR multiplier for 0-based ``epoch``.

    The multiplier is 1.0 up to DECAY_START_EPOCH, then falls linearly to 0.0
    at EPOCHS and never goes below zero.

    The decay length is derived from EPOCHS.  A fixed ``/ 100`` divisor would
    reach zero at epoch 200 and go negative for the rest of a 250-epoch run,
    which makes Adam ascend the loss.
    """
    decay_epochs = max(1, EPOCHS - DECAY_START_EPOCH)
    return max(0.0, 1.0 - max(0, epoch - DECAY_START_EPOCH) / decay_epochs)


# Learning rate schedulers
scheduler_g = optim.lr_scheduler.LambdaLR(generator_g_optimizer, lr_lambda=linear_lr_decay)
scheduler_f = optim.lr_scheduler.LambdaLR(generator_f_optimizer, lr_lambda=linear_lr_decay)
scheduler_dx = optim.lr_scheduler.LambdaLR(discriminator_x_optimizer, lr_lambda=linear_lr_decay)
scheduler_dy = optim.lr_scheduler.LambdaLR(discriminator_y_optimizer, lr_lambda=linear_lr_decay)

# Training loop

# In [None]:
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir='/home/umang.shikarvar/instaformer/CG_logs')
checkpoint_dir = '/home/umang.shikarvar/instaformer/wb_CG_gen'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
try:
    for epoch in range(EPOCHS):
        g_loss_total, f_loss_total, dx_loss_total, dy_loss_total = 0.0, 0.0, 0.0, 0.0
        cycle_loss_total, identity_loss_total = 0.0, 0.0
        num_batches = 0

        # zip stops at the shorter loader.  That processes exactly the paired
        # batches, without also loading the longer loader's leftover batches
        # only to skip them.
        for real_x_batch, real_y_batch in zip(source, target):
            real_x, real_y = real_x_batch.to(device), real_y_batch.to(device)

            # ------------------------
            # Train Generators G and F
            # ------------------------
            # Identity loss: F(X) should be X, G(Y) should be Y
            id_loss_x = identity_loss(real_x, generator_f(real_x))
            id_loss_y = identity_loss(real_y, generator_g(real_y))

            # Adversarial loss: each generator wants its fakes scored as real (1).
            # Each discriminator runs once per fake; its output gives the target shape.
            fake_y = generator_g(real_x)  # G(X)
            fake_x = generator_f(real_y)  # F(Y)
            pred_fake_y = discriminator_y(fake_y)
            pred_fake_x = discriminator_x(fake_x)
            adv_loss_g = mse_loss(pred_fake_y, torch.ones_like(pred_fake_y))
            adv_loss_f = mse_loss(pred_fake_x, torch.ones_like(pred_fake_x))

            # Cycle-consistency loss
            cycle_loss_x = cycle_consistency_loss(real_x, generator_f(fake_y))  # F(G(X)) vs X
            cycle_loss_y = cycle_consistency_loss(real_y, generator_g(fake_x))  # G(F(Y)) vs Y

            # Combine generator losses
            total_g_loss = adv_loss_g + cycle_loss_x + id_loss_y
            total_f_loss = adv_loss_f + cycle_loss_y + id_loss_x

            generator_g_optimizer.zero_grad()
            generator_f_optimizer.zero_grad()
            # Both objectives share one graph: each cycle term passes through
            # both generators.  One backward of the sum accumulates the same
            # gradients as two backward calls, and needs no retain_graph=True.
            (total_g_loss + total_f_loss).backward()
            generator_g_optimizer.step()
            generator_f_optimizer.step()

            # -------------------------
            # Train Discriminators X, Y
            # -------------------------
            # zero_grad also clears the gradients that the generator backward
            # pass left on the discriminator parameters.
            # Discriminator loss for X
            pred_real_x = discriminator_x(real_x)
            pred_fake_x = discriminator_x(fake_x.detach())
            real_loss_x = mse_loss(pred_real_x, torch.ones_like(pred_real_x))
            fake_loss_x = mse_loss(pred_fake_x, torch.zeros_like(pred_fake_x))
            dx_loss = (real_loss_x + fake_loss_x) * 0.5

            discriminator_x_optimizer.zero_grad()
            dx_loss.backward()
            discriminator_x_optimizer.step()

            # Discriminator loss for Y
            pred_real_y = discriminator_y(real_y)
            pred_fake_y = discriminator_y(fake_y.detach())
            real_loss_y = mse_loss(pred_real_y, torch.ones_like(pred_real_y))
            fake_loss_y = mse_loss(pred_fake_y, torch.zeros_like(pred_fake_y))
            dy_loss = (real_loss_y + fake_loss_y) * 0.5

            discriminator_y_optimizer.zero_grad()
            dy_loss.backward()
            discriminator_y_optimizer.step()

            # Accumulate losses
            num_batches += 1
            g_loss_total += total_g_loss.item()
            f_loss_total += total_f_loss.item()
            dx_loss_total += dx_loss.item()
            dy_loss_total += dy_loss.item()
            cycle_loss_total += cycle_loss_x.item() + cycle_loss_y.item()
            identity_loss_total += id_loss_x.item() + id_loss_y.item()

        scheduler_g.step()
        scheduler_f.step()
        scheduler_dx.step()
        scheduler_dy.step()

        # Report per-batch means so values do not scale with the number of batches.
        denom = max(num_batches, 1)
        g_avg = g_loss_total / denom
        f_avg = f_loss_total / denom
        dx_avg = dx_loss_total / denom
        dy_avg = dy_loss_total / denom
        cycle_avg = cycle_loss_total / denom
        identity_avg = identity_loss_total / denom

        # Log losses to TensorBoard
        writer.add_scalar('Loss/Generator_G', g_avg, epoch + 1)
        writer.add_scalar('Loss/Generator_F', f_avg, epoch + 1)
        writer.add_scalar('Loss/Discriminator_X', dx_avg, epoch + 1)
        writer.add_scalar('Loss/Discriminator_Y', dy_avg, epoch + 1)
        writer.add_scalar('Loss/Cycle_Consistency', cycle_avg, epoch + 1)
        writer.add_scalar('Loss/Identity', identity_avg, epoch + 1)

        # Print epoch summary
        print(
            f"Epoch [{epoch + 1}/{EPOCHS}]: "
            f"G_loss: {g_avg:.4f}, F_loss: {f_avg:.4f}, "
            f"D_X_loss: {dx_avg:.4f}, D_Y_loss: {dy_avg:.4f}, "
            f"Cycle_loss: {cycle_avg:.4f}, Identity_loss: {identity_avg:.4f}"
        )

        # Save a checkpoint of generator G's weights every 25 epochs
        if (epoch + 1) % 25 == 0:
            torch.save(generator_g.state_dict(), os.path.join(checkpoint_dir, f'generator_CG_{epoch+1}.pth'))
            print(f"Saved generator_G state at epoch {epoch + 1}")
finally:
    # Flush and close TensorBoard logs even if training is interrupted
    writer.close()

print("Training complete!")