In [None]:
!pip install pytorch-lightning albumentations

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import os

In [None]:
def rebuild_image(full_image, center_image):
    """Return a copy of ``full_image`` with ``center_image`` pasted into its central window.

    The window is rows/columns 32..95 on an HWC image, i.e. the 64x64 center of a
    128x128 image (the same region ``extract_center_image`` cuts out).

    The input is NOT modified: callers such as ``display_examples`` paste two different
    centers onto the same base image, and an in-place write made both results alias
    the same array.

    Args:
        full_image: HWC numpy array or torch tensor.
        center_image: patch broadcastable to ``full_image[32:96, 32:96, :]``.

    Returns:
        A new array/tensor of the same type as ``full_image``.
    """
    if isinstance(full_image, torch.Tensor):
        rebuilt = full_image.clone()
    else:
        rebuilt = np.array(full_image, copy=True)
    rebuilt[32:96, 32:96, :] = center_image
    return rebuilt

In [None]:
def extract_center_image(image):
    """Cut the central window (rows and columns 32 to 95, all channels) out of an HWC image.

    For the 128x128 images this notebook works with, that is the 64x64 center patch.
    With numpy arrays or torch tensors the result is a slice view sharing memory with ``image``.
    """
    window = slice(32, 96)
    return image[window, window, :]

In [None]:
def load_data(data_dir, complete_subdir="image_complète", incomplete_subdir="image_incomplète"):
    """Load the complete and incomplete RGB images stored under ``data_dir``.

    Expected layout: ``<data_dir>/<complete_subdir>`` and ``<data_dir>/<incomplete_subdir>``.
    Each folder is read in sorted file-name order so that the i-th complete image and the
    i-th incomplete image can be paired by index (``os.listdir`` order is arbitrary).
    Hidden entries (names starting with ".", e.g. ``.ipynb_checkpoints``) and
    sub-directories are skipped.

    Args:
        data_dir: root folder of one split (a trailing slash is fine).
        complete_subdir: name of the folder holding the complete images.
        incomplete_subdir: name of the folder holding the incomplete images.

    Returns:
        ``(images_completes, images_incompletes)``: two lists of RGB numpy arrays
        (``cv2.imread`` output converted from BGR to RGB).

    Raises:
        ValueError: if a file in one of the folders cannot be decoded as an image.
    """

    def read_folder(folder):
        # Read every visible file of `folder` as an RGB image, in sorted name order.
        images = []
        for file_name in sorted(os.listdir(folder)):
            path = os.path.join(folder, file_name)
            if file_name.startswith(".") or not os.path.isfile(path):
                continue
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise ValueError(f"Could not read image file: {path}")
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        return images

    # The original joined file names onto `data_dir` instead of the sub-folder they were listed from.
    images_completes = read_folder(os.path.join(data_dir, complete_subdir))
    images_incompletes = read_folder(os.path.join(data_dir, incomplete_subdir))
    return images_completes, images_incompletes

In [None]:
class ImageNetDataset(torch.utils.data.Dataset):
    """Map-style dataset of index-aligned ``(image_full, center_image)`` pairs.

    In this notebook ``image_full`` holds the incomplete images fed to the generator and
    ``center_image`` the true center patches it must predict.

    Args:
        image_full: sequence of input images.
        center_image: sequence of target center patches, paired with ``image_full`` by index.
        transform: optional callable applied to each element of a pair.
            NOTE(review): it is called separately on the two images, so random augmentations
            are not synchronised between them; an albumentations pipeline would also need to
            be wrapped as ``lambda img: pipeline(image=img)["image"]``.

    Raises:
        ValueError: if the two sequences have different lengths.
    """

    def __init__(self, image_full, center_image, transform=None):
        if len(image_full) != len(center_image):
            # Pairs are matched by index: a mismatch would otherwise silently drop samples
            # or raise IndexError in the middle of an epoch.
            raise ValueError(
                "image_full and center_image must have the same length, "
                f"got {len(image_full)} and {len(center_image)}"
            )
        self.image_full = image_full
        self.center_image = center_image
        self.transform = transform

    def __len__(self):
        return len(self.image_full)

    def __getitem__(self, idx):
        image_full = self.image_full[idx]
        center_image = self.center_image[idx]

        if self.transform is not None:
            image_full = self.transform(image_full)
            center_image = self.transform(center_image)

        return image_full, center_image

In [None]:
class ImageNetDataModule(pl.LightningDataModule):
    """Serves (incomplete image, true center patch) pairs for training, validation and test.

    Each split directory must follow the layout read by ``load_data``; the true center of
    every complete image is cut out with ``extract_center_image``.

    Args:
        batch_size: batch size shared by the train / validation / test loaders.
        train_data_dir: root folder of the training split.
        val_data_dir: root folder of the validation split.
        test_data_dir: root folder of the test split.
    """

    def __init__(self, batch_size = 64, train_data_dir='data/train/', val_data_dir='data/validation/', test_data_dir='data/test/'):
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.test_data_dir = test_data_dir
        self.batch_size_train, self.batch_size_valid, self.batch_size_test = batch_size, batch_size, batch_size

    @staticmethod
    def _build_dataset(data_dir):
        """Load one split and pair each incomplete image with the center of its complete counterpart."""
        images_complete, images_incomplete = load_data(data_dir)
        centers = [extract_center_image(img) for img in images_complete]
        return ImageNetDataset(images_incomplete, centers)

    def setup(self, stage=None):
        """Build the datasets needed by ``stage`` ("fit", "validate", "test", or None for all three)."""
        if stage in ("fit", None):
            self.imagenet_train = self._build_dataset(self.train_data_dir)
        # "validate" (trainer.validate) needs the validation set but not the training set.
        if stage in ("fit", "validate", None):
            self.imagenet_valid = self._build_dataset(self.val_data_dir)
        if stage in ("test", None):
            self.imagenet_test = self._build_dataset(self.test_data_dir)
        # TODO: add a "predict" stage (self.imagenet_predict + predict_dataloader) once
        # prediction data is available.

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.imagenet_train, batch_size=self.batch_size_train, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.imagenet_valid, batch_size=self.batch_size_valid, shuffle=False)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.imagenet_test, batch_size=self.batch_size_test, shuffle=False)

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling unit: Conv2d -> BatchNorm2d -> LeakyReLU.

    With the default kernel 4 / stride 2 / padding 1 it halves H and W.
    Note: despite its name, ``learning_rate`` is the LeakyReLU negative slope,
    not an optimizer setting.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, learning_rate=0.2, inplace=True):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(negative_slope=learning_rate, inplace=inplace)
        self.encoderblock = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.encoderblock(x)

In [None]:
# TODO: expose the layer hyper-parameters as constructor arguments once the
# architecture has been tuned.

class Encoder(nn.Module):
    """Context-encoder encoder: (N, 3, 128, 128) image -> (N, 4000, 1, 1) code.

    Five 4x4 / stride-2 / padding-1 convolutions halve the spatial size
    128 -> 64 -> 32 -> 16 -> 8 -> 4 while widening channels to 512. A final 4x4
    convolution without padding collapses the 4x4x512 map into the 4000-channel
    code consumed by the "channel-wise fully-connected layer" stage.

    Args:
        use_maxpool: insert a 2x2 max-pool before the 4000-channel convolution.
            The paper mentions one, but the reference code does not use it. With it, a
            128x128 input shrinks to 2x2 and the 4x4 convolution fails, so it only
            fits larger inputs (e.g. 256x256). Defaults to False.
    """

    def __init__(self, use_maxpool=False):
        super().__init__()

        layers = [
            # No BatchNorm after the first convolution.
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        ]
        if use_maxpool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Bottleneck convolution feeding the "channel-wise fully-connected layer" stage.
        layers.append(nn.Conv2d(512, 4000, kernel_size=4))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)

    def forward_encoder(self, x):
        """Backward-compatible alias of calling the module (goes through ``__call__`` so hooks run)."""
        return self(x)

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling unit: ConvTranspose2d -> BatchNorm2d -> ReLU.

    With the default kernel 4 / stride 2 / padding 1 it doubles H and W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, inplace=True):
        super().__init__()

        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.decoderblock = nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=inplace))

    def forward(self, x):
        return self.decoderblock(x)

In [None]:
class Decoder(nn.Module):
    """Context-encoder decoder: (N, 4000, 1, 1) code -> (N, 3, 64, 64) patch in [-1, 1].

    A 4x4 transposed convolution expands the 1x1 code to 4x4x512, then four
    4x4 / stride-2 / padding-1 transposed convolutions double the spatial size
    4 -> 8 -> 16 -> 32 -> 64 while narrowing channels to 3.
    """

    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(

            nn.ConvTranspose2d(4000, 512, kernel_size=4),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # The reference code has no BatchNorm2d / ReLU after this last
            # ConvTranspose2d, so none is added here either.
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),

            # Final activation squashing outputs into [-1, 1]; present in the
            # reference code but not in the paper.
            nn.Tanh()
        )

    def forward(self, x):
        return self.decoder(x)

    def forward_decoder(self, x):
        """Backward-compatible alias of calling the module (goes through ``__call__`` so hooks run)."""
        return self(x)

In [None]:
class Bottleneck(nn.Module):
    """Normalises the 4000-channel encoder code: BatchNorm2d(4000) followed by LeakyReLU(0.2).

    NOTE(review): on a 1x1 code, BatchNorm2d in training mode needs more than one sample per
    batch, so a final batch of size 1 will raise.
    """

    def __init__(self):
        super().__init__()

        # TODO: decide whether to add the noise generator ("noiseGen") here.

        self.bottleneck = nn.Sequential(
            nn.BatchNorm2d(4000),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x):
        return self.bottleneck(x)

    def forward_bottleneck(self, x):
        """Backward-compatible alias of calling the module (goes through ``__call__`` so hooks run)."""
        return self(x)

In [None]:
class Generator(nn.Module):
    """Encoder -> bottleneck -> decoder pipeline.

    In this notebook it receives the incomplete images and predicts the missing center patch.
    """

    def __init__(self):
        super().__init__()

        # Registration order (encoder, bottleneck, decoder) also fixes the parameter init order.
        self.encoder = Encoder()
        self.bottleneck = Bottleneck()
        self.decoder = Decoder()

    def forward(self, x):
        stages = (
            self.encoder.forward_encoder,
            self.bottleneck.forward_bottleneck,
            self.decoder.forward_decoder,
        )
        for stage in stages:
            x = stage(x)
        return x

In [None]:
class Discriminator(nn.Module):
    """Adversarial discriminator applied to 3-channel center patches.

    Four 4x4 / stride-2 / padding-1 convolutions each halve H and W (for H, W divisible
    by 16, e.g. a 64x64 patch becomes 4x4) while widening channels 3 -> 64 -> 128 -> 256 -> 512.
    The last convolution is followed directly by a sigmoid, so the network emits a grid of
    per-location probabilities rather than a single score per image.

    NOTE(review): for an (N, 3, 64, 64) input the flattened output is (N, 512 * 4 * 4) = (N, 8192),
    so BCE targets must have that same shape; an (N, 1) target does not match.
    """
    def __init__(self):
        super(Discriminator, self).__init__()

        self.discriminator = nn.Sequential(

            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            # EncoderBlock(3, 64),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # EncoderBlock(64, 128),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # EncoderBlock(128, 256),

            # No BatchNorm/LeakyReLU here: the sigmoid maps each output location to a probability.
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()

        )

    def forward(self, x):
        """Return the per-location probabilities flattened to shape (N, C * H' * W')."""
        x = self.discriminator(x)
        x = x.view(x.size(0), -1)

        return x

In [None]:
class GeneratorModule(pl.LightningModule):
    """Adversarial (context-encoder style) training of ``Generator`` against ``Discriminator``.

    Each training step first updates the discriminator on real vs. generated center patches,
    then updates the generator on the weighted joint loss
    ``lambda_rec * MSE(fake, real) + lambda_adv * BCE(D(fake), 1)``.

    Args:
        lambda_rec: weight of the reconstruction (MSE) term.
        lambda_adv: weight of the adversarial (BCE) term.
        lr_d: Adam learning rate of the discriminator.
        lr_g: Adam learning rate of the generator.
        betas: Adam ``betas`` shared by both optimizers.
    """

    def __init__(self, lambda_rec = 0.999, lambda_adv = 0.001, lr_d=0.0002, lr_g=0.002, betas=(0.5, 0.999)):
        super().__init__()

        self.generator = Generator()
        self.discriminator = Discriminator()

        self.lambda_rec = lambda_rec
        self.lambda_adv = lambda_adv
        self.lr_d = lr_d
        self.lr_g = lr_g
        self.betas = betas

        # Two optimizers stepped in a fixed order per batch. Lightning >= 2.0 only supports
        # multiple optimizers with manual optimization (the optimizer_idx API was removed).
        self.automatic_optimization = False

        # TODO(review): the reference code uses these plain losses; the paper uses slightly
        # different losses that allow per-term weighting.
        self.criterion_rec = nn.MSELoss()
        self.criterion_adv = nn.BCELoss()

    @staticmethod
    def _prepare_images(images):
        """Bring a batch into the generator's space: float (N, C, H, W) scaled to [-1, 1].

        Batches collated from cv2 images are uint8 and channel-last; they are permuted to
        channel-first and rescaled from [0, 255] to match the generator's tanh output.
        Non-uint8 tensors are assumed to be prepared already and are only cast to float.
        """
        if images.dtype == torch.uint8:
            if images.dim() == 4 and images.shape[-1] == 3:
                images = images.permute(0, 3, 1, 2)
            return images.float() / 127.5 - 1.0
        return images.float()

    def forward(self, x):
        return self.generator(self._prepare_images(x))

    def configure_optimizers(self):
        optimizerD = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr_d, betas=self.betas)
        # NOTE(review): the generator's default lr is 10x the discriminator's; confirm this is intended.
        optimizerG = torch.optim.Adam(self.generator.parameters(), lr=self.lr_g, betas=self.betas)
        return [optimizerD, optimizerG], []

    def _generator_loss(self, fake_centers, true_centers):
        """Weighted reconstruction + adversarial loss of the generator."""
        pred_fake = self.discriminator(fake_centers)
        loss_rec = self.criterion_rec(fake_centers, true_centers)
        # Targets mirror the discriminator's output shape (one probability per output
        # location) and live on the same device.
        loss_adv = self.criterion_adv(pred_fake, torch.ones_like(pred_fake))
        return loss_rec * self.lambda_rec + loss_adv * self.lambda_adv

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """One discriminator update followed by one generator update (manual optimization).

        ``optimizer_idx`` is accepted for backward compatibility and ignored.
        """
        inputs = self._prepare_images(batch[0])
        true_centers = self._prepare_images(batch[1])
        optimizerD, optimizerG = self.optimizers()

        # TODO: try feeding noise to the generator (z = torch.randn(...)) to see whether it helps training.
        fake_centers = self.generator(inputs)

        # Discriminator: real centers -> 1, generated centers -> 0. detach() keeps this update
        # from back-propagating into the generator.
        pred_real = self.discriminator(true_centers)
        pred_fake = self.discriminator(fake_centers.detach())
        real_loss = self.criterion_adv(pred_real, torch.ones_like(pred_real))
        fake_loss = self.criterion_adv(pred_fake, torch.zeros_like(pred_fake))
        # Mean of the two terms; halving only rescales D's gradients (pix2pix uses the same
        # trick to slow D down relative to G).
        d_loss = (real_loss + fake_loss) / 2
        optimizerD.zero_grad()
        self.manual_backward(d_loss)
        optimizerD.step()
        self.log('d_loss', d_loss, prog_bar=True)

        # Generator: re-scores the same fakes with the freshly updated discriminator.
        joint_loss = self._generator_loss(fake_centers, true_centers)
        optimizerG.zero_grad()
        self.manual_backward(joint_loss)
        optimizerG.step()
        self.log('joint_loss', joint_loss, prog_bar=True)

        return joint_loss

    def _evaluation_step(self, batch, log_name):
        """Compute and log the generator's joint loss on one evaluation batch."""
        inputs = self._prepare_images(batch[0])
        true_centers = self._prepare_images(batch[1])
        joint_loss = self._generator_loss(self.generator(inputs), true_centers)
        self.log(log_name, joint_loss, prog_bar=True)
        return joint_loss

    def validation_step(self, batch, batch_idx):
        return self._evaluation_step(batch, 'val_joint_loss')

    def test_step(self, batch, batch_idx):
        return self._evaluation_step(batch, 'joint_loss')

    # TODO: add predict_step if a prediction stage is needed.

: 

In [None]:
def display_examples(model, dataloader, num_examples = 5):
    """Plot, for the first batch of ``dataloader``, the input with its true center pasted in
    ("True Image") next to the input with the model's predicted center ("Fake Image").

    Args:
        model: module mapping a batch of inputs to predicted centers; it is switched to eval mode.
        dataloader: iterable yielding ``(full_image, center_image)`` batches.
        num_examples: maximum number of examples to plot (capped at the batch size).
    """

    def to_model_input(images):
        # Collated cv2 images are uint8 channel-last; the generator expects float NCHW in [-1, 1].
        if images.dtype == torch.uint8:
            if images.dim() == 4 and images.shape[-1] == 3:
                images = images.permute(0, 3, 1, 2)
            return images.float() / 127.5 - 1.0
        return images.float()

    def to_displayable(image):
        # One image -> contiguous HWC uint8 numpy array.
        # Float images are assumed to be in [-1, 1] (tanh output).
        array = image.detach().cpu().numpy() if isinstance(image, torch.Tensor) else np.asarray(image)
        if array.ndim == 3 and array.shape[0] == 3 and array.shape[-1] != 3:
            array = np.transpose(array, (1, 2, 0))
        if np.issubdtype(array.dtype, np.floating):
            array = np.clip((array + 1.0) * 127.5, 0, 255)
        return np.ascontiguousarray(array.astype(np.uint8))

    model.eval()
    batch = next(iter(dataloader), None)
    if batch is None:
        return
    full_images, center_images = batch

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        center_preds = model(to_model_input(full_images).to(device))

    for i in range(min(num_examples, len(full_images))):
        base = to_displayable(full_images[i])
        # Separate copies: pasting both centers onto one shared array made "True" show the prediction.
        full_image_true = rebuild_image(base.copy(), to_displayable(center_images[i]))
        full_image_pred = rebuild_image(base.copy(), to_displayable(center_preds[i]))

        fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
        ax_true.imshow(full_image_true)
        ax_true.set_title("True Image")
        ax_pred.imshow(full_image_pred)
        ax_pred.set_title("Fake Image")
        plt.show()
        plt.close(fig)

In [None]:
def save_images(model, dataloader, output_dir = 'output_directory'):
    """Write one JPEG per sample of ``dataloader``: the input with the predicted center pasted in.

    Files are named ``fake_image_<k>.jpg``, where ``k`` counts samples across the whole
    dataloader (the index previously restarted at 0 for every batch, so each batch
    overwrote the previous one).

    Args:
        model: module mapping a batch of inputs to predicted centers; it is switched to eval mode.
        dataloader: iterable yielding ``(full_image, center_image)`` batches.
        output_dir: destination folder, created if missing.

    Raises:
        OSError: if OpenCV fails to write a file.
    """

    def to_model_input(images):
        # Collated cv2 images are uint8 channel-last; the generator expects float NCHW in [-1, 1].
        if images.dtype == torch.uint8:
            if images.dim() == 4 and images.shape[-1] == 3:
                images = images.permute(0, 3, 1, 2)
            return images.float() / 127.5 - 1.0
        return images.float()

    def to_displayable(image):
        # One image -> contiguous HWC uint8 numpy array.
        # Float images are assumed to be in [-1, 1] (tanh output).
        array = image.detach().cpu().numpy() if isinstance(image, torch.Tensor) else np.asarray(image)
        if array.ndim == 3 and array.shape[0] == 3 and array.shape[-1] != 3:
            array = np.transpose(array, (1, 2, 0))
        if np.issubdtype(array.dtype, np.floating):
            array = np.clip((array + 1.0) * 127.5, 0, 255)
        return np.ascontiguousarray(array.astype(np.uint8))

    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    image_index = 0
    with torch.no_grad():
        for full_images, _ in dataloader:
            center_preds = model(to_model_input(full_images).to(device))

            for full_image, center_pred in zip(full_images, center_preds):
                # to_displayable returns a fresh array, so the batch itself is never modified.
                full_image_pred = rebuild_image(to_displayable(full_image), to_displayable(center_pred))

                path = os.path.join(output_dir, f"fake_image_{image_index}.jpg")
                # cv2.imwrite reports failure through its boolean return value rather than raising.
                if not cv2.imwrite(path, cv2.cvtColor(full_image_pred, cv2.COLOR_RGB2BGR)):
                    raise OSError(f"Failed to write {path}")
                image_index += 1

In [None]:
# TODO: revisit these hyper-parameters and move them to a dedicated configuration cell.
seed = 42
batch_size = 64
max_epochs = 10

# Seeds Python's `random`, NumPy and torch so weight init and shuffling are reproducible.
pl.seed_everything(seed)

data_module = ImageNetDataModule(batch_size)
train_loader = data_module  # alias kept for cells that still use the old (misleading) name
model = GeneratorModule()

# Trainer options must be passed by keyword: max_epochs is not its first positional
# parameter (all parameters are keyword-only in Lightning 2.x). accelerator/devices="auto"
# use every available GPU and fall back to the CPU, where devices=-1 is rejected.
trainer = pl.Trainer(max_epochs=max_epochs, accelerator="auto", devices="auto")
trainer.fit(model, datamodule=data_module)