# Generative Adversarial Network on MNIST 

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline

In [17]:
import torch
import torchvision
import torch.utils.data as Data
import numpy as np
import torch.nn as nn
# Seed PyTorch and NumPy with the same value so that weight initialisation
# and noise sampling repeat from run to run.
_SEED = 7
torch.manual_seed(_SEED)
np.random.seed(_SEED)

In [25]:
# --- Hyper-parameters shared by the cells below ---
BATCH_SIZE = 64              # samples per mini-batch handed to the DataLoader
LR_G = 1e-2                  # learning rate for the G network's optimizer
LR_D = 1e-2                  # learning rate for the D network's optimizer
INPUT_DIMENSION = 28 * 28    # length of one flattened 28x28 image (784)
INPUT_NUMBERS = 100          # size of the noise vector fed to the generator
N_NODES = 128                # width of the hidden layer in both networks

## 1. Load MNIST Data

In [30]:
# MNIST dataset stored under ./MNIST (downloaded on first use); ToTensor
# converts each image to a float tensor.
mnist_trainset = torchvision.datasets.MNIST(
    root='MNIST',
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

## 2. Create Two Networks

In [56]:
# Inspect the shape of the most recent noise batch. noise_x is only created
# inside the training loop in section 3, so that cell must have run first.
noise_x.shape

torch.Size([64, 100])

In [31]:
# Generator: noise vector (INPUT_NUMBERS) -> hidden layer (N_NODES, ReLU)
# -> flattened image (INPUT_DIMENSION).
# NOTE(review): there is no activation after the last layer, so generated
# pixel values are unbounded — confirm this is intended.
_gen_hidden = nn.Linear(in_features=INPUT_NUMBERS, out_features=N_NODES)
_gen_output = nn.Linear(in_features=N_NODES, out_features=INPUT_DIMENSION)
Generator = nn.Sequential(_gen_hidden, nn.ReLU(), _gen_output)

# Discriminator: flattened image (INPUT_DIMENSION) -> hidden layer
# (N_NODES, ReLU) -> a single sigmoid output in (0, 1) per sample.
_disc_hidden = nn.Linear(in_features=INPUT_DIMENSION, out_features=N_NODES)
_disc_output = nn.Linear(in_features=N_NODES, out_features=1)
Discriminator = nn.Sequential(
    _disc_hidden,
    nn.ReLU(),
    _disc_output,
    nn.Sigmoid(),
)

In [70]:
def init_weights(m):
    """Initialise an ``nn.Linear`` layer in place; ignore any other module.

    Weights are drawn from the Xavier (Glorot) uniform distribution and every
    bias entry is set to 0.01. Designed to be passed to ``Module.apply``.

    Args:
        m: the module currently being visited.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of nn.Linear.
    if isinstance(m, nn.Linear):
        # xavier_uniform_ is the in-place initialiser; the older
        # xavier_uniform spelling is deprecated and emits a warning.
        nn.init.xavier_uniform_(m.weight)
        # A layer built with bias=False has bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

# Module.apply runs init_weights on each network and all of its sub-layers.
# The value of the last expression is what the notebook echoes below.
Generator.apply(init_weights)
Discriminator.apply(init_weights)

  This is separate from the ipykernel package so we can avoid doing imports until


Sequential(
  (0): Linear(in_features=784, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=1, bias=True)
  (3): Sigmoid()
)

## 3. Model Training

In [32]:
# Here we will do the batch training.

In [33]:
# Serve the training set in shuffled mini-batches of BATCH_SIZE samples.
train_loader = Data.DataLoader(
    dataset=mnist_trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [35]:
# One Adam optimizer per network, so each can be stepped on its own loss
# with its own learning rate.
opt_D = torch.optim.Adam(params=Discriminator.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(params=Generator.parameters(), lr=LR_G)

In [None]:
# Small constant added inside every log so a saturated sigmoid output
# (exactly 0 or 1) cannot produce log(0) = -inf and poison the gradients.
LOG_EPS = 1e-8

for epoch in range(10):

    # true_y (the digit labels) is not used: the GAN is trained unsupervised.
    for step, (true_x, true_y) in enumerate(train_loader):
        # Flatten every image into a vector of INPUT_DIMENSION values.
        true_x = true_x.view(-1, INPUT_DIMENSION)
        # Draw one noise vector per real sample; the size is read from the
        # batch itself because the last batch can be smaller than BATCH_SIZE.
        NOISE_SIZE = true_x.shape[0]
        noise_x = torch.randn(NOISE_SIZE, INPUT_NUMBERS)
        # G Model
        fake_x = Generator(noise_x)

        # ---- Discriminator step ----
        # fake_x is detached so this backward pass stops at the discriminator
        # and does not compute gradients for the generator.
        prob0 = Discriminator(true_x)
        prob1 = Discriminator(fake_x.detach())
        D_loss = -torch.mean(
            torch.log(prob0 + LOG_EPS) + torch.log(1. - prob1 + LOG_EPS)
        )
        opt_D.zero_grad()
        D_loss.backward()
        opt_D.step()

        # ---- Generator step ----
        # The fakes are scored again AFTER opt_D.step(): that step changes the
        # discriminator weights in place, so back-propagating through the
        # graph built before it makes autograd raise a RuntimeError.
        prob1 = Discriminator(fake_x)
        G_loss = torch.mean(torch.log(1. - prob1 + LOG_EPS))
        opt_G.zero_grad()
        G_loss.backward()
        opt_G.step()

        if step % 200 == 0:  # plotting
            # Show the first generated sample of the batch as a 28x28 image.
            fake_pic = fake_x.view(-1, 28, 28).detach().numpy()[0]
            plt.cla()
            plt.imshow(fake_pic)