In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import imageio

In [2]:
# Number of MNIST images per training batch.
BATCH_SIZE = 100

# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 per channel, taking the
# [0, 1] pixels from ToTensor() to [-1, 1] -- the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
to_image = transforms.ToPILImage()  # tensor -> PIL image, for viewing samples

trainset = MNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [11]:
# Pull a single batch to check tensor shapes before training.
i = 0
image, label = next(iter(trainloader))
print(i, image.shape, label.shape)

0 torch.Size([100, 1, 28, 28]) torch.Size([100])


In [15]:
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
class Generator(nn.Module):
    """Fully-connected GAN generator.

    Maps a (batch, 128) latent vector through hidden widths 256 -> 512 -> 1024
    (each Linear followed by LeakyReLU(0.2)), then a Linear + Tanh layer producing
    784 values in [-1, 1], reshaped to (batch, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()
        self.n_features = 128
        self.n_out = 784

        def hidden(width_in, width_out):
            # One hidden stage: affine map then LeakyReLU with negative slope 0.2.
            return nn.Sequential(nn.Linear(width_in, width_out), nn.LeakyReLU(0.2))

        # Attribute names fc0..fc3 define the state_dict keys (fc0.0.weight, ...),
        # so they must stay stable for saved checkpoints.
        self.fc0 = hidden(self.n_features, 256)
        self.fc1 = hidden(256, 512)
        self.fc2 = hidden(512, 1024)
        self.fc3 = nn.Sequential(nn.Linear(1024, self.n_out), nn.Tanh())

    def forward(self, x):
        for stage in (self.fc0, self.fc1, self.fc2, self.fc3):
            x = stage(x)
        # Flat 784-vector -> single-channel 28x28 image.
        return x.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    """Fully-connected GAN discriminator.

    Flattens its input to (batch, 784), passes it through hidden widths
    1024 -> 512 -> 256 (each Linear -> LeakyReLU(0.2) -> Dropout(0.3)), and ends
    with Linear + Sigmoid, returning a (batch, 1) probability that the input is real.
    """

    def __init__(self):
        super().__init__()
        self.n_in = 784
        self.n_out = 1

        def hidden(width_in, width_out):
            # One hidden stage: affine map, LeakyReLU(0.2), then Dropout(p=0.3).
            return nn.Sequential(
                nn.Linear(width_in, width_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        # Attribute names fc0..fc3 define the state_dict keys; keep them stable.
        self.fc0 = hidden(self.n_in, 1024)
        self.fc1 = hidden(1024, 512)
        self.fc2 = hidden(512, 256)
        self.fc3 = nn.Sequential(nn.Linear(256, self.n_out), nn.Sigmoid())

    def forward(self, x):
        out = x.view(-1, 784)
        for stage in (self.fc0, self.fc1, self.fc2, self.fc3):
            out = stage(out)
        return out

In [17]:
# Seed torch's global RNG before building the models so weight initialisation and
# the torch.randn noise draws repeat from run to run.
SEED = 0
torch.manual_seed(SEED)

generator = Generator()
discriminator = Discriminator()

generator.to(device)
discriminator.to(device)

# Adam with the same learning rate (2e-4) for both networks.
g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)

# Per-epoch history appended to by the training loop.
g_losses = []
d_losses = []
images = []

# Binary cross-entropy on the discriminator's Sigmoid output probabilities.
criterion = nn.BCELoss()

def noise(n, n_features=128):
    """Sample a batch of generator latent vectors.

    Args:
        n: number of samples (batch size).
        n_features: latent size; the default 128 matches the generator's input layer.

    Returns:
        Tensor of shape (n, n_features) drawn from a standard normal, moved to ``device``.
    """
    # torch.autograd.Variable is a deprecated no-op wrapper; plain tensors carry autograd
    # state. Sampling on the CPU generator and then moving keeps the draws identical
    # for a given seed whatever ``device`` is.
    return torch.randn(n, n_features).to(device)

def make_ones(size):
    """Return a (size, 1) tensor of ones on ``device``: the "real" target for BCELoss."""
    # Allocated directly on the target device; the deprecated Variable wrapper is not needed.
    return torch.ones(size, 1, device=device)

def make_zeros(size):
    """Return a (size, 1) tensor of zeros on ``device``: the "fake" target for BCELoss."""
    # Allocated directly on the target device; the deprecated Variable wrapper is not needed.
    return torch.zeros(size, 1, device=device)


In [18]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one optimisation step of the discriminator.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real images, scored against target 1.
        fake_data: batch of generated images, scored against target 0. It is typically
            detached from the generator so this step leaves generator gradients alone.

    Returns:
        Detached scalar tensor: BCE loss on the real batch plus BCE loss on the fake batch.
    """
    optimizer.zero_grad()

    prediction_real = discriminator(real_data)
    error_real = criterion(prediction_real, make_ones(real_data.size(0)))
    error_real.backward()

    # Targets are sized from the fake batch itself, so it need not match the real batch size.
    prediction_fake = discriminator(fake_data)
    error_fake = criterion(prediction_fake, make_zeros(fake_data.size(0)))
    error_fake.backward()

    # Both backward() calls accumulate into .grad, so this step uses the summed gradient.
    optimizer.step()

    # Detach so a caller that sums losses across batches does not keep every step's
    # autograd graph alive.
    return (error_real + error_fake).detach()

def train_generator(optimizer, fake_data):
    """Run one optimisation step of the generator.

    The generated batch is scored by the discriminator against the "real" target (1),
    i.e. the non-saturating generator loss, and gradients flow back through
    ``fake_data`` into the generator.

    Args:
        optimizer: optimizer over the generator's parameters.
        fake_data: generated batch that is still attached to the generator's graph.

    Returns:
        Detached scalar tensor with the generator's BCE loss.
    """
    n = fake_data.size(0)
    optimizer.zero_grad()

    prediction = discriminator(fake_data)
    error = criterion(prediction, make_ones(n))

    # backward() also fills discriminator gradients; only the parameters held by
    # ``optimizer`` are updated by step().
    error.backward()
    optimizer.step()

    # Detach so callers accumulating the loss don't retain this step's autograd graph.
    return error.detach()


In [19]:
num_epochs = 250
k = 1  # discriminator updates per generator update
test_noise = noise(64)  # fixed latent batch so per-epoch sample grids are comparable

generator.train()
discriminator.train()
n_batches = len(trainloader)
for epoch in range(num_epochs):
    g_error = 0.0
    d_error = 0.0
    for imgs, _ in trainloader:
        n = len(imgs)
        # The real batch is the same for all k discriminator steps; move it once.
        real_data = imgs.to(device)
        for _ in range(k):
            # detach(): the discriminator step must not backpropagate into the generator.
            fake_data = generator(noise(n)).detach()
            # float() keeps only the scalar, so the running sum never holds autograd graphs.
            d_error += float(train_discriminator(d_optim, real_data, fake_data))
        fake_data = generator(noise(n))
        g_error += float(train_generator(g_optim, fake_data))

    # Snapshot samples from the fixed noise; no graph is needed for visualisation.
    with torch.no_grad():
        img = generator(test_noise).cpu()
    img = make_grid(img)
    images.append(img)

    # Mean loss per optimisation step this epoch: divide by the number of batches
    # (not the last 0-based batch index) and, for D, by the k steps per batch.
    g_loss = g_error / n_batches
    d_loss = d_error / (n_batches * k)
    g_losses.append(g_loss)
    d_losses.append(d_loss)
    print('Epoch {}: g_loss: {:.8f} d_loss: {:.8f}\r'.format(epoch, g_loss, d_loss))

print('Training Finished')
torch.save(generator.state_dict(), 'mnist_generator.pth')

Epoch 0: g_loss: 2.88339138 d_loss: 0.95510399
Epoch 1: g_loss: 1.81115687 d_loss: 1.02838194
Epoch 2: g_loss: 2.47281313 d_loss: 0.77506012
Epoch 3: g_loss: 2.61726665 d_loss: 0.59707808
Epoch 4: g_loss: 3.35524750 d_loss: 0.43079934
Epoch 5: g_loss: 2.89860702 d_loss: 0.53202575
Epoch 6: g_loss: 2.76994109 d_loss: 0.58541030
Epoch 7: g_loss: 2.57045937 d_loss: 0.57326770
Epoch 8: g_loss: 2.65082765 d_loss: 0.57668263
Epoch 9: g_loss: 2.42610359 d_loss: 0.62487388
Epoch 10: g_loss: 2.30858684 d_loss: 0.66642123
Epoch 11: g_loss: 2.04738617 d_loss: 0.73612946
Epoch 12: g_loss: 2.01965261 d_loss: 0.76425332
Epoch 13: g_loss: 1.88807273 d_loss: 0.81586903
Epoch 14: g_loss: 1.89055574 d_loss: 0.83553463
Epoch 15: g_loss: 1.79148924 d_loss: 0.84527814
Epoch 16: g_loss: 1.71810889 d_loss: 0.85064739
Epoch 17: g_loss: 1.69372201 d_loss: 0.87936336
Epoch 18: g_loss: 1.72564459 d_loss: 0.85618520
Epoch 19: g_loss: 1.67952538 d_loss: 0.87805218
Epoch 20: g_loss: 1.57553232 d_loss: 0.92156512
Ep

Epoch 170: g_loss: 0.88601029 d_loss: 1.29034090
Epoch 171: g_loss: 0.88896686 d_loss: 1.28475547
Epoch 172: g_loss: 0.87605643 d_loss: 1.29159474
Epoch 173: g_loss: 0.87774253 d_loss: 1.28654921
Epoch 174: g_loss: 0.88042969 d_loss: 1.28832483
Epoch 175: g_loss: 0.87577224 d_loss: 1.28582692
Epoch 176: g_loss: 0.88051426 d_loss: 1.28365970
Epoch 177: g_loss: 0.87632698 d_loss: 1.29152465
Epoch 178: g_loss: 0.88533276 d_loss: 1.28629851
Epoch 179: g_loss: 0.87059200 d_loss: 1.29234362
Epoch 180: g_loss: 0.87689227 d_loss: 1.28816092
Epoch 181: g_loss: 0.87507397 d_loss: 1.28801942
Epoch 182: g_loss: 0.88452679 d_loss: 1.28962088
Epoch 183: g_loss: 0.87511015 d_loss: 1.29478121
Epoch 184: g_loss: 0.87070036 d_loss: 1.29036832
Epoch 185: g_loss: 0.86518383 d_loss: 1.29352152
Epoch 186: g_loss: 0.87365323 d_loss: 1.29150963
Epoch 187: g_loss: 0.87984043 d_loss: 1.29396677
Epoch 188: g_loss: 0.87432921 d_loss: 1.29398298
Epoch 189: g_loss: 0.87650520 d_loss: 1.28694546
Epoch 190: g_loss: 0