In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
# MNIST dataset stored under ./MNIST (downloaded if it is not there yet).
# ToTensor() is the only transform applied to each image.
datasets = torchvision.datasets.MNIST(
    root="./MNIST",
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of 200 samples each.
dataloader = torch.utils.data.DataLoader(
    dataset=datasets,
    batch_size=200,
    shuffle=True,
)

In [3]:
class Encoder(nn.Module):
    """Fully connected encoder that maps a batch of images to latent codes.

    The input is flattened to ``width * height * n_channel`` features, passed
    through one ReLU-activated linear layer and then a three-layer MLP whose
    final linear layer produces ``z_dim`` outputs (no activation on the output).

    Note the positional order of the constructor is ``(height, width)``. Only
    the product ``width * height`` is used, so swapping the two values does not
    change the network.
    """

    def __init__(self, height, width, n_channel = 1, n_hidden = 500, z_dim = 300):
        """Build the layers.

        Args:
            height: image height in pixels.
            width: image width in pixels.
            n_channel: number of image channels.
            n_hidden: width of every hidden layer.
            z_dim: size of the latent code returned by ``forward``.
        """
        super(Encoder, self).__init__()
        self.width = width
        self.height = height
        self.z_dim = z_dim
        self.n_hidden = n_hidden
        self.n_channel = n_channel
        self.lin1 = nn.Linear(width * height * n_channel, n_hidden)

        self.out = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, z_dim))

    def forward(self, X):
        """Encode ``X`` and return a ``(batch, z_dim)`` tensor of latent codes."""
        # Flatten every sample to one feature vector; the batch size is inferred.
        X = X.view(-1, self.width * self.height * self.n_channel)
        out = F.relu(self.lin1(X))
        out = self.out(out)
        return out

    def sample(self, batch_size, mu, log_var):
        """Return ``mu + exp(log_var / 2) * eps`` with ``eps`` standard normal.

        Args:
            batch_size: number of noise rows to draw.
            mu: tensor of means, broadcastable against ``(batch_size, z_dim)``.
            log_var: tensor of log-variances, same shape as ``mu``.

        Returns:
            A tensor of samples on the same device and with the same dtype as ``mu``.
        """
        # Draw the noise where mu lives instead of unconditionally on CUDA, so
        # the method also works for CPU models and non-default dtypes.
        eps = torch.randn(batch_size, self.z_dim, device=mu.device, dtype=mu.dtype)
        return mu + torch.exp(log_var / 2) * eps

In [4]:
class Decoder(nn.Module):
    """Fully connected decoder that maps latent codes back to flat images.

    A two-layer MLP followed by a ReLU feeds a final linear layer with
    ``width * height * n_channel`` outputs, squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, width, height, n_channel = 1, n_hidden = 500, z_dim = 300):
        """Store the sizes and build the layers (same creation order as parameters are listed)."""
        super(Decoder, self).__init__()
        self.width, self.height = width, height
        self.z_dim = z_dim
        self.n_hidden = n_hidden
        self.n_channel = n_channel

        # Latent code -> hidden representation.
        hidden_layers = [
            nn.Linear(z_dim, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
        ]
        self.lin1 = nn.Sequential(*hidden_layers)

        # Hidden representation -> one value per pixel and channel.
        n_pixels = self.width * self.height * n_channel
        self.out = nn.Linear(n_hidden, n_pixels)

    def forward(self, X):
        """Decode latent codes ``X`` into a ``(batch, width * height * n_channel)`` tensor."""
        hidden = F.relu(self.lin1(X))
        return torch.sigmoid(self.out(hidden))

In [5]:
class Discriminator(nn.Module):
    """MLP that scores a latent code with a single value in (0, 1)."""

    def __init__(self, z_dim, n_hidden = 500):
        """Build a three-layer network: ``z_dim -> n_hidden -> n_hidden -> 1`` with a sigmoid output."""
        super(Discriminator, self).__init__()
        self.z_dim, self.n_hidden = z_dim, n_hidden

        layers = [
            nn.Linear(z_dim, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Return a ``(batch, 1)`` tensor of scores for the latent codes ``X``."""
        return self.net(X)

In [6]:
# Image dimensions used to size the encoder input and the decoder output.
width, height = 28, 28

# Encoder and decoder share a 5-dimensional latent space and 1000-unit hidden
# layers. Encoder's first two positional parameters are (height, width) while
# Decoder's are (width, height); both values are 28 here, so the order is moot.
encoder = Encoder(width, height, n_hidden=1000, z_dim=5)
decoder = Decoder(width, height, n_hidden=1000, z_dim=5)

# The discriminator works on latent codes and keeps its default hidden width.
discriminator = Discriminator(encoder.z_dim)

# Small constant added to probabilities before taking logs in the losses.
TINY = 1e-15

In [7]:
# Learning rates. gen_lr is not referenced anywhere in this cell and reg_lr is
# only read by turn_reg_on(); the optimizers below use the other two values.
gen_lr = 0.0006
reg_lr = 0.0003
autoenc_lr = 0.01
# NOTE(review): 0.1 is an unusually large step size for Adam -- confirm it is intended.
gen_disc_lr = 0.1

# Reconstruction phase: the decoder and the encoder each get their own optimizer.
decoding_optim = torch.optim.Adam(decoder.parameters(), lr=autoenc_lr)
encoder_encoding_optim = torch.optim.Adam(encoder.parameters(), lr=autoenc_lr)

# Regularization phase: the encoder acts as the generator against the discriminator.
encoder_generator_optim = torch.optim.Adam(encoder.parameters(), lr=gen_disc_lr)
discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=gen_disc_lr)

def turn_reg_off():
    """Disable the adversarial regularization updates.

    Sets the learning rate of ``encoder_generator_optim`` and
    ``discriminator_optim`` to 0 in place, so their ``step()`` calls leave the
    weights untouched while the optimizers keep their internal state.

    The previous implementation bound brand-new optimizers to function-local
    names, which never affected the module-level optimizers used by the
    training loop. It also targeted ``encoder_encoding_optim`` (the
    reconstruction optimizer) rather than the generator optimizer that
    ``turn_reg_on`` re-enables.
    """
    for optimizer in (encoder_generator_optim, discriminator_optim):
        for group in optimizer.param_groups:
            group["lr"] = 0

def turn_reg_on():
    """Enable the adversarial regularization updates at learning rate ``reg_lr``.

    Sets the learning rate of ``encoder_generator_optim`` and
    ``discriminator_optim`` in place. The previous implementation bound new
    optimizers to function-local names, so the module-level optimizers used by
    the training loop were never changed.
    """
    for optimizer in (encoder_generator_optim, discriminator_optim):
        for group in optimizer.param_groups:
            group["lr"] = reg_lr

In [8]:
# Move the networks to the GPU when one is available and fall back to the CPU
# otherwise, instead of failing on machines without CUDA.
# NOTE(review): the optimizers are created in an earlier cell, before this move.
# The torch.optim documentation recommends moving a model to its device before
# constructing optimizers for it -- consider swapping the two cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)
discriminator.to(device)

Discriminator (
  (net): Sequential (
    (0): Linear (5 -> 500)
    (1): ReLU ()
    (2): Linear (500 -> 500)
    (3): ReLU ()
    (4): Linear (500 -> 1)
    (5): Sigmoid ()
  )
)

In [9]:

def zero_grads():
    """Reset the accumulated gradients of the encoder, decoder and discriminator."""
    for network in (encoder, decoder, discriminator):
        network.zero_grad()
    
def run_training_epoch():
    """Run one pass over ``dataloader``, training the adversarial autoencoder.

    Every batch goes through three phases:

    1. Reconstruction: encoder and decoder minimise the binary cross-entropy
       between the decoded image and the input.
    2. Discriminator: the discriminator learns to tell samples of the prior
       (standard normal noise scaled by 5.0) from encoder outputs.
    3. Generator: the encoder is updated so that the discriminator scores its
       codes as prior samples.

    Returns:
        A tuple ``(recon_loss, total_disc_loss, g_loss)`` of Python floats
        holding the losses of the LAST batch of the epoch.
    """
    # Work on whichever device the encoder lives on rather than assuming CUDA.
    device = next(encoder.parameters()).device

    for img, _ in dataloader:
        # The images are used exactly as the dataset delivers them. The former
        # `img * 0.3081 + 0.1307` un-normalization has been dropped: the
        # dataset transform performs no normalization, so it only distorted
        # the reconstruction targets.
        X = img.to(device)
        # The last batch may be smaller than the nominal batch size.
        batch_size = X.size(0)
        # The decoder emits flat vectors, so the BCE target must be flat too.
        X_flat = X.view(batch_size, -1)

        # --- Reconstruction phase ---
        zero_grads()
        X_recon = decoder(encoder(X))
        # binary_cross_entropy clamps its log terms internally, so no epsilon
        # is added to the prediction or the target.
        recon_loss = F.binary_cross_entropy(X_recon, X_flat)
        recon_loss.backward()
        decoding_optim.step()
        encoder_encoding_optim.step()

        # --- Regularization phase: discriminator ---
        zero_grads()
        z_real_gauss = torch.randn(batch_size, encoder.z_dim, device=device) * 5.0
        # Detach: this step updates only the discriminator, so there is no
        # need to backpropagate into the encoder.
        z_fake_gauss = encoder(X).detach()

        d_real = discriminator(z_real_gauss)
        d_fake = discriminator(z_fake_gauss)

        total_disc_loss = -torch.mean(torch.log(d_real + TINY) + torch.log(1 - d_fake + TINY))
        total_disc_loss.backward()
        discriminator_optim.step()

        # --- Regularization phase: generator (the encoder) ---
        # The optimizer is stepped only after the generator loss has been
        # backpropagated; the earlier extra step() before the forward pass
        # has been removed.
        zero_grads()
        z_fake_disc = discriminator(encoder(X))

        g_loss = -torch.mean(torch.log(z_fake_disc + TINY))
        g_loss.backward()
        encoder_generator_optim.step()

        zero_grads()

    # .item() extracts the Python number from a 0-dim loss tensor.
    return recon_loss.item(), total_disc_loss.item(), g_loss.item()


In [10]:
n_epochs = 50
# Sample on whichever device the decoder lives on instead of assuming CUDA.
sample_device = next(decoder.parameters()).device

for epoch in range(n_epochs):
    losses = run_training_epoch()
    print("Epoch {}, loss {}".format(epoch, losses))

    # Decode 200 latent samples and save them as an image grid for this epoch.
    # The noise is scaled by 5.0 so that it follows the same prior that
    # run_training_epoch feeds to the discriminator as "real" codes.
    # no_grad avoids building an autograd graph for pure inference.
    with torch.no_grad():
        test_z = torch.randn(200, encoder.z_dim, device=sample_device) * 5.0
        result = decoder(test_z).view(-1, decoder.n_channel, decoder.height, decoder.width)
    torchvision.utils.save_image(result, "aae-results-{}.png".format(epoch))

KeyboardInterrupt: 

In [None]:
import torchvision
# Decode 200 latent samples with the trained decoder and save them as one grid.
# The noise is drawn on the decoder's device (no CUDA assumption) and scaled by
# 5.0 so that it follows the prior used for the discriminator during training.
# NOTE(review): the output file is named "vae-results.png" although the model
# trained above is an adversarial autoencoder -- rename if that is unintended.
sample_device = next(decoder.parameters()).device
with torch.no_grad():
    test_z = torch.randn(200, encoder.z_dim, device=sample_device) * 5.0
    result = decoder(test_z).view(-1, decoder.n_channel, decoder.height, decoder.width)
torchvision.utils.save_image(result, "vae-results.png")