In [None]:
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    Reads every ``*.*`` file from ``<root>/<mode>A`` and ``<root>/<mode>B``.
    Domain-A images are indexed deterministically (wrapping around with ``%``
    when the dataset length exceeds the number of A files); the domain-B image
    is drawn uniformly at random on each access, so the A/B pairs are unaligned.

    Args:
        root: dataset root directory.
        transform: optional callable applied to each (RGB) PIL image.
        mode: split prefix of the sub-folders, e.g. ``'train'`` -> ``trainA``/``trainB``.
    """

    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform

        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode, '*.*')))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode, '*.*')))

    @staticmethod
    def _load(path):
        """Fully decode the image at ``path`` as RGB and release its file handle."""
        # Image.open is lazy and keeps the file open; the context manager closes it.
        # Forcing RGB keeps grayscale / RGBA files compatible with the 3-channel
        # Normalize used in training and with the generator input.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``(item_A, item_B)``: the A image at ``index`` and a random B image."""
        item_A = self._load(self.files_A[index % len(self.files_A)])
        item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])

        if self.transform:
            item_A = self.transform(item_A)
            item_B = self.transform(item_B)

        return item_A, item_B

    def __len__(self):
        # An epoch covers the larger of the two domains.
        return max(len(self.files_A), len(self.files_B))
    

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual block: ``x + F(x)`` with two reflection-padded 3x3 convolutions.

    ``F`` is pad -> conv -> InstanceNorm -> ReLU -> pad -> conv -> InstanceNorm.
    The 1-pixel reflection padding with a 3x3 kernel keeps the spatial size, and
    the channel count stays ``in_features``, so the skip connection is a plain add.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        channels = in_features
        layers = []
        for stage in range(2):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(channels, channels, 3))
            layers.append(nn.InstanceNorm2d(channels))
            # Only the first stage has an activation; the residual sum is left linear.
            if stage == 0:
                layers.append(nn.ReLU(inplace=True))

        # Sequential indices define the state_dict keys stored in checkpoints.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout: 7x7 conv stem -> two stride-2 downsampling convs (channels x2 each)
    -> ``n_residual_blocks`` ResidualBlocks -> two stride-2 transposed convs
    (channels /2 each) -> 7x7 conv to ``output_nc`` channels -> Tanh, so outputs
    lie in [-1, 1]. For inputs whose height and width are divisible by 4 the
    output has the same spatial size as the input.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        n_residual_blocks: number of residual blocks at the bottleneck.
        ngf: channel width of the stem; defaults to 64 (the previously
            hard-coded value, so existing checkpoints still load).
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=4, ngf=64):
        super(Generator, self).__init__()

        # Stem: reflection padding 3 + 7x7 conv keeps the spatial size.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, 7),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(inplace=True)]

        # Downsampling: each stage halves H/W and doubles the channels.
        in_features = ngf
        for _ in range(2):
            out_features = in_features * 2
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Bottleneck residual blocks at the lowest resolution.
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: each stage doubles H/W and halves the channels.
        for _ in range(2):
            out_features = in_features // 2
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features

        # Output head: in_features is back to ngf here (was hard-coded as 64).
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_features, output_nc, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one realism score per image.

    Four 4x4 stride-2 convolutions (64 -> 128 -> 256 -> 512 channels, LeakyReLU
    with slope 0.2, InstanceNorm on all but the first) followed by a 4x4 conv
    down to a single-channel score map. ``forward`` averages that map spatially.

    Args:
        input_nc: number of channels of the images to classify.
    """

    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # Layer order matters: Sequential indices are the state_dict keys in checkpoints.
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend([nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                           nn.InstanceNorm2d(out_ch),
                           nn.LeakyReLU(0.2, inplace=True)])
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor holding the mean score map of each sample."""
        score_map = self.model(x)
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return pooled.view(score_map.size()[0], -1)

In [None]:
class LambdaLR:
    """Linear learning-rate decay multiplier.

    ``step(epoch)`` returns 1.0 until ``epoch + offset`` reaches
    ``decay_start_epoch``, then decreases linearly to 0.0 at ``n_epochs``.
    Intended as an ``lr_lambda`` for ``torch.optim.lr_scheduler.LambdaLR``
    (the same name, but this class is a plain callable provider, not a scheduler).

    Args:
        n_epochs: total number of epochs.
        offset: epochs already completed (e.g. when resuming a run).
        decay_start_epoch: epoch at which the decay begins; must be < n_epochs.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert n_epochs - decay_start_epoch > 0, "Decay must start before training session ends"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay / decay_span

In [None]:
def weights_init_normal(m):
    """Initialise a module's weights in place; meant for ``nn.Module.apply``.

    - Any class whose name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02); the bias is left untouched.
    - Any class whose name contains ``'BatchNorm2d'``: weight ~ N(1.0, 0.02),
      bias = 0.
    Other modules (including InstanceNorm2d, ReLU, padding) are left unchanged.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0, 0.02)
    if 'BatchNorm2d' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0.0)

In [None]:
import pytorch_lightning as pl
import itertools
from torch.utils.data import DataLoader
import torchvision
 
class CycleGAN(pl.LightningModule):
    """CycleGAN (unpaired image-to-image translation) as a LightningModule.

    Holds two generators (A->B, B->A) and one discriminator per domain. It trains
    them with an MSE adversarial loss, an L1 cycle-consistency loss (weight 10)
    and an L1 identity loss (weight 5). Manual optimization is used: three Adam
    optimizers are stepped explicitly in ``training_step``, and the linear-decay
    LR schedulers are stepped once per epoch in ``on_train_epoch_end``.
    """

    def __init__(self,
                 input_nc=3,  # number of input channels
                 output_nc=3,  # number of output channels
                 image_size=256,  # side of the random training crop
                 lr=0.0002,  # learning rate
                 beta1=0.5, beta2=0.999,  # Adam beta values
                 starting_epoch=0,  # starting epoch when resuming an interrupted run (NOTE(review): not read anywhere in this class)
                 n_epochs=200,  # total number of epochs
                 decay_epoch=100,  # epoch after which the learning rate starts decaying linearly
                 data_root='your_dataset',  # folder containing the trainA/trainB data
                 batch_size=1,  # batch size
                 num_workers=4):  # number of DataLoader worker processes

        super(CycleGAN, self).__init__()
        self.save_hyperparameters()

        # Generators: A -> B and B -> A
        self.netG_A2B = Generator(input_nc, output_nc)
        self.netG_B2A = Generator(output_nc, input_nc)

        # Discriminators: one for domain A and one for domain B
        self.netD_A = Discriminator(input_nc)
        self.netD_B = Discriminator(output_nc)

        # Weight initialisation (same order as before, so the RNG draws are unchanged)
        for net in (self.netG_A2B, self.netG_B2A, self.netD_A, self.netD_B):
            net.apply(weights_init_normal)

        # Losses
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        # Three optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

    def forward(self, x, mode='A2B'):
        """Translate ``x`` A->B when ``mode == 'A2B'``, otherwise B->A."""
        if mode == 'A2B':
            return self.netG_A2B(x)
        else:
            return self.netG_B2A(x)

    def configure_optimizers(self):
        """Return one Adam optimizer for both generators and one per discriminator.

        Each optimizer gets a LambdaLR schedule: constant LR until ``decay_epoch``,
        then linear decay reaching 0 at ``n_epochs``.
        """
        betas = (self.hparams.beta1, self.hparams.beta2)
        optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A2B.parameters(), self.netG_B2A.parameters()),
                                       lr=self.hparams.lr, betas=betas)
        optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=self.hparams.lr, betas=betas)
        optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=self.hparams.lr, betas=betas)

        def lr_lambda(epoch):
            return 1.0 - max(0, epoch - self.hparams.decay_epoch) / (self.hparams.n_epochs - self.hparams.decay_epoch)

        optimizers = [optimizer_G, optimizer_D_A, optimizer_D_B]
        schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda) for opt in optimizers]
        return optimizers, schedulers

    def training_step(self, batch, batch_idx):
        """One generator update followed by one update of each discriminator."""
        real_A, real_B = batch
        target_real = torch.ones((real_A.shape[0], 1)).type_as(real_A)
        target_fake = torch.zeros((real_A.shape[0], 1)).type_as(real_A)

        optimizer_G, optimizer_D_A, optimizer_D_B = self.optimizers()

        # ---- Generators ----
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its *target* domain
        # should leave it unchanged (G_A2B(B) ~ B, G_B2A(A) ~ A).
        same_B = self.netG_A2B(real_B)
        same_A = self.netG_B2A(real_A)
        loss_identity = (self.criterion_identity(same_B, real_B) * 5.0
                         + self.criterion_identity(same_A, real_A) * 5.0)

        # Adversarial loss: the fakes should be classified as real.
        fake_B = self.netG_A2B(real_A)
        loss_GAN_A2B = self.criterion_GAN(self.netD_B(fake_B), target_real)
        fake_A = self.netG_B2A(real_B)
        loss_GAN_B2A = self.criterion_GAN(self.netD_A(fake_A), target_real)

        # Cycle-consistency loss: A -> B -> A and B -> A -> B must reconstruct the input.
        recovered_A = self.netG_B2A(fake_B)
        loss_cycle_ABA = self.criterion_cycle(recovered_A, real_A) * 10.0
        recovered_B = self.netG_A2B(fake_A)
        loss_cycle_BAB = self.criterion_cycle(recovered_B, real_B) * 10.0

        loss_G = loss_identity + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        self.manual_backward(loss_G)
        optimizer_G.step()

        self.log("loss_G/overall", loss_G)

        if batch_idx % 100 == 0:
            self._log_images(real_A, fake_B, recovered_A, real_B, fake_A, recovered_B)

        # ---- Discriminators ----
        # Reuse the fakes from the generator pass (detached) instead of running the
        # generators again; zero_grad in the helper also clears the D gradients that
        # loss_G.backward accumulated.
        loss_D_A = self._discriminator_step(self.netD_A, optimizer_D_A, real_A, fake_A.detach(),
                                            target_real, target_fake)
        self.log("loss_D/loss_D_A", loss_D_A)

        loss_D_B = self._discriminator_step(self.netD_B, optimizer_D_B, real_B, fake_B.detach(),
                                            target_real, target_fake)
        self.log("loss_D/loss_D_B", loss_D_B)

    def _discriminator_step(self, netD, optimizer, real, fake, target_real, target_fake):
        """Update ``netD`` to score ``real`` as 1 and ``fake`` as 0; return the loss."""
        optimizer.zero_grad()
        loss_D_real = self.criterion_GAN(netD(real), target_real)
        loss_D_fake = self.criterion_GAN(netD(fake), target_fake)
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        self.manual_backward(loss_D)
        optimizer.step()
        return loss_D

    def _log_images(self, real_A, fake_B, recovered_A, real_B, fake_A, recovered_B):
        """Write image grids (at most 50 images, 8 per row) to the experiment logger."""
        # NOTE(review): assumes a TensorBoard-style logger exposing experiment.add_image.
        images = {
            "A/A": real_A, "A/A2B": fake_B, "A/ABA": recovered_A,
            "B/B": real_B, "B/B2A": fake_A, "B/BAB": recovered_B,
        }
        for tag, tensor in images.items():
            grid = torchvision.utils.make_grid(tensor[:50].detach(), nrow=8, normalize=True)
            self.logger.experiment.add_image(tag, grid, self.global_step)

    def on_train_epoch_end(self):
        """Advance every LR scheduler once per epoch.

        With ``automatic_optimization = False`` Lightning does not step the
        schedulers returned by ``configure_optimizers``; without this the learning
        rate would never decay.
        """
        for scheduler in self.lr_schedulers():
            scheduler.step()

    def train_dataloader(self):
        """Build the shuffled, unpaired training DataLoader from ``data_root``."""
        transform = transforms.Compose([
            # NOTE(review): a fixed 640x640 resize ignores aspect ratio; presumably intentional.
            transforms.Resize((640, 640), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(self.hparams.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Mean/std 0.5 are not dataset statistics: they map ToTensor's [0, 1]
            # range to [-1, 1], the same range as the generator's Tanh output.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        return DataLoader(ImageDataset(self.hparams.data_root, transform=transform),
                          batch_size=self.hparams.batch_size,
                          shuffle=True,
                          num_workers=self.hparams.num_workers)

In [None]:
# Restore the trained CycleGAN (weights + saved hyperparameters) from a Lightning checkpoint.
CHECKPOINT_PATH = './models/cycleGAN/epoch=4-step=21270.ckpt'
mod = CycleGAN.load_from_checkpoint(CHECKPOINT_PATH)

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np

# Inference dataset: images from several folders, returned with their file names.
# Deliberately NOT named `ImageDataset`: redefining that name would shadow the
# training dataset class that CycleGAN.train_dataloader instantiates.
class FolderImageDataset(Dataset):
    """Flat dataset over every regular file in one or more directories.

    Args:
        image_dirs: iterable of directory paths. Files are listed in sorted order
            within each directory; directories are concatenated in the given order.
        transform: optional callable applied to each RGB PIL image.

    Each item is ``(image, filename)``, where ``filename`` is the bare file name
    (no directory) so it can be reused for the output file.
    """

    def __init__(self, image_dirs, transform=None):
        self.image_files = []  # list of (full path, file name)
        for image_dir in image_dirs:
            for name in sorted(os.listdir(image_dir)):
                path = os.path.join(image_dir, name)
                if os.path.isfile(path):  # skip sub-directories
                    self.image_files.append((path, name))

        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path, filename = self.image_files[idx]
        # Decode as RGB and close the underlying file handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, filename


# Image folders to translate
folder_A = "./datasets/fullSynth/train/images"
folder_B = "./datasets/fullSynth/val/images"

# Same input scaling as in training: ToTensor gives [0, 1], Normalize(0.5, 0.5)
# maps it to [-1, 1], the range the generators were trained on.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset that concatenates the images of both folders
dset = FolderImageDataset([folder_A, folder_B], transform)

# Output folder
output_dir = "adaptedSynth"
os.makedirs(output_dir, exist_ok=True)

# Translate every image from domain A to domain B and save the results.
def save_adapted_images(model=None, dataset=None, out_dir=None):
    """Run ``netG_A2B`` on every dataset item and save the output as an 8-bit image.

    Args:
        model: object exposing a ``netG_A2B`` module; defaults to the global ``mod``.
        dataset: supports ``len`` and integer indexing; each item is
            ``(image_tensor, filename)``. Defaults to the global ``dset``.
        out_dir: existing destination directory; defaults to the global ``output_dir``.

    Each output reuses the input's file name, so the extension selects the saved
    format. NOTE(review): equal file names coming from different input folders
    map to the same output path, and the later one overwrites the earlier.
    """
    model = mod if model is None else model
    dataset = dset if dataset is None else dataset
    out_dir = output_dir if out_dir is None else out_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = model.netG_A2B.to(device)
    generator.eval()

    # Inference only: no autograd graph is needed for the generated images.
    with torch.no_grad():
        for i in range(len(dataset)):
            image_in, filename = dataset[i]
            fake = generator(image_in.unsqueeze(0).to(device))[0]  # drop the batch dim -> (C, H, W)

            # The generator ends in Tanh, so its output is in [-1, 1]; map it linearly
            # to [0, 255]. A per-image min/max stretch would change contrast from image
            # to image and divide by zero on constant outputs.
            pixels = ((fake.clamp(-1.0, 1.0) + 1.0) * 127.5).round().to(torch.uint8)
            pixels = pixels.permute(1, 2, 0).cpu().numpy()  # (H, W, C) for PIL
            if pixels.shape[-1] == 1:
                pixels = pixels[..., 0]  # PIL expects (H, W) for single-channel images

            Image.fromarray(pixels).save(os.path.join(out_dir, filename))


save_adapted_images()
