In [2]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

In [4]:
import torch
from torch import nn
from torch import optim

In [6]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from nflows.transforms.dequantization import UniformDequantization
from nflows.transforms.dequantization import VariationalDequantization

In [7]:
from nflows.distributions.normal import ConditionalDiagonalNormal
from nflows.distributions.normal import StandardNormal
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform

In [8]:
# Generate a toy distribution with 2 continuous dims (x) and 2 discrete label dims (y).
#
# Seed every RNG this notebook draws from -- numpy (label weights, proposals),
# `random` (label and acceptance draws in generate()) and torch (used by the
# nflows transform below) -- so Restart & Run All reproduces the outputs.
SEED = 1234
np.random.seed(SEED)
r.seed(SEED)
torch.manual_seed(SEED)

# Number of distinct values each discrete dimension can take (labels 0..N-1).
N_LABELS_1 = 10
N_LABELS_2 = 5

# Random positive weights -> normalized categorical probabilities per label dim.
x_prob_1 = np.random.rand(N_LABELS_1)
x_prob_2 = np.random.rand(N_LABELS_2)
p_labels_1 = x_prob_1 / np.sum(x_prob_1)
p_labels_2 = x_prob_2 / np.sum(x_prob_2)
# Cumulative tables used for inverse-CDF label sampling in generate().
p_cum_labels_1 = np.cumsum(p_labels_1)
p_cum_labels_2 = np.cumsum(p_labels_2)

def p(x, y):
    """Unnormalized conditional density of the continuous dims given the labels.

    p(x | y) is proportional to max(0, exp(-sum(x) / 2) * cos(pi * sum(x * y)))
    on the open unit cube. generate() uses this value directly as a rejection
    sampling acceptance probability.

    Args:
        x: 1-D array-like of continuous coordinates, each strictly inside (0, 1).
        y: 1-D array-like of labels with the same shape as ``x``.

    Returns:
        A float in [0, 1]. It is 0.0 outside the open unit cube, when the shapes
        of ``x`` and ``y`` differ, or where the cosine factor is negative.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape or not (np.all(x > 0) and np.all(x < 1)):
        return 0.0
    value = m.exp(-np.sum(x) / 2) * m.cos(m.pi * np.sum(x * y))
    # The cosine factor can be negative; a density (and an acceptance
    # probability) cannot, so clamp. The rejection sampler already treated
    # negative values as "never accept", so sampling is unchanged.
    return max(0.0, value)

def _draw_label(p_cum):
    """Draw one label index by inverse-CDF sampling from a cumulative table.

    Consumes exactly one ``random.random()`` value. ``np.searchsorted(...,
    side="right")`` returns the first index whose cumulative probability
    exceeds the uniform draw, which is the same index ``np.argmax(u < p_cum)``
    picks. The result is also clamped to the last label. Float round-off can
    leave ``p_cum[-1]`` slightly below 1.0 (e.g. ``np.cumsum([0.1] * 10)``),
    and the argmax form then silently returned label 0 for draws above it.
    """
    u = r.random()
    index = int(np.searchsorted(p_cum, u, side="right"))
    return min(index, len(p_cum) - 1)


def generate(n, density=None, p_cum_1=None, p_cum_2=None):
    """Draw ``n`` samples (x0, x1, y0, y1) from the toy mixed distribution.

    For each sample, both labels are drawn first from their cumulative tables.
    Then (x0, x1) is drawn by rejection sampling: propose uniformly on
    [0, 1)^2 and accept with probability ``density(x, y)``.

    Args:
        n: number of samples to draw.
        density: callable ``(x, y) -> float`` used as the acceptance
            probability; defaults to the notebook's ``p``. It must be positive
            somewhere on the unit square for every label pair, otherwise the
            rejection loop below never terminates.
        p_cum_1: cumulative probability table for label column 2; defaults
            to ``p_cum_labels_1``.
        p_cum_2: cumulative probability table for label column 3; defaults
            to ``p_cum_labels_2``.

    Returns:
        A float array of shape ``(n, 4)``: columns 0-1 hold the continuous
        coordinates, columns 2-3 hold the integer labels stored as floats.
    """
    # Resolve defaults at call time so the current notebook globals are used.
    density = p if density is None else density
    p_cum_1 = p_cum_labels_1 if p_cum_1 is None else p_cum_1
    p_cum_2 = p_cum_labels_2 if p_cum_2 is None else p_cum_2

    samples = np.zeros((n, 4))
    for row in range(n):
        y = np.zeros(2)
        y[0] = _draw_label(p_cum_1)
        y[1] = _draw_label(p_cum_2)

        # Rejection sampling of the continuous part, conditional on the labels.
        while True:
            x_trial = np.random.rand(2)
            if r.random() < density(x_trial, y):
                samples[row, :2] = x_trial
                samples[row, 2:] = y
                break

    return samples

In [9]:
# Dequantization transform over the 4 sample columns. The discrete entries of
# max_labels match the largest label generate() can emit for columns 2 and 3,
# so they are derived from the label tables and stay in sync if those change.
# NOTE(review): -1 for columns 0-1 presumably marks "continuous, leave as is"
# -- TODO confirm against VariationalDequantization.
vardeq = VariationalDequantization(
    max_labels=torch.tensor([-1, -1, len(p_labels_1) - 1, len(p_labels_2) - 1]),
    rqs_hidden_features=15,
    rqs_flow_layers=1,
)

In [10]:
# Small batch of samples from the toy distribution, converted to a float32
# tensor (generate() returns float64) for the torch transform.
x_test = torch.as_tensor(generate(10), dtype=torch.float32)

In [11]:
# Apply the transform by calling the module rather than .forward() directly, so
# that nn.Module.__call__ runs any registered hooks.
# NOTE(review): assumes VariationalDequantization is a torch nn.Module like other
# nflows transforms -- TODO confirm.
x_forward, llh_forward = vardeq(x_test)

derivs tensor(0.5644) tensor(1.0157)
what tensor(1.0147)
derivs tensor(0.5644) tensor(1.0157)
what tensor(1.0147)


In [12]:
# Map the forward output back to sample space; the round trip should recover x_test.
# llh_inverse is the per-sample (column) log-likelihood term of the inverse pass,
# presumably the inverse log|det J| -- TODO confirm in VariationalDequantization.
x_inverse, llh_inverse = vardeq.inverse(x_forward)

tensor([[0.4823, 0.6992],
        [0.3941, 0.1129],
        [0.9702, 0.1847],
        [0.0718, 0.7597],
        [0.2705, 0.2407],
        [0.5501, 0.0097],
        [0.5447, 0.8463],
        [0.9689, 0.0555],
        [0.9394, 0.2060],
        [0.9075, 0.6193]])
torch.Size([10, 4])
derivs tensor(0.5644) tensor(1.0157)
what tensor(1.0147)


In [13]:
# Original samples drawn by generate(), shown for comparison with x_inverse.
x_test

tensor([[0.3638, 0.0900, 4.0000, 3.0000],
        [0.6305, 0.7375, 3.0000, 0.0000],
        [0.4686, 0.4766, 9.0000, 0.0000],
        [0.4981, 0.1276, 0.0000, 3.0000],
        [0.7681, 0.3206, 2.0000, 1.0000],
        [0.0400, 0.5088, 5.0000, 0.0000],
        [0.2186, 0.7807, 5.0000, 4.0000],
        [0.1841, 0.4129, 9.0000, 0.0000],
        [0.6080, 0.5518, 9.0000, 1.0000],
        [0.6679, 0.7004, 9.0000, 3.0000]])

In [14]:
# Round-trip check: inverse(forward(x_test)) must reproduce the original samples.
# atol is a loose tolerance for numerical inversion inside the transform
# -- TODO tighten if the transform is known to be exact.
assert x_inverse.shape == x_test.shape, "round trip changed the sample shape"
assert torch.allclose(x_inverse, x_test, atol=1e-4), "dequantization round trip did not recover x_test"
x_inverse

tensor([[0.3638, 0.0900, 4.0000, 3.0000],
        [0.6305, 0.7375, 3.0000, 0.0000],
        [0.4686, 0.4766, 9.0000, 0.0000],
        [0.4981, 0.1276, 0.0000, 3.0000],
        [0.7681, 0.3206, 2.0000, 1.0000],
        [0.0400, 0.5088, 5.0000, 0.0000],
        [0.2186, 0.7807, 5.0000, 4.0000],
        [0.1841, 0.4129, 9.0000, 0.0000],
        [0.6080, 0.5518, 9.0000, 1.0000],
        [0.6679, 0.7004, 9.0000, 3.0000]])

In [15]:
# Per-sample exp(llh) from the forward pass. .detach() drops the autograd graph,
# and .cpu() lets .numpy() work even if the tensors live on a GPU.
np.exp(llh_forward.detach().cpu().numpy())

array([[0.98914516],
       [1.0129949 ],
       [0.98784626],
       [0.9839106 ],
       [0.9771976 ],
       [0.99293524],
       [1.0124247 ],
       [0.99451196],
       [0.9941572 ],
       [0.98960286]], dtype=float32)

In [16]:
# Per-sample exp(llh) from the inverse pass. .detach() drops the autograd graph,
# and .cpu() lets .numpy() work even if the tensors live on a GPU.
np.exp(llh_inverse.detach().cpu().numpy())

array([[1.0168296 ],
       [1.0155218 ],
       [0.99945664],
       [0.98976994],
       [0.9891932 ],
       [0.99958265],
       [0.98490906],
       [0.9880851 ],
       [1.0035509 ],
       [1.010094  ]], dtype=float32)

In [17]:
def select_label_pair(samples, label_1, label_2):
    """Return the rows of an (n, 4) sample array whose labels are (label_1, label_2).

    Label columns are 2 and 3, as produced by generate().
    """
    mask = (samples[:, 2] == label_1) & (samples[:, 3] == label_2)
    return samples[mask]


def plot_conditional(samples, label_1=7, label_2=3, bins=25):
    """2-D histogram of the continuous columns (0-1) for one label pair."""
    selected = select_label_pair(samples, label_1, label_2)
    fig, ax = plt.subplots()
    ax.hist2d(selected[:, 0], selected[:, 1], bins=bins)
    ax.set(title=f"x | y = ({label_1}, {label_2})", xlabel="x[0]", ylabel="x[1]")
    return fig, ax


# Off by default: generate() draws one sample at a time in a Python rejection
# loop, so 10^6 samples takes a long time. Set to True to produce the figure.
PLOT_CONDITIONAL = False
if PLOT_CONDITIONAL:
    plot_conditional(generate(1000000))
    plt.show()