In [None]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

In [None]:
class DiscriminatorNet(nn.Module):
    """MLP that scores each input sample with a probability-like value in (0, 1).

    Layers: Linear(n_features, n_hidden) -> LeakyReLU(0.2) -> Dropout(0.3)
    -> Linear(n_hidden, n_out) -> Sigmoid. The Sigmoid output pairs with
    ``nn.BCELoss``.

    Args:
        n_features: Size of each flattened input sample (784 by default,
            e.g. a flattened 28x28 image).
        n_hidden: Width of the hidden layer.
        n_out: Number of scores produced per sample.
    """

    def __init__(self, n_features=784, n_hidden=512, n_out=1):
        super(DiscriminatorNet, self).__init__()

        # Module names ("hidden", "out") and creation order are kept so that
        # state_dict keys and seeded weight initialisation stay unchanged.
        self.hidden = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.out = nn.Sequential(
            nn.Linear(n_hidden, n_out),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of samples.

        Inputs with more than two dims (e.g. an image batch of shape
        ``(N, 1, 28, 28)``) are flattened to ``(N, -1)`` first; 1-D and 2-D
        inputs are passed through unchanged.
        """
        if x.dim() > 2:
            # Keep the batch dimension, collapse everything else into features.
            x = torch.flatten(x, start_dim=1)
        x = self.hidden(x)
        return self.out(x)

class GeneratorNet(nn.Module):
    """MLP mapping 100-dimensional inputs to 784-dimensional outputs in [-1, 1].

    Layers: Linear(100, 512) -> LeakyReLU(0.2) -> Linear(512, 784) -> Tanh.
    """

    def __init__(self):
        super().__init__()
        latent_dim, hidden_dim, output_dim = 100, 512, 784

        # Attribute names and layer creation order match the state_dict layout
        # ("hidden.0.*", "out.0.*") and the seeded initialisation sequence.
        self.hidden = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.LeakyReLU(0.2))
        self.out = nn.Sequential(nn.Linear(hidden_dim, output_dim), nn.Tanh())

    def forward(self, x):
        """Return ``out(hidden(x))``."""
        return self.out(self.hidden(x))

In [None]:
def noise(size, n_features=100):
    """Sample a batch of standard-normal input vectors for the generator.

    Args:
        size: Number of vectors to sample (the batch size).
        n_features: Length of each vector; the default of 100 matches the
            input width of ``GeneratorNet``.

    Returns:
        A ``(size, n_features)`` float tensor.
    """
    # ``Variable`` is merged into ``Tensor`` in modern PyTorch; wrapping is a
    # deprecated no-op, so a plain tensor is returned.
    return torch.randn(size, n_features)


def ones_target(size):
    """Return a ``(size, 1)`` tensor of ones (BCE label for "real")."""
    return torch.ones(size, 1)


def zeros_target(size):
    """Return a ``(size, 1)`` tensor of zeros (BCE label for "fake")."""
    return torch.zeros(size, 1)

In [None]:
def train_discriminator(optimizer, real_data, fake_data, model=None, criterion=None):
    """Run one optimisation step of the discriminator on a real and a fake batch.

    Real samples are labelled 1 and fake samples 0. Both losses are
    backpropagated into the same (freshly zeroed) gradient buffers and applied
    with a single ``optimizer.step()``.

    Args:
        optimizer: Optimiser over the discriminator's parameters.
        real_data: Batch of real samples in a shape the discriminator accepts.
        fake_data: Batch of generated samples, typically detached from the
            generator graph so this step does not backpropagate into the
            generator. Its batch size may differ from ``real_data``'s.
        model: Discriminator to train; defaults to the global ``discriminator``.
        criterion: Loss function; defaults to the global ``loss``.

    Returns:
        Tuple ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    model = discriminator if model is None else model
    criterion = loss if criterion is None else criterion

    optimizer.zero_grad()

    # Train on real data. Targets come from *_like so their shape, dtype and
    # device always match the predictions (GPU-safe, and correct when the fake
    # batch size differs from the real one).
    prediction_real = model(real_data)
    error_real = criterion(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()

    # Train on fake data.
    prediction_fake = model(fake_data)
    error_fake = criterion(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()

    # Update using the gradients accumulated from both losses.
    optimizer.step()
    return error_real + error_fake, prediction_real, prediction_fake

def train_generator(optimizer, fake_data, model=None, criterion=None):
    """Run one optimisation step of the generator.

    The discriminator's predictions on ``fake_data`` are compared against a
    target of ones, so the generator is rewarded when its samples are
    classified as real.

    Args:
        optimizer: Optimiser over the generator's parameters.
        fake_data: Generator output still attached to the generator's autograd
            graph (not detached), so gradients can flow back into it.
        model: Discriminator used to score the samples; defaults to the global
            ``discriminator``.
        criterion: Loss function; defaults to the global ``loss``.

    Returns:
        The loss value returned by ``criterion``.
    """
    model = discriminator if model is None else model
    criterion = loss if criterion is None else criterion

    optimizer.zero_grad()

    # Score the already-generated batch; noise sampling and the generator
    # forward pass happen in the caller.
    prediction = model(fake_data)

    # Label the fakes as "real" (ones); *_like keeps shape/dtype/device aligned.
    error = criterion(prediction, torch.ones_like(prediction))
    # backward() also writes into the discriminator's .grad buffers; only the
    # parameters held by `optimizer` are stepped, and train_discriminator
    # zeroes its optimizer's gradients before its own backward passes.
    error.backward()
    optimizer.step()

    return error

In [None]:

# Seed torch's global RNG so the weight initialisation below and later torch
# random draws are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Learning rate shared by both Adam optimisers.
LEARNING_RATE = 0.0001

discriminator = DiscriminatorNet()
generator = GeneratorNet()

d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)

# Binary cross-entropy; expects probabilities, matching the discriminator's
# Sigmoid output.
loss = nn.BCELoss()

In [None]:
# Training hyper-parameters (lower num_epochs for a quick smoke test).
num_epochs = 200
batch_size = 100

# MNIST digits (28x28 = 784 features). Normalize maps ToTensor's [0, 1] range
# to [-1, 1] so real samples share the range of the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist_train = datasets.MNIST(root="./data", train=True, download=True,
                             transform=mnist_transform)
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                          shuffle=True)

for epoch in range(num_epochs):
    # enumerate() yields (batch_index, (images, labels)); labels are unused.
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)  # the final batch may be smaller than batch_size
        # Flatten (N, 1, 28, 28) images into the (N, 784) vectors the nets use.
        data_real = real_batch.view(N, -1)

        ### 1. train discriminator (fake batch detached: no generator update here)
        data_fake = generator(noise(N)).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, data_real, data_fake)

        ### 2. train generator
        data_fake = generator(noise(N))
        g_error = train_generator(g_optimizer, data_fake)

    print(f"epoch {epoch + 1}/{num_epochs}  "
          f"d_error={d_error.item():.4f}  g_error={g_error.item():.4f}")
