In [2]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

In [4]:
import torch
from torch import nn
from torch import optim

In [6]:
from nflows.flows.base import Flow
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.autoregressive import MaskedPiecewiseQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.splines.rational_quadratic import rational_quadratic_spline
from nflows.transforms.dequantization import UniformDequantization
from nflows.transforms.dequantization import VariationalDequantization

In [8]:
from nflows.distributions.normal import ConditionalDiagonalNormal
from nflows.distributions.normal import StandardNormal
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform

In [10]:
# Generate some dist with 2 real and 2 discrete dims

# These are the probabilities to draw a particular label
x_prob_1 = np.random.rand(10)
x_prob_2 = np.random.rand(5)
p_labels_1 = x_prob_1/(np.sum(x_prob_1))
p_labels_2 = x_prob_2/(np.sum(x_prob_2))
p_cum_labels_1 = np.cumsum(p_labels_1)
p_cum_labels_2 = np.cumsum(p_labels_2)

def p(x,y):
    if np.all(x > 0) and np.all(x < 1) and x.shape == y.shape:
        return m.exp(-(x[0] + x[1])/2) * np.cos((x[0]*y[0] + x[1]*y[1])*m.pi)
    return 0

def generate(n):
    x = np.zeros((n,4))

    counter = 0
    while(counter < n):
        y = np.zeros(2)
        y[0] = np.argmax(r.random() < p_cum_labels_1)
        y[1] = np.argmax(r.random() < p_cum_labels_2)

        while(True):
            # Get a point
            x_trial = np.random.rand(2)
            if r.random() < p(x_trial, y):
                x[counter][0] = x_trial[0]
                x[counter][1] = x_trial[1]
                x[counter][2] = y[0]
                x[counter][3] = y[1]
                counter += 1
                break

    return x

In [12]:
vardeq = VariationalDequantization(max_labels=torch.tensor([-1,-1,9,4]), rqs_hidden_features=15)

In [14]:
x_test = torch.tensor(generate(10), dtype=torch.float32)

In [16]:
x_forward = vardeq.forward(x_test)[0]

In [18]:
vardeq.inverse(x_forward)

torch.Size([10, 2]) torch.Size([10, 4])


InputOutsideDomain: 

In [21]:
x_test

tensor([[5.3657e-01, 8.9568e-02, 7.0000e+00, 1.0000e+00],
        [1.5493e-01, 9.5132e-01, 7.0000e+00, 3.0000e+00],
        [7.8500e-01, 1.5821e-01, 7.0000e+00, 3.0000e+00],
        [8.1737e-01, 5.3751e-03, 7.0000e+00, 0.0000e+00],
        [5.2431e-01, 6.8089e-01, 7.0000e+00, 3.0000e+00],
        [5.6633e-02, 2.7536e-02, 4.0000e+00, 3.0000e+00],
        [8.1685e-01, 3.6908e-01, 7.0000e+00, 1.0000e+00],
        [1.8659e-01, 7.9795e-01, 5.0000e+00, 1.0000e+00],
        [5.7817e-01, 9.6970e-01, 2.0000e+00, 3.0000e+00],
        [6.4404e-01, 1.5107e-01, 3.0000e+00, 1.0000e+00]])

In [22]:
'''
x_plot = generate(1000000)
x_select_1 = x_plot[x_plot[:,2] == 7]
x_select_2 = x_select_1[x_select_1[:,3] == 3]
plt.hist2d(x_select_2[:,0], x_select_2[:,1], bins=25)
plt.show()
'''

'\nx_plot = generate(1000000)\nx_select_1 = x_plot[x_plot[:,2] == 7]\nx_select_2 = x_select_1[x_select_1[:,3] == 3]\nplt.hist2d(x_select_2[:,0], x_select_2[:,1], bins=25)\nplt.show()\n'