In [1]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt

In [2]:
import torch
from torch import nn
from torch import optim

In [3]:
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.flows.base import Flow
from nflows.transforms.dropout import UniformStochasticDropout
from nflows.transforms.dropout import VariationalStochasticDropout
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform

In [4]:
#device = torch.device("cuda:0")
device = torch.device("cpu")

In [5]:
# This works with any size x
def p(x, n_probs):
    sums = torch.sum(x, axis=1)
    probs = torch.cos(torch.ger(sums, torch.arange(1, n_probs+1, dtype=torch.float32)))**2
    norm = torch.sum(probs, axis=1)

    for i in range(n_probs):
        probs[:,i] /= norm
    
    return probs

In [6]:
def generate(n, drop_indices):
    n_probs = torch.max(drop_indices) + 1
    x = torch.rand(n, drop_indices.shape[0])
    probs = p(x, n_probs)

    # Pick a prob
    probs_cumsum = torch.cumsum(probs, axis=1)

    # Tensor with bools that are true when r passes the cumprob
    larger_than_cumprob = torch.rand(n,1) < probs_cumsum
    # Do the arange trick to find first nonzero
    # This is the HIGHEST LABEL FROM DROP_INDICES THAT IS KEPT
    selected_index = torch.argmax(larger_than_cumprob*torch.arange(n_probs, 0, -1), axis=1)

    '''
    print("The index of the selected probability")
    print("This is also the highest label in drop_indices that is kept")
    print(selected_index)
    ''' 
    
    # Find the index of the first true
    drop_mask = drop_indices > selected_index[:,None]
    x[drop_mask] = 0
    
    return x

In [7]:
drop_indices = torch.tensor([0,0,1,1,1,2,3,3,4])
n_data = int(1e6)
x_data = generate(n_data, drop_indices).to(device)

In [8]:
drop_indices.shape[0]

9

In [9]:
num_layers = 6
base_dist_uniform = BoxUniform(torch.zeros(drop_indices.shape[0]), torch.ones(drop_indices.shape[0]))
#base_dist_variational = BoxUniform(drop_indices.shape[0], drop_indices.shape[0])

transforms_uniform = []
#transforms_variational = []

transforms_uniform.append(UniformStochasticDropout(drop_indices))
#transforms_variational.append(VariationalStochasticDropout(drop_indices))

for _ in range(num_layers):
    transforms_uniform.append(RandomPermutation(features=drop_indices.shape[0]))
    transforms_uniform.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=drop_indices.shape[0], 
        hidden_features=25,
        num_bins=10,
        num_blocks=4,
    ))

    #transforms_variational.append(RandomPermutation(features=drop_indices.shape[0]))
    #transforms_variational.append(MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
    #    features=drop_indices.shape[0], 
    #    hidden_features=25,
    #    num_bins=10,
    #    num_blocks=4,
    #))

transform_uniform = CompositeTransform(transforms_uniform)
#transform_variational = CompositeTransform(transforms_variational)

flow_uniform = Flow(transform_uniform, base_dist_uniform).to(device)
#flow_variational = Flow(transform_variational, base_dist_variational).to(device)

optimizer_uniform = optim.Adam(flow_uniform.parameters())
#optimizer_variational = optim.Adam(flow_variational.parameters())

In [10]:
n_epochs = 10
batch_size = 1000
n_batches = m.ceil(x_data.shape[0]/batch_size)

for epoch in range(n_epochs):
    permutation = torch.randperm(x_data.shape[0], device=device)    

    # Loop over batches
    cum_loss_uniform = 0
    #cum_loss_variational = 0
    for batch in range(n_batches):
        # Set up the batch
        batch_begin = batch*batch_size
        batch_end   = min( (batch+1)*batch_size, x_data.shape[0]-1 )
        indices = permutation[batch_begin:batch_end]
        batch_x = x_data[indices]
        
        # Take a step
        optimizer_uniform.zero_grad()
        #optimizer_variational.zero_grad()

        loss_uniform = -flow_uniform.log_prob(inputs=batch_x).mean()
        #loss_variational = -flow_variational.log_prob(inputs=batch_x).mean()

        loss_uniform.backward()
        #loss_variational.backward()

        optimizer_uniform.step()
        #optimizer_variational.step()

        # Compute cumulative loss
        cum_loss_uniform = (cum_loss_uniform*batch + loss_uniform.item())/(batch+1)
        #cum_loss_variational = (cum_loss_variational*batch + loss_variational.item())/(batch+1)

        print("epoch = ", epoch, "batch = ",batch+1, "/", n_batches, "loss_uniform = ", cum_loss_uniform)#, " loss_variational = ", cum_loss_variational)

uniform =  1.7655362319069496
epoch =  0 batch =  708 / 1000 loss_uniform =  1.7653532401990075
epoch =  0 batch =  709 / 1000 loss_uniform =  1.7651719502232468
epoch =  0 batch =  710 / 1000 loss_uniform =  1.7650157156124915
epoch =  0 batch =  711 / 1000 loss_uniform =  1.7648237995625207
epoch =  0 batch =  712 / 1000 loss_uniform =  1.7646351592259455
epoch =  0 batch =  713 / 1000 loss_uniform =  1.7644597338760712
epoch =  0 batch =  714 / 1000 loss_uniform =  1.7642669632655226
epoch =  0 batch =  715 / 1000 loss_uniform =  1.7640994090300333
epoch =  0 batch =  716 / 1000 loss_uniform =  1.7639266527231838
epoch =  0 batch =  717 / 1000 loss_uniform =  1.7637540145232913
epoch =  0 batch =  718 / 1000 loss_uniform =  1.7635827881380037
epoch =  0 batch =  719 / 1000 loss_uniform =  1.7634065749748022
epoch =  0 batch =  720 / 1000 loss_uniform =  1.7632252130243506
epoch =  0 batch =  721 / 1000 loss_uniform =  1.763056726958318
epoch =  0 batch =  722 / 1000 loss_uniform =  

KeyboardInterrupt: 

In [11]:
n_sample = 100
with torch.no_grad():
    x_uniform = flow_uniform.sample(n_sample).cpu()
    #x_variational = flow_variational.sample(n_sample).cpu()

In [12]:
plt.hist(np.count_nonzero(x_data.numpy(), axis=1), np.linspace(-0.5,4.5,6), histtype='stepfilled', edgecolor="black", facecolor="lightgray", density=True)
plt.hist(np.count_nonzero(x_flow.numpy(), axis=1), np.linspace(-0.5,4.5,6), edgecolor="red", histtype="step", density=True)
plt.show()

NameError: name 'x_data' is not defined

In [13]:
x_random = torch.rand(10,4)

In [14]:
p(x_random)

TypeError: p() missing 1 required positional argument: 'n_probs'

RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 1

In [15]:
test = StochasticDropout(torch.tensor([1,2,3,4]), hidden_layers=2)

In [16]:
a = test.inverse(torch.rand(10,4))

ModuleAttributeError: 'StochasticDropout' object has no attribute 'inverse'