In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
import torchvision
import torchvision.transforms as transforms

In [3]:
transform = transforms.Compose([
    transforms.ToTensor()
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
img_set = torchvision.datasets.MNIST(root="./mnist", train=True, download=True, transform=transform)

In [5]:
img_loader = torch.utils.data.DataLoader(dataset=img_set, shuffle=True, batch_size=100)

In [6]:
def output_dimension(height, width, kernal, stride, padding, output_padding=(0,0)):
    h_out = (height - 1) * stride[0] - 2 * padding[0] + kernal[0] + output_padding[0]
    w_out = (width - 1) * stride[1] - 2 * padding[1] + kernal[1] + output_padding[1]
    return (h_out, w_out)

In [7]:
output_dimension(1, 1, (4, 4), (1,1), (0,0))

(4, 4)

In [8]:
output_dimension(4, 4, (4, 4), (2,2), (1,1))

(8, 8)

In [9]:
output_dimension(8, 8, (4, 4), (2,2), (1,1))

(16, 16)

In [10]:
output_dimension(16, 16, (4, 4), (2,2), (3,3))

(28, 28)

In [11]:
class LeNet(nn.Module):
    #the discriminator
    def __init__(self, side, n_hidden, n_chan1, n_chan2, kernal):
        super(LeNet, self).__init__()
        padding = (kernal - 1)//2
        self.conv1 = nn.Conv2d(1, n_chan1, kernal, padding=padding)
        self.conv2 = nn.Conv2d(n_chan1, n_chan2, kernal, padding=padding)
        
        dimension = self.size_after_pool(side, padding, padding)
        dimension = self.size_after_pool(dimension, padding, padding)
        
        self.fc = nn.Linear(dimension * dimension * n_chan2, n_hidden)
        self.output = nn.Linear(n_hidden, 1)
        
        self.drop1 = nn.Dropout2d(p=0.25)
        self.drop2 = nn.Dropout2d(p=0.25)
    
    def forward(self, x):
        
        out = F.relu(self.drop1(self.conv1(x)))
        out = F.max_pool2d(out, (2,2))
        out = F.relu(self.drop2(self.conv2(out)))
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        
        out = F.relu(self.fc(out))
        out = F.sigmoid(self.output(out))
        return out.view(-1, 1)
    
    def size_after_pool(self, size, stride, sp_ext):
        return int((size - sp_ext)/stride) + 1

In [12]:
class Generator(nn.Module):
    def __init__(self, n_batch, n_input, n_z_step):
        super(Generator, self).__init__()
        self.n_input = n_input
        self.nzstep = n_z_step
        self.n_batch = n_batch
        
        self.conv1 = nn.ConvTranspose2d(n_input, n_z_step * 8, 4, 1, 0, bias=False)
        self.conv2 = nn.ConvTranspose2d(n_z_step * 8, n_z_step * 4, 4, 2, 1, bias=False)
        self.conv3 = nn.ConvTranspose2d(n_z_step * 4, n_z_step * 2, 4, 2, 1, bias=False)
        self.conv4 = nn.ConvTranspose2d(n_z_step * 2, 1, 4, 2, 3, bias=False)
        
        self.batch1 = nn.BatchNorm2d(n_z_step * 8)
        self.batch2 = nn.BatchNorm2d(n_z_step * 4)
        self.batch3 = nn.BatchNorm2d(n_z_step * 2)
    
    def generate_noise(self):
        noise = torch.FloatTensor(self.n_batch, self.n_input, 1, 1)
        noise.normal_()
        return noise
        
    def forward(self, x):
        out = F.relu(self.batch1(self.conv1(x)))
        out = F.relu(self.batch2(self.conv2(out)))
        out = F.relu(self.batch3(self.conv3(out)))
        out = F.tanh(self.conv4(out))
        
        return out
        

In [13]:
disc = LeNet(28, 500, 16, 32, 5)
gen = Generator(100, 512, 64)

In [14]:
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.9, 0.999))
optimizerG = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.9, 0.999))

In [15]:
disc.cuda()
gen.cuda()
criterion.cuda()

BCELoss (
)

In [16]:
import torchvision

In [17]:
import time

n_epochs = 30

one = torch.FloatTensor([1]).cuda()
mone = one * -1

gen_iter = 0
for epoch in range(n_epochs):
    current_time = time.clock()
    
    i = 0 
    for image, _ in img_loader:
        
        disc.zero_grad()
        real_image = Variable(image.cuda())
        real_label = Variable(torch.ones(100).cuda())
        
        real_output = disc(real_image)
        real_loss = criterion(real_output, real_label)
        real_loss.backward()
        
        #train fake
        noise = gen.generate_noise().cuda()
        noise = Variable(noise)
        
        fake_image = gen(noise)
        fake_label = Variable(torch.zeros(100).cuda())
        
        fake_output = disc(fake_image.detach())
        fake_loss = criterion(fake_output, fake_label)
        fake_loss.backward()
        
        errD = real_loss + fake_loss
        optimizerD.step()
        
        #GENERATOR
        
        gen.zero_grad()
        gen_label = Variable(torch.ones(100).cuda())
        gen_output = disc(fake_image)
        gen_loss = criterion(gen_output, gen_label)
        gen_loss.backward()
        optimizerG.step()
        errG = gen_loss
        
        i += 1
        
        if i % 100 == 0:
            torchvision.utils.save_image(fake_image.data, "fake_image_iep%03d,%03d.png" % (epoch, i))
        
        
    passed_time = time.clock() - current_time
    print("Epoch: ", epoch, "lossD: ", errD.data[0], " lossG: ", errG.data[0], " took: ", passed_time)

Epoch:  0 lossD:  0.17502692341804504  lossG:  3.2645599842071533  took:  104.25778825791035
Epoch:  1 lossD:  0.21347205340862274  lossG:  5.313624382019043  took:  103.24735601968246
Epoch:  2 lossD:  0.023498620837926865  lossG:  6.631292343139648  took:  103.31037837478203
Epoch:  3 lossD:  0.08215543627738953  lossG:  6.47157096862793  took:  103.32489602443883
Epoch:  4 lossD:  0.016522588208317757  lossG:  6.773327827453613  took:  103.44342179395716


KeyboardInterrupt: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
def imshow(img):
    img = img / 2 + 0.5      # unnormalize
    npimg = img.numpy()
    npimg = np.reshape(npimg, (50, 32, 32, 3), order='F')
    plt.imshow(npimg[0])

In [None]:

n = Variable(gen.generate_noise().cuda())
img = gen(n)
torchvision.utils.save_image(img.cpu().data, "wtf.png", normalize=False)

In [None]:
torch.save(disc.state_dict(), "disc.pkl")
torch.save(gen.state_dict(), "gen.pkl")

In [None]:
%matplotlib inline
imshow(img.cpu().data)