In [9]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from ilan_src.models import *



In [18]:
class WGANGP(LightningModule):
    def __init__(self,channels_noise, channels_img, features_g, num_classes, img_size, embed_size, features_d, 
                      b1 = 0.0, b2 = 0.9, lr = 1e-4, lambda_gp = 10): # fill in
        super().__init__()
        self.lr, self.b1, self.b2 = lr, b1, b2
        self.disc_freq, self.gen_freq = 5, 1
        self.noise_shape = (channels_noise, 1, 1)
        self.lambda_gp = lambda_gp
        self.gen = Generator(channels_noise, channels_img, features_g, num_classes, img_size, embed_size)
        self.disc = Discriminator(channels_img, features_d, num_classes, img_size)
        self.num_to_visualise = 12
        self.num_classes = num_classes
        self.automatic_optimization = False
        
        
    def forward(self, condition, noise):
        return self.gen(condition, noise)
    
    def gradient_penalty(self, condition, real, fake):
        BATCH_SIZE, C, H, W = real.shape
        epsilon = torch.rand((BATCH_SIZE, 1, 1, 1), device=self.device).repeat(1,C,H,W)
        interpolated_images = real*epsilon + fake*(1-epsilon)
        interpolated_images.requires_grad = True
        mixed_scores = self.disc(condition, interpolated_images)
        gradient = torch.autograd.grad(
                    inputs=interpolated_images,
                    outputs=mixed_scores, 
                    grad_outputs = torch.ones_like(mixed_scores), 
                    create_graph=True, 
                    retain_graph = True)[0]

        gradient = gradient.view(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        gradient_penalty = torch.mean((gradient_norm - 1)**2)
        return gradient_penalty
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        gen_opt, disc_opt = self.optimizers()
        
        real, condition = batch # if label is condition.
#         print(real.device)
        if batch_idx%100==0:
            noise = torch.randn(real.shape[0], *self.noise_shape, device=self.device)
    #         # log sampled images
            sample_imgs = self.gen(condition, noise)
            sample_imgs = torch.cat([real, sample_imgs], dim = 0)
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, batch_idx)
        
        for _ in range(self.disc_freq):
#         # train discriminator
#         if optimizer_idx == 0:
            noise = torch.randn(real.shape[0], *self.noise_shape, device=self.device)
            fake = self.gen(condition, noise)
            disc_real = self.disc(condition, real).reshape(-1)
            disc_fake = self.disc(condition, fake).reshape(-1)
            gp = self.gradient_penalty(condition, real, fake)
            loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake)) + self.lambda_gp*gp
#             self.logger.experiment.add_scalar('discriminator_loss', loss_disc, self.current_epoch)
            self.log('discriminator_loss', loss_disc, on_epoch=True, on_step=True, prog_bar=True, logger=True)
#             return loss_disc
            disc_opt.zero_grad()
            self.manual_backward(loss_disc, retain_graph=True)
            disc_opt.step()
        
#         #train generator
#         elif optimizer_idx ==1:
        noise = torch.randn(real.shape[0], *self.noise_shape, device = self.device)
        fake = self.gen(condition, noise)
        gen_fake = self.disc(condition, fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        self.log('generator_loss', loss_gen, on_epoch=True, on_step=True, prog_bar=True, logger=True)
#         return loss_gen 
        gen_opt.zero_grad()
        self.manual_backward(loss_gen)
        gen_opt.step()
        

        
    
    def training_epoch_end(self, outputs):
        noise = torch.randn(self.num_to_visualise, *self.noise_shape, device = self.device)

        # log sampled images
        sample_imgs = self(torch.randint(low=0, high = self.num_classes - 1, size=(self.num_to_visualise,), device=self.device), noise)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
        
    def configure_optimizers(self):
        gen_opt = optim.Adam(self.disc.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        disc_opt = optim.Adam(self.disc.parameters(), lr=self.lr, betas=(self.b1, self.b2))
#         return [{"optimizer": disc_opt, "frequency": self.disc_freq}, {"optimizer": gen_opt, "frequency": self.gen_freq}]
        return gen_opt, disc_opt
        

In [19]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMG_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
NUM_EPOCHS = 30
FEATURES_CRITIC = 32 #64
FEATURES_GEN = 32 # 64
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10
NUM_CLASSES = 10
GEN_EMBEDDING = 100

trans = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=trans, download=True)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=6)

trainer = pl.Trainer(gpus = 1)
model = WGANGP(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING, FEATURES_CRITIC)

trainer.fit(model, loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


AttributeError: can't set attribute

In [5]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs

Reusing TensorBoard on port 6006 (pid 124372), started 2:43:02 ago. (Use '!kill 124372' to kill it.)