In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt

In [2]:
import torch
from torch import nn
from torch import optim

In [3]:
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.flows.base import Flow
from nflows.transforms.dropout import UniformStochasticDropout
from nflows.transforms.dropout import VariationalStochasticDropout
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform

In [4]:
device = torch.device("cuda:0")
#device = torch.device("cpu")

In [5]:
# This works with any size x
def p(x, n_probs):
    sums = torch.sum(x, axis=1)
    probs = torch.cos(torch.ger(sums, torch.arange(1, n_probs+1, dtype=torch.float32)))**2
    norm = torch.sum(probs, axis=1)

    for i in range(n_probs):
        probs[:,i] /= norm
    
    return probs

In [6]:
def generate(n, drop_indices):
    n_probs = torch.max(drop_indices) + 1
    x = torch.rand(n, drop_indices.shape[0])
    probs = p(x, n_probs)

    # Pick a prob
    probs_cumsum = torch.cumsum(probs, axis=1)

    # Tensor with bools that are true when r passes the cumprob
    larger_than_cumprob = torch.rand(n,1) < probs_cumsum
    # Do the arange trick to find first nonzero
    # This is the HIGHEST LABEL FROM DROP_INDICES THAT IS KEPT
    selected_index = torch.argmax(larger_than_cumprob*torch.arange(n_probs, 0, -1), axis=1)

    '''
    print("The index of the selected probability")
    print("This is also the highest label in drop_indices that is kept")
    print(selected_index)
    ''' 
    
    # Find the index of the first true
    drop_mask = drop_indices > selected_index[:,None]
    x[drop_mask] = 0
    
    return x

In [7]:
drop_indices = torch.tensor([0,0,1,1,1,2,3,3,4])
n_data = int(1e6)
x_data = generate(n_data, drop_indices).to(device)

In [8]:
num_layers = 6
base_dist_uniform = BoxUniform(torch.zeros(drop_indices.shape[0]), torch.ones(drop_indices.shape[0]))
base_dist_variational = BoxUniform(torch.zeros(drop_indices.shape[0]), torch.ones(drop_indices.shape[0]))

transforms_uniform = []
transforms_variational = []

transforms_uniform.append(UniformStochasticDropout(drop_indices))
transforms_variational.append(VariationalStochasticDropout(drop_indices))

for _ in range(num_layers):
    transforms_uniform.append(RandomPermutation(features=drop_indices.shape[0]))
    transforms_uniform.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=drop_indices.shape[0], 
        hidden_features=50,
        num_bins=10,
        num_blocks=4,
    ))

    transforms_variational.append(RandomPermutation(features=drop_indices.shape[0]))
    transforms_variational.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=drop_indices.shape[0], 
        hidden_features=50,
        num_bins=10,
        num_blocks=4,
    ))

transform_uniform = CompositeTransform(transforms_uniform)
transform_variational = CompositeTransform(transforms_variational)

flow_uniform = Flow(transform_uniform, base_dist_uniform).to(device)
flow_variational = Flow(transform_variational, base_dist_variational).to(device)

optimizer_uniform = optim.Adam(flow_uniform.parameters())
optimizer_variational = optim.Adam(flow_variational.parameters())

In [9]:
n_epochs = 200
batch_size = 10000
n_batches = m.ceil(x_data.shape[0]/batch_size)

for epoch in range(n_epochs):
    permutation = torch.randperm(x_data.shape[0], device=device)    

    # Loop over batches
    cum_loss_uniform = 0
    cum_loss_variational = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, x_data.shape[0]-1 )
        indices = permutation[batch_begin:batch_end]
        batch_x = x_data[indices]
        
        # Take a step
        optimizer_uniform.zero_grad()
        optimizer_variational.zero_grad()

        loss_uniform = -flow_uniform.log_prob(inputs=batch_x).mean()
        loss_variational = -flow_variational.log_prob(inputs=batch_x).mean()

        loss_uniform.backward()
        loss_variational.backward()

        optimizer_uniform.step()
        optimizer_variational.step()

        # Compute cumulative loss
        cum_loss_uniform = (cum_loss_uniform*batch + loss_uniform.item())/(batch+1)
        cum_loss_variational = (cum_loss_variational*batch + loss_variational.item())/(batch+1)

        print("epoch = ", epoch, "batch = ",batch+1, "/", n_batches, "loss_uniform = ", cum_loss_uniform, " loss_variational = ", cum_loss_variational)

  log_probs_dropout_selected = torch.log(F.softmax(self._weights)[probs_dropout_index])


epoch =  0 batch =  1 / 1000 loss_uniform =  4.152677059173584  loss_variational =  6.120657920837402
epoch =  0 batch =  2 / 1000 loss_uniform =  4.094895124435425  loss_variational =  5.9657301902771
epoch =  0 batch =  3 / 1000 loss_uniform =  3.9773677984873452  loss_variational =  5.8011555671691895
epoch =  0 batch =  4 / 1000 loss_uniform =  3.885342299938202  loss_variational =  5.615455746650696
epoch =  0 batch =  5 / 1000 loss_uniform =  3.818865489959717  loss_variational =  5.487174034118652
epoch =  0 batch =  6 / 1000 loss_uniform =  3.7399772802988687  loss_variational =  5.333377122879028
epoch =  0 batch =  7 / 1000 loss_uniform =  3.675703832081386  loss_variational =  5.201341833387103
epoch =  0 batch =  8 / 1000 loss_uniform =  3.610759735107422  loss_variational =  5.082662045955658
epoch =  0 batch =  9 / 1000 loss_uniform =  3.5531297789679632  loss_variational =  4.984467294481066
epoch =  0 batch =  10 / 1000 loss_uniform =  3.492845368385315  loss_variationa

epoch =  0 batch =  81 / 1000 loss_uniform =  2.3266981634092923  loss_variational =  2.886233635890631
epoch =  0 batch =  82 / 1000 loss_uniform =  2.3210371063976756  loss_variational =  2.8754255902476427
epoch =  0 batch =  83 / 1000 loss_uniform =  2.315337341952037  loss_variational =  2.8648890343057105
epoch =  0 batch =  84 / 1000 loss_uniform =  2.3096360493273966  loss_variational =  2.8552088127249764
epoch =  0 batch =  85 / 1000 loss_uniform =  2.3046014252830958  loss_variational =  2.845240759849548
epoch =  0 batch =  86 / 1000 loss_uniform =  2.299064548902734  loss_variational =  2.8354436145272364
epoch =  0 batch =  87 / 1000 loss_uniform =  2.2936447310721744  loss_variational =  2.82580261120851
epoch =  0 batch =  88 / 1000 loss_uniform =  2.2883881086652935  loss_variational =  2.8162853392687706
epoch =  0 batch =  89 / 1000 loss_uniform =  2.2831771735394946  loss_variational =  2.807171682293495
epoch =  0 batch =  90 / 1000 loss_uniform =  2.27812335226270

epoch =  0 batch =  161 / 1000 loss_uniform =  2.047660806904669  loss_variational =  2.3900273601460893
epoch =  0 batch =  162 / 1000 loss_uniform =  2.0456321290981627  loss_variational =  2.3865088283279787
epoch =  0 batch =  163 / 1000 loss_uniform =  2.0436333001025617  loss_variational =  2.3829907822462673
epoch =  0 batch =  164 / 1000 loss_uniform =  2.041665043045835  loss_variational =  2.3793153094082338
epoch =  0 batch =  165 / 1000 loss_uniform =  2.0395518613584116  loss_variational =  2.3756414189483173
epoch =  0 batch =  166 / 1000 loss_uniform =  2.0376053797193325  loss_variational =  2.372096761163458
epoch =  0 batch =  167 / 1000 loss_uniform =  2.0357514441370252  loss_variational =  2.368721665022615
epoch =  0 batch =  168 / 1000 loss_uniform =  2.0339115616821113  loss_variational =  2.3652909000714613
epoch =  0 batch =  169 / 1000 loss_uniform =  2.0320138839574966  loss_variational =  2.3620890936202543
epoch =  0 batch =  170 / 1000 loss_uniform =  2.0

epoch =  0 batch =  239 / 1000 loss_uniform =  1.9325100068766696  loss_variational =  2.1872250500084456
epoch =  0 batch =  240 / 1000 loss_uniform =  1.9314478784799571  loss_variational =  2.185433349510033
epoch =  0 batch =  241 / 1000 loss_uniform =  1.9303882572166153  loss_variational =  2.1836693855736753
epoch =  0 batch =  242 / 1000 loss_uniform =  1.9292671192776067  loss_variational =  2.1820429154664023
epoch =  0 batch =  243 / 1000 loss_uniform =  1.9283219561164757  loss_variational =  2.1802525981463514
epoch =  0 batch =  244 / 1000 loss_uniform =  1.9272636531806377  loss_variational =  2.178476048786131
epoch =  0 batch =  245 / 1000 loss_uniform =  1.9262617763207879  loss_variational =  2.176736729485647
epoch =  0 batch =  246 / 1000 loss_uniform =  1.9252146283785498  loss_variational =  2.1748035797258694
epoch =  0 batch =  247 / 1000 loss_uniform =  1.9242078268576244  loss_variational =  2.17288868099089
epoch =  0 batch =  248 / 1000 loss_uniform =  1.92

epoch =  0 batch =  317 / 1000 loss_uniform =  1.8685860302922097  loss_variational =  2.076131899650163
epoch =  0 batch =  318 / 1000 loss_uniform =  1.8679304089186324  loss_variational =  2.0750249954139646
epoch =  0 batch =  319 / 1000 loss_uniform =  1.8672990361724897  loss_variational =  2.0738772082851957
epoch =  0 batch =  320 / 1000 loss_uniform =  1.8666543878614898  loss_variational =  2.072725166752933
epoch =  0 batch =  321 / 1000 loss_uniform =  1.8660089353163292  loss_variational =  2.071570485180411
epoch =  0 batch =  322 / 1000 loss_uniform =  1.865405592488946  loss_variational =  2.0704875218201853
epoch =  0 batch =  323 / 1000 loss_uniform =  1.8647686294726908  loss_variational =  2.0693698700736536
epoch =  0 batch =  324 / 1000 loss_uniform =  1.8641192732769762  loss_variational =  2.0683162823135457
epoch =  0 batch =  325 / 1000 loss_uniform =  1.8635186136685882  loss_variational =  2.067366305497975
epoch =  0 batch =  326 / 1000 loss_uniform =  1.86

epoch =  0 batch =  395 / 1000 loss_uniform =  1.8267748561086532  loss_variational =  2.0047204974331434
epoch =  0 batch =  396 / 1000 loss_uniform =  1.8263300569972605  loss_variational =  2.0039694167748845
epoch =  0 batch =  397 / 1000 loss_uniform =  1.8259011906400433  loss_variational =  2.0032381218686814
epoch =  0 batch =  398 / 1000 loss_uniform =  1.8254244075947668  loss_variational =  2.0024736113284702
epoch =  0 batch =  399 / 1000 loss_uniform =  1.8249941224741155  loss_variational =  2.001693761438356
epoch =  0 batch =  400 / 1000 loss_uniform =  1.8245380419492718  loss_variational =  2.0009630483388907
epoch =  0 batch =  401 / 1000 loss_uniform =  1.824127712451906  loss_variational =  2.0002194134671796
epoch =  0 batch =  402 / 1000 loss_uniform =  1.823722595006079  loss_variational =  1.9994512585858213
epoch =  0 batch =  403 / 1000 loss_uniform =  1.8233171358297833  loss_variational =  1.9986753336549106
epoch =  0 batch =  404 / 1000 loss_uniform =  1.

epoch =  0 batch =  473 / 1000 loss_uniform =  1.7975454753851536  loss_variational =  1.954235821647322
epoch =  0 batch =  474 / 1000 loss_uniform =  1.7972292301524035  loss_variational =  1.9537267164338998
epoch =  0 batch =  475 / 1000 loss_uniform =  1.796922266859757  loss_variational =  1.9531482829545679
epoch =  0 batch =  476 / 1000 loss_uniform =  1.7965903314722684  loss_variational =  1.9525916075506136
epoch =  0 batch =  477 / 1000 loss_uniform =  1.7963081763225528  loss_variational =  1.952057540041846
epoch =  0 batch =  478 / 1000 loss_uniform =  1.7960029760663976  loss_variational =  1.9515471790126182
epoch =  0 batch =  479 / 1000 loss_uniform =  1.7957285754615924  loss_variational =  1.9510152235409417
epoch =  0 batch =  480 / 1000 loss_uniform =  1.7953951266904669  loss_variational =  1.9504346753160164
epoch =  0 batch =  481 / 1000 loss_uniform =  1.7950584142470805  loss_variational =  1.9499273124207086
epoch =  0 batch =  482 / 1000 loss_uniform =  1.

epoch =  0 batch =  551 / 1000 loss_uniform =  1.775925948918406  loss_variational =  1.9170797730097975
epoch =  0 batch =  552 / 1000 loss_uniform =  1.7756817520096675  loss_variational =  1.9166831531818367
epoch =  0 batch =  553 / 1000 loss_uniform =  1.7754593986283582  loss_variational =  1.9162508167366874
epoch =  0 batch =  554 / 1000 loss_uniform =  1.77520869354909  loss_variational =  1.9158762581511959
epoch =  0 batch =  555 / 1000 loss_uniform =  1.7750082540082495  loss_variational =  1.9154879144720134
epoch =  0 batch =  556 / 1000 loss_uniform =  1.774779812895136  loss_variational =  1.9150521392873727
epoch =  0 batch =  557 / 1000 loss_uniform =  1.774561965700967  loss_variational =  1.9146265827346751
epoch =  0 batch =  558 / 1000 loss_uniform =  1.774342632635519  loss_variational =  1.9141931211222034
epoch =  0 batch =  559 / 1000 loss_uniform =  1.7741003149523922  loss_variational =  1.9137897092669087
epoch =  0 batch =  560 / 1000 loss_uniform =  1.773

epoch =  0 batch =  629 / 1000 loss_uniform =  1.7594542859658142  loss_variational =  1.887785662528252
epoch =  0 batch =  630 / 1000 loss_uniform =  1.759268567107972  loss_variational =  1.8874730501856127
epoch =  0 batch =  631 / 1000 loss_uniform =  1.7590885168021144  loss_variational =  1.887177166394689
epoch =  0 batch =  632 / 1000 loss_uniform =  1.7589032702808132  loss_variational =  1.8868576978580864
epoch =  0 batch =  633 / 1000 loss_uniform =  1.758720581942071  loss_variational =  1.8865483445578846
epoch =  0 batch =  634 / 1000 loss_uniform =  1.7585227047607344  loss_variational =  1.8862004896819784
epoch =  0 batch =  635 / 1000 loss_uniform =  1.7583373805669342  loss_variational =  1.8858699098346743
epoch =  0 batch =  636 / 1000 loss_uniform =  1.7581656031638564  loss_variational =  1.8855326693013033
epoch =  0 batch =  637 / 1000 loss_uniform =  1.7579610841801994  loss_variational =  1.885204284120018
epoch =  0 batch =  638 / 1000 loss_uniform =  1.75

epoch =  0 batch =  707 / 1000 loss_uniform =  1.7463610845037019  loss_variational =  1.8643335749775891
epoch =  0 batch =  708 / 1000 loss_uniform =  1.7461971743295415  loss_variational =  1.8640572654325414
epoch =  0 batch =  709 / 1000 loss_uniform =  1.7460591554305442  loss_variational =  1.8638415281123608
epoch =  0 batch =  710 / 1000 loss_uniform =  1.745913403638651  loss_variational =  1.8635958904951395
epoch =  0 batch =  711 / 1000 loss_uniform =  1.7457557645025126  loss_variational =  1.8632810996051583
epoch =  0 batch =  712 / 1000 loss_uniform =  1.745584067668807  loss_variational =  1.863004827767276
epoch =  0 batch =  713 / 1000 loss_uniform =  1.7454479546566992  loss_variational =  1.8627414631475743
epoch =  0 batch =  714 / 1000 loss_uniform =  1.7453332555060286  loss_variational =  1.8624863766488577
epoch =  0 batch =  715 / 1000 loss_uniform =  1.7451878429292793  loss_variational =  1.8622195927413197
epoch =  0 batch =  716 / 1000 loss_uniform =  1.

epoch =  0 batch =  785 / 1000 loss_uniform =  1.7356111843874493  loss_variational =  1.8450539104498118
epoch =  0 batch =  786 / 1000 loss_uniform =  1.7354937477876207  loss_variational =  1.8448568012271525
epoch =  0 batch =  787 / 1000 loss_uniform =  1.7353755309893786  loss_variational =  1.8446455907639892
epoch =  0 batch =  788 / 1000 loss_uniform =  1.735223455780048  loss_variational =  1.844449044181611
epoch =  0 batch =  789 / 1000 loss_uniform =  1.7350864745818158  loss_variational =  1.8441955200166307
epoch =  0 batch =  790 / 1000 loss_uniform =  1.7349641381939749  loss_variational =  1.843954781037343
epoch =  0 batch =  791 / 1000 loss_uniform =  1.7348377211808248  loss_variational =  1.843713035022866
epoch =  0 batch =  792 / 1000 loss_uniform =  1.7347061530207137  loss_variational =  1.8434746930695547
epoch =  0 batch =  793 / 1000 loss_uniform =  1.7345608324000388  loss_variational =  1.8432463569869613
epoch =  0 batch =  794 / 1000 loss_uniform =  1.7

epoch =  0 batch =  863 / 1000 loss_uniform =  1.7264377376019329  loss_variational =  1.8286413962705665
epoch =  0 batch =  864 / 1000 loss_uniform =  1.7263399921357627  loss_variational =  1.828468347175254
epoch =  0 batch =  865 / 1000 loss_uniform =  1.7262319189964686  loss_variational =  1.82825764259162
epoch =  0 batch =  866 / 1000 loss_uniform =  1.726139164411443  loss_variational =  1.8280771556407147
epoch =  0 batch =  867 / 1000 loss_uniform =  1.7260359144815793  loss_variational =  1.8278918259421575
epoch =  0 batch =  868 / 1000 loss_uniform =  1.725934174615666  loss_variational =  1.82770372920322
epoch =  0 batch =  869 / 1000 loss_uniform =  1.725830625198241  loss_variational =  1.827498980029406
epoch =  0 batch =  870 / 1000 loss_uniform =  1.7257400512695307  loss_variational =  1.827293953128245
epoch =  0 batch =  871 / 1000 loss_uniform =  1.7256383416847023  loss_variational =  1.827111829846379
epoch =  0 batch =  872 / 1000 loss_uniform =  1.72552257

epoch =  0 batch =  941 / 1000 loss_uniform =  1.7186386819600294  loss_variational =  1.8150543561017225
epoch =  0 batch =  942 / 1000 loss_uniform =  1.7185592655163657  loss_variational =  1.8149134803982552
epoch =  0 batch =  943 / 1000 loss_uniform =  1.718456823651383  loss_variational =  1.8147404631273873
epoch =  0 batch =  944 / 1000 loss_uniform =  1.7183659301470897  loss_variational =  1.8145704617944818
epoch =  0 batch =  945 / 1000 loss_uniform =  1.7182739877196211  loss_variational =  1.81440085885386
epoch =  0 batch =  946 / 1000 loss_uniform =  1.7181777991905771  loss_variational =  1.8142459514781237
epoch =  0 batch =  947 / 1000 loss_uniform =  1.718089684513579  loss_variational =  1.8140651211693273
epoch =  0 batch =  948 / 1000 loss_uniform =  1.7179883028133  loss_variational =  1.8138940515397473
epoch =  0 batch =  949 / 1000 loss_uniform =  1.7178986997071755  loss_variational =  1.8137316738968783
epoch =  0 batch =  950 / 1000 loss_uniform =  1.7178

epoch =  1 batch =  19 / 1000 loss_uniform =  1.6278203098397506  loss_variational =  1.6528499879335101
epoch =  1 batch =  20 / 1000 loss_uniform =  1.628013402223587  loss_variational =  1.653170371055603
epoch =  1 batch =  21 / 1000 loss_uniform =  1.6279938277744113  loss_variational =  1.6529998722530546
epoch =  1 batch =  22 / 1000 loss_uniform =  1.6284429214217446  loss_variational =  1.6529346975413235
epoch =  1 batch =  23 / 1000 loss_uniform =  1.628336621367413  loss_variational =  1.6529978980188784
epoch =  1 batch =  24 / 1000 loss_uniform =  1.6287933041652043  loss_variational =  1.6532084792852402
epoch =  1 batch =  25 / 1000 loss_uniform =  1.6288789653778075  loss_variational =  1.653622260093689
epoch =  1 batch =  26 / 1000 loss_uniform =  1.629077970981598  loss_variational =  1.6538846859565148
epoch =  1 batch =  27 / 1000 loss_uniform =  1.629141953256395  loss_variational =  1.6539628594009965
epoch =  1 batch =  28 / 1000 loss_uniform =  1.6290137682642

epoch =  1 batch =  99 / 1000 loss_uniform =  1.629113298473936  loss_variational =  1.6565155802351055
epoch =  1 batch =  100 / 1000 loss_uniform =  1.6291351640224456  loss_variational =  1.6563155877590179
epoch =  1 batch =  101 / 1000 loss_uniform =  1.629168196479873  loss_variational =  1.656421954088872
epoch =  1 batch =  102 / 1000 loss_uniform =  1.6292017639852037  loss_variational =  1.656356372085272
epoch =  1 batch =  103 / 1000 loss_uniform =  1.6291343714427022  loss_variational =  1.656329342462484
epoch =  1 batch =  104 / 1000 loss_uniform =  1.6292056017197096  loss_variational =  1.6563335760281637
epoch =  1 batch =  105 / 1000 loss_uniform =  1.629287959280468  loss_variational =  1.6562902518681117
epoch =  1 batch =  106 / 1000 loss_uniform =  1.6293926463936859  loss_variational =  1.6561358109960016
epoch =  1 batch =  107 / 1000 loss_uniform =  1.6294617942560499  loss_variational =  1.6561722532611027
epoch =  1 batch =  108 / 1000 loss_uniform =  1.6294

epoch =  1 batch =  177 / 1000 loss_uniform =  1.629983339606032  loss_variational =  1.654963948632364
epoch =  1 batch =  178 / 1000 loss_uniform =  1.629958206348205  loss_variational =  1.6549839678775056
epoch =  1 batch =  179 / 1000 loss_uniform =  1.6299071964604894  loss_variational =  1.6549861198031033
epoch =  1 batch =  180 / 1000 loss_uniform =  1.6299321631590524  loss_variational =  1.655004576179716
epoch =  1 batch =  181 / 1000 loss_uniform =  1.6298844037135003  loss_variational =  1.6549438995551007
epoch =  1 batch =  182 / 1000 loss_uniform =  1.6298224670546395  loss_variational =  1.654971116191738
epoch =  1 batch =  183 / 1000 loss_uniform =  1.6298347705048941  loss_variational =  1.6550593043937054
epoch =  1 batch =  184 / 1000 loss_uniform =  1.629820797106494  loss_variational =  1.6550456194773961
epoch =  1 batch =  185 / 1000 loss_uniform =  1.629782440211322  loss_variational =  1.6549670122765203
epoch =  1 batch =  186 / 1000 loss_uniform =  1.6297

epoch =  1 batch =  255 / 1000 loss_uniform =  1.6294018932417327  loss_variational =  1.6544519073822916
epoch =  1 batch =  256 / 1000 loss_uniform =  1.629426009953022  loss_variational =  1.6544656865298746
epoch =  1 batch =  257 / 1000 loss_uniform =  1.6293736779736174  loss_variational =  1.6544684385047346
epoch =  1 batch =  258 / 1000 loss_uniform =  1.6293517600658327  loss_variational =  1.6544117904448692
epoch =  1 batch =  259 / 1000 loss_uniform =  1.6293591734985586  loss_variational =  1.654409821889575
epoch =  1 batch =  260 / 1000 loss_uniform =  1.6293446059410388  loss_variational =  1.6544238833280709
epoch =  1 batch =  261 / 1000 loss_uniform =  1.6293398763028142  loss_variational =  1.6544518333741987
epoch =  1 batch =  262 / 1000 loss_uniform =  1.629299983723473  loss_variational =  1.6544118218749533
epoch =  1 batch =  263 / 1000 loss_uniform =  1.6292881820591683  loss_variational =  1.654402312217103
epoch =  1 batch =  264 / 1000 loss_uniform =  1.6

epoch =  1 batch =  333 / 1000 loss_uniform =  1.6294577397383734  loss_variational =  1.6536214888632834
epoch =  1 batch =  334 / 1000 loss_uniform =  1.629475187041803  loss_variational =  1.6536187700882643
epoch =  1 batch =  335 / 1000 loss_uniform =  1.6294726122671106  loss_variational =  1.6536121364849716
epoch =  1 batch =  336 / 1000 loss_uniform =  1.6294883737961459  loss_variational =  1.6536008027337847
epoch =  1 batch =  337 / 1000 loss_uniform =  1.6295248238198488  loss_variational =  1.6536014688474843
epoch =  1 batch =  338 / 1000 loss_uniform =  1.6294850460171  loss_variational =  1.6535276650677067
epoch =  1 batch =  339 / 1000 loss_uniform =  1.6294826803657516  loss_variational =  1.6535806961819135
epoch =  1 batch =  340 / 1000 loss_uniform =  1.6294774458688854  loss_variational =  1.6535287888611063
epoch =  1 batch =  341 / 1000 loss_uniform =  1.629498542928277  loss_variational =  1.6535185425162664
epoch =  1 batch =  342 / 1000 loss_uniform =  1.62

KeyboardInterrupt: 

In [None]:
n_sample = 100000
with torch.no_grad():
    x_uniform = flow_uniform.sample(n_sample).cpu().numpy()
    x_variational = flow_variational.sample(n_sample).cpu().numpy()
x_data_plot = x_data.cpu()[:n_sample,:].numpy()

In [None]:
bins = np.linspace(-0.5, drop_indices.shape[0]+0.5, drop_indices.shape[0]+2)
plt.hist(np.sum(x_data_plot == 0, axis=1), histtype='stepfilled', edgecolor="black", facecolor="lightgray", bins = bins)
plt.hist(np.sum(x_uniform == 0, axis=1), edgecolor="red", histtype="step", bins = bins)
plt.hist(np.sum(x_variational == 0, axis=1), edgecolor="green", histtype="step", bins = bins)
plt.show()

In [None]:
bins = np.linspace(0, drop_indices.shape[0], 20)
plt.hist(np.sum(x_data_plot, axis=1), histtype='stepfilled', edgecolor="black", facecolor="lightgray", bins = bins)
plt.hist(np.sum(x_uniform, axis=1), edgecolor="red", histtype="step", bins = bins)
plt.hist(np.sum(x_variational, axis=1), edgecolor="green", histtype="step", bins = bins)
plt.show()

In [None]:
bins = np.linspace(0, 1, 20)
plt.hist(x_data_plot[:,2], histtype='stepfilled', edgecolor="black", facecolor="lightgray", bins = bins)
plt.hist(x_uniform[:,2], edgecolor="red", histtype="step", bins = bins)
plt.hist(x_variational[:,2], edgecolor="green", histtype="step", bins = bins)
plt.show()