In [1]:
from __future__ import print_function

import argparse
import itertools
import random

import matplotlib.pyplot as plt 
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

from ipdb import set_trace as st

In [2]:
# argparse reads sys.argv at parse time; replace it with a bare program name so
# the parser in the next cell sees no command-line flags (presumably to hide the
# notebook kernel's own launch arguments). No `sys` name is left in the namespace.
setattr(__import__("sys"), "argv", [""])

In [3]:
# Construct the argument parser and parse the arguments.
# `%(default)s` makes argparse print the actual default value, so the help text
# cannot drift out of sync with `default=` again (it previously claimed 128 and
# 10 while the real defaults were 2048 and 100).
parser = argparse.ArgumentParser(description='VAE Example')
parser.add_argument('--batch-size', type=int, default=2048, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# parse_known_args ignores unrecognised flags (e.g. ones added by whatever
# launched the interpreter) instead of aborting the notebook with SystemExit.
args, _unknown_args = parser.parse_known_args()

# Seed every RNG the notebook imports so runs are reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# we use CPU for computation
device = torch.device("cpu")

# Width of the hidden layers and of the latent space
# (reduced from 1024 to 100 following the review).
K = 100

In [4]:
def data_gen(BATCH_SIZE, n_modes=8, noise_std=0.1, radius=1.0):
    """Endlessly yield batches from a 2-D Gaussian mixture arranged on a circle.

    Each point picks one of ``n_modes`` equally spaced centres on a circle of
    the given ``radius`` uniformly at random and adds isotropic Gaussian noise
    with standard deviation ``noise_std``. The defaults reproduce the original
    "8 gaussians" setup exactly (same RNG calls, same values).

    Args:
        BATCH_SIZE: number of points per yielded batch.
        n_modes: number of mixture components (centres on the circle).
        noise_std: standard deviation of the additive Gaussian noise.
        radius: radius of the circle the centres lie on.

    Yields:
        Float tensor of shape ``(BATCH_SIZE, 2)`` placed on the module-level
        ``device``.
    """
    # Angular spacing between neighbouring centres (pi/4 for 8 modes).
    angle_step = 2 * np.pi / n_modes
    while True:
        theta = angle_step * torch.randint(0, n_modes, (BATCH_SIZE,)).float().to(device)
        centers = radius * torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)
        noise = torch.randn_like(centers) * noise_std
        yield centers + noise


In [5]:
# Separate generator objects for training, testing and validation. Previously a
# single generator was aliased as both train and test loader, so every test
# batch silently consumed a training batch. `valid_loader` is also needed by
# the early-stopping training loop further down.
train_loader = data_gen(args.batch_size)
test_loader = data_gen(args.batch_size)
valid_loader = data_gen(args.batch_size)

In [6]:
class VAE(nn.Module):
    """Fully connected VAE for 2-D points with SELU activations.

    Encoder: 2 -> K -> K, then two K -> K heads giving the posterior mean and
    log-variance. Decoder: K -> K -> K -> 2 with a linear output layer.
    """

    def __init__(self):
        super(VAE, self).__init__()
        # Layers are created in the same order as before, so the default random
        # initialisation consumes the RNG identically.
        self.fc0 = nn.Linear(2, K)    # encoder input layer
        self.fc1 = nn.Linear(K, K)    # encoder hidden layer
        self.fc21 = nn.Linear(K, K)   # head: posterior mean
        self.fc22 = nn.Linear(K, K)   # head: posterior log-variance
        self.fc3 = nn.Linear(K, K)    # decoder input layer
        self.fc4 = nn.Linear(K, K)    # decoder hidden layer
        self.fc5 = nn.Linear(K, 2)    # decoder output layer (linear)

    def encode(self, x):
        """Map inputs to the posterior parameters ``(mu, logvar)``."""
        hidden = F.selu(self.fc0(x))
        hidden = F.selu(self.fc1(hidden))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Return ``mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes back to 2-D points."""
        hidden = F.selu(self.fc3(z))
        hidden = F.selu(self.fc4(hidden))
        return self.fc5(hidden)

    def forward(self, x):
        """Encode, sample a latent code, decode; returns ``(recon, mu, logvar)``."""
        points = x.view(-1, 2)
        mu, logvar = self.encode(points)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar


In [7]:
# Instantiate the model on the chosen device, its Adam optimiser, and a
# scheduler that lowers the learning rate when the loss passed to
# `scheduler.step` (the epoch's training loss in `train`) plateaus.
model = VAE().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    threshold=1e-2,
    eps=1e-6,
    verbose=True,
)

### Using torch.nn.Sequential and adding a sample method
Your comment on the last code update:
"To simplify the definition of the layers in the VAE and the usage of the nonlinearities, you can use torch.nn.Sequential.
* Aim to integrate the prior into your model in the sense that at a later stage we could easily replace the prior. 
Currently, for instance, you sample from the prior in __main__ to generate data. That should rather be part of the VAE as not each VAE we consider will probably use a zero-mean unit-variance Gaussian and it is then error-prone to do such sampling outside of the VAE."
I tried to define the VAE with torch.nn.Sequential and to make sampling (including the prior) part of the VAE class.
From the PyTorch documentation of `nn.Sequential`: "A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in."


In [None]:
#update
class VAE(nn.Module):
    """VAE for 2-D points built from ``nn.Sequential`` blocks, with the prior inside the model.

    Mirrors the fully connected VAE defined earlier (SELU activations, linear
    output so negative coordinates can be reconstructed) but takes its sizes as
    constructor arguments. Sampling from the prior is part of the model:
    ``sample_prior`` draws latent codes and ``sample`` decodes them, so a
    subclass can use a different prior by overriding ``sample_prior`` alone.

    Args:
        input_size: dimensionality of the data points (2 for the toy data).
        hidden_size: width of the hidden layers (default equals the notebook's K = 100).
        latent_size: dimensionality of the latent space (default equals K = 100).
    """

    def __init__(self, input_size=2, hidden_size=100, latent_size=100):
        super(VAE, self).__init__()
        self.input_size = input_size
        self.latent_size = latent_size
        # Shared encoder trunk: input -> hidden -> hidden.
        self.encodinglayer1 = nn.Sequential(
            nn.Linear(input_size, hidden_size), nn.SELU(),
            nn.Linear(hidden_size, hidden_size), nn.SELU(),
        )
        # Heads producing the posterior mean and log-variance.
        self.encodinglayer2_mean = nn.Linear(hidden_size, latent_size)
        self.encodinglayer2_logvar = nn.Linear(hidden_size, latent_size)
        # Decoder: latent -> hidden -> hidden -> input. No output squashing:
        # the data has negative coordinates, which a Sigmoid could never produce.
        self.decodinglayer = nn.Sequential(
            nn.Linear(latent_size, hidden_size), nn.SELU(),
            nn.Linear(hidden_size, hidden_size), nn.SELU(),
            nn.Linear(hidden_size, input_size),
        )

    def encode(self, x):
        """Map inputs (flattened to ``(-1, input_size)``) to ``(mean, log_var)``."""
        h = self.encodinglayer1(x.view(-1, self.input_size))
        return self.encodinglayer2_mean(h), self.encodinglayer2_logvar(h)

    def decode(self, z):
        """Map latent codes of shape ``(n, latent_size)`` to points ``(n, input_size)``."""
        return self.decodinglayer(z)

    def reparameterize(self, mu, logvar):
        """Return ``mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def sample_prior(self, k):
        """Draw ``k`` latent codes from the prior p(z) = N(0, I).

        Override this method to use a different prior.
        """
        param_device = next(self.parameters()).device
        return torch.randn(k, self.latent_size, device=param_device)

    def sample(self, k):
        """Generate ``k`` new points by decoding draws from the prior.

        Returns a tensor of shape ``(k, input_size)``; gradients are tracked
        unless the caller wraps the call in ``torch.no_grad()``.
        """
        return self.decode(self.sample_prior(k))

    def forward(self, x):
        """Encode, draw z with the reparameterisation trick, decode.

        Returns:
            ``(reconstruction, mean, log_var)``.
        """
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        return self.decode(z), mean, log_var

In [8]:
# Reconstruction (mean squared error) + KL divergence, each averaged over all
# elements of the batch.
def loss_function(recon_x, x, mu, logvar):
    """VAE loss: element-wise mean squared error plus element-wise mean KL term.

    Args:
        recon_x: decoder output, e.g. shape ``(batch, 2)``.
        x: original inputs; reshaped to ``recon_x``'s shape before comparison
           (raises if the element counts differ).
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder (same shape as ``mu``).

    Returns:
        Scalar tensor ``mean((recon_x - x)**2) + KLD``.

    Note:
        Both terms use ``torch.mean`` rather than a sum, so the KL term is also
        averaged over the latent dimensions. This weights it differently from the
        summed form in the paper's Appendix B.
    """
    # Reshape explicitly: with e.g. x of shape (B, 1, 2) the subtraction would
    # otherwise broadcast silently to (B, B, 2) and give a wrong loss.
    x = x.reshape_as(recon_x)
    L2 = torch.mean((recon_x - x) ** 2)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # -KL per element: 0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return L2 + KLD

In [9]:
def train(epoch, n_batches=100):
    """Run one training epoch of ``n_batches`` mini-batches and step the LR scheduler.

    Uses the module-level ``model``, ``optimizer``, ``scheduler``, ``device``
    and ``train_loader`` (an endless batch generator).

    Args:
        epoch: epoch index, used in the log message.
        n_batches: number of mini-batches drawn from ``train_loader`` per epoch.

    Returns:
        The mean training loss per batch over the epoch (also printed).
    """
    model.train()
    train_loss = 0.0
    n_seen = 0
    # islice draws exactly n_batches. The old `batch_idx > 100` check ran 101
    # batches (and pulled one extra from the generator) but divided by 100.
    for data in itertools.islice(train_loader, n_batches):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        n_seen += 1
    # Average over the batches actually processed (guards a short/empty loader).
    train_loss /= max(n_seen, 1)
    scheduler.step(train_loss)
    print(f'[epoch {epoch}] train_loss: {train_loss:.5f}')
    return train_loss

### Test function
Computes the loss on held-out batches (validation loss).

In [None]:
#update
def test(epoch, n_batches=10):
    """Estimate the loss on ``n_batches`` held-out batches without updating the model.

    ``test_loader`` is an endless generator of data tensors: it yields neither
    ``(data, label)`` pairs nor anything with a ``.dataset``. A fixed number
    of batches is therefore drawn and the per-batch losses are averaged. Uses
    the module-level ``model`` and ``device``.

    Args:
        epoch: unused; kept for interface symmetry with ``train``.
        n_batches: number of batches to evaluate.

    Returns:
        The mean test loss per batch (also printed).
    """
    model.eval()
    test_loss = 0.0
    n_seen = 0
    with torch.no_grad():
        for data in itertools.islice(test_loader, n_batches):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            n_seen += 1

    test_loss /= max(n_seen, 1)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


### Early stopping
Your comment on the last code update:
"* Make the network smaller, e.g., use K=100.
* Run more epochs, e.g., 1000, but implement an early stopping heuristic so that it can abort training when converged."

Early stopping is a form of regularization used to avoid overfitting on the training dataset. It keeps track of the validation loss; if the loss stops decreasing for several epochs in a row, training stops. The EarlyStopping class in pytorchtools.py creates an object that tracks the validation loss while a PyTorch model trains, and it saves a checkpoint of the model each time the validation loss decreases. The patience argument of EarlyStopping sets how many epochs to wait after the last improvement in validation loss before breaking out of the training loop.

In [None]:
#update
# import EarlyStopping
from pytorchtools import EarlyStopping

In [None]:
#update
def train_model(model, batch_size, patience, n_epochs,
                n_train_batches=100, n_valid_batches=10,
                train_iter=None, valid_iter=None):
    """Train ``model`` with early stopping on the validation loss.

    Each epoch draws ``n_train_batches`` batches for optimisation and
    ``n_valid_batches`` batches for validation. ``EarlyStopping`` checkpoints
    the model whenever the validation loss improves and sets ``early_stop``
    after ``patience`` epochs without improvement. The saved checkpoint is
    reloaded at the end.

    Args:
        model: the VAE to train. NOTE(review): updates use the module-level
            ``optimizer``, so it must have been built over this model's parameters.
        batch_size: batch size of the data streams created when
            ``train_iter`` / ``valid_iter`` are not given.
        patience: epochs to wait after the last validation improvement.
        n_epochs: maximum number of epochs.
        n_train_batches: training batches per epoch (the data streams are
            infinite, so an epoch has to be bounded explicitly).
        n_valid_batches: validation batches per epoch.
        train_iter, valid_iter: optional endless iterators of data tensors;
            each defaults to a fresh ``data_gen(batch_size)`` stream.

    Returns:
        ``(model, avg_train_losses, avg_valid_losses)``: the model with the
        checkpoint loaded, and the per-epoch mean training/validation losses.
    """
    if train_iter is None:
        train_iter = data_gen(batch_size)
    if valid_iter is None:
        valid_iter = data_gen(batch_size)

    # per-epoch average losses, returned to the caller
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=patience, verbose=True)
    epoch_len = len(str(n_epochs))

    for epoch in range(1, n_epochs + 1):
        ###################
        # train the model #
        ###################
        model.train()
        train_losses = []
        # The generators yield bare data tensors (no targets), one batch each.
        for data in itertools.islice(train_iter, n_train_batches):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        ######################
        # validate the model #
        ######################
        model.eval()
        valid_losses = []
        # No autograd graph is needed for validation.
        with torch.no_grad():
            for data in itertools.islice(valid_iter, n_valid_batches):
                data = data.to(device)
                recon_batch, mu, logvar = model(data)
                loss = loss_function(recon_batch, data, mu, logvar)
                valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print(f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] '
              f'train_loss: {train_loss:.5f} '
              f'valid_loss: {valid_loss:.5f}')

        # early_stopping checks whether the validation loss has decreased and,
        # if so, makes a checkpoint of the current model
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    model.load_state_dict(torch.load('checkpoint.pt'))

    return model, avg_train_losses, avg_valid_losses

In [None]:
#update
# Upper bound on the number of epochs; early stopping may end training sooner.
n_epochs = 1000
batch_size = args.batch_size

# One generator per role. The previous version unpacked a single infinite
# generator into three names, which raises "too many values to unpack".
train_loader = data_gen(batch_size)
test_loader = data_gen(batch_size)
valid_loader = data_gen(batch_size)

# early stopping patience; how long to wait after last time validation loss improved.
patience = 20

model, train_loss_history, valid_loss_history = train_model(model, batch_size, patience, n_epochs)

### Next steps
Your comment on the last code update:
"What are next steps from perspective? From my perspective it would be great to run an experiment in which you
 report some form of error between the data generated from the VAE and the original data over different number of sizes of the latent space."


In [None]:
if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        train(epoch)

    # One batch of real data, its reconstructions, and the same number of fresh
    # samples decoded from N(0, I) draws in the K-dimensional latent space.
    gt = next(train_loader)
    n_points = gt.shape[0]

    model.eval()
    with torch.no_grad():
        prior_z = torch.randn(n_points, K).to(device)
        out = model.decode(prior_z).cpu().numpy()
        recon = model(gt)[0].cpu().numpy()
    gt = gt.cpu().numpy()

    # Draw on explicit axes; a bare plt.axes() call may create a new, empty Axes
    # rather than returning the one the scatter plots were drawn on.
    fig, ax = plt.subplots()
    ax.scatter(gt[:, 0], gt[:, 1], c='red', s=3, label='data')
    ax.scatter(out[:, 0], out[:, 1], c='blue', s=3, label='samples from prior')
    ax.scatter(recon[:, 0], recon[:, 1], c='green', s=3, label='reconstructions')
    ax.set_aspect('equal')
    ax.set(title='VAE on the 8-Gaussians toy data', xlabel='x', ylabel='y')
    ax.legend()
    plt.show()
