# Project AI 
This script trains VAE models with three types of latent distribution — Gaussian, Gumbel-Softmax, and Logit-normal — to study discrete latent spaces. The Logit-normal models are trained for several settings of the prior variance.

In [None]:
# Import required sources.
%pylab inline
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch import nn, optim
from VAE import *
from train import *
import numpy as np
from collections import *
import random
import pickle
import os

## Setting parameters and loading data

In [None]:
# set hyperparameters
latent_dims = [2, 4, 8, 20, 40]  # latent dimensionalities to train, one model each
variances = [0.32, 0.56, 1., 1.78]  # prior variances for the logit-normal runs
# Filename tags for `variances`, paired by position (0.32 -> 32, 1. -> 1, ...).
variance_tags = [32, 56, 1, 178]
if len(variances) != len(variance_tags):
    # The logit cell zips these lists; zip() would silently drop unmatched runs.
    raise ValueError('variances and variance_tags must have the same length')
epochs = 1

# Load data
train_data = datasets.MNIST('../data', train=True, download=True,
                            transform=transforms.ToTensor())

# Make model directories. makedirs creates the parent 'models' directory as
# needed, and exist_ok=True makes re-running this cell safe.
for subdir in ('gumbel', 'gauss', 'logit'):
    os.makedirs(os.path.join('models', subdir), exist_ok=True)

## Running the Gaussian models

In [None]:
# Train and save one Gaussian-latent model per latent dimensionality.
for dim in latent_dims:
    # The trained model is bound to `model`, not `VAE`: `from VAE import *`
    # presumably imports a `VAE` name, which the original rebound and then
    # removed with `del`. Verify against the VAE module.
    model, loss, z, KL, log_bern = run_train(dim, epochs, 'Gaussian', train_data, 1e-3)
    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open('models/gauss/gauss{}.p'.format(dim), 'wb') as f:
        pickle.dump([model, loss, z, KL, log_bern], f)
    # Drop references before the next run so the previous model can be freed.
    del model, loss, z, KL, log_bern

## Running the Gumbel-Softmax models

In [None]:
# Train and save one Gumbel-Softmax-latent model per latent dimensionality.
for dim in latent_dims:
    # The trained model is bound to `model`, not `VAE`: `from VAE import *`
    # presumably imports a `VAE` name, which the original rebound and then
    # removed with `del`. Verify against the VAE module.
    model, loss, z, KL, log_bern = run_train(dim, epochs, 'Gumbel', train_data, 1e-3)
    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open('models/gumbel/gumbel{}.p'.format(dim), 'wb') as f:
        pickle.dump([model, loss, z, KL, log_bern], f)
    # Drop references before the next run so the previous model can be freed.
    del model, loss, z, KL, log_bern

## Running the Logit-normal models with varying priors

In [None]:
# Train and save one logit-normal-latent model per (latent dim, prior variance).
for dim in latent_dims:
    for (var, label) in zip(variances, variance_tags):
        # The trained model is bound to `model`, not `VAE`: `from VAE import *`
        # presumably imports a `VAE` name, which the original rebound and then
        # removed with `del`. Verify against the VAE module.
        model, loss, z, KL, log_bern = run_train(dim, epochs, 'logit', train_data, 1e-3, variance=var)
        # `with` guarantees the file is flushed and closed even if pickling fails.
        with open('models/logit/logit{}_{}.p'.format(dim, label), 'wb') as f:
            pickle.dump([model, loss, z, KL, log_bern], f)
        # Drop references before the next run so the previous model can be freed.
        del model, loss, z, KL, log_bern