In [1]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

In [4]:
def mnist_data():
    compose = transforms.Compose(
        [transforms.ToTensor(), 
         transforms.Normalize((.5, .5, .5), (.5, .5, .5))
        ]
    )
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

data = mnist_data()
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
num_batches = len(data_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [8]:
data

Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ./dataset
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                         )
    Target Transforms (if any): None

In [10]:
class DiscriminatorNet(nn.Module):
    
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        n_feature = 784
        n_out = 1
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_feature, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
    

discriminator = DiscriminatorNet()
        

In [11]:
def img2vec(img):
    return img.view(img.size(0), 784)
def vec2img(vec):
    return vec.view(vec.size(0), 1, 28, 28)

In [14]:
class GeneratorNet(nn.Module):
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_features = 100
        n_out = 784
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

generator = GeneratorNet()


In [15]:
def noise(size):
    n = Variable(torch.randn(size, 100))
    return n

In [16]:
d_optimizer = optim.Adam(discriminator.parameters(), lr = 0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr = 0.0002)
loss = nn.BCELoss()


In [17]:
def ones_target(size):
    data = Variable(torch.ones(size, 1))
    return data

def zeros_target(size):
    data = Variable(torch.zeros(size,1))
    return data

In [19]:
def train_discriminator(optimizer, real_data, fake_data):
    N = real_data.size(0)
    optimizer.zero_grad()
    
    # train on real data
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(N))
    error_real.backward()
    
    # train on fake data
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(N))
    error_fake.backward()
    
    # update
    optimizer.step()
    return error_real+error_fake, prediction_real, prediction_fake

def train_generator(optimizer, fake_data):
    N = fake_data.size(0)
    
    optimizer.zero_grad()
    
    # sample noise and generate fake data
    prediction = discriminator(fake_data)
    
    error = loss(prediction, ones_target(N))
    error.backward()
    optimizer.step()
    
    return error

In [20]:
num_test_samples = 16
test_noise = noise(num_test_samples)

In [27]:
num_epochs = 200
for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)
        
        ### 1. train discriminator
        real_data = Variable(img2vec(real_batch))
        print(real_data.shape)
        
        # generate fake data and detach 
        #(so gradients are not calculated for generator, 
            #which means that we don't want to update the weights in generator 
            #that gives fake data in the discriminator training process
        #)
        fake_data = generator(noise(N)).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)
        
        ### 2. train generator
        fake_data = generator(noise(N))
        g_error = train_generator(g_optimizer, fake_data)


torch.Size([100, 784])
