# In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torchvision import transforms, datasets
from torchvision.utils import save_image


# Generator Model
class Generator(nn.Module):
    """ResNet-style image-to-image generator (encoder, residual trunk, decoder).

    Layer stack, all held in ``self.model``:
      * 7x7 stride-1 conv, ``input_nc`` -> 64 channels, InstanceNorm, ReLU
      * two 3x3 stride-2 convs, 64 -> 128 -> 256, each + InstanceNorm, ReLU
      * ``n_residual_blocks`` ``ResidualBlock`` modules at 256 channels
      * two 3x3 stride-2 transposed convs, 256 -> 128 -> 64, each + InstanceNorm, ReLU
      * 7x7 stride-1 conv, 64 -> ``output_nc`` channels, then Tanh

    The final Tanh keeps outputs in [-1, 1]. The output height and width equal
    the input's when both are multiples of 4 (two halvings, then two doublings).
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()

        def conv_block(conv, channels):
            # A conv layer followed by InstanceNorm and an in-place ReLU.
            return [conv, nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        width = 64
        layers = conv_block(
            nn.Conv2d(input_nc, width, kernel_size=7, stride=1, padding=3, bias=False),
            width,
        )

        # Encoder: halve the resolution and double the width, twice.
        for _ in range(2):
            wider = width * 2
            layers.extend(conv_block(
                nn.Conv2d(width, wider, kernel_size=3, stride=2, padding=1, bias=False),
                wider,
            ))
            width = wider

        # Residual trunk at constant width.
        layers.extend(ResidualBlock(width) for _ in range(n_residual_blocks))

        # Decoder: double the resolution and halve the width, twice.
        for _ in range(2):
            narrower = width // 2
            layers.extend(conv_block(
                nn.ConvTranspose2d(
                    width, narrower, kernel_size=3, stride=2, padding=1,
                    output_padding=1, bias=False,
                ),
                narrower,
            ))
            width = narrower

        # Project back to image channels; width is 64 again at this point.
        layers.extend([
            nn.Conv2d(width, output_nc, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the full generator stack on the batch ``x``."""
        return self.model(x)


# Residual Block for Generator
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm, with
    every conv keeping ``channels`` channels and the spatial size (stride 1,
    padding 1). This is why the skip addition is always well-defined.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Bias is omitted because the following InstanceNorm removes the per-channel mean anyway.
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(), nn.InstanceNorm2d(channels), nn.ReLU(inplace=True),
            conv3x3(), nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.block(x)
        return x + residual


# Discriminator Model
class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    Four 4x4 stride-2 convs widen ``input_nc`` -> 64 -> 128 -> 256 -> 512.
    Every layer except the first gets InstanceNorm, and each is followed by
    LeakyReLU(0.2). A final 4x4 stride-1 conv then reduces to one channel.
    The output is a spatial map of raw (unbounded) scores, not a single
    scalar per image.
    """

    def __init__(self, input_nc):
        super().__init__()

        widths = (64, 128, 256, 512)

        # Entry layer: no normalization, so the conv keeps its bias.
        layers = [
            nn.Conv2d(input_nc, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Stride-1 projection to a single score channel.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch score map for the batch ``x``."""
        return self.model(x)


# CycleGAN Model in PyTorch Lightning
class CycleGAN(pl.LightningModule):
    """CycleGAN for unpaired image-to-image translation between domains A and B.

    Sub-modules:
        G_AB / G_BA: generators mapping A -> B and B -> A.
        D_A / D_B: discriminators scoring images as real members of A / B.

    Losses:
        * least-squares GAN loss (``MSELoss`` against 1 for real, 0 for fake);
        * L1 cycle-consistency loss, weighted by ``lambda_cycle``;
        * L1 identity loss, weighted by ``lambda_identity``.

    Optimization uses Lightning's automatic optimization with three optimizers.
    ``training_step`` is therefore called once per batch for each
    ``optimizer_idx``: 0 -> generators, 1 -> D_A, 2 -> D_B.

    NOTE: the ``optimizer_idx`` argument to ``training_step`` exists only in
    PyTorch Lightning < 2.0. The same is true of the ``gpus=`` Trainer argument
    used elsewhere in this file. Lightning 2.x requires manual optimization.

    Args:
        input_nc: Number of channels of domain-A images.
        output_nc: Number of channels of domain-B images.
        lr: Adam learning rate for all three optimizers.
        beta1: Adam beta1 (beta2 is fixed at 0.999).
        lambda_cycle: Weight of the cycle-consistency loss.
        lambda_identity: Weight of the identity loss; 0 disables it.

    Raises:
        ValueError: If ``lambda_identity > 0`` while ``input_nc != output_nc``.
            The identity loss feeds A images to G_BA and B images to G_AB, so
            both domains must have the same channel count.
    """

    def __init__(self, input_nc=1, output_nc=1, lr=0.0002, beta1=0.5, lambda_cycle=10.0, lambda_identity=5.0):
        super().__init__()

        if lambda_identity > 0 and input_nc != output_nc:
            raise ValueError(
                "identity loss requires input_nc == output_nc "
                f"(got input_nc={input_nc}, output_nc={output_nc}); "
                "pass lambda_identity=0 to disable it"
            )

        self.G_AB = Generator(input_nc, output_nc)
        self.G_BA = Generator(output_nc, input_nc)
        self.D_A = Discriminator(input_nc)
        self.D_B = Discriminator(output_nc)

        # Optimizers are built eagerly and kept as public attributes.
        # configure_optimizers hands them to Lightning.
        betas = (beta1, 0.999)
        self.optimizer_G = optim.Adam(
            list(self.G_AB.parameters()) + list(self.G_BA.parameters()), lr=lr, betas=betas
        )
        self.optimizer_D_A = optim.Adam(self.D_A.parameters(), lr=lr, betas=betas)
        self.optimizer_D_B = optim.Adam(self.D_B.parameters(), lr=lr, betas=betas)

        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def forward(self, x):
        """Translate ``x`` with both generators and return ``(G_AB(x), G_BA(x))``.

        The same ``x`` is fed to both generators, so this only works when
        ``input_nc == output_nc``.
        """
        fake_B = self.G_AB(x)
        fake_A = self.G_BA(x)
        return fake_B, fake_A

    def generator_step(self, real_A, real_B):
        """Return the total generator loss for one batch of A and B images.

        The total is adversarial + ``lambda_cycle`` * cycle, plus
        ``lambda_identity`` * identity when that weight is positive.
        """
        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)

        # Adversarial term: generators want the discriminators to output 1
        # ("real"). Each discriminator runs once per fake batch.
        pred_fake_B = self.D_B(fake_B)
        pred_fake_A = self.D_A(fake_A)
        loss_GAN_AB = self.criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_GAN_BA = self.criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the input.
        recov_A = self.G_BA(fake_B)
        recov_B = self.G_AB(fake_A)
        loss_cycle_A = self.criterion_cycle(recov_A, real_A)
        loss_cycle_B = self.criterion_cycle(recov_B, real_B)

        loss_G = (
            loss_GAN_AB + loss_GAN_BA +
            self.lambda_cycle * (loss_cycle_A + loss_cycle_B)
        )

        # Identity: a generator fed an image already in its target domain
        # should leave it unchanged. Skipped entirely when the weight is 0.
        if self.lambda_identity > 0:
            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
            loss_G = loss_G + self.lambda_identity * (loss_id_A + loss_id_B)
        return loss_G

    def _discriminator_loss(self, discriminator, real, fake):
        """Return the LSGAN loss for one discriminator: real -> 1, fake -> 0, averaged."""
        pred_real = discriminator(real)
        # Detach so this loss never back-propagates into a generator.
        pred_fake = discriminator(fake.detach())
        loss_real = self.criterion_GAN(pred_real, torch.ones_like(pred_real))
        loss_fake = self.criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        return (loss_real + loss_fake) * 0.5

    def discriminator_step(self, real_A, real_B):
        """Return the sum of the D_A and D_B losses for one batch."""
        with torch.no_grad():  # fakes are only inputs to the discriminators here
            fake_A = self.G_BA(real_B)
            fake_B = self.G_AB(real_A)
        loss_D_A = self._discriminator_loss(self.D_A, real_A, fake_A)
        loss_D_B = self._discriminator_loss(self.D_B, real_B, fake_B)
        return loss_D_A + loss_D_B

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Index 0 is the generators, 1 is D_A and 2 is D_B (the order returned
        by ``configure_optimizers``). ``batch`` must be a ``(real_A, real_B)`` pair.
        """
        real_A, real_B = batch

        if optimizer_idx == 0:
            loss_G = self.generator_step(real_A, real_B)
            self.log('loss_G', loss_G, on_step=True, on_epoch=True)
            return loss_G

        if optimizer_idx == 1:
            with torch.no_grad():
                fake_A = self.G_BA(real_B)
            loss_D_A = self._discriminator_loss(self.D_A, real_A, fake_A)
            self.log('loss_D_A', loss_D_A, on_step=True, on_epoch=True)
            return loss_D_A

        if optimizer_idx == 2:
            with torch.no_grad():
                fake_B = self.G_AB(real_A)
            loss_D_B = self._discriminator_loss(self.D_B, real_B, fake_B)
            self.log('loss_D_B', loss_D_B, on_step=True, on_epoch=True)
            return loss_D_B

        raise ValueError(f"unexpected optimizer_idx: {optimizer_idx}")

    def validation_step(self, batch, batch_idx):
        """Log the generator loss on a ``(real_A, real_B)`` validation batch."""
        real_A, real_B = batch
        loss_G = self.generator_step(real_A, real_B)
        self.log('val_loss_G', loss_G)

    def test_step(self, batch, batch_idx):
        """Translate a ``(real_A, real_B)`` batch and save both results as PNGs."""
        real_A, real_B = batch
        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)

        # save_image does not create missing directories.
        os.makedirs("test_results", exist_ok=True)
        # Generators end in Tanh; map [-1, 1] back to [0, 1] for saving.
        save_image((fake_B + 1) / 2, f"test_results/fake_B_{batch_idx}.png")
        save_image((fake_A + 1) / 2, f"test_results/fake_A_{batch_idx}.png")

    def configure_optimizers(self):
        """Return the generator, D_A and D_B optimizers (``optimizer_idx`` 0, 1, 2)."""
        # Must be a flat list. A tuple of two lists is interpreted by Lightning
        # as (optimizers, lr_schedulers).
        return [self.optimizer_G, self.optimizer_D_A, self.optimizer_D_B]


# Data loading
def get_dataloaders(batch_size=1, image_size=256):
    """Build one shuffled DataLoader per domain from ImageFolder directories.

    Reads ``data/domain_A`` and ``data/domain_B`` with
    ``torchvision.datasets.ImageFolder``. Images must therefore sit inside at
    least one class sub-directory. Each batch is an ``(images, class_indices)``
    pair, not an (A, B) image pair.

    Preprocessing:
        * Grayscale: ImageFolder's default loader yields RGB images, but the
          normalization below uses single-channel statistics.
        * Resize + CenterCrop to ``image_size x image_size``: ``Resize(int)``
          alone only fixes the shorter side. Images of different aspect ratios
          could then not be batched, and the generator would not return the
          input size. Keep ``image_size`` a multiple of 4.

    Args:
        batch_size: Batch size for both loaders.
        image_size: Side length of the square crops (default 256).

    Returns:
        (loader_A, loader_B)
    """
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
    ])

    dataset_A = datasets.ImageFolder("data/domain_A", transform=transform)
    dataset_B = datasets.ImageFolder("data/domain_B", transform=transform)

    loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
    loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

    return loader_A, loader_B


if __name__ == "__main__":
    from torch.utils.data import Dataset

    class _DomainPairs(Dataset):
        """Yield ``(image_A, image_B)`` from two ImageFolder datasets, dropping class labels.

        Item ``idx`` pairs ``ds_A[idx % len(ds_A)]`` with ``ds_B[idx % len(ds_B)]``,
        so the smaller domain wraps around.
        """

        def __init__(self, ds_A, ds_B):
            self.ds_A = ds_A
            self.ds_B = ds_B

        def __len__(self):
            return max(len(self.ds_A), len(self.ds_B))

        def __getitem__(self, idx):
            image_A, _ = self.ds_A[idx % len(self.ds_A)]
            image_B, _ = self.ds_B[idx % len(self.ds_B)]
            return image_A, image_B

    loader_A, loader_B = get_dataloaders()

    # CycleGAN.training_step expects (real_A, real_B) batches. Each ImageFolder
    # loader on its own yields (images, labels), so zip the two domains.
    pairs = _DomainPairs(loader_A.dataset, loader_B.dataset)
    train_loader = DataLoader(pairs, batch_size=loader_A.batch_size, shuffle=True)
    # NOTE: no held-out split exists in this setup; validation/test reuse the training pairs.
    eval_loader = DataLoader(pairs, batch_size=loader_A.batch_size, shuffle=False)

    model = CycleGAN(input_nc=1, output_nc=1)

    trainer = pl.Trainer(
        max_epochs=100,
        gpus=1 if torch.cuda.is_available() else 0,
        log_every_n_steps=10
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=eval_loader)
    trainer.test(model, dataloaders=eval_loader)


# In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import os
from PIL import Image


class UnpairedImageDataset(Dataset):
    """
    Custom Dataset for loading unpaired images from two domains.

    Item ``idx`` pairs the ``idx % len(files_A)``-th image of domain A with
    the ``idx % len(files_B)``-th image of domain B. Both file lists are
    sorted by name. The dataset length is the size of the larger domain, and
    the smaller domain wraps around. Both images are converted to
    single-channel grayscale (PIL mode "L") before ``transform`` is applied.

    Only regular, non-hidden files with an extension in ``IMG_EXTENSIONS``
    are used. Stray entries such as ``.DS_Store``, text files or
    sub-directories therefore do not crash loading.

    Args:
        root_dir_A: Directory containing domain-A images.
        root_dir_B: Directory containing domain-B images.
        transform: Optional callable applied separately to each PIL image.

    Raises:
        ValueError: If either directory contains no usable image files.
    """

    # Lower-case suffixes accepted as images (matched case-insensitively).
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".pgm", ".bmp", ".tif", ".tiff", ".webp",
    )

    def __init__(self, root_dir_A, root_dir_B, transform=None):
        super().__init__()
        self.transform = transform

        self.root_dir_A = root_dir_A
        self.root_dir_B = root_dir_B

        self.files_A = self._list_images(root_dir_A)
        self.files_B = self._list_images(root_dir_B)

        # An empty domain would otherwise surface later as ZeroDivisionError in `idx % 0`.
        for root_dir, files in ((root_dir_A, self.files_A), (root_dir_B, self.files_B)):
            if not files:
                raise ValueError(f"No image files found in {root_dir!r}")

    @classmethod
    def _list_images(cls, root_dir):
        """Return the sorted names of image files directly inside ``root_dir``."""
        return sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith(".")
            and name.lower().endswith(cls.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    @staticmethod
    def _load_grayscale(root_dir, files, idx):
        """Open ``files[idx % len(files)]`` and return it as an "L"-mode PIL image."""
        path = os.path.join(root_dir, files[idx % len(files)])
        # The context manager closes the file handle; convert() returns an
        # image that stays valid after the source is closed.
        with Image.open(path) as img:
            return img.convert("L")

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        image_A = self._load_grayscale(self.root_dir_A, self.files_A, idx)
        image_B = self._load_grayscale(self.root_dir_B, self.files_B, idx)

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return image_A, image_B


def get_dataloaders(root_dir_A, root_dir_B, batch_size=1, val_split=0.2, image_size=256):
    """
    Prepares training and validation DataLoaders for unpaired image datasets.

    Every image is resized so its shorter side is ``image_size`` and then
    center-cropped to ``image_size x image_size``. ``Resize(int)`` alone keeps
    the aspect ratio, so images of different shapes could not be batched
    together. A non-multiple-of-4 side also makes the generator output a
    different size than its input, which breaks the L1 cycle/identity losses.
    Keep ``image_size`` a multiple of 4.

    Args:
        root_dir_A: Path to images in domain A.
        root_dir_B: Path to images in domain B.
        batch_size: Batch size for DataLoader.
        val_split: Fraction of the dataset to use for validation, in [0, 1).
        image_size: Side length of the square crops (default 256).

    Returns:
        train_loader, val_loader

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    ])

    dataset = UnpairedImageDataset(root_dir_A, root_dir_B, transform=transform)

    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size

    # NOTE: both subsets share one dataset and therefore one transform, so
    # validation images are randomly flipped too.
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


# In [None]:
if __name__ == "__main__":
    # One directory of (unpaired) images per domain.
    root_dir_A, root_dir_B = "data/domain_A", "data/domain_B"

    # Hold out 20% of the index-paired dataset for validation; 4 pairs per batch.
    train_loader, val_loader = get_dataloaders(
        root_dir_A, root_dir_B, batch_size=4, val_split=0.2
    )

    # Both domains are single-channel.
    model = CycleGAN(input_nc=1, output_nc=1)

    # One GPU when CUDA is visible, otherwise CPU.
    gpu_count = 1 if torch.cuda.is_available() else 0
    trainer = pl.Trainer(max_epochs=100, gpus=gpu_count, log_every_n_steps=10)

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
