In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.distributions.normal import Normal

In [2]:
batch_size = 64

# MNIST Dataset 
mnist = dsets.MNIST(root='./data', 
                      train=True, 
                      transform=transforms.ToTensor(),  
                      download=True)

# Data Loader (Input Pipeline)
data_loader = torch.utils.data.DataLoader(dataset=mnist, 
                                          batch_size=batch_size, 
                                          shuffle=True)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  ################# uses gpu if available
print(device)

cpu


In [4]:
def to_np(x):
    return x.data.cpu().numpy()

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

In [5]:
class aae_encoder(nn.Module):
    def __init__(self, X_dim, N, latent_dim):
        super(aae_encoder, self).__init__()
        self.input_size = X_dim
        self.latent_dim = latent_dim
        
        self.layer1 = nn.Linear(X_dim, N)
        self.layer2 = nn.Linear(N, N)
        
        ######## mean and variance for gauss distribution
        self.mean = nn.Linear(N, latent_dim)
        self.variance = nn.Linear(N, latent_dim)
        
    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        x = F.dropout(self.layer2(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        
        ####### gaussion distribution
        mean_layer = self.mean(x)
        variance_layer = self.variance(x)
        gauss_layer = Normal(torch.tensor(mean_layer), torch.tensor(variance_layer))
        xgauss = gauss_layer.sample()+mean_layer
        
        return xgauss
    
class aae_decoder(nn.Module):
    def __init__(self, latent_dim, N, X_dim):
        super(aae_decoder, self).__init__()
        self.latent_dim = latent_dim
        self.output_size = X_dim
        
        self.layer1 = nn.Linear(latent_dim, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, X_dim)
    
    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        x = F.dropout(self.layer2(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        
        x = self.layer3(x)
        
        return torch.sigmoid(x)
    
class aae_discriminator(nn.Module):
    def __init__(self, latent_dim, N):
        super(aae_discriminator, self).__init__()
        
        self.layer1 = nn.Linear(latent_dim, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, 1)
        
    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        x = F.dropout(self.layer2(x), p=0.2)
        x = F.leaky_relu(x, 0.2)
        
        x = self.layer3(x)
        
        return torch.sigmoid(x)
        

In [6]:
class AAE(nn.Module):
    def __init__(self, encoder, decoder, discriminator):
        super(AAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        
    def forward(self, x):
        encoder_output = self.encoder(x)
        decoder_output = self.decoder(encoder_output)
        discriminator_output = self.discriminator(encoder_output)
        
        return decoder_output, discriminator_output

In [14]:
EPS = 1e-15
N = 512
latent_dim = 128
encoder = aae_encoder(784, N, latent_dim)
decoder = aae_decoder(latent_dim, N, 784)
discriminator = aae_discriminator(latent_dim, 64)

aae = AAE(encoder, decoder, discriminator)

In [15]:
# Set learning rates
gen_lr = 0.0001
reg_lr = 0.00005

#encode/decode optimizers
encoder_optim = torch.optim.Adam(encoder.parameters(), lr=gen_lr)
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=gen_lr)
#regularizing optimizers
encoder_gen_optim = torch.optim.Adam(decoder.parameters(), lr=reg_lr)
discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=reg_lr)

data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 5000

In [17]:
# Start training

reconstruction_loss = []
discriminator_loss = []
generator_loss = []
gen_learning_rate = []
reg_learning_rate = []

# z_real_gauss = Variable(torch.randn(batch_size, latent_dim) * 5.)

for step in range(total_step):
    
    if (step+1)%1000:
        gen_lr /= 5
        reg_lr /= 5

    gen_learning_rate.append(gen_lr)
    reg_learning_rate.append(reg_lr)
    
    # Reset the data_iter
    if (step+1) % iter_per_epoch == 0:
        data_iter = iter(data_loader)

    # Fetch the images and labels and convert them to variables
    images, labels = next(data_iter)
    images, labels = to_var(images.view(batch_size, -1)), to_var(labels)

    #reconstruction loss
    decoder.zero_grad()
    encoder.zero_grad()
    discriminator.zero_grad()

    X_sample, pred = aae(images)
    recon_loss = F.binary_cross_entropy(X_sample+EPS,images+EPS)

    recon_loss.backward()
    decoder_optim.step()
    encoder_optim.step()

    # Discriminator
    ## true prior is random normal (randn)
    ## this is constraining the Z-projection to be normal!
    encoder.eval()
#     z_real_gauss = Variable(torch.randn(images.size()[0], latent_dim) * 5.).cuda()
#     z_real_gauss = Variable(torch.randn(batch_size, latent_dim) * 5., )
    X_sample, pred = aae(images)
    z_real_gauss = torch.randn(batch_size, latent_dim) * 5.
    Disc_real_gauss = discriminator(z_real_gauss)

    Disc_loss = -torch.mean(torch.log(Disc_real_gauss + EPS) + torch.log(1 - pred + EPS))

    Disc_loss.backward(retain_graph=True)
    discriminator_optim.step()

    # Generator
    encoder.train()
    z_fake_gauss = encoder(images)
    Disc_fake_gauss = discriminator(z_fake_gauss)
    
    Gen_loss = -torch.mean(torch.log(Disc_fake_gauss + EPS))

    Gen_loss.backward()
    encoder_gen_optim.step()   
    
    if (step+1) % 10 == 0:
        print('For Step:', step+1 ,'recon_loss:', recon_loss.item(),
        '\tdiscriminator_loss:', Disc_loss.item(),
        '\tgenerator_loss:', Gen_loss.item())
        reconstruction_loss.append(recon_loss.item())
        discriminator_loss.append(Disc_loss.item())
        generator_loss.append(Gen_loss.item())





For Step: 10 recon_loss: 0.6809074282646179 	discriminator_loss: 1.5029983520507812 	generator_loss: 0.6650325059890747
For Step: 20 recon_loss: 0.6065708994865417 	discriminator_loss: 1.4838974475860596 	generator_loss: 0.665739119052887
For Step: 30 recon_loss: 0.3767111301422119 	discriminator_loss: 1.4573198556900024 	generator_loss: 0.6902081370353699
For Step: 40 recon_loss: 0.2907136082649231 	discriminator_loss: 1.4237043857574463 	generator_loss: 0.7225086092948914
For Step: 50 recon_loss: 0.2934366762638092 	discriminator_loss: 1.3274537324905396 	generator_loss: 0.7617560625076294
For Step: 60 recon_loss: 0.28590139746665955 	discriminator_loss: 1.234162449836731 	generator_loss: 0.7803312540054321
For Step: 70 recon_loss: 0.28752437233924866 	discriminator_loss: 1.2213034629821777 	generator_loss: 0.8098970651626587
For Step: 80 recon_loss: 0.26922914385795593 	discriminator_loss: 1.1803001165390015 	generator_loss: 0.8167733550071716
For Step: 90 recon_loss: 0.276106148958

KeyboardInterrupt: 