In [44]:
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

In [45]:
# --- Hyperparameters ---------------------------------------------------------
z_dim = 64              # length of each noise vector fed to the generator
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 128
display_step = 50       # intended logging interval, in generator steps
lr = 0.0002
beta_1 = 0.5            # Adam beta coefficients
beta_2 = 0.999
crit_repeats = 5        # critic updates per generator update
n_epochs = 10
c_lambda = 10           # weight of the gradient penalty in the critic loss

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's final Tanh.
# Bound to `transform` (not `transforms`) so the torchvision module is not
# shadowed and this cell stays safe to re-run.
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST("./", download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [46]:
class Generator(nn.Module):
  """DCGAN-style generator mapping noise vectors to 28x28 images.

  Four transposed convolutions upsample a (N, noise_dim, 1, 1) tensor through
  spatial sizes 1 -> 3 -> 6 -> 13 -> 28; the final Tanh bounds pixels to [-1, 1].

  Args:
    hidden_dim: base channel width; the hidden layers use 4x, 2x and 1x this.
    noise_dim: length of the input noise vectors. When None, falls back to the
      notebook-level `z_dim`, read once here at construction time and stored
      as `self.z_dim` so `forward` never depends on the global afterwards.
    im_chan: number of output image channels (1 for MNIST).
  """

  def __init__(self, hidden_dim=64, noise_dim=None, im_chan=1):

    super(Generator, self).__init__()

    if noise_dim is None:
      noise_dim = z_dim  # notebook config value; captured once, not re-read later
    self.z_dim = noise_dim

    self.model = nn.Sequential(
        nn.ConvTranspose2d(in_channels = noise_dim, out_channels = hidden_dim * 4, kernel_size=3, stride=2),
        nn.BatchNorm2d(hidden_dim * 4),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels = hidden_dim * 4, out_channels = hidden_dim * 2, kernel_size=4, stride=1),
        nn.BatchNorm2d(hidden_dim * 2),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels = hidden_dim * 2, out_channels = hidden_dim, kernel_size=3, stride=2),
        nn.BatchNorm2d(hidden_dim),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(in_channels = hidden_dim, out_channels = im_chan, kernel_size=4, stride=2),
        nn.Tanh()
    )

  def forward(self, noise):
    """Generate images from a batch of noise vectors.

    Args:
      noise: tensor holding len(noise) vectors of length self.z_dim, e.g. the
        (N, z_dim) output of get_noise.

    Returns:
      Tensor of shape (N, im_chan, 28, 28) with values in [-1, 1].
    """
    return self.model(noise.view(len(noise), self.z_dim, 1, 1))


In [47]:
def get_noise(batch_size, z_dim, device):
  """Draw a batch of i.i.d. standard-normal noise vectors.

  Args:
    batch_size: number of vectors (rows) to sample.
    z_dim: length of each vector.
    device: device on which the tensor is allocated.

  Returns:
    Tensor of shape (batch_size, z_dim).
  """
  noise_shape = (batch_size, z_dim)
  return torch.randn(*noise_shape, device=device)

In [56]:
class Critic(nn.Module):
  """WGAN-GP critic assigning one unbounded score to each image.

  Three kernel-4, stride-2 convolutions reduce a 28x28 input to 13x13, 5x5
  and finally 1x1, so `forward` returns shape (N, 1) for (N, im_chan, 28, 28).

  Normalization is per-sample: GroupNorm with a single group normalizes each
  sample over (channels, height, width), like LayerNorm with per-channel
  affine parameters. BatchNorm is deliberately not used: the gradient penalty
  is defined per sample, and BatchNorm would make each sample's score (and
  its input gradient) depend on the other samples in the batch.

  Args:
    hidden_dim: channel width of the first conv layer (the second uses 2x).
    im_chan: number of input image channels (1 for MNIST).
  """

  def __init__(self, hidden_dim=64, im_chan=1):

    super(Critic, self).__init__()

    self.model = nn.Sequential(

        nn.Conv2d(in_channels = im_chan, out_channels = hidden_dim, kernel_size=4, stride=2),
        nn.GroupNorm(1, hidden_dim),
        nn.LeakyReLU(0.2),

        nn.Conv2d(in_channels = hidden_dim, out_channels = hidden_dim * 2, kernel_size=4, stride=2),
        nn.GroupNorm(1, hidden_dim * 2),
        nn.LeakyReLU(0.2),

        nn.Conv2d(in_channels = hidden_dim * 2, out_channels = 1, kernel_size=4, stride=2),
    )

  def forward(self, x):
    """Score a batch of images; returns shape (N, -1), i.e. (N, 1) for 28x28 input."""
    crit_pred = self.model(x)
    return crit_pred.view(len(crit_pred), -1)

In [75]:
def get_gradient_penalty(crit, real, fake, epsilon):
  """WGAN-GP gradient penalty evaluated on random real/fake interpolations.

  Args:
    crit: callable (the critic) mapping a batch of images to per-sample scores.
    real: batch of real images.
    fake: batch of generated images, same shape as `real`.
    epsilon: interpolation weights broadcastable against `real`, e.g. shape
      (N, 1, 1, 1) for one weight per sample. It need not require grad.

  Returns:
    Scalar tensor: the batch mean of (||d crit(x_hat) / d x_hat||_2 - 1)^2,
    where x_hat = epsilon * real + (1 - epsilon) * fake. The gradient is built
    with create_graph=True so the penalty can be backpropagated into the
    critic's parameters.
  """
  mixed_images = epsilon * real + (1 - epsilon) * fake
  # autograd.grad needs `mixed_images` in the graph. If no input requires grad
  # (e.g. plain epsilon and a detached `fake`), the result is a leaf tensor
  # that can be flagged directly.
  if not mixed_images.requires_grad:
    mixed_images.requires_grad_(True)

  mixed_scores = crit(mixed_images)

  gradient = torch.autograd.grad(
      inputs=mixed_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]

  # Flatten each sample's gradient so the L2 norm is taken per sample.
  gradient = gradient.reshape(len(gradient), -1)
  gradient_norm = gradient.norm(2, dim=1)

  penalty = torch.mean((gradient_norm - 1) ** 2)

  return penalty

In [76]:
def get_gen_loss(crit_fake_pred):
  """WGAN generator loss: the negated mean critic score on generated images.

  Minimizing this pushes the generator toward samples the critic scores highly.

  Args:
    crit_fake_pred: critic scores for a batch of generated images.

  Returns:
    Scalar tensor.
  """
  mean_fake_score = crit_fake_pred.mean()
  return -mean_fake_score

In [77]:
def get_critic_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
  """WGAN-GP critic loss.

  Computes mean(fake_scores - real_scores) + c_lambda * gp, so minimizing it
  widens the critic's score gap between real and generated images while
  penalizing gradient norms away from 1.

  Args:
    crit_fake_pred: critic scores on generated images.
    crit_real_pred: critic scores on real images.
    gp: gradient-penalty value.
    c_lambda: weight of the gradient penalty.

  Returns:
    Scalar tensor.
  """
  wasserstein_term = (crit_fake_pred - crit_real_pred).mean()
  penalty_term = c_lambda * gp
  return wasserstein_term + penalty_term

In [78]:
# Build both networks on `device`; each gets its own Adam optimizer using the
# learning rate and (beta_1, beta_2) coefficients from the config cell.
gen = Generator().to(device)
crit = Critic().to(device)

gen_opt = torch.optim.Adam(params=gen.parameters(), betas=(beta_1, beta_2), lr=lr)
crit_opt = torch.optim.Adam(params=crit.parameters(), betas=(beta_1, beta_2), lr=lr)

In [81]:
cur_step = 0
for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    # The last batch of an epoch can be smaller than `batch_size`.
    cur_batch_size = len(real)
    real = real.to(device)

    # --- Critic: `crit_repeats` updates per generator update ---
    for _ in range(crit_repeats):
      # Also clears critic grads left over from the previous generator step.
      crit_opt.zero_grad()

      noise = get_noise(cur_batch_size, z_dim, device=device)
      # Detached so critic updates do not backpropagate into the generator.
      fake = gen(noise).detach()
      crit_fake_pred = crit(fake)
      crit_real_pred = crit(real)

      # One interpolation weight per sample for the gradient penalty.
      epsilon = torch.rand(cur_batch_size, 1, 1, 1, device=device, requires_grad=True)
      gp = get_gradient_penalty(crit, real, fake, epsilon)

      crit_loss = get_critic_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
      crit_loss.backward()
      crit_opt.step()

    # --- Generator: one update against the freshly updated critic ---
    gen_opt.zero_grad()
    noise = get_noise(cur_batch_size, z_dim, device=device)
    gen_loss = get_gen_loss(crit(gen(noise)))
    gen_loss.backward()
    gen_opt.step()

    if cur_step % display_step == 0:
      tqdm.write(
          f"epoch {epoch} step {cur_step}: "
          f"generator loss {gen_loss.item():.4f}, critic loss {crit_loss.item():.4f}"
      )
    cur_step += 1

  0%|          | 0/469 [00:00<?, ?it/s]

KeyboardInterrupt: ignored