In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import pandas as pd
import os

In [2]:
import torch
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from nflows.transforms.dequantization import UniformDequantization
from nflows.transforms.dequantization import VariationalDequantization

In [4]:
device = torch.device("cuda:0")
#device = torch.device("cpu")

In [5]:
# Make model folder
if not os.path.exists("models"):
    os.makedirs("models")

In [6]:
# Import data and weights
x_data_raw = torch.tensor(pd.read_csv("ee_data.csv", header=None, delimiter=",").to_numpy(), dtype=torch.float32, device=device)
x_weights_raw = torch.tensor(pd.read_csv("ee_weights.csv", header=None, delimiter=",").to_numpy(), dtype=torch.float32, device=device).squeeze()

# Permute the data
permutation = torch.randperm(x_data_raw.shape[0])
x_data_raw = x_data_raw[permutation]
x_weights_raw = x_weights_raw[permutation]

# Normalize weights by mean
x_weights_raw /= x_weights_raw.mean()

# Chop up the data into training and validation
data_size = x_data_raw.shape[0]
training_size = int(data_size*0.8)

x_data_train    = x_data_raw[:training_size]
x_weights_train = x_weights_raw[:training_size]
x_data_test     = x_data_raw[training_size:]
x_weights_test  = x_weights_raw[training_size:]

# Data dimension
data_dim = x_data_raw.shape[1]

In [7]:
# Determine the maximum labels of discrete dims
max_features = torch.max(x_data_raw, dim=0)[0]
is_discrete = max_features > 1
max_labels = torch.where(max_features > 1, max_features, torch.tensor(-1., device=device))

In [8]:
# Tensorboard writer
writer = SummaryWriter()

In [9]:
num_layers = 6
base_dist_uniform = BoxUniform(torch.zeros(data_dim), torch.ones(data_dim))
base_dist_variational = BoxUniform(torch.zeros(data_dim), torch.ones(data_dim))

transforms_uniform = []
transforms_variational = []

transforms_uniform.append(UniformDequantization(max_labels=max_labels))
transforms_variational.append(VariationalDequantization(max_labels=max_labels, rqs_hidden_features=15, rqs_flow_layers=2))

for _ in range(num_layers):
    transforms_uniform.append(RandomPermutation(features=data_dim))
    transforms_uniform.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=data_dim, 
        hidden_features=26,
        num_bins=10,
        num_blocks=4,
    ))

    transforms_variational.append(RandomPermutation(features=data_dim))
    transforms_variational.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=data_dim, 
        hidden_features=25,
        num_bins=10,
        num_blocks=4,
    ))

transform_uniform = CompositeTransform(transforms_uniform)
transform_variational = CompositeTransform(transforms_variational)

flow_uniform = Flow(transform_uniform, base_dist_uniform).to(device)
flow_variational = Flow(transform_variational, base_dist_variational).to(device)

optimizer_uniform = optim.Adam(flow_uniform.parameters())
optimizer_variational = optim.Adam(flow_variational.parameters())

In [10]:
# Number of parameters in uniform and variational models
uniform_parameters = filter(lambda p: p.requires_grad, flow_uniform.parameters())
variational_parameters = filter(lambda p: p.requires_grad, flow_variational.parameters())
num_uniform_parameters = sum([np.prod(p.size()) for p in uniform_parameters])
num_variational_parameters = sum([np.prod(p.size()) for p in variational_parameters])

print(num_uniform_parameters, num_variational_parameters)


54564 54778


In [11]:
n_epochs = 200
batch_size = 10000
n_batches = m.ceil(x_data_train.shape[0]/batch_size)

for epoch in range(n_epochs):
    permutation = torch.randperm(x_data_train.shape[0], device=device)    

    # Loop over batches
    cum_loss_uniform = 0
    cum_loss_variational = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, x_data_train.shape[0]-1 )
        indices = permutation[batch_begin:batch_end]
        x_data_train_batch = x_data_train[indices]
        x_weights_train_batch = x_weights_train[indices]
        
        # Take a step
        optimizer_uniform.zero_grad()
        optimizer_variational.zero_grad()

        loss_uniform = -(flow_uniform.log_prob(inputs=x_data_train_batch)*x_weights_train_batch).mean()
        loss_variational = -(flow_variational.log_prob(inputs=x_data_train_batch)*x_weights_train_batch).mean()

        loss_uniform.backward()
        loss_variational.backward()

        optimizer_uniform.step()
        optimizer_variational.step()

        # Compute cumulative loss
        cum_loss_uniform = (cum_loss_uniform*batch + loss_uniform.item())/(batch+1)
        cum_loss_variational = (cum_loss_variational*batch + loss_variational.item())/(batch+1)
        
        if batch%10 == 0:
            print("epoch = ", epoch, "batch = ",batch+1, "/", n_batches, "loss_uniform = ", cum_loss_uniform, " loss_variational = ", cum_loss_variational)
    
    # Validation log prob
    #loss_uniform_validation     = -(flow_uniform.log_prob(inputs=x_data_test)*x_weights_test).mean()
    #loss_variational_validation = -(flow_variational.log_prob(inputs=x_data_test)*x_weights_test).mean()

    writer.add_scalar("Loss_train/uniform_loss_train", cum_loss_uniform, epoch)
    writer.add_scalar("Loss_train/variational_loss_train", cum_loss_variational, epoch)

    #writer.add_scalar("Loss_test/uniform_loss_test", loss_uniform_validation, epoch)
    #writer.add_scalar("Loss_test/variational_loss_test", loss_variational_validation, epoch)

    torch.save(flow_uniform, "models/uniform_model.pt")
    torch.save(flow_variational, "models/variational_model.pt")

epoch =  0 batch =  1 / 101 loss_uniform =  1.039359211921692  loss_variational =  3.0154170989990234
epoch =  0 batch =  11 / 101 loss_uniform =  -0.4221790592101487  loss_variational =  0.727385669269345
epoch =  0 batch =  21 / 101 loss_uniform =  -1.3507582628656  loss_variational =  -0.6958527525975591
epoch =  0 batch =  31 / 101 loss_uniform =  -2.1927127260113917  loss_variational =  -1.6664091206846698
epoch =  0 batch =  41 / 101 loss_uniform =  -2.894826833703896  loss_variational =  -2.347036150352257
epoch =  0 batch =  51 / 101 loss_uniform =  -3.4124934742701987  loss_variational =  -2.8667809058930365
epoch =  0 batch =  61 / 101 loss_uniform =  -3.82716992833331  loss_variational =  -3.3034394829243916
epoch =  0 batch =  71 / 101 loss_uniform =  -4.160175432569124  loss_variational =  -3.6636405315407563
epoch =  0 batch =  81 / 101 loss_uniform =  -4.44180588805933  loss_variational =  -3.9787511815443453
epoch =  0 batch =  91 / 101 loss_uniform =  -4.67704908791315

epoch =  7 batch =  21 / 101 loss_uniform =  -6.7856463477725075  loss_variational =  -7.110301721663702
epoch =  7 batch =  31 / 101 loss_uniform =  -6.784555681290165  loss_variational =  -7.075018129041118
epoch =  7 batch =  41 / 101 loss_uniform =  -6.838515688733357  loss_variational =  -7.102017274717006
epoch =  7 batch =  51 / 101 loss_uniform =  -6.860841395808201  loss_variational =  -7.1079448625153185
epoch =  7 batch =  61 / 101 loss_uniform =  -6.876746552889465  loss_variational =  -7.113179488260238
epoch =  7 batch =  71 / 101 loss_uniform =  -6.891586216402725  loss_variational =  -7.119035546208771
epoch =  7 batch =  81 / 101 loss_uniform =  -6.910512376714636  loss_variational =  -7.129239317811566
epoch =  7 batch =  91 / 101 loss_uniform =  -6.919445734757644  loss_variational =  -7.131008672190236
epoch =  7 batch =  101 / 101 loss_uniform =  -6.861651198108597  loss_variational =  -7.066060063567492
epoch =  8 batch =  1 / 101 loss_uniform =  -7.41647005081176

epoch =  14 batch =  41 / 101 loss_uniform =  -6.41842538554494  loss_variational =  -7.024178039736864
epoch =  14 batch =  51 / 101 loss_uniform =  -6.525918567881865  loss_variational =  -7.06086025986017
epoch =  14 batch =  61 / 101 loss_uniform =  -6.611629126501865  loss_variational =  -7.091600848025963
epoch =  14 batch =  71 / 101 loss_uniform =  -6.682871214101012  loss_variational =  -7.120682649209466
epoch =  14 batch =  81 / 101 loss_uniform =  -6.738204543973192  loss_variational =  -7.143464506408314
epoch =  14 batch =  91 / 101 loss_uniform =  -6.78348910677564  loss_variational =  -7.1603276912982645
epoch =  14 batch =  101 / 101 loss_uniform =  -6.805865061165083  loss_variational =  -7.171474626748869
epoch =  15 batch =  1 / 101 loss_uniform =  -6.836126327514648  loss_variational =  -7.012639045715332
epoch =  15 batch =  11 / 101 loss_uniform =  -6.655221288854426  loss_variational =  -6.519225077195601
epoch =  15 batch =  21 / 101 loss_uniform =  -6.74365490

epoch =  21 batch =  51 / 101 loss_uniform =  -7.322493440964642  loss_variational =  -7.390672384523878
epoch =  21 batch =  61 / 101 loss_uniform =  -7.340564876306252  loss_variational =  -7.406331875285164
epoch =  21 batch =  71 / 101 loss_uniform =  -7.327074151643565  loss_variational =  -7.392808578383755
epoch =  21 batch =  81 / 101 loss_uniform =  -7.321454154120551  loss_variational =  -7.387681201652244
epoch =  21 batch =  91 / 101 loss_uniform =  -7.320013020064804  loss_variational =  -7.385328539125212
epoch =  21 batch =  101 / 101 loss_uniform =  -7.241978583949627  loss_variational =  -7.307405410426678
epoch =  22 batch =  1 / 101 loss_uniform =  -7.4455437660217285  loss_variational =  -7.572098255157471
epoch =  22 batch =  11 / 101 loss_uniform =  -7.322871988469904  loss_variational =  -7.421155366030606
epoch =  22 batch =  21 / 101 loss_uniform =  -7.297795136769612  loss_variational =  -7.389173098972866
epoch =  22 batch =  31 / 101 loss_uniform =  -7.30463

epoch =  28 batch =  61 / 101 loss_uniform =  -7.353679961845523  loss_variational =  -7.403968451453037
epoch =  28 batch =  71 / 101 loss_uniform =  -7.356687666664661  loss_variational =  -7.404500793403303
epoch =  28 batch =  81 / 101 loss_uniform =  -7.361766497294108  loss_variational =  -7.405957398591219
epoch =  28 batch =  91 / 101 loss_uniform =  -7.356721238775568  loss_variational =  -7.402060016170963
epoch =  28 batch =  101 / 101 loss_uniform =  -7.517218877773474  loss_variational =  -7.561400989494701
epoch =  29 batch =  1 / 101 loss_uniform =  -6.8059234619140625  loss_variational =  -6.869312286376953
epoch =  29 batch =  11 / 101 loss_uniform =  -6.663770719008013  loss_variational =  -6.818626360459761
epoch =  29 batch =  21 / 101 loss_uniform =  -6.827008769625709  loss_variational =  -6.99235150927589
epoch =  29 batch =  31 / 101 loss_uniform =  -6.908705249909432  loss_variational =  -7.068680132589033
epoch =  29 batch =  41 / 101 loss_uniform =  -6.976074

epoch =  35 batch =  81 / 101 loss_uniform =  -7.391881507120015  loss_variational =  -7.392819816683546
epoch =  35 batch =  91 / 101 loss_uniform =  -7.396471406077291  loss_variational =  -7.3978506077776895
epoch =  35 batch =  101 / 101 loss_uniform =  -7.3101106109743075  loss_variational =  -7.312373705429606
epoch =  36 batch =  1 / 101 loss_uniform =  -7.288185119628906  loss_variational =  -7.292384147644043
epoch =  36 batch =  11 / 101 loss_uniform =  -7.347646279768511  loss_variational =  -7.360758564688942
epoch =  36 batch =  21 / 101 loss_uniform =  -7.313851356506348  loss_variational =  -7.3291018803914385
epoch =  36 batch =  31 / 101 loss_uniform =  -7.339937886884136  loss_variational =  -7.3503878808790635
epoch =  36 batch =  41 / 101 loss_uniform =  -7.38087825077336  loss_variational =  -7.389080943130866
epoch =  36 batch =  51 / 101 loss_uniform =  -7.3812214626985435  loss_variational =  -7.393009307337742
epoch =  36 batch =  61 / 101 loss_uniform =  -7.36

epoch =  42 batch =  101 / 101 loss_uniform =  -7.402964256777622  loss_variational =  -7.415763264835471
epoch =  43 batch =  1 / 101 loss_uniform =  -7.025403022766113  loss_variational =  -6.889221668243408
epoch =  43 batch =  11 / 101 loss_uniform =  -7.123179479078813  loss_variational =  -6.792905287309126
epoch =  43 batch =  21 / 101 loss_uniform =  -7.256330194927397  loss_variational =  -6.984611624763126
epoch =  43 batch =  31 / 101 loss_uniform =  -7.2991119815457255  loss_variational =  -7.086299803949172
epoch =  43 batch =  41 / 101 loss_uniform =  -7.320552023445687  loss_variational =  -7.147930075482624
epoch =  43 batch =  51 / 101 loss_uniform =  -7.3529273575427485  loss_variational =  -7.207542877571256
epoch =  43 batch =  61 / 101 loss_uniform =  -7.3648775053805995  loss_variational =  -7.239472631548272
epoch =  43 batch =  71 / 101 loss_uniform =  -7.38234818821222  loss_variational =  -7.272902676756953
epoch =  43 batch =  81 / 101 loss_uniform =  -7.3814

epoch =  50 batch =  11 / 101 loss_uniform =  -7.386488090861928  loss_variational =  -7.161863673817027
epoch =  50 batch =  21 / 101 loss_uniform =  -7.3686084520249135  loss_variational =  -7.199963433401925
epoch =  50 batch =  31 / 101 loss_uniform =  -7.396278350583969  loss_variational =  -7.273980479086599
epoch =  50 batch =  41 / 101 loss_uniform =  -7.420491288347942  loss_variational =  -7.327030298186512
epoch =  50 batch =  51 / 101 loss_uniform =  -7.401922263351142  loss_variational =  -7.326568089279474
epoch =  50 batch =  61 / 101 loss_uniform =  -7.39952021739522  loss_variational =  -7.339176295233554
epoch =  50 batch =  71 / 101 loss_uniform =  -7.39086263952121  loss_variational =  -7.340039871108364
epoch =  50 batch =  81 / 101 loss_uniform =  -7.396597862243652  loss_variational =  -7.35361236407433
epoch =  50 batch =  91 / 101 loss_uniform =  -7.405312732025817  loss_variational =  -7.367074264274849
epoch =  50 batch =  101 / 101 loss_uniform =  -7.3306093

epoch =  57 batch =  21 / 101 loss_uniform =  -7.356587591625395  loss_variational =  -7.378618830726261
epoch =  57 batch =  31 / 101 loss_uniform =  -7.360534729496125  loss_variational =  -7.3898392338906564
epoch =  57 batch =  41 / 101 loss_uniform =  -7.38445563432647  loss_variational =  -7.416266732099579
epoch =  57 batch =  51 / 101 loss_uniform =  -7.383699847202675  loss_variational =  -7.417967534532734
epoch =  57 batch =  61 / 101 loss_uniform =  -7.39681941954816  loss_variational =  -7.4315809343681964
epoch =  57 batch =  71 / 101 loss_uniform =  -7.384311998394174  loss_variational =  -7.420406529601191
epoch =  57 batch =  81 / 101 loss_uniform =  -7.38902156735644  loss_variational =  -7.426509056562259
epoch =  57 batch =  91 / 101 loss_uniform =  -7.393632773514632  loss_variational =  -7.432744806939429
epoch =  57 batch =  101 / 101 loss_uniform =  -7.349489217937583  loss_variational =  -7.383152953468927
epoch =  58 batch =  1 / 101 loss_uniform =  -7.5771927

epoch =  64 batch =  41 / 101 loss_uniform =  -7.438616892186607  loss_variational =  -7.475194128548226
epoch =  64 batch =  51 / 101 loss_uniform =  -7.445355396644742  loss_variational =  -7.482899852827484
epoch =  64 batch =  61 / 101 loss_uniform =  -7.455211186018146  loss_variational =  -7.492806176670262
epoch =  64 batch =  71 / 101 loss_uniform =  -7.454541508580597  loss_variational =  -7.495872235633958
epoch =  64 batch =  81 / 101 loss_uniform =  -7.441670453106916  loss_variational =  -7.4865883662376875
epoch =  64 batch =  91 / 101 loss_uniform =  -7.438810804388026  loss_variational =  -7.484534195491245
epoch =  64 batch =  101 / 101 loss_uniform =  -7.718545937302089  loss_variational =  -7.759276125690724
epoch =  65 batch =  1 / 101 loss_uniform =  -6.787951469421387  loss_variational =  -5.48668098449707
epoch =  65 batch =  11 / 101 loss_uniform =  -6.909181421453303  loss_variational =  -5.2308289354497735
epoch =  65 batch =  21 / 101 loss_uniform =  -7.03724

epoch =  71 batch =  61 / 101 loss_uniform =  -7.329546099803487  loss_variational =  -7.419966299025739
epoch =  71 batch =  71 / 101 loss_uniform =  -7.353816475666744  loss_variational =  -7.43912574606882
epoch =  71 batch =  81 / 101 loss_uniform =  -7.362212516643383  loss_variational =  -7.4455832905239525
epoch =  71 batch =  91 / 101 loss_uniform =  -7.361663870759063  loss_variational =  -7.444990985996121
epoch =  71 batch =  101 / 101 loss_uniform =  -7.416605510333977  loss_variational =  -7.496553543770667
epoch =  72 batch =  1 / 101 loss_uniform =  -6.120482444763184  loss_variational =  -7.2599778175354
epoch =  72 batch =  11 / 101 loss_uniform =  -6.7115466811440205  loss_variational =  -7.189560370011763
epoch =  72 batch =  21 / 101 loss_uniform =  -6.810209773835682  loss_variational =  -7.2251581237429665
epoch =  72 batch =  31 / 101 loss_uniform =  -6.9446185481163765  loss_variational =  -7.262262898106729
epoch =  72 batch =  41 / 101 loss_uniform =  -7.00863

epoch =  78 batch =  71 / 101 loss_uniform =  -7.280148365128208  loss_variational =  -7.378096526777241
epoch =  78 batch =  81 / 101 loss_uniform =  -7.2892389768435635  loss_variational =  -7.383287400375178
epoch =  78 batch =  91 / 101 loss_uniform =  -7.312124603397244  loss_variational =  -7.402272538824396
epoch =  78 batch =  101 / 101 loss_uniform =  -7.371002768525983  loss_variational =  -7.455750116027228
epoch =  79 batch =  1 / 101 loss_uniform =  -6.927256107330322  loss_variational =  -7.150731086730957
epoch =  79 batch =  11 / 101 loss_uniform =  -7.1773762702941895  loss_variational =  -7.216276515613902
epoch =  79 batch =  21 / 101 loss_uniform =  -7.243048145657494  loss_variational =  -7.296059290568034
epoch =  79 batch =  31 / 101 loss_uniform =  -7.296152207159227  loss_variational =  -7.343765797153596
epoch =  79 batch =  41 / 101 loss_uniform =  -7.314559087520692  loss_variational =  -7.361050291759212
epoch =  79 batch =  51 / 101 loss_uniform =  -7.3485

epoch =  85 batch =  81 / 101 loss_uniform =  -7.281840236098678  loss_variational =  -7.364785371003328
epoch =  85 batch =  91 / 101 loss_uniform =  -7.2967889073130845  loss_variational =  -7.375802726536007
epoch =  85 batch =  101 / 101 loss_uniform =  -7.410561306641833  loss_variational =  -7.481069970839094
epoch =  86 batch =  1 / 101 loss_uniform =  -7.725101470947266  loss_variational =  -7.410797595977783
epoch =  86 batch =  11 / 101 loss_uniform =  -7.184988108548251  loss_variational =  -6.90491346879439
epoch =  86 batch =  21 / 101 loss_uniform =  -7.2482693535940985  loss_variational =  -7.068295683179583
epoch =  86 batch =  31 / 101 loss_uniform =  -7.274385913725822  loss_variational =  -7.144010989896713
epoch =  86 batch =  41 / 101 loss_uniform =  -7.297441377872374  loss_variational =  -7.203315676712409
epoch =  86 batch =  51 / 101 loss_uniform =  -7.3215760249717565  loss_variational =  -7.249546135173125
epoch =  86 batch =  61 / 101 loss_uniform =  -7.3387

epoch =  92 batch =  91 / 101 loss_uniform =  -7.423056015601525  loss_variational =  -7.437275593097393
epoch =  92 batch =  101 / 101 loss_uniform =  -7.434387952974527  loss_variational =  -7.451079854870787
epoch =  93 batch =  1 / 101 loss_uniform =  -7.56840705871582  loss_variational =  -7.591821670532227
epoch =  93 batch =  11 / 101 loss_uniform =  -7.217133652080189  loss_variational =  -7.32918830351396
epoch =  93 batch =  21 / 101 loss_uniform =  -7.273716427031017  loss_variational =  -7.354805719284784
epoch =  93 batch =  31 / 101 loss_uniform =  -7.306480823024627  loss_variational =  -7.37447770949333
epoch =  93 batch =  41 / 101 loss_uniform =  -7.36073310200761  loss_variational =  -7.421193506659531
epoch =  93 batch =  51 / 101 loss_uniform =  -7.378510811749627  loss_variational =  -7.433018862032423
epoch =  93 batch =  61 / 101 loss_uniform =  -7.388470016542028  loss_variational =  -7.438373956523958
epoch =  93 batch =  71 / 101 loss_uniform =  -7.4068563085

epoch =  100 batch =  11 / 101 loss_uniform =  -7.158711260015314  loss_variational =  -7.185826171528209
epoch =  100 batch =  21 / 101 loss_uniform =  -7.288526966458275  loss_variational =  -7.30496579124814
epoch =  100 batch =  31 / 101 loss_uniform =  -7.332135646573959  loss_variational =  -7.346684778890302
epoch =  100 batch =  41 / 101 loss_uniform =  -7.359499989486322  loss_variational =  -7.370259819961175
epoch =  100 batch =  51 / 101 loss_uniform =  -7.375064494563084  loss_variational =  -7.382372416701972
epoch =  100 batch =  61 / 101 loss_uniform =  -7.398993593747498  loss_variational =  -7.403986516546031
epoch =  100 batch =  71 / 101 loss_uniform =  -7.404127497068593  loss_variational =  -7.407502530326306
epoch =  100 batch =  81 / 101 loss_uniform =  -7.405507564544678  loss_variational =  -7.407821602291531
epoch =  100 batch =  91 / 101 loss_uniform =  -7.414361220139724  loss_variational =  -7.416144612071278
epoch =  100 batch =  101 / 101 loss_uniform = 

epoch =  107 batch =  21 / 101 loss_uniform =  -7.3597711608523415  loss_variational =  -7.380282061440604
epoch =  107 batch =  31 / 101 loss_uniform =  -7.375102935298797  loss_variational =  -7.394151949113415
epoch =  107 batch =  41 / 101 loss_uniform =  -7.412888526916504  loss_variational =  -7.432191418438423
epoch =  107 batch =  51 / 101 loss_uniform =  -7.429113930346919  loss_variational =  -7.446981028014538
epoch =  107 batch =  61 / 101 loss_uniform =  -7.431139985068899  loss_variational =  -7.450848759197798
epoch =  107 batch =  71 / 101 loss_uniform =  -7.430302290849283  loss_variational =  -7.449846999745973
epoch =  107 batch =  81 / 101 loss_uniform =  -7.42537992383227  loss_variational =  -7.444125752390167
epoch =  107 batch =  91 / 101 loss_uniform =  -7.434476873376867  loss_variational =  -7.452559030972994
epoch =  107 batch =  101 / 101 loss_uniform =  -7.380797221518979  loss_variational =  -7.398100386751761
epoch =  108 batch =  1 / 101 loss_uniform = 

epoch =  114 batch =  31 / 101 loss_uniform =  -7.460361957550049  loss_variational =  -7.474187635606335
epoch =  114 batch =  41 / 101 loss_uniform =  -7.456898840462289  loss_variational =  -7.471588285957894
epoch =  114 batch =  51 / 101 loss_uniform =  -7.43905221714693  loss_variational =  -7.454421015346751
epoch =  114 batch =  61 / 101 loss_uniform =  -7.443940451887787  loss_variational =  -7.458921268338063
epoch =  114 batch =  71 / 101 loss_uniform =  -7.451025734485035  loss_variational =  -7.465917909649057
epoch =  114 batch =  81 / 101 loss_uniform =  -7.453627068319438  loss_variational =  -7.46987826735885
epoch =  114 batch =  91 / 101 loss_uniform =  -7.450344405331454  loss_variational =  -7.467324251657004
epoch =  114 batch =  101 / 101 loss_uniform =  -7.387775963338295  loss_variational =  -7.405044752329883
epoch =  115 batch =  1 / 101 loss_uniform =  -7.103956699371338  loss_variational =  -7.2320942878723145
epoch =  115 batch =  11 / 101 loss_uniform =  

epoch =  121 batch =  41 / 101 loss_uniform =  -7.381101294261653  loss_variational =  -7.385495662689209
epoch =  121 batch =  51 / 101 loss_uniform =  -7.39835262298584  loss_variational =  -7.405893811992571
epoch =  121 batch =  61 / 101 loss_uniform =  -7.413904612181617  loss_variational =  -7.422740654867203
epoch =  121 batch =  71 / 101 loss_uniform =  -7.416240255597612  loss_variational =  -7.426713956913478
epoch =  121 batch =  81 / 101 loss_uniform =  -7.422697043713228  loss_variational =  -7.433774683210585
epoch =  121 batch =  91 / 101 loss_uniform =  -7.428634874113313  loss_variational =  -7.44021259035383
epoch =  121 batch =  101 / 101 loss_uniform =  -7.364614751374368  loss_variational =  -7.376717315451933
epoch =  122 batch =  1 / 101 loss_uniform =  -7.23789119720459  loss_variational =  -7.27773904800415
epoch =  122 batch =  11 / 101 loss_uniform =  -7.4040399031205615  loss_variational =  -7.432887900959361
epoch =  122 batch =  21 / 101 loss_uniform =  -7

epoch =  128 batch =  51 / 101 loss_uniform =  -7.461068564770269  loss_variational =  -7.501202985352161
epoch =  128 batch =  61 / 101 loss_uniform =  -7.442844664464231  loss_variational =  -7.482825287052842
epoch =  128 batch =  71 / 101 loss_uniform =  -7.430586996212812  loss_variational =  -7.475831072095414
epoch =  128 batch =  81 / 101 loss_uniform =  -7.43217858561763  loss_variational =  -7.477586010356008
epoch =  128 batch =  91 / 101 loss_uniform =  -7.446873235178518  loss_variational =  -7.491100164560171
epoch =  128 batch =  101 / 101 loss_uniform =  -7.492575914552896  loss_variational =  -7.5315690701550775
epoch =  129 batch =  1 / 101 loss_uniform =  -7.471210956573486  loss_variational =  -7.409645080566406
epoch =  129 batch =  11 / 101 loss_uniform =  -7.2560177716341885  loss_variational =  -7.136312918229536
epoch =  129 batch =  21 / 101 loss_uniform =  -7.289433116004581  loss_variational =  -7.225932212102981
epoch =  129 batch =  31 / 101 loss_uniform =

epoch =  135 batch =  61 / 101 loss_uniform =  -7.123063556483535  loss_variational =  -7.445354383499896
epoch =  135 batch =  71 / 101 loss_uniform =  -7.166750813873721  loss_variational =  -7.457881564825353
epoch =  135 batch =  81 / 101 loss_uniform =  -7.1989596743642545  loss_variational =  -7.464209356425721
epoch =  135 batch =  91 / 101 loss_uniform =  -7.210769532801031  loss_variational =  -7.453617557064518
epoch =  135 batch =  101 / 101 loss_uniform =  -7.150892365723848  loss_variational =  -7.374501482583582
epoch =  136 batch =  1 / 101 loss_uniform =  -7.099291801452637  loss_variational =  -7.141332626342773
epoch =  136 batch =  11 / 101 loss_uniform =  -7.395949233661998  loss_variational =  -7.444174159656871
epoch =  136 batch =  21 / 101 loss_uniform =  -7.429993629455566  loss_variational =  -7.47765538806007
epoch =  136 batch =  31 / 101 loss_uniform =  -7.40452294195852  loss_variational =  -7.446468045634608
epoch =  136 batch =  41 / 101 loss_uniform =  

epoch =  142 batch =  71 / 101 loss_uniform =  -7.402651900976476  loss_variational =  -7.386633510320959
epoch =  142 batch =  81 / 101 loss_uniform =  -7.402893089953764  loss_variational =  -7.3867900930804975
epoch =  142 batch =  91 / 101 loss_uniform =  -7.4191777889545145  loss_variational =  -7.39895163001595
epoch =  142 batch =  101 / 101 loss_uniform =  -7.354646633996969  loss_variational =  -7.332967581605484
epoch =  143 batch =  1 / 101 loss_uniform =  -7.793947696685791  loss_variational =  -7.785207748413086
epoch =  143 batch =  11 / 101 loss_uniform =  -7.5050152431834825  loss_variational =  -7.483800801363858
epoch =  143 batch =  21 / 101 loss_uniform =  -7.472840672447568  loss_variational =  -7.447501636686779
epoch =  143 batch =  31 / 101 loss_uniform =  -7.438653192212505  loss_variational =  -7.413944536639798
epoch =  143 batch =  41 / 101 loss_uniform =  -7.454971243695515  loss_variational =  -7.4319401834069225
epoch =  143 batch =  51 / 101 loss_uniform

epoch =  149 batch =  81 / 101 loss_uniform =  -7.43519691184715  loss_variational =  -7.427314128404782
epoch =  149 batch =  91 / 101 loss_uniform =  -7.435637814658029  loss_variational =  -7.429506930676135
epoch =  149 batch =  101 / 101 loss_uniform =  -7.355016975101121  loss_variational =  -7.349562082329009
epoch =  150 batch =  1 / 101 loss_uniform =  -7.471591949462891  loss_variational =  -7.537527084350586
epoch =  150 batch =  11 / 101 loss_uniform =  -7.394115144556219  loss_variational =  -7.4225773377852
epoch =  150 batch =  21 / 101 loss_uniform =  -7.4441241309756325  loss_variational =  -7.4662812777927945
epoch =  150 batch =  31 / 101 loss_uniform =  -7.444834570730886  loss_variational =  -7.463654764236942
epoch =  150 batch =  41 / 101 loss_uniform =  -7.44862856516024  loss_variational =  -7.453817181470917
epoch =  150 batch =  51 / 101 loss_uniform =  -7.446271232530182  loss_variational =  -7.444763146194757
epoch =  150 batch =  61 / 101 loss_uniform =  -

epoch =  156 batch =  91 / 101 loss_uniform =  -7.4316427786271655  loss_variational =  -7.449049546168401
epoch =  156 batch =  101 / 101 loss_uniform =  -7.364047991167201  loss_variational =  -7.379095143315816
epoch =  157 batch =  1 / 101 loss_uniform =  -7.661181926727295  loss_variational =  -7.6563920974731445
epoch =  157 batch =  11 / 101 loss_uniform =  -7.439349954778498  loss_variational =  -7.411531708457253
epoch =  157 batch =  21 / 101 loss_uniform =  -7.436200300852458  loss_variational =  -7.426065717424665
epoch =  157 batch =  31 / 101 loss_uniform =  -7.459768387579149  loss_variational =  -7.464341255926317
epoch =  157 batch =  41 / 101 loss_uniform =  -7.451111025926544  loss_variational =  -7.465354803131848
epoch =  157 batch =  51 / 101 loss_uniform =  -7.47261921564738  loss_variational =  -7.489489256166944
epoch =  157 batch =  61 / 101 loss_uniform =  -7.460259890947186  loss_variational =  -7.476978724120094
epoch =  157 batch =  71 / 101 loss_uniform =

epoch =  163 batch =  101 / 101 loss_uniform =  -7.509891042614927  loss_variational =  -7.350356682692424
epoch =  164 batch =  1 / 101 loss_uniform =  -7.781054496765137  loss_variational =  -7.59788179397583
epoch =  164 batch =  11 / 101 loss_uniform =  -7.362235762856224  loss_variational =  -7.251187888058749
epoch =  164 batch =  21 / 101 loss_uniform =  -7.384686129433768  loss_variational =  -7.270432699294317
epoch =  164 batch =  31 / 101 loss_uniform =  -7.3980097616872476  loss_variational =  -7.2833964132493545
epoch =  164 batch =  41 / 101 loss_uniform =  -7.408263125070712  loss_variational =  -7.293769231656703
epoch =  164 batch =  51 / 101 loss_uniform =  -7.4401353293774175  loss_variational =  -7.327055585150625
epoch =  164 batch =  61 / 101 loss_uniform =  -7.446140187685607  loss_variational =  -7.334395635323446
epoch =  164 batch =  71 / 101 loss_uniform =  -7.4323704477766865  loss_variational =  -7.3208129909676565
epoch =  164 batch =  81 / 101 loss_unifor

epoch =  171 batch =  11 / 101 loss_uniform =  -7.457661368630149  loss_variational =  -7.439451564442027
epoch =  171 batch =  21 / 101 loss_uniform =  -7.465119793301537  loss_variational =  -7.4474135580517
epoch =  171 batch =  31 / 101 loss_uniform =  -7.447474018219979  loss_variational =  -7.4297140029168895
epoch =  171 batch =  41 / 101 loss_uniform =  -7.444770196589028  loss_variational =  -7.43164232300549
epoch =  171 batch =  51 / 101 loss_uniform =  -7.43800687789917  loss_variational =  -7.426338897031896
epoch =  171 batch =  61 / 101 loss_uniform =  -7.424772184403216  loss_variational =  -7.414505911655113
epoch =  171 batch =  71 / 101 loss_uniform =  -7.434064549459538  loss_variational =  -7.42425531736562
epoch =  171 batch =  81 / 101 loss_uniform =  -7.448674025358977  loss_variational =  -7.438159972061346
epoch =  171 batch =  91 / 101 loss_uniform =  -7.463545862134996  loss_variational =  -7.452216295095591
epoch =  171 batch =  101 / 101 loss_uniform =  -7

epoch =  178 batch =  21 / 101 loss_uniform =  -7.466577847798665  loss_variational =  -7.4430866695585705
epoch =  178 batch =  31 / 101 loss_uniform =  -7.423993356766239  loss_variational =  -7.391540896508001
epoch =  178 batch =  41 / 101 loss_uniform =  -7.4404080786356115  loss_variational =  -7.401588160817216
epoch =  178 batch =  51 / 101 loss_uniform =  -7.405391281726313  loss_variational =  -7.361982233384076
epoch =  178 batch =  61 / 101 loss_uniform =  -7.424645947628334  loss_variational =  -7.378542814098421
epoch =  178 batch =  71 / 101 loss_uniform =  -7.444310430070044  loss_variational =  -7.395793807338661
epoch =  178 batch =  81 / 101 loss_uniform =  -7.446136433401225  loss_variational =  -7.399062551098106
epoch =  178 batch =  91 / 101 loss_uniform =  -7.4501468899485825  loss_variational =  -7.4034120329133755
epoch =  178 batch =  101 / 101 loss_uniform =  -7.433414969113794  loss_variational =  -7.385224153499792
epoch =  179 batch =  1 / 101 loss_unifor

epoch =  185 batch =  31 / 101 loss_uniform =  -7.49330965165169  loss_variational =  -7.433893080680601
epoch =  185 batch =  41 / 101 loss_uniform =  -7.5007137554447825  loss_variational =  -7.443253156615467
epoch =  185 batch =  51 / 101 loss_uniform =  -7.503905735763849  loss_variational =  -7.450237152623195
epoch =  185 batch =  61 / 101 loss_uniform =  -7.469805811272293  loss_variational =  -7.420088924345423
epoch =  185 batch =  71 / 101 loss_uniform =  -7.454098479848512  loss_variational =  -7.407083430760343
epoch =  185 batch =  81 / 101 loss_uniform =  -7.45509676285732  loss_variational =  -7.40986227106165
epoch =  185 batch =  91 / 101 loss_uniform =  -7.450374388432765  loss_variational =  -7.407080582209995
epoch =  185 batch =  101 / 101 loss_uniform =  -7.552681592431399  loss_variational =  -7.51521357923451
epoch =  186 batch =  1 / 101 loss_uniform =  -7.48994255065918  loss_variational =  -7.432224750518799
epoch =  186 batch =  11 / 101 loss_uniform =  -7.

epoch =  192 batch =  41 / 101 loss_uniform =  -7.3232860332582055  loss_variational =  -7.2415303486149485
epoch =  192 batch =  51 / 101 loss_uniform =  -7.353232982111912  loss_variational =  -7.2865345618304085
epoch =  192 batch =  61 / 101 loss_uniform =  -7.384973096065834  loss_variational =  -7.327257031300029
epoch =  192 batch =  71 / 101 loss_uniform =  -7.386563777923584  loss_variational =  -7.334472468201543
epoch =  192 batch =  81 / 101 loss_uniform =  -7.389992078145345  loss_variational =  -7.34266123359586
epoch =  192 batch =  91 / 101 loss_uniform =  -7.402824129377093  loss_variational =  -7.358514696687132
epoch =  192 batch =  101 / 101 loss_uniform =  -7.577186046260418  loss_variational =  -7.515251891447766
epoch =  193 batch =  1 / 101 loss_uniform =  -6.987879276275635  loss_variational =  -6.824455261230469
epoch =  193 batch =  11 / 101 loss_uniform =  -7.188717495311391  loss_variational =  -7.147468913685191
epoch =  193 batch =  21 / 101 loss_uniform 

epoch =  199 batch =  51 / 101 loss_uniform =  -7.481681393642051  loss_variational =  -7.494746946821026
epoch =  199 batch =  61 / 101 loss_uniform =  -7.487934323607898  loss_variational =  -7.49907946977459
epoch =  199 batch =  71 / 101 loss_uniform =  -7.481984427277471  loss_variational =  -7.491020316808996
epoch =  199 batch =  81 / 101 loss_uniform =  -7.479757944742839  loss_variational =  -7.486784540576699
epoch =  199 batch =  91 / 101 loss_uniform =  -7.4674507811829285  loss_variational =  -7.473547689207307
epoch =  199 batch =  101 / 101 loss_uniform =  -7.531852708004489  loss_variational =  -7.5393625344380295


In [12]:
n_sample = 100000
with torch.no_grad():
    x_uniform = flow_uniform.sample(n_sample).cpu()
    x_variational = flow_variational.sample(n_sample).cpu()
x_data_plot = x_data_raw.cpu()[:n_sample,:]
x_weights_plot = x_weights_raw.cpu()[:n_sample]

In [17]:
x_variational[:,3]

tensor([7., 7., 5.,  ..., 7., 7., 6.])

In [20]:
x_data_plot[:,3]

tensor([ 5.,  0., 12.,  ..., 16., 14.,  9.])

In [21]:
plt.yscale("log")
bins = np.linspace(-0.5, 16.5, 18)
plt.hist(x_data_plot[:,3], histtype='stepfilled', edgecolor="black", facecolor="lightgray", bins=bins, weights=x_weights_plot)
#plt.hist(x_uniform[:,3], edgecolor="red", histtype="step", bins = bins)
#plt.hist(x_variational[:,3], edgecolor="green", histtype="step", bins = bins)
plt.show()

KeyboardInterrupt: 

In [15]:
with torch.no_grad():
    x_uniform_dequantized, _ = flow_uniform._transform._transforms[0].forward(x_uniform)
    x_variational_dequantized, _ = flow_variational._transform._transforms[0].forward(x_variational)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [None]:
plt.yscale("log")
plt.hist(x_uniform_dequantized[:,2], histtype='stepfilled', edgecolor="black", facecolor="lightgray", bins = 100)
plt.hist(x_variational_dequantized[:,2], edgecolor="red", histtype="step", bins = 100)
plt.show()