In [None]:
!pip install pytorch-lightning albumentations

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [None]:
class ImageNetDataModule(pl.LightningDataModule):
    """Lightning data module serving the train / valid / test / predict splits.

    The dataset objects are still placeholders (``...``) to be filled in
    ``setup``; the dataloaders raise ``NotImplementedError`` until they are.

    Args:
        data_dir: root directory of the dataset (stored, not used yet).
        batch_size_train: batch size of the training loader.
        batch_size_valid: batch size of the validation loader.
        batch_size_test: batch size of the test and predict loaders.
    """

    def __init__(self, data_dir='', batch_size_train=64, batch_size_valid=64, batch_size_test=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_valid = batch_size_valid
        self.batch_size_test = batch_size_test

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (every split when ``stage`` is None)."""
        if stage in ("fit", None):
            # TODO: load the training dataset.
            self.imagenet_train = ...

        # The validation split is also needed by `trainer.validate()` (stage "validate").
        if stage in ("fit", "validate", None):
            # TODO: load the validation dataset.
            self.imagenet_valid = ...

        if stage in ("test", None):
            # TODO: load the test dataset.
            self.imagenet_test = ...

        if stage in ("predict", None):
            # TODO: load the prediction dataset.
            self.imagenet_predict = ...

    @staticmethod
    def _make_loader(dataset, batch_size, shuffle):
        """Wrap ``dataset`` in a DataLoader, failing clearly while it is still a placeholder."""
        if dataset is Ellipsis:
            raise NotImplementedError("Dataset not loaded yet: fill in ImageNetDataModule.setup().")
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def train_dataloader(self):
        # Training batches are shuffled every epoch.
        return self._make_loader(self.imagenet_train, self.batch_size_train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.imagenet_valid, self.batch_size_valid, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.imagenet_test, self.batch_size_test, shuffle=False)

    def predict_dataloader(self):
        # Prediction reuses the test batch size.
        return self._make_loader(self.imagenet_predict, self.batch_size_test, shuffle=False)

In [None]:
# À VOIR PLUS TARD : UTILISER CE BLOC POUR SIMPLIFIER LE CODE (TODO: reuse this ConvBlock to simplify the code)
# class ConvBlock(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ConvBlock, self).__init__()
#         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, padding=2)
#         self.bn = nn.BatchNorm2d(out_channels)
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         x = self.conv(x)
#         x = self.bn(x)
#         x = self.relu(x)
#         return x

# CHANGER LES PARAMÈTRES D'ENTRÉE QUAND ON AURA OPTIMISÉ LA FONCTION (TODO: update the input parameters once the function has been optimized)

class Encoder(nn.Module):
    """Convolutional encoder of the context-encoder generator.

    Maps a 3-channel image to a 4000-channel feature map, the input of the
    "channel-wise fully-connected layer" step. Each stride-2 convolution halves
    the spatial size:

    * ``use_maxpool=True`` (default, original behaviour): a 256x256 input gives a
      1x1 output. A 128x128 input shrinks to 2x2 before the final 4x4
      convolution, which then fails.
    * ``use_maxpool=False``: a 128x128 input gives a 1x1 output.

    Args:
        use_maxpool: keep the 2x2 max-pooling layer before the final convolution.
    """

    def __init__(self, use_maxpool=True):
        super(Encoder, self).__init__()

        layers = [
            # No BatchNorm after the first convolution.
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        ]

        if use_maxpool:
            # TODO: decide whether to keep it. The article says one was used,
            # but the authors' code does not have it.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Layer leading into the "channel-wise fully-connected layer" step.
        layers.append(nn.Conv2d(512, 4000, kernel_size=4))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` into the 4000-channel feature map."""
        return self.encoder(x)

    def forward_encoder(self, x):
        """Backward-compatible alias of :meth:`forward`."""
        return self.forward(x)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder of the context-encoder generator.

    A 4000-channel 1x1 map becomes 4x4 after the first layer. Each later
    stride-2 layer doubles the size, so the output is 3x64x64. The final
    ``Tanh`` bounds it to [-1, 1].
    """

    def __init__(self):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4000, 512, kernel_size=4),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # The authors' code has no BatchNorm2d / ReLU after this last transposed convolution.
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            # Final activation squashing values into [-1, 1]. It is present in the
            # authors' code but not in the article.
            nn.Tanh()
        )

    def forward(self, x):
        """Decode a 4000-channel feature map into an image in [-1, 1]."""
        return self.decoder(x)

    def forward_decoder(self, x):
        """Backward-compatible alias of :meth:`forward`."""
        return self.forward(x)

In [None]:
class Bottleneck(nn.Module):
    """BatchNorm + LeakyReLU applied to the 4000-channel encoder output.

    Expects a 4-D tensor with 4000 channels (``BatchNorm2d(4000)``). With a
    1x1 spatial input, training-mode batch norm needs at least 2 samples per batch.
    """

    def __init__(self):
        super(Bottleneck, self).__init__()

        # TODO: decide whether to add the noise generator ("noisegen") here.
        self.bottleneck = nn.Sequential(
            nn.BatchNorm2d(4000),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x):
        """Normalise and activate the encoder features."""
        return self.bottleneck(x)

    def forward_bottleneck(self, x):
        """Backward-compatible alias of :meth:`forward`."""
        return self.forward(x)

In [None]:
class Generator(nn.Module):
    """Context-encoder generator: encoder -> bottleneck -> decoder.

    Each stage is run through its ``forward_*`` method, and the output of one
    stage feeds the next.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Registration order (encoder, bottleneck, decoder) fixes the state_dict layout.
        self.encoder = Encoder()
        self.bottleneck = Bottleneck()
        self.decoder = Decoder()

    def forward(self, x):
        """Run ``x`` through the three stages in order and return the decoder output."""
        pipeline = (
            self.encoder.forward_encoder,
            self.bottleneck.forward_bottleneck,
            self.decoder.forward_decoder,
        )
        features = x
        for stage in pipeline:
            features = stage(features)
        return features

In [None]:
class Discriminator(nn.Module):
    """Adversarial discriminator scoring 3x64x64 images.

    Four stride-2 convolutions reduce a 64x64 input to a 512x4x4 map. A final
    4x4 convolution collapses that map to one logit, and ``Sigmoid`` turns the
    logit into a probability. ``forward`` returns shape (batch, 1), which
    matches the ``valid`` / ``fake`` targets used with ``nn.BCELoss``. Other
    input sizes give (batch, k) with k != 1.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.discriminator = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Previously Sigmoid was applied straight to the 512x4x4 map, giving a
            # (batch, 8192) output that BCELoss cannot compare with (batch, 1) targets.
            nn.Conv2d(512, 1, kernel_size=4),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one real/fake probability per sample, shape (batch, 1) for 64x64 inputs."""
        x = self.discriminator(x)
        return x.view(x.size(0), -1)

In [None]:
class MNISTClassifier(pl.LightningModule):
    """MLP classifier template (apparently left over from an MNIST exercise).

    NOTE(review): ``MLP`` is not defined in the visible part of this notebook.
    Instantiation raises NameError unless it is defined or imported elsewhere.

    Args:
        input_size: number of input features, forwarded to ``MLP``.
        hidden_size: hidden-layer width, forwarded to ``MLP``.
        num_classes: number of output classes.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(MNISTClassifier, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes

        self.mlp = MLP(input_size, hidden_size, num_classes)

        # Running sum of per-batch test accuracies and number of test batches seen.
        self.acc = 0
        self._test_batches = 0

    def forward(self, x):
        return self.mlp(x)

    def configure_optimizers(self):
        # `optim` is never imported in this notebook; use the fully qualified module.
        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)

    def _loss_and_acc(self, batch):
        """Return (cross-entropy loss, top-1 accuracy) for an ``(x, y)`` batch.

        Assumes ``y`` holds integer class indices and the model returns logits of
        shape (batch, num_classes). TODO: confirm against MLP.
        """
        x, y = batch
        preds = self(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        # Fraction of argmax predictions equal to the label. This is the micro-averaged
        # accuracy the per-step torchmetrics.Accuracy (never imported here) computed.
        acc = (preds.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log('train_acc', acc)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        self.log('val_acc', acc)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        loss, acc = self._loss_and_acc(batch)
        # Accumulate for the epoch-level average computed in on_test_epoch_end.
        self.acc = self.acc + acc
        self._test_batches += 1
        self.log('test_loss', loss)
        # Log this batch's accuracy, not the running sum.
        self.log('test_acc', acc)

    def on_test_epoch_start(self):
        # Lightning hook name. Lightning never calls the old `test_epoch_start`.
        self.acc = 0
        self._test_batches = 0

    # Kept so any code calling the old name directly still works.
    test_epoch_start = on_test_epoch_start

    def on_test_epoch_end(self):
        # Average over the batches actually evaluated (also correct with limit_test_batches).
        self.acc = self.acc / max(self._test_batches, 1)
        self.log('Final Accuracy', self.acc)
        self.acc = 0
        self._test_batches = 0

In [None]:
import pytorch_lightning as pl

class UNetModule(pl.LightningModule):
    """Lightning wrapper training a 1-channel-in / 1-class-out U-Net with BCE-with-logits.

    NOTE(review): ``UNet`` is not defined in the visible part of this notebook.
    Instantiation raises NameError unless it is defined or imported elsewhere.
    """

    def __init__(self):
        super(UNetModule, self).__init__()
        # Network: 1 input channel, 1 output map, 64 base filters, 4 blocks.
        self.model = UNet(in_channels=1, n_classes=1, n_filters=64, n_blocks=4)
        # The model's raw outputs are treated as logits.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the BCE-with-logits loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        logits = self.model(inputs)
        # BCEWithLogitsLoss needs floating-point targets.
        loss = self.loss_fn(logits, targets.type(torch.float32))
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

In [None]:
class GeneratorModule(pl.LightningModule):
    def __init__(self, in_channel, lambda_rec=0.999, lambda_adv=0.001):
        """Create the generator/discriminator pair, the loss weights and the criteria.

        Args:
            in_channel: accepted but not used by this constructor.
            lambda_rec: reconstruction-loss weight, stored as ``self.lambda_rec``.
            lambda_adv: adversarial-loss weight, stored as ``self.lambda_adv``.
        """
        super(GeneratorModule, self).__init__()

        # Networks, registered in this order: generator first, then discriminator.
        self.generator = Generator()
        self.discriminator = Discriminator()

        self.lambda_rec, self.lambda_adv = lambda_rec, lambda_adv

        # TO REVIEW: the authors' code uses these two criteria, but the article does not.
        # The article uses slightly different losses that allow weighting the terms.
        self.criterion_rec = nn.MSELoss()
        self.criterion_adv = nn.BCELoss()
    
    def forward(self, x):
        """Inference path: return the generator's output for ``x``."""
        generated = self.generator(x)
        return generated
    
    def configure_optimizers(self):
        """Return one Adam optimizer per network (discriminator first) and no schedulers.

        The list order fixes ``optimizer_idx`` in ``training_step``: 0 is the
        discriminator and 1 is the generator. This ``optimizer_idx`` scheme only
        works with Lightning < 2.0, which removed multi-optimizer automatic optimization.
        """
        adam_betas = (0.5, 0.999)
        opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=adam_betas)
        # NOTE(review): the generator LR (0.002) is 10x the discriminator LR (0.0002).
        # Confirm this is intended and not a typo.
        opt_generator = torch.optim.Adam(self.generator.parameters(), lr=0.002, betas=adam_betas)
        optimizers = [opt_discriminator, opt_generator]
        schedulers = []
        return optimizers, schedulers
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        full_images, center_images = batch
        batch_size = full_images.size(0)
        valid = torch.ones(batch_size, 1)
        fake = torch.zeros(batch_size, 1)

        # RAJOUTER DU BRUIT ? AMELIORE L'ENTRAINEMENT ? 
        # --> z = torch.randn(batch_size, 100, 1, 1, device=self.device)
        # --> fake = self.generator(z)
        # CAS DISCRIMINATOR
        if optimizer_idx == 0:
            fake_images = self.generator(full_images)

            pred_discriminator = self.discriminator(center_images)
            fake_discriminator = self.discriminator(fake_images.detach())
            
            real_loss = self.criterion_adv(pred_discriminator, valid)
            fake_loss = self.criterion_adv(fake_discriminator, fake)

            # VOIR COMMENT EXPLIQUER CETTE LOSS ? EST CE QU'ELLE EST DANS LE PAPIER ?
            d_loss = (real_loss + fake_loss) / 2

            self.log('d_loss', d_loss, prog_bar=True)

            return d_loss

        # CAS GENERATOR         
        if optimizer_idx == 1:   


: 