In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
# Image Preprocessing
# MNIST images have a single channel, so Normalize gets exactly one mean and
# one std value. (A 3-element tuple cannot be broadcast in place against the
# 1x28x28 tensor and raises a RuntimeError on current torchvision.)
# ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 then maps them to [-1, 1],
# the same range as the tanh output of the generator defined in this notebook.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))])

In [3]:
# MNIST Dataset
# Training split only. The files are downloaded into ../data/ when they are
# not already there, and `transform` is applied to every image on access.
train_dataset = dsets.MNIST(
    '../data/', train=True, download=True, transform=transform)

In [4]:
# Data Loader (Input Pipeline)
# Serves the training set in reshuffled mini-batches of 100 images.
train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=100)

In [5]:
# Discriminator Model
class Discriminator(nn.Module):
    """Fully connected binary classifier over flattened 28x28 images.

    Three linear layers (784 -> 256 -> 256 -> 1) with ReLU after the first
    two and a sigmoid after the last, so each output is a value in [0, 1]
    that the training loop treats as the probability of "real".
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x):
        """Score a batch of flattened images.

        Args:
            x: float tensor whose last dimension is 784.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1, holding sigmoid-squashed scores.
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        # torch.sigmoid is the supported spelling; the F.sigmoid alias has
        # carried a deprecation warning in several PyTorch releases.
        out = torch.sigmoid(self.fc3(h))
        return out

In [6]:
# Generator Model
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flat image.

    Three linear layers (128 -> 256 -> 256 -> 784) with leaky ReLU after the
    first two and tanh after the last, so every output value lies in
    [-1, 1]. The 784 outputs are reshaped to 1x28x28 by the training loop
    when samples are saved.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 784)

    def forward(self, x):
        """Generate a batch of flattened images.

        Args:
            x: float tensor whose last dimension is 128 (the noise vector).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 784, with values in [-1, 1].
        """
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        # torch.tanh is the supported spelling; the F.tanh alias has carried
        # a deprecation warning in several PyTorch releases.
        out = torch.tanh(self.fc3(h))
        return out

In [7]:
# Build both networks; the discriminator is constructed first, then the
# generator (the right-hand side is evaluated left to right).
discriminator, generator = Discriminator(), Generator()

In [8]:
# Loss and Optimizer
# Binary cross entropy is applied to the discriminator's sigmoid outputs.
# Each network gets its own Adam optimizer, both with learning rate 5e-4.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=5e-4)
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=5e-4)

In [9]:
# Training
# Alternates one discriminator update and one generator update per mini-batch,
# logs progress every 300 steps, and writes a grid of generated samples at the
# end of every epoch.
num_epochs = 200
steps_per_epoch = len(train_loader)
# Noise size is read from the generator so the two cannot drift apart.
latent_dim = generator.fc1.in_features
sample_dir = './data'

# save_image does not create missing directories; make sure the target exists
# before spending an epoch on training.
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Build mini-batch dataset: flatten every image into one vector.
        # Plain tensors are used throughout; the Variable wrapper is no longer
        # needed for autograd.
        batch_size = images.size(0)
        images = images.view(batch_size, -1)
        # Targets are shaped (batch, 1) to match the discriminator output;
        # BCELoss rejects input/target tensors of different sizes.
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the discriminator
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)  # D should say "real" on real data
        real_score = outputs

        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        # detach() keeps this backward pass from reaching the generator.
        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)  # D should say "fake" on generated data
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        generator.zero_grad()
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        # The generator is trained to make D classify its samples as real
        # (labels of ones), i.e. the objective is converted into a
        # maximisation, which trains better at the beginning.
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 300 == 0:
            # .item() extracts the Python number from a 0-dim tensor; indexing
            # it with [0] raises on current PyTorch.
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                  'D(x): %.2f, D(G(z)): %.2f'
                  %(epoch+1, num_epochs, i+1, steps_per_epoch,
                    d_loss.item(), g_loss.item(),
                    real_score.mean().item(), fake_score.mean().item()))

    # Save the sampled images
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    # The generator emits tanh values in [-1, 1]; save_image expects [0, 1],
    # so rescale first instead of letting the negative half be clipped.
    samples = ((fake_images.detach() + 1) / 2).clamp(0, 1)
    torchvision.utils.save_image(samples,
        os.path.join(sample_dir, 'fake_samples_%d.png' %(epoch+1)))

Epoch [0/200], Step[300/600], d_loss: 0.0599, g_loss: 5.2123, D(x): 0.99, D(G(z)): 0.05
Epoch [0/200], Step[600/600], d_loss: 0.4934, g_loss: 6.6483, D(x): 0.89, D(G(z)): 0.17
Epoch [1/200], Step[300/600], d_loss: 0.8111, g_loss: 2.7667, D(x): 0.76, D(G(z)): 0.26
Epoch [1/200], Step[600/600], d_loss: 2.0467, g_loss: 0.7656, D(x): 0.51, D(G(z)): 0.66
Epoch [2/200], Step[300/600], d_loss: 0.3638, g_loss: 3.1908, D(x): 0.85, D(G(z)): 0.12
Epoch [2/200], Step[600/600], d_loss: 1.0088, g_loss: 1.3936, D(x): 0.67, D(G(z)): 0.38
Epoch [3/200], Step[300/600], d_loss: 1.8104, g_loss: 1.2920, D(x): 0.53, D(G(z)): 0.61
Epoch [3/200], Step[600/600], d_loss: 3.4970, g_loss: 0.3797, D(x): 0.38, D(G(z)): 0.84
Epoch [4/200], Step[300/600], d_loss: 0.8351, g_loss: 3.4683, D(x): 0.72, D(G(z)): 0.27
Epoch [4/200], Step[600/600], d_loss: 1.3750, g_loss: 1.0079, D(x): 0.53, D(G(z)): 0.36
Epoch [5/200], Step[300/600], d_loss: 1.1849, g_loss: 2.7277, D(x): 0.68, D(G(z)): 0.28
Epoch [5/200], Step[600/600], d_

In [10]:
# Save the Models
# Only the parameters (state_dict) of each network are written to the
# working directory, not the module objects themselves.
torch.save(obj=generator.state_dict(), f='./generator.pkl')
torch.save(obj=discriminator.state_dict(), f='./discriminator.pkl')