# Import libraries

In [32]:
import torchvision
from torchvision.datasets import VisionDataset
import torchvision.datasets as dset
import torch
from torchvision import transforms
import os
from pathlib import Path
import wandb
import pytorch_lightning as pl
import numpy as np
from lightning.pytorch.callbacks import ModelCheckpoint

ModuleNotFoundError: No module named 'lightning'

# Set up Weights & Biases (wandb) logging

In [2]:
# Authenticate this machine with Weights & Biases. When credentials are already
# cached, wandb reports the current login (as in the output below); use
# `wandb login --relogin` to switch accounts. No API key is hardcoded here.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrespwill[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Start a wandb run for this experiment. mode='disabled' keeps the run from
# logging to wandb; the hyperparameters below are still exposed through
# wandb.config, which later cells read.
wandb.init(
    project="Monet image gan project",  # project the run would be logged under
    name="Test0",                       # explicit run name instead of a random one
    mode='disabled',
    config=dict(                        # tracked hyperparameters / run metadata
        learning_rate=0.0003,
        batch_size=32,
        latent_dim=100,
        b1=0.5,
        b2=0.999,
    ),
)



# Load the dataset with image transforms

In [26]:
class MonetDataModule(pl.LightningDataModule):
    """LightningDataModule that serves images from an ``ImageFolder`` directory.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.ImageFolder``.
            Every sub-directory is treated as one class; the labels are returned
            in each batch but the GAN ignores them.
        batch_size: Batch size of the training DataLoader.
        transform: Selects the image preprocessing:
            * ``None`` / ``False`` -- ``ToTensor()`` only (pixel values in [0, 1]).
            * ``True`` -- ``ToTensor()`` followed by ``Normalize(0.5, 0.5)`` per
              channel, mapping pixels to [-1, 1] to match a ``Tanh`` generator.
            * a callable -- used as-is as the ``ImageFolder`` transform.
    """

    def __init__(self, data_dir, batch_size, transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = self._build_transform(transform)

    @staticmethod
    def _build_transform(transform):
        """Translate the ``transform`` argument into a torchvision transform."""
        if callable(transform):
            # Caller supplied its own pipeline; previously it was silently ignored.
            return transform
        if transform:
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        # None / False: tensors in [0, 1] without normalisation.
        return transforms.Compose([transforms.ToTensor()])

    def setup(self, stage=None):
        """Build the dataset; ``stage`` is accepted for Lightning but unused."""
        # NOTE(review): ImageFolder loads every image sub-directory of data_dir.
        # If this folder holds more than the Monet images (e.g. a photo folder
        # as well) they are all mixed into training -- confirm the layout.
        self.dataset = dset.ImageFolder(self.data_dir, transform=self.transform)

    def train_dataloader(self):
        """Shuffled DataLoader over the dataset created in ``setup``."""
        return torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size, shuffle=True
        )

In [29]:
# DataModule over the Kaggle "gan-getting-started" folder. transform=True
# selects the ToTensor + Normalize pipeline so real images share the
# generator's Tanh output range.
dm = MonetDataModule(
    data_dir='./dataset/gan-getting-started/',
    batch_size=wandb.config['batch_size'],
    transform=True,
)

# Initialize Generator and Discriminator weights
* Reference: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [8]:
def weights_init(m):
    """DCGAN-style weight initialisation, intended for ``module.apply(weights_init)``.

    * Conv / ConvTranspose layers: ``weight ~ N(0, 0.02)``.
    * BatchNorm layers: ``weight ~ N(1, 0.02)``, ``bias = 0``.
    Any other module type (e.g. ``Linear``) is left unchanged.

    Args:
        m: A ``torch.nn.Module``; dispatched on its class name.
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            # nn.init already runs under no_grad, so no `.data` access is needed.
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm with affine=False has weight/bias set to None.
        if getattr(m, 'weight', None) is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias, 0)

# Generator

In [10]:
class Generator(torch.nn.Module):
    """MLP generator mapping latent vectors to image tensors.

    Architecture: for each width in ``hidden_dims`` a
    ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` stage, then a ``Linear`` layer
    to ``prod(img_shape)`` features and ``Tanh``, so outputs lie in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector.
        img_shape: Per-sample output shape, e.g. ``(channels, height, width)``.
        hidden_dims: Widths of the hidden stages. The default ``(128, 256)``
            is the shallow network; ``(128, 256, 512, 1024)`` gives the deeper one.
    """

    def __init__(self, latent_dim, img_shape, hidden_dims=(128, 256)):
        super().__init__()
        self.img_shape = img_shape
        layers = []
        in_dim = latent_dim
        for out_dim in hidden_dims:
            layers += [
                torch.nn.Linear(in_dim, out_dim),
                # BatchNorm1d's second positional argument is `eps`, not
                # `momentum`; passing 0.8 there makes eps huge. Use the default.
                torch.nn.BatchNorm1d(out_dim),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ]
            in_dim = out_dim
        layers += [
            torch.nn.Linear(in_dim, int(np.prod(img_shape))),
            torch.nn.Tanh(),
        ]
        self.main = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Map ``input`` of shape (batch, latent_dim) to (batch, *img_shape)."""
        img = self.main(input)
        # img.size(0) is the batch dimension; the flat features become img_shape.
        return img.view(img.size(0), *self.img_shape)

In [20]:
# netG = Generator(1, 100, 200, 64)

In [25]:
# netG.apply(weights_init)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 1600, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): ConvTranspose2d(200, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): Tanh()
  )
)

# Discriminator

In [11]:
class Discriminator(torch.nn.Module):
    """MLP discriminator returning one probability in (0, 1) per image.

    Architecture: flatten, then for each width in ``hidden_dims`` a
    ``Linear -> LeakyReLU(0.2)`` stage, then ``Linear`` to 1 unit and ``Sigmoid``.

    Args:
        img_shape: Per-sample input shape, e.g. ``(channels, height, width)``.
        hidden_dims: Hidden layer widths; the default ``(512, 256)`` keeps the
            original architecture.
    """

    def __init__(self, img_shape, hidden_dims=(512, 256)):
        super().__init__()
        layers = []
        in_dim = int(np.prod(img_shape))
        for out_dim in hidden_dims:
            layers += [
                torch.nn.Linear(in_dim, out_dim),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ]
            in_dim = out_dim
        layers += [torch.nn.Linear(in_dim, 1), torch.nn.Sigmoid()]
        self.main = torch.nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch ``img`` of shape (batch, *img_shape); returns (batch, 1)."""
        # flatten() copies when needed, so non-contiguous batches work where
        # view() would raise.
        img_flat = torch.flatten(img, start_dim=1)
        return self.main(img_flat)

In [39]:
# netD = Discriminator(1, 64, 512)

In [35]:
# netD.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(64, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv2d(4096, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (4): Sigmoid()
  )
)

# GAN

In [15]:
class GAN(pl.LightningModule):
    """Vanilla GAN: MLP ``Generator`` + MLP ``Discriminator`` trained with BCE.

    Uses Lightning manual optimisation: each ``training_step`` performs one
    generator update followed by one discriminator update. A grid of samples
    from fixed noise is logged once per epoch.

    Args:
        channels, width, height: Image geometry. Samples are shaped
            ``(channels, height, width)``, the layout torchvision's ``ToTensor`` produces.
        latent_dim: Length of the generator's noise vector.
        lr, b1, b2: Adam learning rate and betas, shared by both optimisers.
        batch_size: Recorded in ``self.hparams``; not otherwise used by this class.
    """

    def __init__(self, channels, width, height, latent_dim, lr, b1, b2, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # Both optimisers are stepped by hand inside training_step.
        self.automatic_optimization = False

        # Height before width so non-square images get the (C, H, W) layout.
        data_shape = (channels, height, width)
        self.generator = Generator(latent_dim=self.hparams.latent_dim,
                                   img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)
        # Fixed noise, so sample grids from different epochs are comparable.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        # Lets Lightning's model summary report generator input/output sizes.
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (batch, latent_dim)."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator probabilities and 0/1 targets."""
        return torch.nn.functional.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """Run one generator update, then one discriminator update, on ``batch``."""
        imgs, _ = batch  # ImageFolder class labels are not used
        optimizer_g, optimizer_d = self.optimizers()
        n = imgs.size(0)

        # Noise and targets on the same device/dtype as the real images.
        z = torch.randn(n, self.hparams.latent_dim).type_as(imgs)
        valid = torch.ones(n, 1).type_as(imgs)
        fake = torch.zeros(n, 1).type_as(imgs)

        # Generator step: push the discriminator to label generated images real.
        self.toggle_optimizer(optimizer_g, optimizer_idx=0)
        self.generated_imgs = self(z)
        # Score the images generated above instead of running the generator
        # again, which would duplicate the forward pass and its BatchNorm update.
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log('g_loss', g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # Discriminator step: real images -> 1, generated images -> 0.
        self.toggle_optimizer(optimizer_d, optimizer_idx=1)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        # Samples from the just-updated generator, detached so no gradient
        # flows back into it.
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        self.log('d_loss', d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """One Adam optimiser per network, in the order (generator, discriminator)."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def _log_sample_grid(self, tag, step):
        """Log generator samples for ``validation_z`` as an image grid, if possible."""
        experiment = getattr(self.logger, 'experiment', None)
        # add_image is the TensorBoard SummaryWriter API; skip when there is no
        # logger or the logger's experiment object does not provide it.
        if experiment is None or not hasattr(experiment, 'add_image'):
            return
        # Match the generator's device and dtype without relying on layer names.
        z = self.validation_z.type_as(next(self.generator.parameters()))
        with torch.no_grad():
            sample_imgs = self(z)
        # Generator output is Tanh in [-1, 1]; float images are logged in [0, 1].
        grid = torchvision.utils.make_grid((sample_imgs + 1) / 2)
        experiment.add_image(tag, grid, step)

    def on_train_epoch_end(self):
        """Log one sample grid per training epoch."""
        self._log_sample_grid("generated_images", self.current_epoch)

    def on_validation_epoch_end(self):
        """Log one sample grid per validation epoch (only when validation runs)."""
        self._log_sample_grid("generated_images", self.current_epoch)

In [16]:
# Constructor arguments for GAN: fixed 3x256x256 image geometry plus the
# optimiser / latent settings tracked in wandb.config.
param = dict(
    channels=3,
    width=256,
    height=256,
    latent_dim=wandb.config['latent_dim'],
    lr=wandb.config['learning_rate'],
    b1=wandb.config['b1'],
    b2=wandb.config['b2'],
    batch_size=wandb.config['batch_size'],
)
model = GAN(**param)

In [17]:
# Five-epoch trainer on a single device; accelerator="auto" lets Lightning
# pick whatever hardware is available.
trainer = pl.Trainer(max_epochs=5, devices=1, accelerator="auto")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [30]:
# Pass the DataModule by keyword: Trainer.fit's second positional parameter is
# `train_dataloaders`, so naming `datamodule` states the intent explicitly.
trainer.fit(model, datamodule=dm)


  | Name          | Type          | Params | In sizes | Out sizes       
------------------------------------------------------------------------------
0 | generator     | Generator     | 50.6 M | [2, 100] | [2, 3, 256, 256]
1 | discriminator | Discriminator | 100 M  | ?        | ?               
------------------------------------------------------------------------------
151 M     Trainable params
0         Non-trainable params
151 M     Total params
605.481   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  tensor = (tensor * 255.0).astype(np.uint8)
`Trainer.fit` stopped: `max_epochs=5` reached.
