In [2]:
!pip install corner

Collecting corner
  Downloading https://files.pythonhosted.org/packages/65/af/a7ba022f2d5787f51db91b5550cbe8e8c40a6eebd8f15119e743a09a9c19/corner-2.0.1.tar.gz
Building wheels for collected packages: corner
  Building wheel for corner (setup.py) ... done
  Created wheel for corner: filename=corner-2.0.1-cp36-none-any.whl size=11642 sha256=82a4809b0aeb0e148cf17a7f1e5d41afea86cce858fc2e3c93030f5816724183
  Stored in directory: /root/.cache/pip/wheels/70/d8/e5/e0e7974a2a5757483ea5a180c937041cf6872dc9993d78234a
Successfully built corner
Installing collected packages: corner
Successfully installed corner-2.0.1


In [3]:
from google.colab import drive
# Colab mount point for Google Drive; remount even if already mounted.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [4]:
cd drive/My\ Drive/Projects/VariationalAE

/content/drive/My Drive/Projects/VariationalAE


In [0]:
import os
import numpy as np
import pandas as pd

In [0]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

In [0]:
from my_utils import load_data, train_evaluate_model

In [0]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [0]:
class InfoVAE(nn.Module):
    """InfoVAE (MMD-regularized VAE) with a two-hidden-layer MLP encoder/decoder.

    The training objective returned by ``loss`` is

        MSE + (1 - alpha) * KLD + (lambd + alpha - 1) * MMD

    so with ``alpha=0`` it is the usual VAE objective (MSE + KLD) plus an MMD
    penalty weighted by ``lambd - 1``.

    Args:
        nfeat: number of input features.
        ncode: latent-code dimensionality.
        alpha: weight shifting mass between the KL and MMD terms (see above).
        lambd: weight on the MMD term (see above).
        nhidden: width of the first encoder / last decoder hidden layer.
        nhidden2: width of the second encoder / first decoder hidden layer.
        dropout: dropout probability applied after every hidden activation.

    All tensors follow the device of the inputs / module parameters; the class
    does not depend on any notebook-global ``device``.
    """

    def __init__(self, nfeat=4, ncode=3, alpha=0, lambd=10000, nhidden=128, nhidden2=35, dropout=0.2):
        super(InfoVAE, self).__init__()

        self.ncode = int(ncode)
        self.alpha = float(alpha)
        self.lambd = float(lambd)

        # Encoder: nfeat -> nhidden -> nhidden2 -> (mu, logvar), each of size ncode.
        self.leaky_relu = nn.LeakyReLU()
        self.encd = nn.Linear(nfeat, nhidden)
        self.d1 = nn.Dropout(p=dropout)
        self.enc2 = nn.Linear(nhidden, nhidden2)
        self.d2 = nn.Dropout(p=dropout)
        self.mu = nn.Linear(nhidden2, ncode)
        self.lv = nn.Linear(nhidden2, ncode)

        # Decoder mirrors the encoder: ncode -> nhidden2 -> nhidden -> nfeat.
        self.decd = nn.Linear(ncode, nhidden2)
        self.d3 = nn.Dropout(p=dropout)
        self.dec2 = nn.Linear(nhidden2, nhidden)
        self.d4 = nn.Dropout(p=dropout)
        self.outp = nn.Linear(nhidden, nfeat)

    def encode(self, x):
        """Map inputs to the posterior mean and log-variance, each (batch, ncode)."""
        h = self.d1(self.leaky_relu(self.encd(x)))
        h = self.d2(self.leaky_relu(self.enc2(h)))
        return self.mu(h), self.lv(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I), differentiable in mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to feature space (no output activation)."""
        h = self.d3(self.leaky_relu(self.decd(z)))
        h = self.d4(self.leaky_relu(self.dec2(h)))
        return self.outp(h)

    def _encode_sample_decode(self, x):
        """Shared pass used by forward/loss; also returns the sampled z."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``."""
        recon_x, mu, logvar, _ = self._encode_sample_decode(x)
        return recon_x, mu, logvar

    # https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    def compute_kernel(self, x, y):
        """Gaussian kernel matrix k(x_i, y_j) = exp(-mean_d (x_id - y_jd)^2).

        Args:
            x: 2-D tensor (x_size, dim).
            y: 2-D tensor (y_size, dim); moved to ``x``'s device and dtype.

        Returns:
            Tensor of shape (x_size, y_size).
        """
        y = y.to(device=x.device, dtype=x.dtype)
        x_size, dim = x.size(0), x.size(1)
        y_size = y.size(0)
        tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
        tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)
        # The tutorial computes .mean(2) / dim, i.e. divides by dim twice.
        # Here only the mean over dim is taken, so kernel_input stays O(1)
        # for unit-scale inputs regardless of dim.
        kernel_input = (tiled_x - tiled_y).pow(2).mean(2)
        return torch.exp(-kernel_input)

    # https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    def compute_mmd(self, x, y):
        """Biased (V-statistic) squared-MMD estimate between samples x and y."""
        xx_kernel = self.compute_kernel(x, x)
        yy_kernel = self.compute_kernel(y, y)
        xy_kernel = self.compute_kernel(x, y)
        return torch.mean(xx_kernel) + torch.mean(yy_kernel) - 2 * torch.mean(xy_kernel)

    def loss(self, x, n_prior_samples=200):
        """Return ``(total, MSE, KLD, MMD)`` for a batch ``x``.

        A single latent sample ``z`` is drawn and used both for the
        reconstruction and for the MMD term, so the penalty is evaluated on
        the same codes that were decoded.

        Args:
            x: input batch, shape (batch, nfeat).
            n_prior_samples: number of N(0, I) samples in the MMD estimate.

        Returns:
            total: MSE + (1 - alpha) * KLD + (lambd + alpha - 1) * MMD.
            MSE: 0.5 * sum of squared reconstruction errors over the batch.
            KLD: KL(q(z|x) || N(0, I)) summed over batch and code dims.
            MMD: MMD(prior samples, z) scaled by batch * ncode.
        """
        recon_x, mu, logvar, z = self._encode_sample_decode(x)
        MSE = torch.sum(0.5 * (x - recon_x).pow(2))

        # KL divergence (Kingma and Welling, https://arxiv.org/abs/1312.6114, Appendix B)
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # Prior samples live on z's device/dtype so no cross-device copies are needed.
        true_samples = torch.randn(n_prior_samples, self.ncode, device=z.device, dtype=z.dtype)
        # MMD ~ O(1) while KLD is a sum over batch * ncode terms; upweight to match.
        MMD = self.compute_mmd(true_samples, z) * x.size(0) * self.ncode
        total = MSE + (1 - self.alpha) * KLD + (self.lambd + self.alpha - 1) * MMD
        return total, MSE, KLD, MMD

In [0]:
# Output directory for the saved models of this sweep (one .pth per config).
tag = '4D/'
# Create it if missing (idempotent), otherwise torch.save fails on a fresh checkout.
os.makedirs(tag, exist_ok=True)

In [0]:
BATCH_SIZE = 512
# Unpack the six values returned by my_utils.load_data.
train_dl, valid_dl, train_x, train_y, test_x, test_y = load_data(batch_size=BATCH_SIZE)

In [0]:
def train(model=None, optimizer=None, train_loader=None, valid_loader=None):
    """Run one training epoch, then evaluate on the validation set.

    Each argument defaults to the notebook-level object with the matching
    role (``model``, ``optimizer``, ``train_dl``, ``valid_dl``), so existing
    ``train()`` calls keep working; pass them explicitly to avoid relying on
    global state.

    ``model.loss(x)`` must return four scalar tensors
    ``(total, recon, KLD, MMD)``. Batches are moved to the device of the
    model's first parameter.

    Returns:
        (valid_loss, valid_logL, valid_KLD, valid_MMD): validation sums divided
        by the number of validation samples. ``valid_logL`` is the second
        output of ``model.loss`` (the reconstruction MSE for InfoVAE).
    """
    env = globals()
    model = env['model'] if model is None else model
    optimizer = env['optimizer'] if optimizer is None else optimizer
    train_loader = env['train_dl'] if train_loader is None else train_loader
    valid_loader = env['valid_dl'] if valid_loader is None else valid_loader
    dev = next(model.parameters()).device

    model.train()
    for data in train_loader:
        x = data[0].to(dev)
        optimizer.zero_grad()
        loss = model.loss(x)[0]
        loss.backward()
        optimizer.step()

    model.eval()
    valid_loss = valid_logL = valid_KLD = valid_MMD = 0.0
    with torch.no_grad():
        for data in valid_loader:
            x = data[0].to(dev)
            loss, logL, KLD, MMD = model.loss(x)
            valid_loss += loss.item()
            valid_logL += logL.item()
            valid_KLD += KLD.item()
            valid_MMD += MMD.item()

    n_valid = len(valid_loader.dataset)
    return valid_loss / n_valid, valid_logL / n_valid, valid_KLD / n_valid, valid_MMD / n_valid

In [0]:
class EarlyStopper:
    """Signal when a monitored loss has stopped improving.

    A loss counts as an improvement only if it is below the best value seen so
    far scaled by ``(1 - precision)``. ``step`` returns False when the number of
    consecutive non-improving calls reaches ``patience``.

    NOTE(review): the relative threshold assumes the best loss is positive.
    """

    def __init__(self, precision=1e-3, patience=10):
        self.precision = precision
        self.patience = patience
        self.min_valid_loss = float('inf')
        self.badepochs = 0

    def step(self, valid_loss):
        """Record one epoch's loss; return True to continue, False to stop."""
        threshold = self.min_valid_loss * (1 - self.precision)
        if valid_loss >= threshold or valid_loss != valid_loss and not valid_loss < threshold:
            self.badepochs += 1
        else:
            self.min_valid_loss = valid_loss
            self.badepochs = 0
        return self.badepochs != self.patience

In [0]:
# Sweep configuration.
epochs, log_interval = 100, 10  # max epochs per config; print every log_interval epochs
mdl_ncode = 3                   # latent-code size passed to InfoVAE
n_config = 1000                 # number of random hyperparameter draws
nfeat = 4                       # input feature count passed to InfoVAE

In [0]:
# One slot per config for the final validation metrics written by the sweep loop.
mdl_MSE, mdl_KLD, mdl_MMD = (np.zeros(n_config) for _ in range(3))

In [0]:
from tqdm import tqdm
from tqdm import tqdm_notebook

In [0]:
# Random hyperparameter sweep: train one InfoVAE per config, early-stop on the
# validation loss, record final validation metrics and save each model.
epoch_fmt = '====> Epoch: {} VALIDATION Loss: {:.2e} logL: {:.2e} KL: {:.2e} MMD: {:.2e}'
for config in range(n_config):
    alpha = 0
    # lambda drawn log-uniformly from [1, 1e5).
    lambd = np.exp(np.random.uniform(0, np.log(1e5)))
    dropout = 0  # previously: 0.9*np.random.uniform()
    dfac = 1./(1.-dropout)

    nhidden = 200
    nhidden2 = 200
    print('config %i, alpha = %0.1f, lambda = %0.1f, dropout = %0.2f; 2 hidden layers with %i, %i nodes' % (config, alpha, lambd, dropout, nhidden, nhidden2))
    model = InfoVAE(alpha=alpha, lambd=lambd, nfeat=nfeat, nhidden=nhidden, nhidden2=nhidden2, ncode=mdl_ncode, dropout=dropout)
    # .to(device) instead of .cuda(): works on CPU-only runtimes too.
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=5)
    stopper = EarlyStopper(patience=10)

    for epoch in tqdm_notebook(range(1, epochs + 1)):
        valid_loss, valid_logL, valid_KLD, valid_MMD = train()
        if epoch % log_interval == 0:
            print(epoch_fmt.format(epoch, valid_loss, valid_logL, valid_KLD, valid_MMD))

        # No optimizer.step() here: train() already stepped after every batch;
        # an extra step would re-apply the last batch's stale gradients.
        scheduler.step(valid_loss)
        if (not stopper.step(valid_loss)) or (epoch == epochs):
            print('Stopping')
            print(epoch_fmt.format(epoch, valid_loss, valid_logL, valid_KLD, valid_MMD))
            # valid_logL is the reconstruction MSE (second output of InfoVAE.loss),
            # which is already non-negative -- store it without negation.
            model.MSE = valid_logL
            model.KLD = valid_KLD
            model.MMD = valid_MMD
            mdl_MSE[config] = model.MSE
            mdl_KLD[config] = model.KLD
            mdl_MMD[config] = model.MMD
            torch.save(model, os.path.join(tag, '%04i.pth' % config))
            break

config 0, alpha = 0.0, lambda = 226.4, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 5.33e+00 logL: 1.96e+00 KL: 2.49e-02 MMD: 1.48e-02
Epoch    11: reducing learning rate of group 0 to 1.0000e-05.
Epoch    19: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 20 VALIDATION Loss: 5.52e+00 logL: 1.95e+00 KL: 2.65e-02 MMD: 1.58e-02
Stopping
====> Epoch: 23 VALIDATION Loss: 5.57e+00 logL: 1.95e+00 KL: 2.65e-02 MMD: 1.59e-02
config 1, alpha = 0.0, lambda = 7047.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


  "type " + obj.__name__ + ". It won't be checked "


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.11e+02 logL: 2.00e+00 KL: 8.92e-03 MMD: 1.55e-02
Epoch    14: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 18 VALIDATION Loss: 1.15e+02 logL: 2.00e+00 KL: 4.89e-03 MMD: 1.60e-02
config 2, alpha = 0.0, lambda = 158.7, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 4.21e+00 logL: 1.73e+00 KL: 1.29e-01 MMD: 1.49e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
Epoch    17: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 20 VALIDATION Loss: 4.52e+00 logL: 1.70e+00 KL: 1.45e-01 MMD: 1.69e-02
Stopping
====> Epoch: 21 VALIDATION Loss: 4.23e+00 logL: 1.71e+00 KL: 1.45e-01 MMD: 1.51e-02
config 3, alpha = 0.0, lambda = 18.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.03e+00 logL: 1.38e+00 KL: 3.81e-01 MMD: 1.54e-02
Epoch    16: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.01e+00 logL: 1.35e+00 KL: 4.05e-01 MMD: 1.47e-02
====> Epoch: 30 VALIDATION Loss: 2.00e+00 logL: 1.35e+00 KL: 4.11e-01 MMD: 1.36e-02
====> Epoch: 40 VALIDATION Loss: 2.02e+00 logL: 1.34e+00 KL: 4.17e-01 MMD: 1.50e-02
Epoch    40: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 44 VALIDATION Loss: 2.03e+00 logL: 1.34e+00 KL: 4.15e-01 MMD: 1.58e-02
config 4, alpha = 0.0, lambda = 2667.4, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 3.98e+01 logL: 1.99e+00 KL: 1.08e-02 MMD: 1.42e-02
Epoch    12: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 16 VALIDATION Loss: 3.94e+01 logL: 1.99e+00 KL: 9.91e-03 MMD: 1.40e-02
config 5, alpha = 0.0, lambda = 13.3, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.95e+00 logL: 1.37e+00 KL: 3.86e-01 MMD: 1.59e-02
====> Epoch: 20 VALIDATION Loss: 1.94e+00 logL: 1.35e+00 KL: 4.06e-01 MMD: 1.54e-02
Stopping
====> Epoch: 23 VALIDATION Loss: 1.96e+00 logL: 1.35e+00 KL: 4.08e-01 MMD: 1.64e-02
config 6, alpha = 0.0, lambda = 1375.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.25e+01 logL: 1.98e+00 KL: 2.19e-02 MMD: 1.49e-02
Epoch    19: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.27e+01 logL: 1.97e+00 KL: 2.18e-02 MMD: 1.51e-02
Stopping
====> Epoch: 23 VALIDATION Loss: 2.27e+01 logL: 1.97e+00 KL: 2.21e-02 MMD: 1.51e-02
config 7, alpha = 0.0, lambda = 1042.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.88e+01 logL: 1.99e+00 KL: 9.50e-03 MMD: 1.62e-02
Epoch    14: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 18 VALIDATION Loss: 1.84e+01 logL: 1.99e+00 KL: 9.65e-03 MMD: 1.57e-02
config 8, alpha = 0.0, lambda = 7.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.86e+00 logL: 1.34e+00 KL: 4.16e-01 MMD: 1.65e-02
Epoch    17: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.86e+00 logL: 1.33e+00 KL: 4.17e-01 MMD: 1.64e-02
Stopping
====> Epoch: 21 VALIDATION Loss: 1.85e+00 logL: 1.33e+00 KL: 4.20e-01 MMD: 1.53e-02
config 9, alpha = 0.0, lambda = 19.1, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.06e+00 logL: 1.35e+00 KL: 4.05e-01 MMD: 1.68e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.04e+00 logL: 1.35e+00 KL: 4.01e-01 MMD: 1.57e-02
Epoch    25: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 29 VALIDATION Loss: 2.02e+00 logL: 1.35e+00 KL: 4.08e-01 MMD: 1.47e-02
config 10, alpha = 0.0, lambda = 1760.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 3.04e+01 logL: 1.99e+00 KL: 8.88e-03 MMD: 1.62e-02
Stopping
====> Epoch: 17 VALIDATION Loss: 2.89e+01 logL: 1.99e+00 KL: 9.92e-03 MMD: 1.53e-02
config 11, alpha = 0.0, lambda = 1.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.76e+00 logL: 1.34e+00 KL: 4.11e-01 MMD: 1.77e-02
Epoch    17: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 19 VALIDATION Loss: 1.75e+00 logL: 1.33e+00 KL: 4.12e-01 MMD: 1.42e-02
config 12, alpha = 0.0, lambda = 24.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.14e+00 logL: 1.40e+00 KL: 3.59e-01 MMD: 1.60e-02
Epoch    12: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.12e+00 logL: 1.38e+00 KL: 3.86e-01 MMD: 1.50e-02
Epoch    20: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 24 VALIDATION Loss: 2.12e+00 logL: 1.37e+00 KL: 3.84e-01 MMD: 1.53e-02
config 13, alpha = 0.0, lambda = 11607.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.96e+02 logL: 1.99e+00 KL: 1.16e-02 MMD: 1.67e-02
Epoch    15: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.79e+02 logL: 1.99e+00 KL: 8.15e-03 MMD: 1.52e-02
Epoch    22: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 30 VALIDATION Loss: 1.68e+02 logL: 1.99e+00 KL: 8.01e-03 MMD: 1.43e-02
Epoch    32: reducing learning rate of group 0 to 1.0000e-07.
Stopping
====> Epoch: 36 VALIDATION Loss: 1.89e+02 logL: 1.99e+00 KL: 7.95e-03 MMD: 1.61e-02
config 14, alpha = 0.0, lambda = 15272.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.37e+02 logL: 1.99e+00 KL: 9.64e-03 MMD: 1.54e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 14 VALIDATION Loss: 2.42e+02 logL: 1.99e+00 KL: 9.31e-03 MMD: 1.57e-02
config 15, alpha = 0.0, lambda = 2.9, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.78e+00 logL: 1.34e+00 KL: 4.14e-01 MMD: 1.54e-02
Epoch    14: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 18 VALIDATION Loss: 1.78e+00 logL: 1.34e+00 KL: 4.17e-01 MMD: 1.54e-02
config 16, alpha = 0.0, lambda = 5.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.82e+00 logL: 1.33e+00 KL: 4.21e-01 MMD: 1.60e-02
====> Epoch: 20 VALIDATION Loss: 1.81e+00 logL: 1.32e+00 KL: 4.29e-01 MMD: 1.49e-02
Epoch    23: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 27 VALIDATION Loss: 1.81e+00 logL: 1.32e+00 KL: 4.25e-01 MMD: 1.60e-02
config 17, alpha = 0.0, lambda = 1.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.76e+00 logL: 1.33e+00 KL: 4.16e-01 MMD: 1.60e-02
====> Epoch: 20 VALIDATION Loss: 1.76e+00 logL: 1.32e+00 KL: 4.22e-01 MMD: 1.61e-02
Stopping
====> Epoch: 24 VALIDATION Loss: 1.76e+00 logL: 1.32e+00 KL: 4.24e-01 MMD: 1.57e-02
config 18, alpha = 0.0, lambda = 1.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.75e+00 logL: 1.34e+00 KL: 4.12e-01 MMD: 1.55e-02
Epoch    18: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.75e+00 logL: 1.32e+00 KL: 4.22e-01 MMD: 1.59e-02
Epoch    29: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 30 VALIDATION Loss: 1.74e+00 logL: 1.32e+00 KL: 4.18e-01 MMD: 1.56e-02
Stopping
====> Epoch: 33 VALIDATION Loss: 1.75e+00 logL: 1.33e+00 KL: 4.20e-01 MMD: 1.59e-02
config 19, alpha = 0.0, lambda = 4862.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 7.87e+01 logL: 1.99e+00 KL: 9.19e-03 MMD: 1.58e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 14 VALIDATION Loss: 8.03e+01 logL: 1.99e+00 KL: 8.80e-03 MMD: 1.61e-02
config 20, alpha = 0.0, lambda = 28099.4, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 4.16e+02 logL: 2.00e+00 KL: 5.23e-03 MMD: 1.47e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 4.43e+02 logL: 2.00e+00 KL: 5.00e-03 MMD: 1.57e-02
Epoch    20: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 24 VALIDATION Loss: 4.41e+02 logL: 2.00e+00 KL: 4.90e-03 MMD: 1.56e-02
config 21, alpha = 0.0, lambda = 2.1, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.77e+00 logL: 1.34e+00 KL: 4.11e-01 MMD: 1.58e-02
Stopping
====> Epoch: 17 VALIDATION Loss: 1.77e+00 logL: 1.34e+00 KL: 4.10e-01 MMD: 1.64e-02
config 22, alpha = 0.0, lambda = 13.4, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.97e+00 logL: 1.39e+00 KL: 3.73e-01 MMD: 1.61e-02
Epoch    11: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.93e+00 logL: 1.35e+00 KL: 4.01e-01 MMD: 1.47e-02
Epoch    21: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 25 VALIDATION Loss: 1.93e+00 logL: 1.35e+00 KL: 4.03e-01 MMD: 1.43e-02
config 23, alpha = 0.0, lambda = 227.4, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     9: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 5.36e+00 logL: 1.95e+00 KL: 2.55e-02 MMD: 1.49e-02
Stopping
====> Epoch: 13 VALIDATION Loss: 5.25e+00 logL: 1.95e+00 KL: 2.74e-02 MMD: 1.44e-02
config 24, alpha = 0.0, lambda = 2.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.77e+00 logL: 1.35e+00 KL: 4.04e-01 MMD: 1.61e-02
====> Epoch: 20 VALIDATION Loss: 1.76e+00 logL: 1.31e+00 KL: 4.27e-01 MMD: 1.53e-02
Epoch    26: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 30 VALIDATION Loss: 1.77e+00 logL: 1.33e+00 KL: 4.21e-01 MMD: 1.52e-02
Stopping
====> Epoch: 30 VALIDATION Loss: 1.77e+00 logL: 1.33e+00 KL: 4.21e-01 MMD: 1.52e-02
config 25, alpha = 0.0, lambda = 3.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.79e+00 logL: 1.35e+00 KL: 3.98e-01 MMD: 1.75e-02
Epoch    15: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.79e+00 logL: 1.33e+00 KL: 4.17e-01 MMD: 1.77e-02
Epoch    22: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 26 VALIDATION Loss: 1.78e+00 logL: 1.34e+00 KL: 4.12e-01 MMD: 1.53e-02
config 26, alpha = 0.0, lambda = 17383.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.62e+02 logL: 1.99e+00 KL: 9.89e-03 MMD: 1.50e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 14 VALIDATION Loss: 2.73e+02 logL: 1.99e+00 KL: 8.82e-03 MMD: 1.56e-02
config 27, alpha = 0.0, lambda = 76.9, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.94e+00 logL: 1.50e+00 KL: 2.83e-01 MMD: 1.52e-02
Epoch    17: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.88e+00 logL: 1.47e+00 KL: 3.11e-01 MMD: 1.44e-02
Epoch    24: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 28 VALIDATION Loss: 2.98e+00 logL: 1.45e+00 KL: 3.21e-01 MMD: 1.59e-02
config 28, alpha = 0.0, lambda = 37497.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 5.23e+02 logL: 1.98e+00 KL: 1.13e-02 MMD: 1.39e-02
Epoch    16: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 5.52e+02 logL: 1.99e+00 KL: 9.56e-03 MMD: 1.47e-02
Stopping
====> Epoch: 20 VALIDATION Loss: 5.52e+02 logL: 1.99e+00 KL: 9.56e-03 MMD: 1.47e-02
config 29, alpha = 0.0, lambda = 73.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.85e+00 logL: 1.48e+00 KL: 2.99e-01 MMD: 1.48e-02
Epoch    12: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 16 VALIDATION Loss: 2.95e+00 logL: 1.49e+00 KL: 2.98e-01 MMD: 1.61e-02
config 30, alpha = 0.0, lambda = 1.3, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.75e+00 logL: 1.33e+00 KL: 4.14e-01 MMD: 1.59e-02
====> Epoch: 20 VALIDATION Loss: 1.76e+00 logL: 1.34e+00 KL: 4.13e-01 MMD: 1.58e-02
Epoch    28: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 30 VALIDATION Loss: 1.75e+00 logL: 1.32e+00 KL: 4.24e-01 MMD: 1.61e-02
Stopping
====> Epoch: 32 VALIDATION Loss: 1.75e+00 logL: 1.33e+00 KL: 4.16e-01 MMD: 1.68e-02
config 31, alpha = 0.0, lambda = 87.7, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 3.16e+00 logL: 1.55e+00 KL: 2.51e-01 MMD: 1.56e-02
Epoch    18: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 2.96e+00 logL: 1.48e+00 KL: 3.02e-01 MMD: 1.36e-02
Epoch    26: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 30 VALIDATION Loss: 3.07e+00 logL: 1.48e+00 KL: 3.04e-01 MMD: 1.48e-02
Stopping
====> Epoch: 30 VALIDATION Loss: 3.07e+00 logL: 1.48e+00 KL: 3.04e-01 MMD: 1.48e-02
config 32, alpha = 0.0, lambda = 13.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.96e+00 logL: 1.37e+00 KL: 3.98e-01 MMD: 1.60e-02
Epoch    18: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.94e+00 logL: 1.34e+00 KL: 4.15e-01 MMD: 1.51e-02
Epoch    28: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 30 VALIDATION Loss: 1.94e+00 logL: 1.33e+00 KL: 4.19e-01 MMD: 1.53e-02
Epoch    37: reducing learning rate of group 0 to 1.0000e-07.
====> Epoch: 40 VALIDATION Loss: 1.94e+00 logL: 1.34e+00 KL: 4.17e-01 MMD: 1.52e-02
Stopping
====> Epoch: 41 VALIDATION Loss: 1.96e+00 logL: 1.34e+00 KL: 4.17e-01 MMD: 1.74e-02
config 33, alpha = 0.0, lambda = 47141.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 7.14e+02 logL: 1.99e+00 KL: 1.20e-02 MMD: 1.51e-02
====> Epoch: 20 VALIDATION Loss: 6.71e+02 logL: 1.99e+00 KL: 6.48e-03 MMD: 1.42e-02
Epoch    20: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 24 VALIDATION Loss: 6.85e+02 logL: 1.99e+00 KL: 6.04e-03 MMD: 1.45e-02
config 34, alpha = 0.0, lambda = 17.0, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.01e+00 logL: 1.41e+00 KL: 3.63e-01 MMD: 1.51e-02
Epoch    10: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 14 VALIDATION Loss: 2.01e+00 logL: 1.39e+00 KL: 3.70e-01 MMD: 1.55e-02
config 35, alpha = 0.0, lambda = 7.8, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.87e+00 logL: 1.35e+00 KL: 4.10e-01 MMD: 1.59e-02
====> Epoch: 20 VALIDATION Loss: 1.86e+00 logL: 1.32e+00 KL: 4.26e-01 MMD: 1.64e-02
Epoch    22: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 30 VALIDATION Loss: 1.86e+00 logL: 1.33e+00 KL: 4.21e-01 MMD: 1.62e-02
Epoch    32: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 36 VALIDATION Loss: 1.87e+00 logL: 1.33e+00 KL: 4.21e-01 MMD: 1.72e-02
config 36, alpha = 0.0, lambda = 23687.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     7: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 3.66e+02 logL: 1.98e+00 KL: 1.27e-02 MMD: 1.54e-02
====> Epoch: 20 VALIDATION Loss: 3.84e+02 logL: 1.98e+00 KL: 1.28e-02 MMD: 1.61e-02
Epoch    22: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 26 VALIDATION Loss: 3.71e+02 logL: 1.99e+00 KL: 1.25e-02 MMD: 1.56e-02
config 37, alpha = 0.0, lambda = 6297.0, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     7: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 1.06e+02 logL: 2.00e+00 KL: 6.48e-03 MMD: 1.66e-02
Stopping
====> Epoch: 11 VALIDATION Loss: 1.02e+02 logL: 2.00e+00 KL: 6.90e-03 MMD: 1.59e-02
config 38, alpha = 0.0, lambda = 42.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.46e+00 logL: 1.47e+00 KL: 3.09e-01 MMD: 1.65e-02
====> Epoch: 20 VALIDATION Loss: 2.41e+00 logL: 1.36e+00 KL: 4.07e-01 MMD: 1.56e-02
Epoch    29: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 30 VALIDATION Loss: 2.44e+00 logL: 1.35e+00 KL: 4.10e-01 MMD: 1.66e-02
Stopping
====> Epoch: 33 VALIDATION Loss: 2.33e+00 logL: 1.35e+00 KL: 4.11e-01 MMD: 1.39e-02
config 39, alpha = 0.0, lambda = 2082.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     7: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 3.48e+01 logL: 1.98e+00 KL: 1.27e-02 MMD: 1.58e-02
Stopping
====> Epoch: 11 VALIDATION Loss: 3.18e+01 logL: 1.98e+00 KL: 1.24e-02 MMD: 1.43e-02
config 40, alpha = 0.0, lambda = 3591.9, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     8: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 5.13e+01 logL: 1.99e+00 KL: 7.48e-03 MMD: 1.37e-02
Stopping
====> Epoch: 12 VALIDATION Loss: 5.01e+01 logL: 1.99e+00 KL: 7.74e-03 MMD: 1.34e-02
config 41, alpha = 0.0, lambda = 8714.2, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     7: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 1.35e+02 logL: 1.98e+00 KL: 1.09e-02 MMD: 1.53e-02
Epoch    14: reducing learning rate of group 0 to 1.0000e-06.
Stopping
====> Epoch: 18 VALIDATION Loss: 1.32e+02 logL: 1.98e+00 KL: 1.08e-02 MMD: 1.50e-02
config 42, alpha = 0.0, lambda = 7067.0, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     8: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 1.12e+02 logL: 1.98e+00 KL: 1.13e-02 MMD: 1.56e-02
Stopping
====> Epoch: 12 VALIDATION Loss: 1.15e+02 logL: 1.98e+00 KL: 1.19e-02 MMD: 1.60e-02
config 43, alpha = 0.0, lambda = 9.1, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.87e+00 logL: 1.34e+00 KL: 4.12e-01 MMD: 1.51e-02
Epoch    16: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.88e+00 logL: 1.34e+00 KL: 4.16e-01 MMD: 1.57e-02
Stopping
====> Epoch: 20 VALIDATION Loss: 1.88e+00 logL: 1.34e+00 KL: 4.16e-01 MMD: 1.57e-02
config 44, alpha = 0.0, lambda = 55339.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 8.75e+02 logL: 2.00e+00 KL: 8.05e-03 MMD: 1.58e-02
Epoch    12: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 16 VALIDATION Loss: 8.98e+02 logL: 1.99e+00 KL: 9.15e-03 MMD: 1.62e-02
config 45, alpha = 0.0, lambda = 18234.9, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.91e+02 logL: 1.99e+00 KL: 7.78e-03 MMD: 1.58e-02
Epoch    14: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 18 VALIDATION Loss: 2.86e+02 logL: 1.99e+00 KL: 6.66e-03 MMD: 1.56e-02
config 46, alpha = 0.0, lambda = 25.3, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 2.18e+00 logL: 1.42e+00 KL: 3.53e-01 MMD: 1.67e-02
====> Epoch: 20 VALIDATION Loss: 2.13e+00 logL: 1.36e+00 KL: 4.08e-01 MMD: 1.52e-02
Epoch    20: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 30 VALIDATION Loss: 2.15e+00 logL: 1.36e+00 KL: 4.06e-01 MMD: 1.60e-02
Epoch    34: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 40 VALIDATION Loss: 2.10e+00 logL: 1.35e+00 KL: 4.07e-01 MMD: 1.42e-02
Epoch    43: reducing learning rate of group 0 to 1.0000e-07.
Stopping
====> Epoch: 47 VALIDATION Loss: 2.15e+00 logL: 1.34e+00 KL: 4.07e-01 MMD: 1.65e-02
config 47, alpha = 0.0, lambda = 96091.7, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch     7: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 10 VALIDATION Loss: 1.36e+03 logL: 1.99e+00 KL: 1.10e-02 MMD: 1.42e-02
Stopping
====> Epoch: 11 VALIDATION Loss: 1.46e+03 logL: 1.99e+00 KL: 1.18e-02 MMD: 1.52e-02
config 48, alpha = 0.0, lambda = 12.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.96e+00 logL: 1.37e+00 KL: 3.86e-01 MMD: 1.75e-02
Epoch    18: reducing learning rate of group 0 to 1.0000e-05.
====> Epoch: 20 VALIDATION Loss: 1.92e+00 logL: 1.33e+00 KL: 4.13e-01 MMD: 1.53e-02
Epoch    26: reducing learning rate of group 0 to 1.0000e-06.
====> Epoch: 30 VALIDATION Loss: 1.93e+00 logL: 1.33e+00 KL: 4.16e-01 MMD: 1.54e-02
Stopping
====> Epoch: 30 VALIDATION Loss: 1.93e+00 logL: 1.33e+00 KL: 4.16e-01 MMD: 1.54e-02
config 49, alpha = 0.0, lambda = 4.6, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.81e+00 logL: 1.34e+00 KL: 4.09e-01 MMD: 1.65e-02
====> Epoch: 20 VALIDATION Loss: 1.81e+00 logL: 1.33e+00 KL: 4.15e-01 MMD: 1.69e-02
Epoch    21: reducing learning rate of group 0 to 1.0000e-05.
Stopping
====> Epoch: 21 VALIDATION Loss: 1.81e+00 logL: 1.34e+00 KL: 4.16e-01 MMD: 1.54e-02
config 50, alpha = 0.0, lambda = 3.5, dropout = 0.00; 2 hidden layers with 200, 200 nodes


HBox(children=(IntProgress(value=0), HTML(value='')))

====> Epoch: 10 VALIDATION Loss: 1.80e+00 logL: 1.34e+00 KL: 4.10e-01 MMD: 1.68e-02
