# Wasserstein GAN - Gradient Penalty

- GAN의 손실 함수를 개선해 기울기 소실이나 모드 붕괴 현상을 완화하는 기법

In [2]:
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

In [3]:
epochs, batch_size = 100, 64
lr,b1,b2 = 2e-4,0.5,0.999
latent_dim = 100
img_size = 28
channels = 1
n_critic = 5
lambda_gp = 10
img_shape = (channels, img_size, img_size)
if torch.cuda.is_available():
    print("Train on GPU")
    cuda = True
else:
    print("Train on CPU")
    cuda = False

Train on GPU


In [4]:
os.makedirs("FashionMNIST_DATASET",exist_ok=True)

dataloader = DataLoader(
    datasets.FashionMNIST(
        'FashionMNIST_DATASET',
        train=True,
        download=True,
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5],[0.5])
        ])
    ),
    batch_size=batch_size,
    shuffle=True
)

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)),512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1)
        )
    def forward(self,img):
        flat_img = img.view(img.shape[0],-1)
        pred = self.model(flat_img)
        return pred
        
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
    
        def block(in_feature, out_feature, normalize=True):
            layers = [nn.Linear(in_feature,out_feature)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feature,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim,128,normalize=False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,int(np.prod(img_shape))),
            nn.Tanh() # -1 ~ 1 (CLIPPING)
        )
    def forward(self,z):
        img = self.model(z)
        img = img.view(img.shape[0],*img_shape)
        return img

In [6]:
G = Generator()
D = Discriminator()
if cuda:
    G.cuda()
    D.cuda()

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

def gradient_penalty(D,real_img,fake_img):
    alpha = Tensor(np.random.random((real_img.size(0),1,1,1)))
    interpolates = (alpha * real_img + ((1-alpha)*fake_img)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_img.shape[0],1).fill_(1.0),requires_grad=False)
    gradients = autograd.grad(outputs = d_interpolates,
                              inputs=interpolates,
                              grad_outputs=fake,
                              create_graph=True,
                              retain_graph=True,
                              only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0),-1)
    GP = ((gradients.norm(2,dim=1)-1) **2).mean()
    return GP

In [8]:
optimizer_G = torch.optim.Adam(G.parameters(), lr = lr, betas=(b1,b2))
optimizer_D = torch.optim.Adam(D.parameters(),lr=lr,betas=(b1,b2))
current_iters = 0
os.makedirs("WGAN-GP_results",exist_ok=True)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = Variable(imgs.type(Tensor))

        ##Train Discriminator ##
        optimizer_D.zero_grad()
        z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],latent_dim))))
        
        fake_imgs = G(z)
        real_pred = D(real_imgs)
        fake_pred = D(fake_imgs)
        GP = gradient_penalty(D, real_imgs.data, fake_imgs.data)
        d_loss = -torch.mean(real_pred) + torch.mean(fake_pred) + lambda_gp*GP
        d_loss.backward()
        optimizer_D.step()
        optimizer_G.zero_grad()
        
        if i % n_critic == 0:
            ## Train G ##
            fake_imgs = G(z)
            fake_pred = D(fake_imgs)
            g_loss = -torch.mean(fake_pred)
            g_loss.backward()
            optimizer_G.step()
            current_iters += n_critic
            
    print(f'EPOCH : {epoch + 1}/{epochs} | [D loss: {d_loss.item()}] | [G loss: {g_loss.item()}]')
    save_image(fake_imgs.data[:25],f"WGAN-GP_results/epochs_{epoch}.png",nrow=5,normalize=True)
    
torch.save(G.state_dict(),'./Generator.pth')
torch.save(D.state_dict(),'./Discriminator.pth')

  z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],latent_dim))))
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


EPOCH : 1/100 | [D loss: -3.7432658672332764] | [G loss: 0.43370795249938965]
EPOCH : 2/100 | [D loss: -2.8184943199157715] | [G loss: -1.4357749223709106]
EPOCH : 3/100 | [D loss: -3.0436651706695557] | [G loss: -1.6627751588821411]
EPOCH : 4/100 | [D loss: -2.319378137588501] | [G loss: -0.7082768678665161]
EPOCH : 5/100 | [D loss: -2.835312843322754] | [G loss: -0.7042218446731567]
EPOCH : 6/100 | [D loss: -1.454805850982666] | [G loss: -0.16897116601467133]
EPOCH : 7/100 | [D loss: -2.4491446018218994] | [G loss: -2.0427308082580566]
EPOCH : 8/100 | [D loss: -2.2560102939605713] | [G loss: -0.3144831657409668]
EPOCH : 9/100 | [D loss: -2.2736639976501465] | [G loss: 0.6223976016044617]
EPOCH : 10/100 | [D loss: -2.8120322227478027] | [G loss: -1.374320387840271]
EPOCH : 11/100 | [D loss: -2.0508322715759277] | [G loss: -2.15464186668396]
EPOCH : 12/100 | [D loss: -2.7060093879699707] | [G loss: -0.5170355439186096]
EPOCH : 13/100 | [D loss: -1.3908313512802124] | [G loss: -1.776363