In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt

In [2]:
import torch
from torch import nn
from torch import optim

In [3]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline

In [4]:
device = torch.device("cuda:0")
#device = torch.device("cpu")

In [5]:
# Define some arbitrary probability distribution
N = (m.e-1)**2 * (1 + 8*m.pi**2) / m.e**2 / (1 + 16*m.pi**2)
def p_2d(x1_,x2_):
    if x1_ > 1. or x1_ < 0. or x2_ > 1. or x2_ < 0.:
        return 0
    return m.cos(2*m.pi*x1_)**2 *m.exp(-x1_-x2_)/N

def p_unnorm(x):
    if np.all(x > 0) and np.all(x < 1):
        return m.exp(-np.sum(x)/len(x)) * np.cos(2*m.pi*np.sum(x))**2 
    else:
        return 0

In [6]:
# Rejection sampling
def generate_data_2d(n):
    out = np.zeros((n,2))
    counter = 0
    while (counter < n):
        x = r.random()
        y = r.random()
        if (p_2d(x,y)*N > 1):
            print(p_2d(x,y)*N)   
        if (r.random() < p_2d(x,y)*N):
            out[counter][0] = x
            out[counter][1] = y
            counter += 1
    return out

def generate_data_nd(n, d):
    out = np.zeros((n, d))
    counter = 0
    while (counter < n):
        x = np.random.rand(d)
        if (r.random() < p_unnorm(x)):
            out[counter] = x
            counter += 1
    return out

In [7]:
data_size = int(1e6)
dim = 2
x_data = torch.tensor(generate_data_2d(data_size), dtype=torch.float32, device=device)

In [8]:
num_layers = 5
base_dist = BoxUniform(torch.zeros(dim), torch.ones(dim))

transforms = []
for _ in range(num_layers):
    transforms.append(ReversePermutation(features=dim))
    transforms.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=dim, 
        hidden_features=100,
        num_bins=15,
        num_blocks=3,
        #tails="constrained"
    ))
transform = CompositeTransform(transforms)

flow = Flow(transform, base_dist).to(device)
optimizer = optim.Adam(flow.parameters())

In [None]:
n_epochs = 100
batch_size = 10000
n_batches = m.ceil(data_size/batch_size)

for epoch in range(n_epochs):
    permutation = torch.randperm(x_data.size()[0], device=device)    

    # Loop over batches
    cum_loss = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, data_size-1 )
        indices = permutation[batch_begin:batch_end]
        batch_x = x_data[indices]
        
        # Take a step
        optimizer.zero_grad()
        loss = -flow.log_prob(inputs=batch_x).mean()
        loss.backward()
        optimizer.step()

        # Compute cumulative loss
        cum_loss = (cum_loss*batch + loss.item())/(batch+1)

        print("epoch = ", epoch, "batch = ",batch+1, "/", n_batches, "loss = ", cum_loss)

epoch =  0 batch =  1 / 100 loss =  0.5751538276672363
epoch =  0 batch =  2 / 100 loss =  0.4526011794805527
epoch =  0 batch =  3 / 100 loss =  0.35642102360725403
epoch =  0 batch =  4 / 100 loss =  0.2868288066238165
epoch =  0 batch =  5 / 100 loss =  0.23166247587651015
epoch =  0 batch =  6 / 100 loss =  0.19038011056060591
epoch =  0 batch =  7 / 100 loss =  0.15620055994285004
epoch =  0 batch =  8 / 100 loss =  0.12578502937685698
epoch =  0 batch =  9 / 100 loss =  0.09643878446271022
epoch =  0 batch =  10 / 100 loss =  0.06965852743014693
epoch =  0 batch =  11 / 100 loss =  0.04321587144989859
epoch =  0 batch =  12 / 100 loss =  0.022718641208484765
epoch =  0 batch =  13 / 100 loss =  0.002547623374714297
epoch =  0 batch =  14 / 100 loss =  -0.014987185131758455
epoch =  0 batch =  15 / 100 loss =  -0.030622347754736746
epoch =  0 batch =  16 / 100 loss =  -0.04587827675277367
epoch =  0 batch =  17 / 100 loss =  -0.05912733302616021
epoch =  0 batch =  18 / 100 loss =

epoch =  1 batch =  45 / 100 loss =  -0.3860108408663007
epoch =  1 batch =  46 / 100 loss =  -0.38607200980186457
epoch =  1 batch =  47 / 100 loss =  -0.38607203897009496
epoch =  1 batch =  48 / 100 loss =  -0.38620267994701857
epoch =  1 batch =  49 / 100 loss =  -0.3862900253461331
epoch =  1 batch =  50 / 100 loss =  -0.3863795500993728
epoch =  1 batch =  51 / 100 loss =  -0.38611393407279365
epoch =  1 batch =  52 / 100 loss =  -0.3860290864339241
epoch =  1 batch =  53 / 100 loss =  -0.3860304782975394
epoch =  1 batch =  54 / 100 loss =  -0.38596574962139124
epoch =  1 batch =  55 / 100 loss =  -0.38595721938393324
epoch =  1 batch =  56 / 100 loss =  -0.3861035678003515
epoch =  1 batch =  57 / 100 loss =  -0.3861840994734512
epoch =  1 batch =  58 / 100 loss =  -0.38605237829274136
epoch =  1 batch =  59 / 100 loss =  -0.3859955890704009
epoch =  1 batch =  60 / 100 loss =  -0.38585916161537165
epoch =  1 batch =  61 / 100 loss =  -0.3861539080494739
epoch =  1 batch =  62 

epoch =  2 batch =  93 / 100 loss =  -0.3887728436659742
epoch =  2 batch =  94 / 100 loss =  -0.38875409707109987
epoch =  2 batch =  95 / 100 loss =  -0.38883995162813295
epoch =  2 batch =  96 / 100 loss =  -0.3887932834525904
epoch =  2 batch =  97 / 100 loss =  -0.38891471876311556
epoch =  2 batch =  98 / 100 loss =  -0.38889044675291806
epoch =  2 batch =  99 / 100 loss =  -0.38898420845619364
epoch =  2 batch =  100 / 100 loss =  -0.38899416178464896
epoch =  3 batch =  1 / 100 loss =  -0.381522536277771
epoch =  3 batch =  2 / 100 loss =  -0.381961852312088
epoch =  3 batch =  3 / 100 loss =  -0.38755542039871216
epoch =  3 batch =  4 / 100 loss =  -0.38791442662477493
epoch =  3 batch =  5 / 100 loss =  -0.3905823647975922
epoch =  3 batch =  6 / 100 loss =  -0.38848939041296643
epoch =  3 batch =  7 / 100 loss =  -0.3872353562286922
epoch =  3 batch =  8 / 100 loss =  -0.38740143552422523
epoch =  3 batch =  9 / 100 loss =  -0.388008713722229
epoch =  3 batch =  10 / 100 los

epoch =  4 batch =  42 / 100 loss =  -0.39069113703001124
epoch =  4 batch =  43 / 100 loss =  -0.3904916152011517
epoch =  4 batch =  44 / 100 loss =  -0.3907720981673762
epoch =  4 batch =  45 / 100 loss =  -0.39085068504015613
epoch =  4 batch =  46 / 100 loss =  -0.39094426968823315
epoch =  4 batch =  47 / 100 loss =  -0.39117414076277557
epoch =  4 batch =  48 / 100 loss =  -0.3911158597717683
epoch =  4 batch =  49 / 100 loss =  -0.3910955221069103
epoch =  4 batch =  50 / 100 loss =  -0.39098649144172676
epoch =  4 batch =  51 / 100 loss =  -0.3907086656374091
epoch =  4 batch =  52 / 100 loss =  -0.39064008799883043
epoch =  4 batch =  53 / 100 loss =  -0.3905131563825428
epoch =  4 batch =  54 / 100 loss =  -0.3906481398476495
epoch =  4 batch =  55 / 100 loss =  -0.39082171320915227
epoch =  4 batch =  56 / 100 loss =  -0.39066118427685337
epoch =  4 batch =  57 / 100 loss =  -0.39083050008405723
epoch =  4 batch =  58 / 100 loss =  -0.3907955597186911
epoch =  4 batch =  59

epoch =  5 batch =  86 / 100 loss =  -0.3906996402629587
epoch =  5 batch =  87 / 100 loss =  -0.39057807264656863
epoch =  5 batch =  88 / 100 loss =  -0.390647671439431
epoch =  5 batch =  89 / 100 loss =  -0.3906239944227627
epoch =  5 batch =  90 / 100 loss =  -0.39055388271808633
epoch =  5 batch =  91 / 100 loss =  -0.39053230658992316
epoch =  5 batch =  92 / 100 loss =  -0.3904743045568467
epoch =  5 batch =  93 / 100 loss =  -0.3904318457008691
epoch =  5 batch =  94 / 100 loss =  -0.39047455819363297
epoch =  5 batch =  95 / 100 loss =  -0.3904906548951802
epoch =  5 batch =  96 / 100 loss =  -0.3905744934454561
epoch =  5 batch =  97 / 100 loss =  -0.3906701161074885
epoch =  5 batch =  98 / 100 loss =  -0.39084812603434743
epoch =  5 batch =  99 / 100 loss =  -0.39083987775475093
epoch =  5 batch =  100 / 100 loss =  -0.3907132443785668
epoch =  6 batch =  1 / 100 loss =  -0.3827058970928192
epoch =  6 batch =  2 / 100 loss =  -0.3896953761577606
epoch =  6 batch =  3 / 100

epoch =  7 batch =  30 / 100 loss =  -0.39313688774903616
epoch =  7 batch =  31 / 100 loss =  -0.39311677794302663
epoch =  7 batch =  32 / 100 loss =  -0.3933841083198786
epoch =  7 batch =  33 / 100 loss =  -0.3932978321205486
epoch =  7 batch =  34 / 100 loss =  -0.3928861959892161
epoch =  7 batch =  35 / 100 loss =  -0.3928921273776463
epoch =  7 batch =  36 / 100 loss =  -0.39289884020884824
epoch =  7 batch =  37 / 100 loss =  -0.3928937404542355
epoch =  7 batch =  38 / 100 loss =  -0.3929386491838254
epoch =  7 batch =  39 / 100 loss =  -0.3926345025881742
epoch =  7 batch =  40 / 100 loss =  -0.39247647598385804
epoch =  7 batch =  41 / 100 loss =  -0.3924296512836362
epoch =  7 batch =  42 / 100 loss =  -0.3922725156659171
epoch =  7 batch =  43 / 100 loss =  -0.39248472937317774
epoch =  7 batch =  44 / 100 loss =  -0.39223443107171485
epoch =  7 batch =  45 / 100 loss =  -0.39209562738736464
epoch =  7 batch =  46 / 100 loss =  -0.39214394597903535
epoch =  7 batch =  47 

epoch =  8 batch =  74 / 100 loss =  -0.39205459404636067
epoch =  8 batch =  75 / 100 loss =  -0.391943267583847
epoch =  8 batch =  76 / 100 loss =  -0.3920928792733895
epoch =  8 batch =  77 / 100 loss =  -0.3919922682372006
epoch =  8 batch =  78 / 100 loss =  -0.39204786106562
epoch =  8 batch =  79 / 100 loss =  -0.39203505802758126
epoch =  8 batch =  80 / 100 loss =  -0.3919647000730037
epoch =  8 batch =  81 / 100 loss =  -0.3919622096014611
epoch =  8 batch =  82 / 100 loss =  -0.392021037093023
epoch =  8 batch =  83 / 100 loss =  -0.3920288789703185
epoch =  8 batch =  84 / 100 loss =  -0.39197716152384166
epoch =  8 batch =  85 / 100 loss =  -0.391930472500184
epoch =  8 batch =  86 / 100 loss =  -0.39189695792142737
epoch =  8 batch =  87 / 100 loss =  -0.3919731827302911
epoch =  8 batch =  88 / 100 loss =  -0.3920331759886308
epoch =  8 batch =  89 / 100 loss =  -0.392088331533282
epoch =  8 batch =  90 / 100 loss =  -0.3920765393310123
epoch =  8 batch =  91 / 100 loss

epoch =  10 batch =  18 / 100 loss =  -0.3929052816496955
epoch =  10 batch =  19 / 100 loss =  -0.3931518808791512
epoch =  10 batch =  20 / 100 loss =  -0.39348429888486863
epoch =  10 batch =  21 / 100 loss =  -0.39340676012493314
epoch =  10 batch =  22 / 100 loss =  -0.39395039596340875
epoch =  10 batch =  23 / 100 loss =  -0.3937854222629381
epoch =  10 batch =  24 / 100 loss =  -0.3938241849342982
epoch =  10 batch =  25 / 100 loss =  -0.3938210916519165
epoch =  10 batch =  26 / 100 loss =  -0.39416165076769316
epoch =  10 batch =  27 / 100 loss =  -0.3939809865421719
epoch =  10 batch =  28 / 100 loss =  -0.39372595718928743
epoch =  10 batch =  29 / 100 loss =  -0.39350399786028367
epoch =  10 batch =  30 / 100 loss =  -0.39370236496130623
epoch =  10 batch =  31 / 100 loss =  -0.3933314017711147
epoch =  10 batch =  32 / 100 loss =  -0.3929292634129524
epoch =  10 batch =  33 / 100 loss =  -0.3926576701077548
epoch =  10 batch =  34 / 100 loss =  -0.3921427384895437
epoch =

epoch =  11 batch =  59 / 100 loss =  -0.39314368619757184
epoch =  11 batch =  60 / 100 loss =  -0.3931774283448855
epoch =  11 batch =  61 / 100 loss =  -0.39318724782740483
epoch =  11 batch =  62 / 100 loss =  -0.39306860825707834
epoch =  11 batch =  63 / 100 loss =  -0.39299864948741975
epoch =  11 batch =  64 / 100 loss =  -0.39288752852007747
epoch =  11 batch =  65 / 100 loss =  -0.39294272156862114
epoch =  11 batch =  66 / 100 loss =  -0.39299185754674854
epoch =  11 batch =  67 / 100 loss =  -0.39303440967602515
epoch =  11 batch =  68 / 100 loss =  -0.3928859851816121
epoch =  11 batch =  69 / 100 loss =  -0.39281213974607165
epoch =  11 batch =  70 / 100 loss =  -0.39272445014544893
epoch =  11 batch =  71 / 100 loss =  -0.39263648928051265
epoch =  11 batch =  72 / 100 loss =  -0.39252864196896553
epoch =  11 batch =  73 / 100 loss =  -0.39258069167398424
epoch =  11 batch =  74 / 100 loss =  -0.39268056967773945
epoch =  11 batch =  75 / 100 loss =  -0.39277562061945587

epoch =  13 batch =  3 / 100 loss =  -0.3921734591325124
epoch =  13 batch =  4 / 100 loss =  -0.3917006850242615
epoch =  13 batch =  5 / 100 loss =  -0.3934480369091034
epoch =  13 batch =  6 / 100 loss =  -0.39463164905707043
epoch =  13 batch =  7 / 100 loss =  -0.39557430148124695
epoch =  13 batch =  8 / 100 loss =  -0.39528895914554596
epoch =  13 batch =  9 / 100 loss =  -0.39670765068795943
epoch =  13 batch =  10 / 100 loss =  -0.3969224810600281
epoch =  13 batch =  11 / 100 loss =  -0.39552258361469617
epoch =  13 batch =  12 / 100 loss =  -0.3959300220012665
epoch =  13 batch =  13 / 100 loss =  -0.3959298890370589
epoch =  13 batch =  14 / 100 loss =  -0.3947123054947172
epoch =  13 batch =  15 / 100 loss =  -0.39448097745577493
epoch =  13 batch =  16 / 100 loss =  -0.3944192919880152
epoch =  13 batch =  17 / 100 loss =  -0.39339197383207436
epoch =  13 batch =  18 / 100 loss =  -0.3931495067146089
epoch =  13 batch =  19 / 100 loss =  -0.3934793880111293
epoch =  13 ba

epoch =  14 batch =  46 / 100 loss =  -0.3917611286691997
epoch =  14 batch =  47 / 100 loss =  -0.3918199266525025
epoch =  14 batch =  48 / 100 loss =  -0.3919752910733223
epoch =  14 batch =  49 / 100 loss =  -0.39198161205466914
epoch =  14 batch =  50 / 100 loss =  -0.3918775010108948
epoch =  14 batch =  51 / 100 loss =  -0.391972283522288
epoch =  14 batch =  52 / 100 loss =  -0.392195184643452
epoch =  14 batch =  53 / 100 loss =  -0.392240269004174
epoch =  14 batch =  54 / 100 loss =  -0.39232392222793017
epoch =  14 batch =  55 / 100 loss =  -0.3923941704359922
epoch =  14 batch =  56 / 100 loss =  -0.3922826291194984
epoch =  14 batch =  57 / 100 loss =  -0.3922345136341296
epoch =  14 batch =  58 / 100 loss =  -0.39218289893248987
epoch =  14 batch =  59 / 100 loss =  -0.39231291665869245
epoch =  14 batch =  60 / 100 loss =  -0.3924206718802452
epoch =  14 batch =  61 / 100 loss =  -0.392486111062472
epoch =  14 batch =  62 / 100 loss =  -0.3926185797299108
epoch =  14 ba

epoch =  15 batch =  90 / 100 loss =  -0.3936050752798716
epoch =  15 batch =  91 / 100 loss =  -0.39344721475800315
epoch =  15 batch =  92 / 100 loss =  -0.39344851219135785
epoch =  15 batch =  93 / 100 loss =  -0.39340280461054977
epoch =  15 batch =  94 / 100 loss =  -0.3934165103004334
epoch =  15 batch =  95 / 100 loss =  -0.39346441501065305
epoch =  15 batch =  96 / 100 loss =  -0.39342730833838385
epoch =  15 batch =  97 / 100 loss =  -0.39344737302396715
epoch =  15 batch =  98 / 100 loss =  -0.3934884196033283
epoch =  15 batch =  99 / 100 loss =  -0.393374715188537
epoch =  15 batch =  100 / 100 loss =  -0.393345565199852
epoch =  16 batch =  1 / 100 loss =  -0.3848830461502075
epoch =  16 batch =  2 / 100 loss =  -0.39406560361385345
epoch =  16 batch =  3 / 100 loss =  -0.39500118295351666
epoch =  16 batch =  4 / 100 loss =  -0.39579663425683975
epoch =  16 batch =  5 / 100 loss =  -0.3955433189868927
epoch =  16 batch =  6 / 100 loss =  -0.3944848229487737
epoch =  16 

epoch =  17 batch =  34 / 100 loss =  -0.3932587916360182
epoch =  17 batch =  35 / 100 loss =  -0.3931902689593179
epoch =  17 batch =  36 / 100 loss =  -0.3931768122646544
epoch =  17 batch =  37 / 100 loss =  -0.39313925520793813
epoch =  17 batch =  38 / 100 loss =  -0.39291001072055415
epoch =  17 batch =  39 / 100 loss =  -0.3929369617731143
epoch =  17 batch =  40 / 100 loss =  -0.39290680065751077
epoch =  17 batch =  41 / 100 loss =  -0.39297854682294336
epoch =  17 batch =  42 / 100 loss =  -0.3928627300830114
epoch =  17 batch =  43 / 100 loss =  -0.39288990372835203
epoch =  17 batch =  44 / 100 loss =  -0.3929330835288221
epoch =  17 batch =  45 / 100 loss =  -0.3930016961362627
epoch =  17 batch =  46 / 100 loss =  -0.3930015680582627
epoch =  17 batch =  47 / 100 loss =  -0.3931377479370604
epoch =  17 batch =  48 / 100 loss =  -0.393073217322429
epoch =  17 batch =  49 / 100 loss =  -0.3929953933978567
epoch =  17 batch =  50 / 100 loss =  -0.39297484755516054
epoch =  

epoch =  18 batch =  78 / 100 loss =  -0.393871566424003
epoch =  18 batch =  79 / 100 loss =  -0.39388025825536704
epoch =  18 batch =  80 / 100 loss =  -0.3939858961850405
epoch =  18 batch =  81 / 100 loss =  -0.3940165087028786
epoch =  18 batch =  82 / 100 loss =  -0.3939745524307577
epoch =  18 batch =  83 / 100 loss =  -0.39402942664651985
epoch =  18 batch =  84 / 100 loss =  -0.39400995274384815
epoch =  18 batch =  85 / 100 loss =  -0.39415160066941207
epoch =  18 batch =  86 / 100 loss =  -0.3940909369740375
epoch =  18 batch =  87 / 100 loss =  -0.39403678459682684
epoch =  18 batch =  88 / 100 loss =  -0.39398151534524833
epoch =  18 batch =  89 / 100 loss =  -0.39392990644058484
epoch =  18 batch =  90 / 100 loss =  -0.3939126388894187
epoch =  18 batch =  91 / 100 loss =  -0.3938832063596327
epoch =  18 batch =  92 / 100 loss =  -0.3939229211081629
epoch =  18 batch =  93 / 100 loss =  -0.3939771117061697
epoch =  18 batch =  94 / 100 loss =  -0.3940312291713471
epoch = 

epoch =  20 batch =  22 / 100 loss =  -0.3929829841310328
epoch =  20 batch =  23 / 100 loss =  -0.39317103572513745
epoch =  20 batch =  24 / 100 loss =  -0.3934388607740402
epoch =  20 batch =  25 / 100 loss =  -0.3935351276397705
epoch =  20 batch =  26 / 100 loss =  -0.39342606182281786
epoch =  20 batch =  27 / 100 loss =  -0.39339487309809085
epoch =  20 batch =  28 / 100 loss =  -0.3931189009121486
epoch =  20 batch =  29 / 100 loss =  -0.3931906747406927
epoch =  20 batch =  30 / 100 loss =  -0.39317633509635924
epoch =  20 batch =  31 / 100 loss =  -0.3931680927353521
epoch =  20 batch =  32 / 100 loss =  -0.39294965472072363
epoch =  20 batch =  33 / 100 loss =  -0.3935056220401417
epoch =  20 batch =  34 / 100 loss =  -0.39341746095348806
epoch =  20 batch =  35 / 100 loss =  -0.39357866389410834
epoch =  20 batch =  36 / 100 loss =  -0.39359033356110257
epoch =  20 batch =  37 / 100 loss =  -0.3935643015681086
epoch =  20 batch =  38 / 100 loss =  -0.39392455157480744
epoch

epoch =  21 batch =  64 / 100 loss =  -0.393563122022897
epoch =  21 batch =  65 / 100 loss =  -0.393597347002763
epoch =  21 batch =  66 / 100 loss =  -0.39354532370061585
epoch =  21 batch =  67 / 100 loss =  -0.39346062202951804
epoch =  21 batch =  68 / 100 loss =  -0.39353203948806315
epoch =  21 batch =  69 / 100 loss =  -0.39372820223587146
epoch =  21 batch =  70 / 100 loss =  -0.3938277964081083
epoch =  21 batch =  71 / 100 loss =  -0.3937432841515877
epoch =  21 batch =  72 / 100 loss =  -0.39361940779619753
epoch =  21 batch =  73 / 100 loss =  -0.3936695174811638
epoch =  21 batch =  74 / 100 loss =  -0.39365844911820186
epoch =  21 batch =  75 / 100 loss =  -0.3937081138292949
epoch =  21 batch =  76 / 100 loss =  -0.3936276204491917
epoch =  21 batch =  77 / 100 loss =  -0.39368353880845114
epoch =  21 batch =  78 / 100 loss =  -0.3937509743831097
epoch =  21 batch =  79 / 100 loss =  -0.39353897262223164
epoch =  21 batch =  80 / 100 loss =  -0.3933613557368517
epoch = 

epoch =  23 batch =  8 / 100 loss =  -0.3933686688542366
epoch =  23 batch =  9 / 100 loss =  -0.39263201091024613
epoch =  23 batch =  10 / 100 loss =  -0.3916592180728912
epoch =  23 batch =  11 / 100 loss =  -0.39133269407532434
epoch =  23 batch =  12 / 100 loss =  -0.39222336808840436
epoch =  23 batch =  13 / 100 loss =  -0.39207955965628993
epoch =  23 batch =  14 / 100 loss =  -0.39273502571242197
epoch =  23 batch =  15 / 100 loss =  -0.39314250349998475
epoch =  23 batch =  16 / 100 loss =  -0.3926392961293459
epoch =  23 batch =  17 / 100 loss =  -0.3920897894045886
epoch =  23 batch =  18 / 100 loss =  -0.39234553939766353
epoch =  23 batch =  19 / 100 loss =  -0.39309851276247126
epoch =  23 batch =  20 / 100 loss =  -0.3924980103969574
epoch =  23 batch =  21 / 100 loss =  -0.392249896412804
epoch =  23 batch =  22 / 100 loss =  -0.3928364285013892
epoch =  23 batch =  23 / 100 loss =  -0.39290298327155737
epoch =  23 batch =  24 / 100 loss =  -0.3931576882799466
epoch = 

epoch =  24 batch =  52 / 100 loss =  -0.3944311652045983
epoch =  24 batch =  53 / 100 loss =  -0.3944893929193604
epoch =  24 batch =  54 / 100 loss =  -0.3945907481290675
epoch =  24 batch =  55 / 100 loss =  -0.3948035565289583
epoch =  24 batch =  56 / 100 loss =  -0.39479295856186314
epoch =  24 batch =  57 / 100 loss =  -0.3948698451644495
epoch =  24 batch =  58 / 100 loss =  -0.39499703563492866
epoch =  24 batch =  59 / 100 loss =  -0.3948164738840975
epoch =  24 batch =  60 / 100 loss =  -0.3948545709252357
epoch =  24 batch =  61 / 100 loss =  -0.39496811100693996
epoch =  24 batch =  62 / 100 loss =  -0.39493297280803796
epoch =  24 batch =  63 / 100 loss =  -0.3948637459959302
epoch =  24 batch =  64 / 100 loss =  -0.3949244613759219
epoch =  24 batch =  65 / 100 loss =  -0.39478049782606267
epoch =  24 batch =  66 / 100 loss =  -0.39468223669312213
epoch =  24 batch =  67 / 100 loss =  -0.39466589911660144
epoch =  24 batch =  68 / 100 loss =  -0.39466749526122025
epoch 

epoch =  25 batch =  97 / 100 loss =  -0.39478957222909045
epoch =  25 batch =  98 / 100 loss =  -0.3948345546211515
epoch =  25 batch =  99 / 100 loss =  -0.39491758322474935
epoch =  25 batch =  100 / 100 loss =  -0.39480010598897936
epoch =  26 batch =  1 / 100 loss =  -0.38861221075057983
epoch =  26 batch =  2 / 100 loss =  -0.39023540914058685
epoch =  26 batch =  3 / 100 loss =  -0.3910955289999644
epoch =  26 batch =  4 / 100 loss =  -0.3909341171383858
epoch =  26 batch =  5 / 100 loss =  -0.38996033668518065
epoch =  26 batch =  6 / 100 loss =  -0.3918473074833552
epoch =  26 batch =  7 / 100 loss =  -0.3927954392773764
epoch =  26 batch =  8 / 100 loss =  -0.3925544433295727
epoch =  26 batch =  9 / 100 loss =  -0.39356964495446944
epoch =  26 batch =  10 / 100 loss =  -0.393855556845665
epoch =  26 batch =  11 / 100 loss =  -0.39350236545909534
epoch =  26 batch =  12 / 100 loss =  -0.3932073712348938
epoch =  26 batch =  13 / 100 loss =  -0.39232144676722014
epoch =  26 ba

epoch =  27 batch =  41 / 100 loss =  -0.3953188316124241
epoch =  27 batch =  42 / 100 loss =  -0.3955849352337064
epoch =  27 batch =  43 / 100 loss =  -0.395697078732557
epoch =  27 batch =  44 / 100 loss =  -0.39538425207138056
epoch =  27 batch =  45 / 100 loss =  -0.395072462161382
epoch =  27 batch =  46 / 100 loss =  -0.39491215607394337
epoch =  27 batch =  47 / 100 loss =  -0.3949875495535261
epoch =  27 batch =  48 / 100 loss =  -0.3949651556710401
epoch =  27 batch =  49 / 100 loss =  -0.39489079555686635
epoch =  27 batch =  50 / 100 loss =  -0.39468324899673457
epoch =  27 batch =  51 / 100 loss =  -0.3947449998528349
epoch =  27 batch =  52 / 100 loss =  -0.39473632379220075
epoch =  27 batch =  53 / 100 loss =  -0.3947521532481571
epoch =  27 batch =  54 / 100 loss =  -0.39484884617505245
epoch =  27 batch =  55 / 100 loss =  -0.3947754047133705
epoch =  27 batch =  56 / 100 loss =  -0.3946035099881035
epoch =  27 batch =  57 / 100 loss =  -0.39462540913046446
epoch =  

epoch =  28 batch =  84 / 100 loss =  -0.39500056029785263
epoch =  28 batch =  85 / 100 loss =  -0.39482777223867543
epoch =  28 batch =  86 / 100 loss =  -0.39484735872856425
epoch =  28 batch =  87 / 100 loss =  -0.3949598933773481
epoch =  28 batch =  88 / 100 loss =  -0.39500739459287054
epoch =  28 batch =  89 / 100 loss =  -0.3949502257818588
epoch =  28 batch =  90 / 100 loss =  -0.394946180118455
epoch =  28 batch =  91 / 100 loss =  -0.394997802081999
epoch =  28 batch =  92 / 100 loss =  -0.39500463138455943
epoch =  28 batch =  93 / 100 loss =  -0.3948426663234671
epoch =  28 batch =  94 / 100 loss =  -0.3948224025203828
epoch =  28 batch =  95 / 100 loss =  -0.39480666047648394
epoch =  28 batch =  96 / 100 loss =  -0.3949145957206688
epoch =  28 batch =  97 / 100 loss =  -0.3949554648595988
epoch =  28 batch =  98 / 100 loss =  -0.39479464049241997
epoch =  28 batch =  99 / 100 loss =  -0.3948392130509773
epoch =  28 batch =  100 / 100 loss =  -0.3948300802707674
epoch = 

epoch =  30 batch =  27 / 100 loss =  -0.394912823482796
epoch =  30 batch =  28 / 100 loss =  -0.39510725545031683
epoch =  30 batch =  29 / 100 loss =  -0.3952934156204092
epoch =  30 batch =  30 / 100 loss =  -0.3951319694519043
epoch =  30 batch =  31 / 100 loss =  -0.39497136012200385
epoch =  30 batch =  32 / 100 loss =  -0.39491327852010727
epoch =  30 batch =  33 / 100 loss =  -0.39465134342511493
epoch =  30 batch =  34 / 100 loss =  -0.3947114812977174
epoch =  30 batch =  35 / 100 loss =  -0.3949710292475564
epoch =  30 batch =  36 / 100 loss =  -0.3946845763259464
epoch =  30 batch =  37 / 100 loss =  -0.3945807532684223
epoch =  30 batch =  38 / 100 loss =  -0.3945809892917934
epoch =  30 batch =  39 / 100 loss =  -0.394624215670121
epoch =  30 batch =  40 / 100 loss =  -0.3945314325392246
epoch =  30 batch =  41 / 100 loss =  -0.3943584226980442
epoch =  30 batch =  42 / 100 loss =  -0.39439684791224344
epoch =  30 batch =  43 / 100 loss =  -0.39462304184603136
epoch =  3

epoch =  31 batch =  71 / 100 loss =  -0.3942885176396706
epoch =  31 batch =  72 / 100 loss =  -0.3942456667621931
epoch =  31 batch =  73 / 100 loss =  -0.39413840517605825
epoch =  31 batch =  74 / 100 loss =  -0.3941339649058678
epoch =  31 batch =  75 / 100 loss =  -0.39403824289639794
epoch =  31 batch =  76 / 100 loss =  -0.3940752332932071
epoch =  31 batch =  77 / 100 loss =  -0.3941075213543781
epoch =  31 batch =  78 / 100 loss =  -0.3940278983268983
epoch =  31 batch =  79 / 100 loss =  -0.3941576861882512
epoch =  31 batch =  80 / 100 loss =  -0.3941570814698935
epoch =  31 batch =  81 / 100 loss =  -0.39400570738462753
epoch =  31 batch =  82 / 100 loss =  -0.39401574214784113
epoch =  31 batch =  83 / 100 loss =  -0.3939542059438774
epoch =  31 batch =  84 / 100 loss =  -0.39411768388180507
epoch =  31 batch =  85 / 100 loss =  -0.3941035154987784
epoch =  31 batch =  86 / 100 loss =  -0.3942164033651352
epoch =  31 batch =  87 / 100 loss =  -0.39428677161534625
epoch = 

epoch =  33 batch =  15 / 100 loss =  -0.394850891828537
epoch =  33 batch =  16 / 100 loss =  -0.3945472966879606
epoch =  33 batch =  17 / 100 loss =  -0.39439110720858855
epoch =  33 batch =  18 / 100 loss =  -0.39434845083289677
epoch =  33 batch =  19 / 100 loss =  -0.3947330760328393
epoch =  33 batch =  20 / 100 loss =  -0.395074699819088
epoch =  33 batch =  21 / 100 loss =  -0.3951453382060641
epoch =  33 batch =  22 / 100 loss =  -0.3947606086730957
epoch =  33 batch =  23 / 100 loss =  -0.394833883513575
epoch =  33 batch =  24 / 100 loss =  -0.39486056193709373
epoch =  33 batch =  25 / 100 loss =  -0.39487973690032957
epoch =  33 batch =  26 / 100 loss =  -0.3948941104687177
epoch =  33 batch =  27 / 100 loss =  -0.3950900645167739
epoch =  33 batch =  28 / 100 loss =  -0.3949797164116587
epoch =  33 batch =  29 / 100 loss =  -0.3948772498245897
epoch =  33 batch =  30 / 100 loss =  -0.3945663819710414
epoch =  33 batch =  31 / 100 loss =  -0.39477871983282026
epoch =  33 

epoch =  34 batch =  59 / 100 loss =  -0.394882881035239
epoch =  34 batch =  60 / 100 loss =  -0.3949082533518473
epoch =  34 batch =  61 / 100 loss =  -0.3948503004722908
epoch =  34 batch =  62 / 100 loss =  -0.3947057973953985
epoch =  34 batch =  63 / 100 loss =  -0.3946387233242156
epoch =  34 batch =  64 / 100 loss =  -0.3945621596649289
epoch =  34 batch =  65 / 100 loss =  -0.3945678481688866
epoch =  34 batch =  66 / 100 loss =  -0.3946446460304838
epoch =  34 batch =  67 / 100 loss =  -0.3945404735963736
epoch =  34 batch =  68 / 100 loss =  -0.3946044857011122
epoch =  34 batch =  69 / 100 loss =  -0.39473401852276013
epoch =  34 batch =  70 / 100 loss =  -0.3946794935635158
epoch =  34 batch =  71 / 100 loss =  -0.394783830978501
epoch =  34 batch =  72 / 100 loss =  -0.3946446871591939
epoch =  34 batch =  73 / 100 loss =  -0.39461419884472676
epoch =  34 batch =  74 / 100 loss =  -0.3945771991401105
epoch =  34 batch =  75 / 100 loss =  -0.39458594759305315
epoch =  34 b

epoch =  36 batch =  2 / 100 loss =  -0.3998228460550308
epoch =  36 batch =  3 / 100 loss =  -0.398266206185023
epoch =  36 batch =  4 / 100 loss =  -0.3941982463002205
epoch =  36 batch =  5 / 100 loss =  -0.39147602915763857
epoch =  36 batch =  6 / 100 loss =  -0.3919980973005295
epoch =  36 batch =  7 / 100 loss =  -0.3922187771115984
epoch =  36 batch =  8 / 100 loss =  -0.39223019778728485
epoch =  36 batch =  9 / 100 loss =  -0.3907095657454597
epoch =  36 batch =  10 / 100 loss =  -0.39113928377628326
epoch =  36 batch =  11 / 100 loss =  -0.3912483101541346
epoch =  36 batch =  12 / 100 loss =  -0.39117031296094257
epoch =  36 batch =  13 / 100 loss =  -0.391150359924023
epoch =  36 batch =  14 / 100 loss =  -0.3918473848274776
epoch =  36 batch =  15 / 100 loss =  -0.3916473607222239
epoch =  36 batch =  16 / 100 loss =  -0.39238011464476585
epoch =  36 batch =  17 / 100 loss =  -0.3928041388006771
epoch =  36 batch =  18 / 100 loss =  -0.39286339779694873
epoch =  36 batch 

epoch =  37 batch =  45 / 100 loss =  -0.39474678635597227
epoch =  37 batch =  46 / 100 loss =  -0.39463636862195056
epoch =  37 batch =  47 / 100 loss =  -0.3946905345358747
epoch =  37 batch =  48 / 100 loss =  -0.3943544967720906
epoch =  37 batch =  49 / 100 loss =  -0.39461634961926206
epoch =  37 batch =  50 / 100 loss =  -0.3946834796667099
epoch =  37 batch =  51 / 100 loss =  -0.39468398865531473
epoch =  37 batch =  52 / 100 loss =  -0.3947452020186644
epoch =  37 batch =  53 / 100 loss =  -0.39484515898632555
epoch =  37 batch =  54 / 100 loss =  -0.3946152831669207
epoch =  37 batch =  55 / 100 loss =  -0.39450922229073265
epoch =  37 batch =  56 / 100 loss =  -0.3946205420153482
epoch =  37 batch =  57 / 100 loss =  -0.39481176514374583
epoch =  37 batch =  58 / 100 loss =  -0.39477878397908706
epoch =  37 batch =  59 / 100 loss =  -0.3948784233149836
epoch =  37 batch =  60 / 100 loss =  -0.3949716245134672
epoch =  37 batch =  61 / 100 loss =  -0.39507107461085084
epoch

epoch =  38 batch =  89 / 100 loss =  -0.3953969639338804
epoch =  38 batch =  90 / 100 loss =  -0.39550485279825
epoch =  38 batch =  91 / 100 loss =  -0.39544665289449166
epoch =  38 batch =  92 / 100 loss =  -0.3952941596508026
epoch =  38 batch =  93 / 100 loss =  -0.3952863347786729
epoch =  38 batch =  94 / 100 loss =  -0.3952017176024457
epoch =  38 batch =  95 / 100 loss =  -0.39523818775227193
epoch =  38 batch =  96 / 100 loss =  -0.39515170517067116
epoch =  38 batch =  97 / 100 loss =  -0.39520548115071563
epoch =  38 batch =  98 / 100 loss =  -0.3950466005777826
epoch =  38 batch =  99 / 100 loss =  -0.39505183696746826
epoch =  38 batch =  100 / 100 loss =  -0.3951111501455307
epoch =  39 batch =  1 / 100 loss =  -0.3946811556816101
epoch =  39 batch =  2 / 100 loss =  -0.3955254554748535
epoch =  39 batch =  3 / 100 loss =  -0.39603402217229206
epoch =  39 batch =  4 / 100 loss =  -0.39328256994485855
epoch =  39 batch =  5 / 100 loss =  -0.3947270929813385
epoch =  39 b

epoch =  40 batch =  34 / 100 loss =  -0.39442641068907347
epoch =  40 batch =  35 / 100 loss =  -0.39440124716077535
epoch =  40 batch =  36 / 100 loss =  -0.3945012349221442
epoch =  40 batch =  37 / 100 loss =  -0.39466845103212306
epoch =  40 batch =  38 / 100 loss =  -0.39479168308408635
epoch =  40 batch =  39 / 100 loss =  -0.394858541396948
epoch =  40 batch =  40 / 100 loss =  -0.39502365887165064
epoch =  40 batch =  41 / 100 loss =  -0.39504620941673835
epoch =  40 batch =  42 / 100 loss =  -0.3949516983259292
epoch =  40 batch =  43 / 100 loss =  -0.395040234854055
epoch =  40 batch =  44 / 100 loss =  -0.3952626863663847
epoch =  40 batch =  45 / 100 loss =  -0.3952047990428077
epoch =  40 batch =  46 / 100 loss =  -0.39539787691572437
epoch =  40 batch =  47 / 100 loss =  -0.39541163216245934
epoch =  40 batch =  48 / 100 loss =  -0.39550773178537685
epoch =  40 batch =  49 / 100 loss =  -0.39576576193984675
epoch =  40 batch =  50 / 100 loss =  -0.39590774178504945
epoch

epoch =  41 batch =  78 / 100 loss =  -0.3951085137251095
epoch =  41 batch =  79 / 100 loss =  -0.39497749299942686
epoch =  41 batch =  80 / 100 loss =  -0.3949950724840164
epoch =  41 batch =  81 / 100 loss =  -0.3950707275926331
epoch =  41 batch =  82 / 100 loss =  -0.39505823320004996
epoch =  41 batch =  83 / 100 loss =  -0.3951094850718257
epoch =  41 batch =  84 / 100 loss =  -0.395133162183421
epoch =  41 batch =  85 / 100 loss =  -0.3952292989282047
epoch =  41 batch =  86 / 100 loss =  -0.3952173791652502
epoch =  41 batch =  87 / 100 loss =  -0.39521653453509015
epoch =  41 batch =  88 / 100 loss =  -0.39517366005615756
epoch =  41 batch =  89 / 100 loss =  -0.39508168014247763
epoch =  41 batch =  90 / 100 loss =  -0.39520653519365523
epoch =  41 batch =  91 / 100 loss =  -0.3951697074449979
epoch =  41 batch =  92 / 100 loss =  -0.395146400384281
epoch =  41 batch =  93 / 100 loss =  -0.39520773003178256
epoch =  41 batch =  94 / 100 loss =  -0.39525966948651253
epoch = 

epoch =  43 batch =  21 / 100 loss =  -0.39368696156002225
epoch =  43 batch =  22 / 100 loss =  -0.39382156187837775
epoch =  43 batch =  23 / 100 loss =  -0.39403960886208905
epoch =  43 batch =  24 / 100 loss =  -0.39388083294034004
epoch =  43 batch =  25 / 100 loss =  -0.394081746339798
epoch =  43 batch =  26 / 100 loss =  -0.3943535788701131
epoch =  43 batch =  27 / 100 loss =  -0.3944178256723616
epoch =  43 batch =  28 / 100 loss =  -0.39497804641723633
epoch =  43 batch =  29 / 100 loss =  -0.3949837499651416
epoch =  43 batch =  30 / 100 loss =  -0.3950525859991709
epoch =  43 batch =  31 / 100 loss =  -0.3950195216363476
epoch =  43 batch =  32 / 100 loss =  -0.3950796043500304
epoch =  43 batch =  33 / 100 loss =  -0.3954367890502467
epoch =  43 batch =  34 / 100 loss =  -0.39542121834614696
epoch =  43 batch =  35 / 100 loss =  -0.39524727293423245
epoch =  43 batch =  36 / 100 loss =  -0.3951763196123971
epoch =  43 batch =  37 / 100 loss =  -0.3952889998216887
epoch = 

epoch =  44 batch =  65 / 100 loss =  -0.39539562784708465
epoch =  44 batch =  66 / 100 loss =  -0.3954037608522357
epoch =  44 batch =  67 / 100 loss =  -0.3952590591871916
epoch =  44 batch =  68 / 100 loss =  -0.3954341963810079
epoch =  44 batch =  69 / 100 loss =  -0.39531887052715686
epoch =  44 batch =  70 / 100 loss =  -0.39522924593516756
epoch =  44 batch =  71 / 100 loss =  -0.39536547576877434
epoch =  44 batch =  72 / 100 loss =  -0.3953051037258572
epoch =  44 batch =  73 / 100 loss =  -0.3951732108037766
epoch =  44 batch =  74 / 100 loss =  -0.395212270923563
epoch =  44 batch =  75 / 100 loss =  -0.39539558490117394
epoch =  44 batch =  76 / 100 loss =  -0.39550333391678966
epoch =  44 batch =  77 / 100 loss =  -0.39541613049321367
epoch =  44 batch =  78 / 100 loss =  -0.3953453818192849
epoch =  44 batch =  79 / 100 loss =  -0.3953417924386037
epoch =  44 batch =  80 / 100 loss =  -0.3954379159957171
epoch =  44 batch =  81 / 100 loss =  -0.395501846148644
epoch =  

epoch =  46 batch =  8 / 100 loss =  -0.3927052989602089
epoch =  46 batch =  9 / 100 loss =  -0.3938372598754035
epoch =  46 batch =  10 / 100 loss =  -0.39422266781330106
epoch =  46 batch =  11 / 100 loss =  -0.39585182070732117
epoch =  46 batch =  12 / 100 loss =  -0.39688530564308167
epoch =  46 batch =  13 / 100 loss =  -0.397203477529379
epoch =  46 batch =  14 / 100 loss =  -0.39635363646915983
epoch =  46 batch =  15 / 100 loss =  -0.3955320676167806
epoch =  46 batch =  16 / 100 loss =  -0.39472880214452744
epoch =  46 batch =  17 / 100 loss =  -0.39524050670511585
epoch =  46 batch =  18 / 100 loss =  -0.3952171819077598
epoch =  46 batch =  19 / 100 loss =  -0.39490842191796555
epoch =  46 batch =  20 / 100 loss =  -0.3948889434337616
epoch =  46 batch =  21 / 100 loss =  -0.39497197242010207
epoch =  46 batch =  22 / 100 loss =  -0.39515787904912775
epoch =  46 batch =  23 / 100 loss =  -0.3955166819302932
epoch =  46 batch =  24 / 100 loss =  -0.3955921605229378
epoch = 

epoch =  47 batch =  51 / 100 loss =  -0.3948856527898826
epoch =  47 batch =  52 / 100 loss =  -0.3951101555274083
epoch =  47 batch =  53 / 100 loss =  -0.395070788432967
epoch =  47 batch =  54 / 100 loss =  -0.39518387615680695
epoch =  47 batch =  55 / 100 loss =  -0.3951653518460014
epoch =  47 batch =  56 / 100 loss =  -0.39518680796027184
epoch =  47 batch =  57 / 100 loss =  -0.3952613713448508
epoch =  47 batch =  58 / 100 loss =  -0.3952321045357606
epoch =  47 batch =  59 / 100 loss =  -0.39527198520757384
epoch =  47 batch =  60 / 100 loss =  -0.39533587694168093
epoch =  47 batch =  61 / 100 loss =  -0.3951545196478484
epoch =  47 batch =  62 / 100 loss =  -0.3950854568712173
epoch =  47 batch =  63 / 100 loss =  -0.39503468407524955
epoch =  47 batch =  64 / 100 loss =  -0.39491145499050617
epoch =  47 batch =  65 / 100 loss =  -0.3946461668381324
epoch =  47 batch =  66 / 100 loss =  -0.3944154210162885
epoch =  47 batch =  67 / 100 loss =  -0.394450843334198
epoch =  4

epoch =  48 batch =  95 / 100 loss =  -0.3953610690016496
epoch =  48 batch =  96 / 100 loss =  -0.39528872817754745
epoch =  48 batch =  97 / 100 loss =  -0.3953018431196508
epoch =  48 batch =  98 / 100 loss =  -0.3953129685654932
epoch =  48 batch =  99 / 100 loss =  -0.3953275852131121
epoch =  48 batch =  100 / 100 loss =  -0.3953523996472359
epoch =  49 batch =  1 / 100 loss =  -0.3866505026817322
epoch =  49 batch =  2 / 100 loss =  -0.39583083987236023
epoch =  49 batch =  3 / 100 loss =  -0.39476101597150165
epoch =  49 batch =  4 / 100 loss =  -0.3950177952647209
epoch =  49 batch =  5 / 100 loss =  -0.39252246618270875
epoch =  49 batch =  6 / 100 loss =  -0.39282484849294025
epoch =  49 batch =  7 / 100 loss =  -0.39306922044072834
epoch =  49 batch =  8 / 100 loss =  -0.3940621055662632
epoch =  49 batch =  9 / 100 loss =  -0.3944534990522597
epoch =  49 batch =  10 / 100 loss =  -0.39379738867282865
epoch =  49 batch =  11 / 100 loss =  -0.3938770294189453
epoch =  49 bat

epoch =  50 batch =  38 / 100 loss =  -0.39478305766457006
epoch =  50 batch =  39 / 100 loss =  -0.39494353456374925
epoch =  50 batch =  40 / 100 loss =  -0.3951946072280407
epoch =  50 batch =  41 / 100 loss =  -0.3950252772831335
epoch =  50 batch =  42 / 100 loss =  -0.39502336368674323
epoch =  50 batch =  43 / 100 loss =  -0.3949331440204798
epoch =  50 batch =  44 / 100 loss =  -0.3951246846805919
epoch =  50 batch =  45 / 100 loss =  -0.3952626387278239
epoch =  50 batch =  46 / 100 loss =  -0.39497083101583563
epoch =  50 batch =  47 / 100 loss =  -0.395083043169468
epoch =  50 batch =  48 / 100 loss =  -0.39524360311528045
epoch =  50 batch =  49 / 100 loss =  -0.39553629378883204
epoch =  50 batch =  50 / 100 loss =  -0.3952948272228241
epoch =  50 batch =  51 / 100 loss =  -0.3952588538328807
epoch =  50 batch =  52 / 100 loss =  -0.3951082155108452
epoch =  50 batch =  53 / 100 loss =  -0.3951910762292034
epoch =  50 batch =  54 / 100 loss =  -0.39510056376457214
epoch = 

epoch =  51 batch =  82 / 100 loss =  -0.39555743736464805
epoch =  51 batch =  83 / 100 loss =  -0.39557236038058635
epoch =  51 batch =  84 / 100 loss =  -0.3956184930035046
epoch =  51 batch =  85 / 100 loss =  -0.39546821222585793
epoch =  51 batch =  86 / 100 loss =  -0.39544207204219906
epoch =  51 batch =  87 / 100 loss =  -0.3953962877564047
epoch =  51 batch =  88 / 100 loss =  -0.3954360999844291
epoch =  51 batch =  89 / 100 loss =  -0.39545680766695
epoch =  51 batch =  90 / 100 loss =  -0.3953691946135627
epoch =  51 batch =  91 / 100 loss =  -0.39526967294923554
epoch =  51 batch =  92 / 100 loss =  -0.3952300393063089
epoch =  51 batch =  93 / 100 loss =  -0.3952773892751304
epoch =  51 batch =  94 / 100 loss =  -0.395309826795091
epoch =  51 batch =  95 / 100 loss =  -0.39529521810381035
epoch =  51 batch =  96 / 100 loss =  -0.39538982572654885
epoch =  51 batch =  97 / 100 loss =  -0.3954238157296918
epoch =  51 batch =  98 / 100 loss =  -0.39544989138233416
epoch =  

epoch =  53 batch =  26 / 100 loss =  -0.39224432294185346
epoch =  53 batch =  27 / 100 loss =  -0.39262764321433175
epoch =  53 batch =  28 / 100 loss =  -0.39308603746550425
epoch =  53 batch =  29 / 100 loss =  -0.3932923695136761
epoch =  53 batch =  30 / 100 loss =  -0.39322876433531445
epoch =  53 batch =  31 / 100 loss =  -0.39340053450676704
epoch =  53 batch =  32 / 100 loss =  -0.39332133810967207
epoch =  53 batch =  33 / 100 loss =  -0.3935009484941309
epoch =  53 batch =  34 / 100 loss =  -0.3935660509502186
epoch =  53 batch =  35 / 100 loss =  -0.39341477496283395
epoch =  53 batch =  36 / 100 loss =  -0.3935526832938194
epoch =  53 batch =  37 / 100 loss =  -0.393330748822238
epoch =  53 batch =  38 / 100 loss =  -0.39378767970361206
epoch =  53 batch =  39 / 100 loss =  -0.3938672481439052
epoch =  53 batch =  40 / 100 loss =  -0.39424607604742046
epoch =  53 batch =  41 / 100 loss =  -0.3944767233802051
epoch =  53 batch =  42 / 100 loss =  -0.3945151787428629
epoch 

epoch =  54 batch =  70 / 100 loss =  -0.39669235050678253
epoch =  54 batch =  71 / 100 loss =  -0.3966662887956055
epoch =  54 batch =  72 / 100 loss =  -0.39666208376487094
epoch =  54 batch =  73 / 100 loss =  -0.3966669415774411
epoch =  54 batch =  74 / 100 loss =  -0.396711705504237
epoch =  54 batch =  75 / 100 loss =  -0.3966050354639689
epoch =  54 batch =  76 / 100 loss =  -0.39639944740031896
epoch =  54 batch =  77 / 100 loss =  -0.3965348527803049
epoch =  54 batch =  78 / 100 loss =  -0.3966165952957593
epoch =  54 batch =  79 / 100 loss =  -0.39656450250480746
epoch =  54 batch =  80 / 100 loss =  -0.39650455415248864
epoch =  54 batch =  81 / 100 loss =  -0.3961912485552422
epoch =  54 batch =  82 / 100 loss =  -0.3961279759319816
epoch =  54 batch =  83 / 100 loss =  -0.3961262695760611
epoch =  54 batch =  84 / 100 loss =  -0.39609796482892257
epoch =  54 batch =  85 / 100 loss =  -0.3959694529280942
epoch =  54 batch =  86 / 100 loss =  -0.3959900521954824
epoch =  

epoch =  56 batch =  14 / 100 loss =  -0.39299874220575604
epoch =  56 batch =  15 / 100 loss =  -0.39310765266418457
epoch =  56 batch =  16 / 100 loss =  -0.3933531567454338
epoch =  56 batch =  17 / 100 loss =  -0.39273828443358927
epoch =  56 batch =  18 / 100 loss =  -0.39258939690060085
epoch =  56 batch =  19 / 100 loss =  -0.3926833773914136
epoch =  56 batch =  20 / 100 loss =  -0.3924742966890335
epoch =  56 batch =  21 / 100 loss =  -0.3925526709783645
epoch =  56 batch =  22 / 100 loss =  -0.3924287340857766
epoch =  56 batch =  23 / 100 loss =  -0.3924068640107694
epoch =  56 batch =  24 / 100 loss =  -0.3925200973947843
epoch =  56 batch =  25 / 100 loss =  -0.39282333850860596
epoch =  56 batch =  26 / 100 loss =  -0.3931425718160776
epoch =  56 batch =  27 / 100 loss =  -0.39350905021031696
epoch =  56 batch =  28 / 100 loss =  -0.3936975747346878
epoch =  56 batch =  29 / 100 loss =  -0.39395273451147406
epoch =  56 batch =  30 / 100 loss =  -0.3940284957488378
epoch =

epoch =  57 batch =  59 / 100 loss =  -0.3957944214344025
epoch =  57 batch =  60 / 100 loss =  -0.3958323086301486
epoch =  57 batch =  61 / 100 loss =  -0.39581212254821285
epoch =  57 batch =  62 / 100 loss =  -0.39572204312970566
epoch =  57 batch =  63 / 100 loss =  -0.39572039104643325
epoch =  57 batch =  64 / 100 loss =  -0.3954987991601229
epoch =  57 batch =  65 / 100 loss =  -0.3955313737575825
epoch =  57 batch =  66 / 100 loss =  -0.39556557888334454
epoch =  57 batch =  67 / 100 loss =  -0.39571206516294344
epoch =  57 batch =  68 / 100 loss =  -0.3957080236252617
epoch =  57 batch =  69 / 100 loss =  -0.39557419047839404
epoch =  57 batch =  70 / 100 loss =  -0.3956201604434423
epoch =  57 batch =  71 / 100 loss =  -0.3955446590839978
epoch =  57 batch =  72 / 100 loss =  -0.395457865877284
epoch =  57 batch =  73 / 100 loss =  -0.3954667865413509
epoch =  57 batch =  74 / 100 loss =  -0.3953750141569086
epoch =  57 batch =  75 / 100 loss =  -0.3953590738773346
epoch =  

epoch =  59 batch =  3 / 100 loss =  -0.39555182059605914
epoch =  59 batch =  4 / 100 loss =  -0.39752162992954254
epoch =  59 batch =  5 / 100 loss =  -0.39832820892333987
epoch =  59 batch =  6 / 100 loss =  -0.397300864259402
epoch =  59 batch =  7 / 100 loss =  -0.3978922835418156
epoch =  59 batch =  8 / 100 loss =  -0.3979741595685482
epoch =  59 batch =  9 / 100 loss =  -0.397602852847841
epoch =  59 batch =  10 / 100 loss =  -0.39745340347290037
epoch =  59 batch =  11 / 100 loss =  -0.3977192288095301
epoch =  59 batch =  12 / 100 loss =  -0.39694053928057355
epoch =  59 batch =  13 / 100 loss =  -0.3972441554069519
epoch =  59 batch =  14 / 100 loss =  -0.3961082909788404
epoch =  59 batch =  15 / 100 loss =  -0.39614983399709064
epoch =  59 batch =  16 / 100 loss =  -0.39574517495930195
epoch =  59 batch =  17 / 100 loss =  -0.39514584225766797
epoch =  59 batch =  18 / 100 loss =  -0.3953157265981038
epoch =  59 batch =  19 / 100 loss =  -0.3959704794381794
epoch =  59 bat

epoch =  60 batch =  46 / 100 loss =  -0.395984952216563
epoch =  60 batch =  47 / 100 loss =  -0.3959077254254767
epoch =  60 batch =  48 / 100 loss =  -0.3959803966184457
epoch =  60 batch =  49 / 100 loss =  -0.39597603739524373
epoch =  60 batch =  50 / 100 loss =  -0.3958993375301361
epoch =  60 batch =  51 / 100 loss =  -0.395941345130696
epoch =  60 batch =  52 / 100 loss =  -0.3959144617502506
epoch =  60 batch =  53 / 100 loss =  -0.39609082239978716
epoch =  60 batch =  54 / 100 loss =  -0.3961425048333627
epoch =  60 batch =  55 / 100 loss =  -0.3960233520377766
epoch =  60 batch =  56 / 100 loss =  -0.3958897925913334
epoch =  60 batch =  57 / 100 loss =  -0.39595144830251994
epoch =  60 batch =  58 / 100 loss =  -0.39602301603761214
epoch =  60 batch =  59 / 100 loss =  -0.3960145474490473
epoch =  60 batch =  60 / 100 loss =  -0.39598218897978466
epoch =  60 batch =  61 / 100 loss =  -0.3960353898220375
epoch =  60 batch =  62 / 100 loss =  -0.39584851841772756
epoch =  6

epoch =  61 batch =  90 / 100 loss =  -0.39587769475248125
epoch =  61 batch =  91 / 100 loss =  -0.39588617034010837
epoch =  61 batch =  92 / 100 loss =  -0.3959076628088951
epoch =  61 batch =  93 / 100 loss =  -0.39574580263066034
epoch =  61 batch =  94 / 100 loss =  -0.3956507361949758
epoch =  61 batch =  95 / 100 loss =  -0.3957727767919239
epoch =  61 batch =  96 / 100 loss =  -0.39585873670876026
epoch =  61 batch =  97 / 100 loss =  -0.39585657709652616
epoch =  61 batch =  98 / 100 loss =  -0.3957448841965928
epoch =  61 batch =  99 / 100 loss =  -0.39569479347479464
epoch =  61 batch =  100 / 100 loss =  -0.3958014205098152
epoch =  62 batch =  1 / 100 loss =  -0.3985939919948578
epoch =  62 batch =  2 / 100 loss =  -0.39694058895111084
epoch =  62 batch =  3 / 100 loss =  -0.3958937426408132
epoch =  62 batch =  4 / 100 loss =  -0.397492416203022
epoch =  62 batch =  5 / 100 loss =  -0.3977018356323242
epoch =  62 batch =  6 / 100 loss =  -0.3959200878938039
epoch =  62 b

epoch =  63 batch =  34 / 100 loss =  -0.3964752660078161
epoch =  63 batch =  35 / 100 loss =  -0.3964745938777924
epoch =  63 batch =  36 / 100 loss =  -0.3962409935063786
epoch =  63 batch =  37 / 100 loss =  -0.39593456886910106
epoch =  63 batch =  38 / 100 loss =  -0.39595202866353485
epoch =  63 batch =  39 / 100 loss =  -0.39608186177718335
epoch =  63 batch =  40 / 100 loss =  -0.39586658775806427
epoch =  63 batch =  41 / 100 loss =  -0.39578384088306895
epoch =  63 batch =  42 / 100 loss =  -0.39595694059417363
epoch =  63 batch =  43 / 100 loss =  -0.39605416600094284
epoch =  63 batch =  44 / 100 loss =  -0.3964009014042941
epoch =  63 batch =  45 / 100 loss =  -0.39634721279144286
epoch =  63 batch =  46 / 100 loss =  -0.39636679252852564
epoch =  63 batch =  47 / 100 loss =  -0.396121066301427
epoch =  63 batch =  48 / 100 loss =  -0.39609240430096787
epoch =  63 batch =  49 / 100 loss =  -0.3957392993021984
epoch =  63 batch =  50 / 100 loss =  -0.39591373801231383
epoc

epoch =  64 batch =  76 / 100 loss =  -0.39582905134088114
epoch =  64 batch =  77 / 100 loss =  -0.39582328208081136
epoch =  64 batch =  78 / 100 loss =  -0.395861675341924
epoch =  64 batch =  79 / 100 loss =  -0.3957991015307511
epoch =  64 batch =  80 / 100 loss =  -0.39592898525297643
epoch =  64 batch =  81 / 100 loss =  -0.3960311401773382
epoch =  64 batch =  82 / 100 loss =  -0.3959689976238623
epoch =  64 batch =  83 / 100 loss =  -0.39603899329541675
epoch =  64 batch =  84 / 100 loss =  -0.39607853087640943
epoch =  64 batch =  85 / 100 loss =  -0.3959583980195663
epoch =  64 batch =  86 / 100 loss =  -0.3959361235069674
epoch =  64 batch =  87 / 100 loss =  -0.39569857853582535
epoch =  64 batch =  88 / 100 loss =  -0.39571295780214394
epoch =  64 batch =  89 / 100 loss =  -0.3958568368734938
epoch =  64 batch =  90 / 100 loss =  -0.3958295921484629
epoch =  64 batch =  91 / 100 loss =  -0.39594180472604523
epoch =  64 batch =  92 / 100 loss =  -0.39590593297844345
epoch 

epoch =  66 batch =  20 / 100 loss =  -0.3969895765185356
epoch =  66 batch =  21 / 100 loss =  -0.3959615315709795
epoch =  66 batch =  22 / 100 loss =  -0.39593063430352643
epoch =  66 batch =  23 / 100 loss =  -0.39583518064540363
epoch =  66 batch =  24 / 100 loss =  -0.39605549971262616
epoch =  66 batch =  25 / 100 loss =  -0.3961253106594086
epoch =  66 batch =  26 / 100 loss =  -0.3964910896924826
epoch =  66 batch =  27 / 100 loss =  -0.39651120812804613
epoch =  66 batch =  28 / 100 loss =  -0.3963782212563923
epoch =  66 batch =  29 / 100 loss =  -0.3962689103751347
epoch =  66 batch =  30 / 100 loss =  -0.3961508403221766
epoch =  66 batch =  31 / 100 loss =  -0.3960243732698502
epoch =  66 batch =  32 / 100 loss =  -0.39577874913811684
epoch =  66 batch =  33 / 100 loss =  -0.39569686578981805
epoch =  66 batch =  34 / 100 loss =  -0.3955513852484086
epoch =  66 batch =  35 / 100 loss =  -0.39576489414487565
epoch =  66 batch =  36 / 100 loss =  -0.39572977854145897
epoch 

epoch =  67 batch =  64 / 100 loss =  -0.3957055704668164
epoch =  67 batch =  65 / 100 loss =  -0.3957330456146827
epoch =  67 batch =  66 / 100 loss =  -0.3957640898950172
epoch =  67 batch =  67 / 100 loss =  -0.39578145208643445
epoch =  67 batch =  68 / 100 loss =  -0.39584220332257886
epoch =  67 batch =  69 / 100 loss =  -0.39577799473983655
epoch =  67 batch =  70 / 100 loss =  -0.3958029040268489
epoch =  67 batch =  71 / 100 loss =  -0.3957330479588307
epoch =  67 batch =  72 / 100 loss =  -0.39572645392682815
epoch =  67 batch =  73 / 100 loss =  -0.395669697082206
epoch =  67 batch =  74 / 100 loss =  -0.3957575015925072
epoch =  67 batch =  75 / 100 loss =  -0.3957677511374156
epoch =  67 batch =  76 / 100 loss =  -0.39579112435642044
epoch =  67 batch =  77 / 100 loss =  -0.39576371497922136
epoch =  67 batch =  78 / 100 loss =  -0.39590298747405034
epoch =  67 batch =  79 / 100 loss =  -0.3960105784331696
epoch =  67 batch =  80 / 100 loss =  -0.3959820058196783
epoch = 

epoch =  69 batch =  8 / 100 loss =  -0.3925073966383934
epoch =  69 batch =  9 / 100 loss =  -0.3934083316061232
epoch =  69 batch =  10 / 100 loss =  -0.39327520728111265
epoch =  69 batch =  11 / 100 loss =  -0.39341534809632733
epoch =  69 batch =  12 / 100 loss =  -0.3932737708091736
epoch =  69 batch =  13 / 100 loss =  -0.3935311918075268
epoch =  69 batch =  14 / 100 loss =  -0.3943320023162024
epoch =  69 batch =  15 / 100 loss =  -0.3946003794670105
epoch =  69 batch =  16 / 100 loss =  -0.3944011218845844
epoch =  69 batch =  17 / 100 loss =  -0.39380387348287244
epoch =  69 batch =  18 / 100 loss =  -0.3947535388999515
epoch =  69 batch =  19 / 100 loss =  -0.3951177251966376
epoch =  69 batch =  20 / 100 loss =  -0.39526255875825883
epoch =  69 batch =  21 / 100 loss =  -0.3948921873455956
epoch =  69 batch =  22 / 100 loss =  -0.39531962167132983
epoch =  69 batch =  23 / 100 loss =  -0.395700700905012
epoch =  69 batch =  24 / 100 loss =  -0.3953633705774943
epoch =  69 

epoch =  70 batch =  53 / 100 loss =  -0.39531079384515866
epoch =  70 batch =  54 / 100 loss =  -0.395307817392879
epoch =  70 batch =  55 / 100 loss =  -0.3954361460425636
epoch =  70 batch =  56 / 100 loss =  -0.3956435253577572
epoch =  70 batch =  57 / 100 loss =  -0.3955667651536171
epoch =  70 batch =  58 / 100 loss =  -0.395639263350388
epoch =  70 batch =  59 / 100 loss =  -0.3960062487650725
epoch =  70 batch =  60 / 100 loss =  -0.39601758519808444
epoch =  70 batch =  61 / 100 loss =  -0.39603967744796
epoch =  70 batch =  62 / 100 loss =  -0.3959165282787815
epoch =  70 batch =  63 / 100 loss =  -0.3960396046676332
epoch =  70 batch =  64 / 100 loss =  -0.3960160394199192
epoch =  70 batch =  65 / 100 loss =  -0.39587526825758124
epoch =  70 batch =  66 / 100 loss =  -0.395730892365629
epoch =  70 batch =  67 / 100 loss =  -0.3957716288851268
epoch =  70 batch =  68 / 100 loss =  -0.39569018518223475
epoch =  70 batch =  69 / 100 loss =  -0.3957820776580036
epoch =  70 bat

epoch =  71 batch =  97 / 100 loss =  -0.39598461678347635
epoch =  71 batch =  98 / 100 loss =  -0.39599981113355986
epoch =  71 batch =  99 / 100 loss =  -0.3960416922063539
epoch =  71 batch =  100 / 100 loss =  -0.39609434694051743
epoch =  72 batch =  1 / 100 loss =  -0.4027171730995178
epoch =  72 batch =  2 / 100 loss =  -0.387325644493103


In [None]:
with torch.no_grad():
    x_flow = flow.sample(100000).cpu().numpy()

In [None]:
plt.hist2d(x_data[:,0].cpu().numpy(), x_data[:,1].cpu().numpy(), bins=25)
plt.show()

In [None]:
plt.hist2d(x_flow[:,0], x_flow[:,1], bins=25)
plt.show()