In [None]:
# !wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
# !tar -xvf facades.tar.gz

# https://github.com/aniketmaurya/pytorch-gans/tree/main

In [None]:
# !pip install pytorch-lightning==1.5.7
# !pip install torchtext

In [None]:
# !pip install pytorch-lightning

In [None]:
import os
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import center_crop
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [None]:
path = "./facades/train/"


class FacadesDataset(Dataset):
    """Paired-image dataset that splits every file into a left and a right half.

    Each entry matched by ``<path>/*`` is opened as an image and converted to a
    tensor. The left half of the width is returned as ``real`` and the right
    half as ``condition``.

    Args:
        path: Directory whose entries are treated as image files.
        target_size: Optional output size passed to
            ``nn.functional.interpolate`` (an int or an ``(H, W)`` pair).
            When falsy the halves keep their native resolution.
    """

    def __init__(self, path, target_size=None):
        # glob() gives no ordering guarantee; sorting keeps index -> file stable
        # across runs and machines.
        self.filenames = sorted(glob(str(Path(path) / "*")))
        self.target_size = target_size

    def __len__(self):
        """Return the number of files found under ``path``."""
        return len(self.filenames)

    @staticmethod
    def _resize(image, size):
        """Resize one unbatched ``(C, H, W)`` tensor to ``size``.

        ``interpolate`` expects a leading batch dimension. Without it a
        3-D tensor is read as ``(N, C, L)`` and only the last axis is resized
        (or a two-element ``size`` is rejected), so a batch axis is added for
        the call and removed afterwards.
        """
        return nn.functional.interpolate(image.unsqueeze(0), size=size).squeeze(0)

    def __getitem__(self, idx):
        """Return ``(real, condition)`` for the file at position ``idx``."""
        filename = self.filenames[idx]
        # Context manager so the file handle is released as soon as the pixel
        # data has been copied into the tensor.
        with Image.open(filename) as handle:
            image = transforms.functional.to_tensor(handle)
        image_width = image.shape[2]

        real = image[:, :, : image_width // 2]
        condition = image[:, :, image_width // 2 :]

        target_size = self.target_size
        if target_size:
            condition = self._resize(condition, target_size)
            real = self._resize(real, target_size)

        return real, condition

In [None]:
class DownSampleConv(nn.Module):
    """Strided convolution followed by optional BatchNorm and LeakyReLU(0.2).

    Paper details:
    - C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4×4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel: Convolution kernel size.
        strides: Convolution stride.
        padding: Zero padding added on each side.
        activation: Apply ``LeakyReLU(0.2)`` as the last step when true.
        batchnorm: Apply ``BatchNorm2d`` after the convolution when true.
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        self.activation, self.batchnorm = activation, batchnorm

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=strides,
            padding=padding,
        )

        # The optional layers are only registered when enabled, so a disabled
        # layer contributes no attribute and no state_dict entry.
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Run conv -> (BatchNorm) -> (LeakyReLU) and return the result."""
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out

In [None]:
class UpSampleConv(nn.Module):
    """Transposed convolution followed by optional BatchNorm, Dropout and ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        kernel: Kernel size of the transposed convolution.
        strides: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.
        activation: Apply ``ReLU`` as the last step when true.
        batchnorm: Apply ``BatchNorm2d`` after the transposed convolution when true.
        dropout: Apply ``Dropout2d(0.5)`` after normalisation when true.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        """Run deconv -> (BatchNorm) -> (Dropout) -> (ReLU) and return the result."""
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)

        if self.dropout:
            x = self.drop(x)

        # The ReLU was constructed in __init__ but never applied, which left
        # the decoder without any non-linearity. It goes last, matching the
        # Convolution-BatchNorm-Dropout-ReLU ordering of the pix2pix paper.
        if self.activation:
            x = self.act(x)
        return x

In [None]:
class Generator(nn.Module):
    """Encoder-decoder generator with skip connections and a Tanh output.

    Eight ``DownSampleConv`` stages feed seven ``UpSampleConv`` stages. The
    output of each of the first six decoder stages is concatenated on the
    channel axis with an encoder output, deepest first. A final transposed
    convolution maps to ``out_channels`` and ``Tanh`` bounds the result.
    """

    def __init__(self, in_channels, out_channels):
        """
        Paper details:
        - Encoder: C64-C128-C256-C512-C512-C512-C512-C512
        - All convolutions are 4×4 spatial filters applied with stride 2
        - Convolutions in the encoder downsample by a factor of 2
        - Decoder: CD512-CD1024-CD1024-C1024-C1024-C512 -C256-C128
        """
        super().__init__()

        # encoder/downsample convs as (in, out, batchnorm); every stage halves
        # H and W, so a 256x256 input ends at 1x1 after the eighth stage.
        encoder_spec = [
            (in_channels, 64, False),  # bs x 64 x 128 x 128
            (64, 128, True),  # bs x 128 x 64 x 64
            (128, 256, True),  # bs x 256 x 32 x 32
            (256, 512, True),  # bs x 512 x 16 x 16
            (512, 512, True),  # bs x 512 x 8 x 8
            (512, 512, True),  # bs x 512 x 4 x 4
            (512, 512, True),  # bs x 512 x 2 x 2
            (512, 512, False),  # bs x 512 x 1 x 1
        ]
        down_stack = [
            DownSampleConv(c_in, c_out, batchnorm=use_bn)
            for c_in, c_out, use_bn in encoder_spec
        ]

        # decoder/upsample convs as (in, out, dropout); inputs after the first
        # stage are doubled because a skip tensor is concatenated beforehand.
        decoder_spec = [
            (512, 512, True),  # bs x 512 x 2 x 2
            (1024, 512, True),  # bs x 512 x 4 x 4
            (1024, 512, True),  # bs x 512 x 8 x 8
            (1024, 512, False),  # bs x 512 x 16 x 16
            (1024, 256, False),  # bs x 256 x 32 x 32
            (512, 128, False),  # bs x 128 x 64 x 64
            (256, 64, False),  # bs x 64 x 128 x 128
        ]
        up_stack = [
            UpSampleConv(c_in, c_out, dropout=use_drop)
            for c_in, c_out, use_drop in decoder_spec
        ]

        # Not read anywhere in this class; kept as a public attribute.
        self.decoder_channels = [512, 512, 512, 512, 256, 128, 64]
        self.final_conv = nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

        # Registered after final_conv/tanh so the submodule (and therefore
        # parameter) ordering stays exactly as before.
        self.encoders = nn.ModuleList(down_stack)
        self.decoders = nn.ModuleList(up_stack)

    def forward(self, x):
        """Encode ``x``, decode with skip concatenation, return the Tanh output."""
        features = []
        for down in self.encoders:
            x = down(x)
            features.append(x)

        # Drop the bottleneck (it is ``x`` itself) and walk the remaining
        # encoder outputs deepest-first. zip() stops after the six decoder
        # stages, so the shallowest encoder output is never concatenated.
        skip_stack = features[-2::-1]
        for up, skip in zip(self.decoders[:-1], skip_stack):
            x = torch.cat((up(x), skip), dim=1)

        x = self.decoders[-1](x)
        return self.tanh(self.final_conv(x))

In [None]:
class PatchGAN(nn.Module):
    """Discriminator that scores patches of an (image, condition) pair.

    The two inputs are concatenated on the channel axis, passed through four
    ``DownSampleConv`` stages and reduced to one logit per spatial location
    by a 1x1 convolution.

    Args:
        input_channels: Channel count of the concatenated input.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.d1 = DownSampleConv(in_channels=input_channels, out_channels=64, batchnorm=False)
        self.d2 = DownSampleConv(in_channels=64, out_channels=128)
        self.d3 = DownSampleConv(in_channels=128, out_channels=256)
        self.d4 = DownSampleConv(in_channels=256, out_channels=512)
        self.final = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x, y):
        """Return the logit map for ``x`` and ``y`` joined along dim 1."""
        features = torch.cat([x, y], dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            features = stage(features)
        return self.final(features)

In [None]:
# Module-level loss instances: binary cross-entropy on logits for the
# adversarial term and L1 for the reconstruction term.
# NOTE(review): Pix2Pix builds its own criteria in __init__ and no other cell
# in this notebook appears to read these two names — confirm before removing.
adversarial_loss = nn.BCEWithLogitsLoss()

reconstruction_loss = nn.L1Loss()

In [None]:
# https://stackoverflow.com/questions/49433936/how-to-initialize-weights-in-pytorch

import torch
import matplotlib.pyplot as plt  # Import for using plt


def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

def display_progress(cond, real, fake, figsize=(10,5)):
    """Show the condition, real and generated image side by side.

    Each argument is a single image tensor. It is detached, moved to the CPU
    and its dimensions permuted from ``(0, 1, 2)`` to ``(1, 2, 0)`` before
    being handed to ``imshow``.

    Args:
        cond: Image shown in the first panel.
        real: Image shown in the second panel.
        fake: Image shown in the third panel.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    panels = (("Condition", cond), ("Real", real), ("Generated", fake))

    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for axis, (title, image) in zip(ax, panels):
        axis.imshow(image.detach().cpu().permute(1, 2, 0))
        # Titles make it visible when a caller passes the images in the wrong order.
        axis.set_title(title)
    plt.show()
    # This runs repeatedly during training; release the figure so open
    # figures do not pile up in memory.
    plt.close(fig)

# Old

In [None]:
# Pix2Pix, CycleGAN, AttentionGAN (https://github.com/Ha0Tang/AttentionGAN?tab=readme-ov-file)




class Pix2Pix(pl.LightningModule):
    """Pix2Pix conditional GAN trained with Lightning manual optimisation.

    The generator maps a condition image to an output image. ``PatchGAN``
    scores (image, condition) pairs. Training alternates by batch index:
    even batches update the generator, odd batches update the discriminator.

    Args:
        in_channels: Channels of the condition image.
        out_channels: Channels of the generated image.
        learning_rate: Learning rate used for both Adam optimisers.
        lambda_recon: Weight of the L1 reconstruction term in the generator loss.
        display_step: A preview is shown on the first batch of every epoch
            whose index is a multiple of this value.
    """

    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200, display_step=25):

        super().__init__()
        self.save_hyperparameters()

        # Two optimisers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.display_step = display_step
        self.gen = Generator(in_channels, out_channels)
        self.patch_gan = PatchGAN(in_channels + out_channels)

        # initializing weights
        self.gen = self.gen.apply(_weights_init)
        self.patch_gan = self.patch_gan.apply(_weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()

    def _gen_step(self, real_images, conditioned_images):
        """Return the generator loss: adversarial + ``lambda_recon`` * L1."""
        # Adversarial term: the generator wants the discriminator to answer "real".
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

        # Reconstruction term: L1 distance to the real image.
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        return adversarial_loss + lambda_recon * recon_loss

    def _disc_step(self, real_images, conditioned_images):
        """Return the discriminator loss averaged over real and fake pairs."""
        # detach(): no gradient should reach the generator from this loss.
        fake_images = self.gen(conditioned_images).detach()
        fake_logits = self.patch_gan(fake_images, conditioned_images)

        real_logits = self.patch_gan(real_images, conditioned_images)

        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        return (real_loss + fake_loss) / 2

    def configure_optimizers(self):
        """Return ``(discriminator_optimizer, generator_optimizer)``."""
        lr = self.hparams.learning_rate
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
        disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
        # training_step unpacks in this same order; the order is kept as-is so
        # optimiser states in existing checkpoints still line up.
        return disc_opt, gen_opt

    def _log_image_metrics(self, batch, ssim_name, psnr_name):
        """Generate from the batch's condition and log/print SSIM and PSNR.

        Args:
            batch: ``(real, condition)`` pair as produced by the dataloader.
            ssim_name: Key under which SSIM is logged and printed.
            psnr_name: Key under which PSNR is logged and printed.
        """
        # Imported here because torchmetrics is not among the notebook's top-level imports.
        from torchmetrics.functional.image import structural_similarity_index_measure
        from torchmetrics.functional.image import peak_signal_noise_ratio

        real, condition = batch
        fake = self.gen(condition).detach()
        ssim = structural_similarity_index_measure(fake, real)
        psnr = peak_signal_noise_ratio(fake, real)
        self.log(ssim_name, ssim)
        print(f"{ssim_name}: {ssim}")
        self.log(psnr_name, psnr)
        print(f"{psnr_name}: {psnr}")

    def validation_step(self, batch, batch_idx):
        """Take one batch, run the generator and log SSIM/PSNR as ``*_val``."""
        self._log_image_metrics(batch, "SSIM_val", "PSNR_val")

    def test_step(self, batch, batch_idx):
        """Take one batch, run the generator and log SSIM/PSNR as ``*_heldout_set``."""
        # PSNR used to be logged under "PSNR_val" here, colliding with the
        # validation key; it now follows the same naming as the SSIM key.
        self._log_image_metrics(batch, "SSIM_heldout_set", "PSNR_heldout_set")

    def training_step(self, batch, batch_idx):
        """Update the generator on even batches, the discriminator on odd ones.

        Returns:
            The loss computed for whichever network was updated.
        """
        real, condition = batch

        # Same order as configure_optimizers() returns them.
        disc_opt, gen_opt = self.optimizers()

        if batch_idx % 2 == 0:
            loss_label, log_key = "Generator Loss", "gen_loss"
            # Gradients accumulate across calls unless cleared; under manual
            # optimisation that is this method's job.
            gen_opt.zero_grad()
            loss = self._gen_step(real, condition)
            self.manual_backward(loss)
            gen_opt.step()
        else:
            loss_label, log_key = "PatchGAN Loss", "disc_loss"
            # Also discards the discriminator gradients left behind by the
            # previous generator step, which back-propagates through patch_gan.
            disc_opt.zero_grad()
            loss = self._disc_step(real, condition)
            self.manual_backward(loss)
            disc_opt.step()

        # "train_loss" interleaves both losses (kept for existing dashboards);
        # the per-network keys are the ones worth reading.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log(log_key, loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        if self.current_epoch % self.display_step == 0 and batch_idx == 0:
            print(loss_label, loss)
            with torch.no_grad():
                fake = self.gen(condition)
            # display_progress expects (cond, real, fake).
            display_progress(condition[0], real[0], fake[0])

        return loss

In [None]:
# target_size = None
# lr=0.0002
# lambda_recon=200
# batch_size = 128

# display_step = True
# dataset = FacadesDataset(path, target_size=target_size)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# pix2pix = Pix2Pix(3, 3, learning_rate=lr, lambda_recon=lambda_recon, display_step=display_step)
# trainer = pl.Trainer(max_epochs=2000)
# trainer.fit(pix2pix, dataloader)     

# Dataset check experiment

In [None]:
# Display the dataset object.
# NOTE(review): within this notebook `dataset` is only assigned inside
# commented-out cells, so on a fresh kernel this line raises NameError —
# confirm whether the Facades cell above should be re-enabled.
dataset

In [None]:
# Pull one batch from the loader and keep it as `batch_orig`.
# NOTE(review): `dataloader` is first assigned in a later cell of this
# notebook; this cell depends on state from an earlier session.
for batch_orig in dataloader:
    # print(batch)
    break

In [None]:
# Shape of the first tensor in the batch fetched above. The previous cell
# binds `batch_orig`; `batch` does not exist until a later cell runs.
batch_orig[0].shape

In [None]:
from DataLoaderManager import DataLoaderManager
from torchvision import transforms


# Example usage:
# Example usage: small 32x32 preprocessing pipeline for the dataset check.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # Resize the image
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
    ]
)

# Build the train/validation loaders; this rebinds `dataloader`.
data_loader_manager = DataLoaderManager(
    root_dir='SIDD_Small_sRGB_Only/SIDD_Small_sRGB_Only/Data/',
    transform=transform,
)
dataloader, val_dataloader = data_loader_manager.process_dataloaders(
    batch_size=128,
    shuffle=True,
)

In [None]:
# Pull one batch from the loader built in the previous cell and keep it as `batch`.
for batch in dataloader:
    # print(batch)
    break

In [None]:
# Element count and first-tensor shape of the batch bound to `batch_orig`.
len(batch_orig), batch_orig[0].shape

In [None]:
# Same summary for `batch`, for comparison with the `batch_orig` output above.
len(batch), batch[0].shape

# Dataset Experiment Check end

In [None]:
# Make a held out set for each set of Dataset
# Make npy file of that
# Make sure that is not being used again    

# THIS IS TO CREATE THE HELDOUT LIST
# import glob
# from sklearn.model_selection import train_test_split


# root_dir = 'SIDD_Small_sRGB_Only/SIDD_Small_sRGB_Only/Data/*'
# dataset =  glob.glob(root_dir)

# train_list, test_list = train_test_split(dataset, test_size=0.2, random_state=42)
# heldout_list = [os.path.basename(folder_name) for folder_name in test_list]
# print(heldout_list)



# folder_list = os.listdir(root_dir) 
# # - heldout_list
# heldout_list = ['0135_006_IP_00400_00400_5500_N', '0138_006_IP_00100_00100_3200_L', '0180_008_GP_00100_00100_5500_N', '0070_003_IP_02000_04000_3200_N', '0121_006_N6_03200_01000_3200_L', '0035_002_GP_00800_00350_3200_N', '0130_006_GP_00400_00400_4400_N', '0065_003_GP_10000_08460_4400_N', '0129_006_GP_00100_00100_4400_N', '0181_008_GP_00800_00800_5500_N', '0022_001_N6_00100_00060_5500_N', '0108_005_GP_06400_06400_4400_N', '0017_001_GP_00100_00060_5500_N', '0086_004_GP_00100_00100_5500_L', '0029_001_IP_00800_01000_5500_N', '0036_002_GP_06400_03200_3200_N', '0164_007_IP_00400_00400_3200_N', '0188_008_IP_00100_00100_3200_N', '0126_006_S6_00400_00200_4400_L', '0018_001_GP_00100_00160_5500_L', '0097_005_N6_03200_02000_3200_L', '0020_001_GP_00800_00350_5500_N', '0014_001_S6_03200_01250_3200_N', '0011_001_S6_00800_00500_5500_L', '0038_002_GP_00800_00640_3200_L', '0192_009_IP_00100_00200_3200_N', '0125_006_S6_00100_00050_4400_L', '0072_003_IP_01000_02000_5500_L', '0167_008_N6_00100_00050_4400_L', '0134_006_IP_00100_00100_5500_N', '0173_008_G4_00400_00400_4400_N', '0101_005_S6_00100_00050_4400_L']
# new_list = [img_name for img_name in folder_list if img_name not in heldout_list]

# len(new_list), len(folder_list), len(heldout_list)

In [None]:
## CREATE DATASETS


# from DataLoaderManager import DataLoaderManager
# from torchvision import transforms

# batch_size = 128
# # Example usage:
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),  # Resize the image
#     transforms.ToTensor(),          # Convert to tensor
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
# ])


# root_dir='SIDD_Small_sRGB/SIDD_Small_sRGB/Data/'
# base_filename = 'SIDD_Small_sRGB'

# data_loader_manager = DataLoaderManager(root_dir=root_dir, train_file='heldout_test',test_file=base_filename,make_held_out_set=True,transform=transform)
# dataloader, val_dataloader = data_loader_manager.process_dataloaders(batch_size=batch_size, shuffle=True)
# print(f"Heldout Dataset Size: {len(dataloader.dataset)} \n\n\n")



# data_loader_manager = DataLoaderManager(root_dir=root_dir, train_file=base_filename,test_file=base_filename,make_held_out_set=False,transform=transform)
# dataloader, val_dataloader = data_loader_manager.process_dataloaders(batch_size=batch_size, shuffle=True)
# print(f"Train Dataset Size: {len(dataloader.dataset)}")
# print(f"Test Dataset Size: {len(val_dataloader.dataset)} \n\n\n")

# root_dir='SIDD_Medium_sRGB/SIDD_Medium_sRGB/Data/'
# base_filename = 'SIDD_Medium_sRGB'

# data_loader_manager = DataLoaderManager(root_dir=root_dir, train_file=base_filename,test_file=base_filename,make_held_out_set=False,transform=transform)
# dataloader, val_dataloader = data_loader_manager.process_dataloaders(batch_size=batch_size, shuffle=True)
# print(f"Train Dataset Size: {len(dataloader.dataset)}")
# print(f"Test Dataset Size: {len(val_dataloader.dataset)}")

# Small Dataset

In [13]:
from DataLoaderManager import DataLoaderManager
from torchvision import transforms

batch_size = 128

# Preprocessing for training. The generator ends in Tanh, so its outputs lie
# in (-1, 1). ImageNet mean/std normalisation pushes pixel values to roughly
# [-2.1, 2.6], which the generator can never reproduce, so images are scaled
# to [-1, 1] instead.
# NOTE(review): assumes DataLoaderManager applies `transform` to the target
# images as well as the inputs — confirm in DataLoaderManager.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 256x256 reduces to 1x1 after the generator's eight stride-2 stages
    transforms.ToTensor(),          # Convert to tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
])


target_size = None
lr=0.0002
lambda_recon=200
# batch_size = 128


display_step = 250

root_dir='SIDD_Small_sRGB/SIDD_Small_sRGB/Data/'
base_filename = 'SIDD_Small_sRGB'

data_loader_manager = DataLoaderManager(root_dir=root_dir, train_file=base_filename,test_file=base_filename,make_held_out_set=False,transform=transform)
dataloader, val_dataloader = data_loader_manager.process_dataloaders(batch_size=batch_size, shuffle=True)

print(f"Train Dataset Size: {len(dataloader.dataset)}")
print(f"Test Dataset Size: {len(val_dataloader.dataset)}")


# Single-epoch run on the small dataset.
pix2pix = Pix2Pix(3, 3, learning_rate=lr, lambda_recon=lambda_recon, display_step=display_step)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(pix2pix, dataloader)

In [None]:
# Run validation on `val_dataloader` with the trainer from the cell above.
# No model or ckpt_path is passed, so Lightning decides which weights to load
# (see the `ckpt_path` parameter of `Trainer.validate`).
# trainer.validate(dataloaders=val_dataloader,ckpt_path='lightning_logs/version_11/checkpoints/epoch=0-step=1.ckpt')
trainer.validate(dataloaders=val_dataloader)

# Medium Dataset

In [None]:
# Hyper-parameters for the medium-dataset run. `transform` and `batch_size`
# are taken from the small-dataset cell above.
target_size = None
lr = 0.0002
lambda_recon = 200
# batch_size = 128

display_step = 250

root_dir = 'SIDD_Medium_sRGB/SIDD_Medium_sRGB/Data/'
base_filename = 'SIDD_Medium_sRGB'

# Rebuild the loaders for the medium dataset; this rebinds `dataloader` and
# `val_dataloader`.
data_loader_manager = DataLoaderManager(
    root_dir=root_dir,
    train_file=base_filename,
    test_file=base_filename,
    make_held_out_set=False,
    transform=transform,
)
dataloader, val_dataloader = data_loader_manager.process_dataloaders(
    batch_size=batch_size,
    shuffle=True,
)

# Fresh model and trainer for the long run.
pix2pix = Pix2Pix(
    3,
    3,
    learning_rate=lr,
    lambda_recon=lambda_recon,
    display_step=display_step,
)
trainer = pl.Trainer(max_epochs=2000)
trainer.fit(pix2pix, dataloader)

In [None]:
# Run the test loop on `val_dataloader` with the trainer from the cell above.
# No model or ckpt_path is passed, so Lightning decides which weights to load
# (see the `ckpt_path` parameter of `Trainer.test`).
trainer.test(dataloaders=val_dataloader)

# Reference

In [None]:
# Reference run. The Facades dataset lines are commented out, so this trains
# on whatever `dataloader` the previously executed cell left behind.
target_size = None
lr=0.0002
lambda_recon=200
# batch_size = 128


# Show a preview every epoch. This was `True`, which only worked because the
# modulo in Pix2Pix.training_step treats a bool as the integer 1.
display_step = 1
# dataset = FacadesDataset(path, target_size=target_size)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

pix2pix = Pix2Pix(3, 3, learning_rate=lr, lambda_recon=lambda_recon, display_step=display_step)
trainer = pl.Trainer(max_epochs=2000)
trainer.fit(pix2pix, dataloader)

In [None]:
# Pull one batch from the current `dataloader` for the visual check below.
for batch in dataloader:
    break

In [None]:
# Preview samples 1..19 of the batch fetched above. The former
# `img_val = 100` assignment was overwritten by the loop before being read.
# NOTE(review): batch[0] is passed for both the first and second panel —
# confirm this is intended.
for img_val in range(1,20):
    display_progress(batch[0][img_val], batch[0][img_val], batch[1][img_val])

In [None]:
# Shape of the first sample in the batch's first tensor.
batch[0][0].shape

In [None]:
# Shape of the first sample in the batch's second tensor.
batch[1][0].shape