In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [2]:
noise_dim = 64        # length of each generator input noise vector
batch_size = 128      # images per optimisation step
lr = 0.00001          # Adam learning rate for both networks
# NOTE(review): 1e-5 is far below the 2e-4 commonly used with Adam for GANs
# (DCGAN); training may converge slowly - consider tuning.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed torch's RNGs (CPU and all CUDA devices) so Restart & Run All reproduces
# the same weight init, DataLoader shuffling and noise samples. Some CUDA
# kernels may still be nondeterministic.
seed = 0
torch.manual_seed(seed)

In [3]:
# MNIST training split converted with ToTensor, batched and reshuffled every
# epoch by the DataLoader.
train_loader = DataLoader(
    MNIST(
        root="./content",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)
len(train_loader)  # number of batches per epoch

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


469

In [4]:
class Generator(nn.Module):
  """MLP generator mapping a batch of noise vectors to flattened images.

  Four hidden stages of Linear -> BatchNorm1d -> ReLU widen from
  ``hidden_dim`` to ``hidden_dim * 8``. A final Linear + Sigmoid produces
  ``output_size`` values per sample in [0, 1].

  Args:
    hidden_dim: width of the first hidden layer; it doubles at each stage.
    output_size: number of values per generated sample (28 * 28 for MNIST).
    z_dim: length of each input noise vector. When omitted, the notebook
      level ``noise_dim`` is used, so ``Generator()`` behaves as before.
  """

  def __init__(self, hidden_dim = 128, output_size = 28 * 28, z_dim = None):
    super().__init__()

    if z_dim is None:
      # Backward-compatible default: the notebook's configured latent size.
      z_dim = noise_dim

    widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
    layers = []
    for in_features, out_features in zip(widths[:-1], widths[1:]):
      layers += [
          nn.Linear(in_features, out_features),
          nn.BatchNorm1d(out_features),
          nn.ReLU(inplace=True),
      ]
    layers += [nn.Linear(widths[-1], output_size), nn.Sigmoid()]

    # Same layer order as before, so state_dict keys (model.0 ... model.13)
    # and parameter initialisation order are unchanged.
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Map noise of shape (N, z_dim) to samples of shape (N, output_size).

    In training mode BatchNorm1d needs N > 1.
    """
    return self.model(x)

In [5]:
def get_noise(size, device):
  """Draw standard-normal noise.

  Args:
    size: shape of the tensor to sample, e.g. (num_images, z_dim).
    device: device on which the tensor is allocated.

  Returns:
    A tensor of shape ``size`` with i.i.d. N(0, 1) entries.
  """
  sample = torch.randn(size, device=device)
  return sample

In [6]:
# One batch of latent vectors. In the visible cells this global is not read
# before the training loop reassigns it.
noise = get_noise(size=(batch_size, noise_dim), device=device)

In [7]:
# Generator on the selected device, optimised with Adam at the configured lr.
gen = Generator()
gen = gen.to(device)
gen_optimizer = torch.optim.Adam(params=gen.parameters(), lr=lr)

In [8]:
class Discriminator(nn.Module):
  """MLP that scores flattened images with one raw logit each.

  Hidden widths narrow from ``hidden * 4`` to ``hidden``, each followed by
  LeakyReLU(0.2). The final Linear(hidden, 1) has no sigmoid, so the output
  is a logit intended for a logits-based loss such as nn.BCEWithLogitsLoss.

  Args:
    image_size: number of input features per image (784 = 28 * 28).
    hidden: base hidden width.
  """

  def __init__(self, image_size = 784, hidden = 128):
    super().__init__()

    widths = (image_size, hidden * 4, hidden * 2, hidden)
    stages = []
    for fan_in, fan_out in zip(widths, widths[1:]):
      stages.append(nn.Linear(fan_in, fan_out))
      stages.append(nn.LeakyReLU(0.2))
    stages.append(nn.Linear(hidden, 1))

    self.model = nn.Sequential(*stages)

  def forward(self, x):
    """Return logits of shape (..., 1) for inputs whose last dim is image_size."""
    return self.model(x)

In [9]:
# Discriminator for flattened 28 x 28 images, with its own Adam optimiser at the shared lr.
disc = Discriminator(hidden=128, image_size=784)
disc = disc.to(device)
disc_optimizer = torch.optim.Adam(params=disc.parameters(), lr=lr)

In [10]:
# Sigmoid + binary cross-entropy on raw discriminator logits, averaged over the batch.
criterion = nn.BCEWithLogitsLoss(reduction="mean")

In [11]:
def gen_loss(generator, discriminator, real, num_images, z_dim, device):
  """Generator loss on a freshly sampled fake batch.

  The discriminator's outputs on the fakes are scored by the module-level
  ``criterion`` against a target of all ones. The generator is therefore
  rewarded when its samples are judged real.

  Args:
    generator: maps noise of shape (num_images, z_dim) to fake images.
    discriminator: scores images; its output is passed directly to ``criterion``.
    real: the current real batch. Not used here; kept so the signature
      matches ``disc_loss``.
    num_images: number of fake samples to draw.
    z_dim: length of each noise vector.
    device: device on which the noise is sampled.

  Returns:
    The loss returned by ``criterion``.
  """
  fake_logits = discriminator(generator(get_noise((num_images, z_dim), device)))
  all_real_targets = torch.ones_like(fake_logits)
  return criterion(fake_logits, all_real_targets)

In [12]:
def disc_loss(generator, discriminator, real, num_images, z_dim, device):
  """Discriminator loss: mean of the fake-vs-0 and real-vs-1 criterion terms.

  The generated batch is detached, so backpropagating this loss produces no
  gradients in the generator.

  Args:
    generator: maps noise of shape (num_images, z_dim) to fake images.
    discriminator: scores images; its output is passed directly to ``criterion``.
    real: real images in the layout ``discriminator`` expects
      (the training loop passes (batch_size, 28 * 28)).
    num_images: number of fake samples to draw (the loop passes len(real)).
    z_dim: length of each noise vector.
    device: device on which the noise is sampled.

  Returns:
    0.5 * (loss on fakes + loss on reals), as computed by ``criterion``.
  """
  fake_images = generator(get_noise((num_images, z_dim), device)).detach()
  fake_logits = discriminator(fake_images)
  loss_on_fake = criterion(fake_logits, torch.zeros_like(fake_logits))

  real_logits = discriminator(real)
  loss_on_real = criterion(real_logits, torch.ones_like(real_logits))

  return 0.5 * (loss_on_fake + loss_on_real)

In [13]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    '''
    Plot the first ``num_images`` images of ``image_tensor`` as a grid.

    Args:
      image_tensor: tensor whose elements can be reshaped to (-1, *size),
        e.g. flattened (N, 784) generator output or (N, 1, 28, 28) images.
        It is detached and copied to the CPU first, so CUDA tensors and
        tensors that require grad are both accepted.
      num_images: how many leading images to draw.
      size: per-image (channels, height, width) used to un-flatten the input.
      nrow: images per grid row. The default 5 lays out the default 25
        images as a 5 x 5 square.
    '''
    # reshape (not view) so non-contiguous inputs do not raise.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)

    # make_grid returns (C, H, W); matplotlib expects channels last.
    _, ax = plt.subplots()
    ax.imshow(image_grid.permute(1, 2, 0).squeeze())
    ax.axis("off")
    plt.show()

In [14]:
# Global optimisation-step counter. It lives in its own cell so re-running
# the training cell resumes the count instead of resetting it.
cur_step = 0

In [None]:
# Alternate one generator step and one discriminator step per batch, and
# show a grid of samples every 500 steps.
num_epochs = 30
for epoch in range(num_epochs):

  # Labels are irrelevant for this unconditional GAN, so they are discarded
  # rather than copied to the device on every step.
  for X, _ in train_loader:

    # Flatten once; both losses take real images as (batch, 28 * 28).
    real = X.to(device).reshape(-1, 28 * 28)
    n_real = len(real)

    # Training Generator
    gen_optimizer.zero_grad()
    gloss = gen_loss(gen, disc, real, n_real, noise_dim, device)
    gloss.backward()
    gen_optimizer.step()

    # Training Discriminator
    # zero_grad here also clears the discriminator gradients left by gloss.backward().
    disc_optimizer.zero_grad()
    dloss = disc_loss(gen, disc, real, n_real, noise_dim, device)
    dloss.backward()
    disc_optimizer.step()

    # Visualize Generator output
    if cur_step % 500 == 0:
      print("Generator Loss : ", gloss.item())
      print("Discriminator Loss : ", dloss.item())

      # Sample in eval mode and without autograd. Otherwise this extra batch
      # builds an unused graph and updates the BatchNorm running statistics.
      gen.eval()
      with torch.no_grad():
        noise = get_noise((batch_size, noise_dim), device)
        out = gen(noise)
      gen.train()

      print(cur_step)
      show_tensor_images(out)
      show_tensor_images(real)

    cur_step += 1