# Import libraries

In [1]:
import torchvision
from torchvision.datasets import VisionDataset
import torchvision.datasets as dset
import torch
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms
import os
from pathlib import Path
import wandb
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# Configure Weights & Biases (wandb) logging

In [2]:
# Authenticate with Weights & Biases. If credentials are already stored, the
# existing login is reused (see the output below); here the call returned True.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrespwill[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Hyperparameters tracked by W&B. Later cells read them back through wandb.config.
run_config = {
    "learning_rate": 0.0004,
    "batch_size": 32,
    "latent_dim": 100,
    "b1": 0.5,
    "b2": 0.999,
}

# Start the W&B run. Without `name`, W&B assigns a random run name.
# To run without logging, add mode='disabled' to this call.
wandb.init(
    project="Monet image gan project",
    name="Test13",
    config=run_config,
)

# Lightning logger that reports into the run started above.
wandb_logger = WandbLogger()

  rank_zero_warn(


# Load the Monet dataset (with transforms)

In [4]:
class MonetDataModule(pl.LightningDataModule):
    """LightningDataModule that serves the Monet paintings through ImageFolder.

    Args:
        data_dir1: Directory of Monet images (stored as ``data_dir_monet``).
        data_dir2: Directory of photo images (stored as ``data_dir_photo``;
            not read by this module yet).
        batch_size: Batch size of the training DataLoader.
        transform: Selects the image transform:
            - ``None``: ``ToTensor()`` only (pixel values in [0, 1]).
            - a callable (e.g. a ``transforms.Compose``): used as-is.
            - any other value (e.g. ``True``): ``ToTensor()`` followed by
              ``Normalize`` with mean/std 0.5, mapping pixels to [-1, 1].
    """

    def __init__(self, data_dir1, data_dir2, batch_size, transform=None):
        super().__init__()
        self.data_dir_monet = data_dir1
        self.data_dir_photo = data_dir2
        self.batch_size = batch_size
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor()
            ])
        elif callable(transform):
            # Honor a caller-supplied transform instead of silently replacing it.
            self.transform = transform
        else:
            # Flag value (e.g. True): normalize so real images share the [-1, 1]
            # range of a Tanh-activated generator.
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def setup(self, stage=None):
        """Build the training dataset. ``stage`` is accepted but not used."""
        # NOTE(review): this root is hard-coded and ignores self.data_dir_monet.
        # ImageFolder expects root/<class_dir>/<image>, so this folder presumably
        # holds one class sub-folder of Monet images. Confirm the layout before
        # switching to self.data_dir_monet.
        self.dataset = dset.ImageFolder('./dataset/gan-getting-started/monet/', transform=self.transform)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the Monet dataset."""
        return torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

In [5]:
# Monet and photo directories plus the batch size from the W&B config.
# transform=True selects the ToTensor + Normalize((0.5,)*3, (0.5,)*3) pipeline.
dm = MonetDataModule(
    './dataset/gan-getting-started/monet_jpg/',
    './dataset/gan-getting-started/photo_jpg/',
    wandb.config['batch_size'],
    transform=True,
)

# Initialize Generator and Discriminator weights
* Reference: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [6]:
def weights_init(m):
    """DCGAN-style weight initialization, meant for ``module.apply(weights_init)``.

    Layers are matched by class name:
      * ``*Conv*`` layers: weight ~ N(0, 0.02).
      * ``*BatchNorm*`` layers: weight ~ N(1, 0.02), bias = 0.
      * Every other module (e.g. ``Linear``) is left untouched.

    Parameters that are ``None`` are skipped instead of raising
    ``AttributeError``. Examples are ``BatchNorm(affine=False)`` and any module
    without the attribute.

    Args:
        m: a ``torch.nn.Module``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0)

# Generator

In [7]:
class Generator(torch.nn.Module):
    """MLP generator that maps latent vectors to images.

    Args:
        latent_dim: Size of the input noise vector.
        img_shape: Output image shape without the batch dimension
            (the product of its entries is the final Linear layer's width).

    The last layer is ``Tanh``, so output values lie in [-1, 1].
    The layer order inside ``self.main`` is part of the interface: callers index
    ``main[0]``, and checkpoints depend on the parameter names.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.main = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(128, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 1024),
            # Default eps, consistent with the other BatchNorm layers. Passing 0.8
            # positionally would set eps=0.8 (the 2nd positional arg), not momentum.
            torch.nn.BatchNorm1d(1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, int(np.prod(img_shape))),
            torch.nn.Tanh()
        )

    def forward(self, input):
        """Map ``input`` of shape (batch, latent_dim) to (batch, *img_shape)."""
        img = self.main(input)
        # Un-flatten each row back into an image; size(0) is the batch size.
        return img.view(img.size(0), *self.img_shape)

# Discriminator

In [8]:
class Discriminator(torch.nn.Module):
    """MLP discriminator that outputs the probability an image is real.

    Args:
        img_shape: Input image shape without the batch dimension. Images are
            flattened to ``prod(img_shape)`` features.

    ``forward`` returns a (batch, 1) tensor of values in (0, 1), produced by a
    final ``Sigmoid``.
    """

    def __init__(self, img_shape):
        super().__init__()
        self.main = torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(img_shape)), 512),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, 256),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, img):
        """Score a batch of images; returns shape (batch, 1)."""
        # flatten() copies only when needed, so non-contiguous inputs (e.g. permuted
        # tensors) are accepted; .view() would raise on them.
        img_flat = torch.flatten(img, start_dim=1)
        return self.main(img_flat)

# GAN

In [9]:
class GAN(pl.LightningModule):
    """Vanilla GAN (MLP generator + MLP discriminator) trained with manual optimization.

    Args:
        channels, width, height: Image shape; both networks are built for
            tensors of shape (channels, width, height).
        latent_dim: Size of the generator's noise vector.
        lr, b1, b2: Adam learning rate and betas, shared by both optimizers.
        batch_size: Only recorded in ``hparams``; batching is done by the DataModule.
    """

    def __init__(self, channels, width, height, latent_dim, lr, b1, b2, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim,
                                   img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)
        # Fixed noise so the per-epoch sample grid is comparable across epochs.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        # Lets Lightning's model summary report input/output sizes.
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator probabilities and targets."""
        return torch.nn.functional.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        """Run one generator update, then one discriminator update.

        ``batch`` is an ``(images, labels)`` pair; the labels are ignored.
        Nothing is returned because optimization is manual.
        """
        imgs, _ = batch
        optimizer_g, optimizer_d = self.optimizers()
        batch_size = imgs.size(0)

        # Latent noise and targets on the same device/dtype as the real images.
        z = torch.randn(batch_size, self.hparams.latent_dim).type_as(imgs)
        valid = torch.ones(batch_size, 1).type_as(imgs)
        fake = torch.zeros(batch_size, 1).type_as(imgs)

        # --- Generator step: reward fooling D into predicting "real" (1). ---
        self.toggle_optimizer(optimizer_g, optimizer_idx=0)
        # Generate once and reuse for the loss; a second forward pass would double
        # the compute and the BatchNorm running-stat updates.
        self.generated_imgs = self(z)
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log('g_loss', g_loss, prog_bar=True, on_epoch=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # --- Discriminator step: real images -> 1, generated images -> 0. ---
        self.toggle_optimizer(optimizer_d, optimizer_idx=1)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        # Fresh samples from the just-updated generator. no_grad because the
        # generator is not trained in this step (same effect as .detach()).
        with torch.no_grad():
            fake_imgs = self(z)
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)
        # Overall discriminator loss is the mean of the real and fake terms.
        d_loss = (real_loss + fake_loss) / 2
        self.log('d_loss', d_loss, prog_bar=True, on_epoch=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        """One Adam optimizer per network: (generator, discriminator)."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def validation_step(self, *args):
        """No-op. Without a val_dataloader, Lightning skips the validation loop."""
        pass

    def on_train_epoch_end(self):
        """Log a grid of images generated from the fixed ``validation_z``."""
        # Only loggers that implement log_image (e.g. WandbLogger) can take the grid.
        if self.logger is None or not hasattr(self.logger, 'log_image'):
            return
        # Match the generator's device and dtype.
        z = self.validation_z.type_as(self.generator.main[0].weight)
        # Sample in eval mode without gradients, so BatchNorm running stats and
        # Dropout are not affected by these few samples.
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                sample_imgs = self(z)
        finally:
            self.generator.train(was_training)
        # The generator ends in Tanh ([-1, 1]); rescale to [0, 1] for display.
        grid = torchvision.utils.make_grid((sample_imgs + 1) / 2)
        self.logger.log_image("generated_images", [grid])

    # Backward-compatible alias for the earlier (non-hook) method name.
    on_training_epoch_end = on_train_epoch_end

In [10]:
# GAN constructor arguments. Image geometry is fixed here; the remaining values
# come from the W&B run config.
param = dict(
    channels=3,
    width=256,
    height=256,
    latent_dim=wandb.config['latent_dim'],
    lr=wandb.config['learning_rate'],
    b1=wandb.config['b1'],
    b2=wandb.config['b2'],
    batch_size=wandb.config['batch_size'],
)
model = GAN(**param)

In [None]:
# Keep the two checkpoints with the lowest epoch-mean generator loss.
# The monitored key must be a logged metric. The model logs only 'g_loss' and
# 'd_loss' (with on_epoch=True, which adds the '_epoch' suffix). A monitor of
# "loss" therefore never matches and makes ModelCheckpoint raise at epoch end.
# NOTE(review): GAN losses are a weak proxy for sample quality; consider
# saving every N epochs instead.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./check_point/',
    filename='{epoch}-{g_loss_epoch:.4f}',
    monitor="g_loss_epoch",
    mode="min",
    save_top_k=2
)

In [11]:
# Single-device trainer that logs to W&B and uses the checkpoint callback.
trainer = pl.Trainer(
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    accelerator="auto",
    devices=1,
    max_epochs=1000,
    log_every_n_steps=27,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the GAN; the DataModule provides train_dataloader().
trainer.fit(model, datamodule=dm)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type          | Params | In sizes | Out sizes       
------------------------------------------------------------------------------
0 | generator     | Generator     | 202 M  | [2, 100] | [2, 3, 256, 256]
1 | discriminator | Discriminator | 100 M  | ?        | ?               
------------------------------------------------------------------------------
303 M     Trainable params
0         Non-trainable params
303 M     Total params
1,212.101 Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]