In [None]:
# Use the %pip magic so the packages are installed into the environment of the
# running kernel (a plain `!pip` may target a different Python installation).
# NOTE(review): consider pinning versions for reproducibility.
%pip install pytorch-lightning albumentations

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [None]:
class ImageNetDataModule(pl.LightningDataModule):
    """Lightning data module skeleton for the ImageNet train/valid/test/predict splits.

    The dataset attributes are still ``...`` placeholders: the actual loading
    code has to be written inside :meth:`setup` before the dataloaders can be
    used.
    """

    def __init__(self):
        super().__init__()
        # Root directory of the data (empty for now).
        self.data_dir = ''
        # Batch sizes per split; the test batch size is also used for prediction.
        self.batch_size_train, self.batch_size_valid, self.batch_size_test = 64, 64, 64

    def setup(self, stage=None):
        """Create the datasets required by ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"``, ``"predict"`` or
                ``None``. ``None`` (the default) prepares every split, so
                ``setup()`` can be called manually without arguments.
        """
        if stage in ("fit", None):
            # LOAD DATA
            self.imagenet_train = ...

        # The validation split is needed both while fitting and for a
        # standalone validation run; without the "validate" case,
        # val_dataloader() would fail on a missing attribute.
        if stage in ("fit", "validate", None):
            # LOAD DATA
            self.imagenet_valid = ...

        if stage in ("test", None):
            # LOAD DATA
            self.imagenet_test = ...

        if stage in ("predict", None):
            # LOAD DATA
            self.imagenet_predict = ...

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        train_loader = torch.utils.data.DataLoader(self.imagenet_train, batch_size=self.batch_size_train, shuffle=True)
        return train_loader

    def val_dataloader(self):
        """Return the (unshuffled) loader over the validation split."""
        val_loader = torch.utils.data.DataLoader(self.imagenet_valid, batch_size=self.batch_size_valid, shuffle=False)
        return val_loader

    def test_dataloader(self):
        """Return the (unshuffled) loader over the test split."""
        test_loader = torch.utils.data.DataLoader(self.imagenet_test, batch_size=self.batch_size_test, shuffle=False)
        return test_loader

    def predict_dataloader(self):
        """Return the (unshuffled) loader over the prediction split.

        Uses the test batch size, as no dedicated prediction batch size exists.
        """
        predict_loader = torch.utils.data.DataLoader(self.imagenet_predict, batch_size=self.batch_size_test, shuffle=False)
        return predict_loader

In [None]:
# TODO: revisit later - use this block to simplify the code
# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ConvBlock, self).__init__()
#         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, padding=2)
#         self.bn = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         x = self.conv(x)
#         x = self.bn(x)
#         x = self.relu(x)
#         return x

# TODO: change the input parameters once the function has been optimized

class Encoder(nn.Module):
    """Convolutional encoder mapping a 3-channel image to a 4000-channel code.

    Five 4x4 / stride-2 convolutions each halve the spatial size, a 2x2
    max-pool halves it once more, and a final un-padded 4x4 convolution
    produces the 4000 feature maps.

    NOTE(review): with the max-pool in place the input must be at least
    256x256 (256x256 -> 4000x1x1); a 128x128 input reaches the last
    convolution as 2x2, which is smaller than its 4x4 kernel. Confirm the
    intended input resolution.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            # TODO: decide whether to keep this layer; the paper says one was
            # used, but the reference code does not contain it.
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Layer leading into the "channel-wise fully-connected layer" stage.
            nn.Conv2d(512, 4000, kernel_size=4)
        )

    def forward(self, x):
        """Encode ``x`` by running it through the convolutional stack.

        Implementing ``forward`` makes the module callable (``Encoder()(x)``),
        so that hooks and container modules work as for any ``nn.Module``.

        Args:
            x: Image batch with 3 channels, laid out as (N, 3, H, W).

        Returns:
            Tensor with 4000 channels.
        """
        return self.encoder(x)

    def forward_encoder(self, x):
        """Backward-compatible alias for calling the module on ``x``."""
        return self(x)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a 4000-channel code to a 3-channel image.

    The first un-padded 4x4 transposed convolution expands a 1x1 code to 4x4;
    each of the four following 4x4 / stride-2 transposed convolutions doubles
    the spatial size, so a 4000x1x1 code becomes a 3x64x64 output. The final
    ``Tanh`` bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4000, 512, kernel_size=4),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            # In their code there is no BatchNorm2d or ReLU after the last
            # ConvTranspose2d layer.
            # nn.BatchNorm2d(3),
            # nn.ReLU(True),

            # Final activation that squashes the values into [-1, 1]; present
            # in the code but not in the paper.
            nn.Tanh()
        )

    def forward(self, x):
        """Decode ``x`` by running it through the transposed-convolution stack.

        Implementing ``forward`` makes the module callable (``Decoder()(x)``),
        so that hooks and container modules work as for any ``nn.Module``.

        Args:
            x: Code tensor with 4000 channels, laid out as (N, 4000, H, W).

        Returns:
            Tensor with 3 channels and values in [-1, 1].
        """
        return self.decoder(x)

    def forward_decoder(self, x):
        """Backward-compatible alias for calling the module on ``x``."""
        return self(x)

In [None]:
class Bottleneck(nn.Module):
    """Normalisation + activation applied to the 4000-channel code.

    Applies ``BatchNorm2d(4000)`` followed by ``LeakyReLU(0.2)``; the tensor
    shape is left unchanged.
    """

    def __init__(self):
        super().__init__()

        # TODO: decide whether to add the noise generator (NoiseGen).

        self.bottleneck = nn.Sequential(
            nn.BatchNorm2d(4000),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x):
        """Normalise and activate ``x``.

        Implementing ``forward`` makes the module callable
        (``Bottleneck()(x)``), so that hooks and container modules work as for
        any ``nn.Module``.

        Args:
            x: Tensor with 4000 channels, laid out as (N, 4000, H, W).

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.bottleneck(x)

    def forward_bottleneck(self, x):
        """Backward-compatible alias for calling the module on ``x``."""
        return self(x)

In [None]:
class Generator(nn.Module):
    """Generator chaining the encoder, the bottleneck and the decoder."""

    def __init__(self):
        super().__init__()

        # Sub-modules are created in pipeline order.
        self.encoder = Encoder()
        self.bottleneck = Bottleneck()
        self.decoder = Decoder()

    def forward(self, x):
        """Run ``x`` through encoder -> bottleneck -> decoder and return the result."""
        code = self.encoder.forward_encoder(x)
        activated_code = self.bottleneck.forward_bottleneck(code)
        return self.decoder.forward_decoder(activated_code)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing sigmoid scores for 3-channel images.

    Three conv / batch-norm / LeakyReLU stages are followed by a last
    convolution and a sigmoid. Every convolution is 4x4 with stride 2 and
    padding 1, so each one halves the spatial size.

    NOTE(review): the last convolution emits 512 channels, so each sample
    yields 512 * H' * W' scores rather than a single value - confirm that
    this is intended.
    """

    def __init__(self):
        super().__init__()

        layers = []
        # Strided conv blocks: 3 -> 64 -> 128 -> 256 channels.
        for c_in, c_out in ((3, 64), (64, 128), (128, 256)):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, True))
        # Last convolution followed by the sigmoid that maps scores into (0, 1).
        layers.append(nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Sigmoid())

        self.discriminator = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` and flatten the result to one row per sample."""
        scores = self.discriminator(x)
        n_samples = scores.shape[0]
        return scores.view(n_samples, -1)