In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

In [2]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,),(0.5,))
                ])
to_image = transforms.ToPILImage()
trainset = MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True)

writer = SummaryWriter('./log/')
device = 'cuda'

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.n_features = 128
        self.n_out = 784
        self.fc0 = nn.Sequential(
                    nn.Linear(self.n_features, 256),
                    nn.LeakyReLU(0.2)
                    )
        self.fc1 = nn.Sequential(
                    nn.Linear(256, 512),
                    nn.LeakyReLU(0.2)
                    )
        self.fc2 = nn.Sequential(
                    nn.Linear(512, 1024),
                    nn.LeakyReLU(0.2)
                    )
        self.fc3 = nn.Sequential(
                    nn.Linear(1024, self.n_out),
                    nn.Tanh()
                    )
    def forward(selfs, x):
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_in = 784
        self.n_out = 1
        self.fc0 = nn.Sequential(
                    nn.Linear(self.n_in, 1024),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc1 = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc2 = nn.Sequential(
                    nn.Linear(512, 256),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3)
                    )
        self.fc3 = nn.Sequential(
                    nn.Linear(256, self.n_out),
                    nn.Sigmoid()
                    )
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

In [4]:
generator = Generator()
discriminator = Discriminator()

generator.to(device)
discriminator.to(device)

g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)

criterion = nn.BCELoss()

def noise(n, n_features=128):
    return Variable(torch.randn(n, n_features)).to(device)

def make_ones(size):
    data = Variable(torch.ones(size, 1))
    return data.to(device)

def make_zeros(size):
    data = Variable(torch.zeros(size, 1))
    return data.to(device)

In [5]:
def train_discriminator(optimizer, real_data, fake_data):
    n = real_data.size(0)

    optimizer.zero_grad()
    
    prediction_real = discriminator(real_data)
    error_real = criterion(prediction_real, make_ones(n))
    error_real.backward()

    prediction_fake = discriminator(fake_data)
    error_fake = criterion(prediction_fake, make_zeros(n))
    
    error_fake.backward()
    optimizer.step()
    
    return error_real + error_fake

def train_generator(optimizer, fake_data):
    n = fake_data.size(0)
    optimizer.zero_grad()
    
    prediction = discriminator(fake_data)
    error = criterion(prediction, make_ones(n))
    
    error.backward()
    optimizer.step()
    
    return error

In [6]:
num_epochs = 400
k = 2
test_noise = noise(64)

generator.train()
discriminator.train()
for epoch in range(num_epochs):
    g_error = 0.0
    d_error = 0.0
    for i, data in enumerate(trainloader):
        imgs, _ = data
        n = len(imgs)
        for j in range(k):
            fake_data = generator(noise(n)).detach()
            real_data = imgs.to(device)
            d_error += train_discriminator(d_optim, real_data, fake_data)
        fake_data = generator(noise(n))
        g_error += train_generator(g_optim, fake_data)
        
    writer.add_scalar('Loss/Generator', g_error/i, epoch)
    writer.add_scalar('Loss/Discriminator', d_error/i, epoch)
        
    img = generator(test_noise).cpu().detach()
    img = make_grid(img)
    writer.add_image('Result', img, epoch)
    print('Epoch {}: g_loss: {:.8f} d_loss: {:.8f}\r'.format(epoch, g_error/i, d_error/i))
    
print('Training Finished')
torch.save(generator.state_dict(), 'mnist_generator.pth')

Epoch 0: g_loss: 3.60787129 d_loss: 0.73424178
Epoch 1: g_loss: 1.68455637 d_loss: 1.03122914
Epoch 2: g_loss: 1.56025362 d_loss: 1.02352405
Epoch 3: g_loss: 1.66856039 d_loss: 0.96457118
Epoch 4: g_loss: 1.41388440 d_loss: 1.02538610
Epoch 5: g_loss: 1.52756274 d_loss: 1.02563298
Epoch 6: g_loss: 1.54301715 d_loss: 1.01866078
Epoch 7: g_loss: 1.43900013 d_loss: 1.02946758
Epoch 8: g_loss: 1.51640964 d_loss: 1.00838947
Epoch 9: g_loss: 1.45800161 d_loss: 1.02215385
Epoch 10: g_loss: 1.28648639 d_loss: 1.10953569
Epoch 11: g_loss: 1.31920636 d_loss: 1.09541667
Epoch 12: g_loss: 1.24594808 d_loss: 1.13200259
Epoch 13: g_loss: 1.25248027 d_loss: 1.10009432
Epoch 14: g_loss: 1.27451050 d_loss: 1.12223697
Epoch 15: g_loss: 1.24741650 d_loss: 1.11819398
Epoch 16: g_loss: 1.21382570 d_loss: 1.11546957
Epoch 17: g_loss: 1.16844130 d_loss: 1.16135430
Epoch 18: g_loss: 1.13485253 d_loss: 1.16816485
Epoch 19: g_loss: 1.12114811 d_loss: 1.16103339
Epoch 20: g_loss: 1.12081015 d_loss: 1.17592466
Ep

Epoch 170: g_loss: 0.88588887 d_loss: 1.28746712
Epoch 171: g_loss: 0.88762128 d_loss: 1.29032815
Epoch 172: g_loss: 0.90749210 d_loss: 1.28725863
Epoch 173: g_loss: 0.90598953 d_loss: 1.27945030
Epoch 174: g_loss: 0.89346021 d_loss: 1.28812551
Epoch 175: g_loss: 0.88172412 d_loss: 1.28955269
Epoch 176: g_loss: 0.88423687 d_loss: 1.28936946
Epoch 177: g_loss: 0.89229381 d_loss: 1.28803766
Epoch 178: g_loss: 0.88098776 d_loss: 1.29117215
Epoch 179: g_loss: 0.87553811 d_loss: 1.29485464
Epoch 180: g_loss: 0.89432448 d_loss: 1.28919709
Epoch 181: g_loss: 0.89438879 d_loss: 1.28715968
Epoch 182: g_loss: 0.89562947 d_loss: 1.28676867
Epoch 183: g_loss: 0.88199240 d_loss: 1.29092431
Epoch 184: g_loss: 0.89816308 d_loss: 1.28603673
Epoch 185: g_loss: 0.88486004 d_loss: 1.28988874
Epoch 186: g_loss: 0.88309807 d_loss: 1.28862441
Epoch 187: g_loss: 0.88697928 d_loss: 1.28989398
Epoch 188: g_loss: 0.88306212 d_loss: 1.28849244
Epoch 189: g_loss: 0.89360517 d_loss: 1.28805554
Epoch 190: g_loss: 0