### WGAN using weight clipping

In [21]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
%pylab inline
%load_ext tensorboard

Populating the interactive namespace from numpy and matplotlib
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


`%matplotlib` prevents importing * from pylab and numpy
  warn("pylab import has clobbered these variables: %s"  % clobbered +


In [22]:
# Critic (WGAN's counterpart of a discriminator)

class Critic(nn.Module):
    """Convolutional WGAN critic.

    Maps each input image to one unbounded score (there is no sigmoid); the
    training loop pushes scores of real images up and of generated ones down.

    Args:
        img_ch: number of input image channels.
        hidden_dim: channel width of the first block; the second block doubles it.

    For 28x28 inputs the three unpadded 4x4 / stride-2 convolutions reduce the
    spatial size 28 -> 13 -> 5 -> 1, so the output has shape (N, 1, 1, 1).
    """
    def __init__(self, img_ch=1, hidden_dim=16):
        super().__init__()
        widths = (hidden_dim, hidden_dim * 2)
        first = self.block(img_ch, widths[0])
        second = self.block(widths[0], widths[1])
        head = nn.Conv2d(widths[1], 1, 4, 2)  # single score channel, no activation
        self.critic = nn.Sequential(first, second, head)

    def block(self, in_channel, op_channel, kernel_size=4, stride=2):
        """Return an unpadded Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stack."""
        layers = [
            nn.Conv2d(in_channel, op_channel, kernel_size, stride),
            nn.BatchNorm2d(op_channel),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.critic(x)
    
# Generator

class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to MNIST-sized images.

    Args:
        z_dim: number of channels of the latent input.
        img_ch: number of channels of the generated image.
        hidden_dim: base channel width; the three blocks use 4x, 2x and 1x of it.

    Input is (N, z_dim, 1, 1), as built by the training loop; output is
    (N, img_ch, 28, 28) with values in (-1, 1) from the final tanh.
    """
    def __init__(self, z_dim, img_ch=1, hidden_dim=64):
        super().__init__()
        # Unpadded ConvTranspose2d gives out = (in - 1) * stride + kernel, so the
        # spatial size grows 1 -> 3 -> 6 -> 13 -> 28. Kernels/strides are chosen
        # to land exactly on 28x28 so fakes match the real MNIST batches the
        # critic sees (an all k=3/s=2 stack would yield 31x31).
        self.gen = nn.Sequential(self.block(z_dim, hidden_dim*4),
                                 self.block(hidden_dim*4, hidden_dim*2, kernel_size=4, stride=1),
                                 self.block(hidden_dim*2, hidden_dim),
                                 nn.ConvTranspose2d(hidden_dim, img_ch, 4, 2),
                                 nn.Tanh())

    def block(self, in_channel, op_channel, kernel_size=3, stride=2):
        """Return an unpadded ConvTranspose2d -> BatchNorm2d -> ReLU stack."""
        return nn.Sequential(nn.ConvTranspose2d(in_channel, op_channel, kernel_size, stride),
                             nn.BatchNorm2d(op_channel),
                             nn.ReLU())

    def forward(self, x):
        return self.gen(x)

In [23]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [24]:
# Seed PyTorch's RNGs (all devices) before any model is built, so weight init,
# noise sampling and DataLoader shuffling are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters. weights_clip (c) and critic_iters (n_critic) match
# Algorithm 1 of the WGAN paper.
# NOTE(review): that paper uses lr = 5e-5 with RMSprop; 2e-4 here is higher —
# lower it if training is unstable.
z_dim = 64           # channels of the (N, z_dim, 1, 1) latent noise
batch_size = 128
lr = 0.0002          # RMSprop learning rate for both networks
critic_iters = 5     # critic updates per generator update
weights_clip = 0.01  # critic parameters are clamped to [-c, c] after each step

In [25]:
# MNIST scaled to [-1, 1] so real images share the generator's tanh range.
transform = transforms.Compose([
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # per-channel (x - 0.5) / 0.5 -> [-1, 1]
])

# Dataset root relative to the notebook, rather than a user-specific absolute path.
data_root = "data"
dataset = datasets.MNIST(data_root, download=True, transform=transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [26]:
# Build both networks (critic first, then generator) on the training device.
critic, gen = Critic().to(device), Generator(z_dim).to(device)

In [27]:
# One RMSprop optimiser per network (the optimiser the original WGAN paper uses
# with weight clipping), sharing the same learning rate.
optim_critic = optim.RMSprop(params=critic.parameters(), lr=lr)
optim_gen = optim.RMSprop(params=gen.parameters(), lr=lr)

In [28]:
# Weight initialisation (DCGAN scheme)

def weights_init(m):
    """Initialise one module's weights in place; intended for ``nn.Module.apply``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
    scales are drawn from N(1, 0.02) and shifts set to 0. Conv biases and every
    other module type are left at their default initialisation.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Mean 1.0, not 0.0: a near-zero BN scale would shrink every normalised
        # activation toward zero at the start of training.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
        
# Apply the initialisation to every submodule of both networks.
# Module.apply recurses into children, modifies them in place and returns the
# module itself, so `critic` and `gen` keep referring to the same objects.
for net in (critic, gen):
    net.apply(weights_init)

In [29]:
# Number of full passes over MNIST. The fixed evaluation noise is created on the
# training device in the TensorBoard set-up cell, right before training.
num_epochs = 50

In [30]:
# TensorBoard writers: separate runs for generated and real images so the two
# image streams can be compared side by side.
summary_fake = SummaryWriter(log_dir='logs_dcgan/fake')
summary_real = SummaryWriter(log_dir='logs_dcgan/real')

# Fixed noise, reused at every logging step so samples are comparable over time;
# `step` counts how many snapshots have been logged.
fixed = torch.randn(batch_size, z_dim, 1, 1)
test_noise = fixed.to(device)
del fixed
step = 0

In [31]:
# Start (or reuse an already running) inline TensorBoard that reads the
# `logs_dcgan` directory the SummaryWriters write into.
%tensorboard --logdir logs_dcgan

Reusing TensorBoard on port 6006 (pid 12840), started 1 day, 1:44:55 ago. (Use '!kill 12840' to kill it.)

In [32]:
# WGAN training with weight clipping: for every real batch, take `critic_iters`
# critic steps (clamping the critic's parameters after each one), then take a
# single generator step.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # The last batch of an epoch can be smaller than `batch_size`; size the
        # noise from the actual batch so real and fake batches always match.
        cur_batch_size = real.size(0)

        # ---- Critic: maximise mean(critic(real)) - mean(critic(fake)) ----
        for _ in range(critic_iters):
            noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
            fake = gen(noise)
            # Detach so the critic update does not back-propagate into the
            # generator; the generator graph of the last `fake` then stays
            # intact for the generator step without `retain_graph=True`.
            critic_fake = critic(fake.detach()).reshape(-1)
            critic_real = critic(real).reshape(-1)
            # Negated because the optimiser minimises.
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            optim_critic.step()

            # Weight clipping, applied to every critic parameter.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-weights_clip, weights_clip)

        # ---- Generator: maximise mean(critic(fake)) on the last fake batch ----
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        # Once per epoch, log samples from the fixed noise next to a real batch.
        if batch_idx == 0:
            with torch.no_grad():
                samples = gen(test_noise)
                summary_fake.add_image('Fake', make_grid(samples, normalize=True), global_step=step)
                summary_real.add_image('Real', make_grid(real, normalize=True), global_step=step)
            step += 1

KeyboardInterrupt: 