In [2]:
import math as m
import numpy as np
import random as r
import matplotlib.pyplot as plt

In [4]:
import torch
from torch import nn
from torch import optim

In [6]:
from nflows.distributions.uniform import BoxUniform
from nflows.transforms.base import CompositeTransform
from nflows.flows.base import Flow
from nflows.distributions.dropout import UniformStochasticDropout
from nflows.distributions.dropout import VariationalStochasticDropout

In [8]:
n_probs = 5
x_drop = torch.rand(n_probs)
p_drop = x_drop/(torch.sum(x_drop))
p_cum_drop = torch.cumsum(p_drop, dim=0)
drop_indices = torch.tensor([0,0,1,1,1,2,3,3,4])

In [10]:
def generate(n, drop_indices):
    x = torch.rand(n, drop_indices.shape[0])

    # Tensor with bools that are true when r passes the cumprob
    larger_than_cumprob = torch.rand(n,1) < p_cum_drop
    # Do the arange trick to find first nonzero
    selected_index = torch.argmax(larger_than_cumprob*torch.arange(n_probs, 0, -1), axis=1)

    # Get the values of the selected llhs
    selected_probs = p_drop[selected_index]

    # Find the index of the first true
    drop_mask = drop_indices > selected_index[:,None]
    x[drop_mask] = 0
    
    return x, torch.log(selected_probs)

In [12]:
uniform_dropout = UniformStochasticDropout(drop_indices)
variational_dropout = VariationalStochasticDropout(drop_indices)
optimizer_uniform = torch.optim.Adam(uniform_dropout.parameters())
optimizer_variational = torch.optim.Adam(variational_dropout.parameters())

In [14]:
variational_dropout.sample(10)

tensor([[0.9681, 0.4382, 0.6580, 0.4396, 0.8842, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9657, 0.2708, 0.3083, 0.0418, 0.3040, 0.0210, 0.8117, 0.0328, 0.0000],
        [0.0614, 0.4462, 0.8920, 0.6788, 0.5892, 0.5509, 0.8516, 0.7177, 0.3241],
        [0.8513, 0.1258, 0.4332, 0.6951, 0.2300, 0.7434, 0.0407, 0.2823, 0.0000],
        [0.8820, 0.8708, 0.3588, 0.1751, 0.4674, 0.5946, 0.7861, 0.1213, 0.8394],
        [0.3350, 0.5908, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0800, 0.2039, 0.0677, 0.2359, 0.2074, 0.4105, 0.3680, 0.7247, 0.4096],
        [0.9672, 0.0722, 0.4764, 0.5729, 0.4611, 0.0871, 0.7416, 0.7464, 0.5334],
        [0.5363, 0.6451, 0.6556, 0.2093, 0.2832, 0.8169, 0.3032, 0.1311, 0.0000],
        [0.1538, 0.6937, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [16]:
n_epochs = 1000
n_batch = 500
for i in range(n_epochs):
    x,llh = generate(n_batch, drop_indices)

    optimizer_uniform.zero_grad()
    optimizer_variational.zero_grad()

    loss_uniform = -uniform_dropout.log_prob(x).mean()
    loss_variational = -variational_dropout.log_prob(x).mean()

    loss_uniform.backward()
    loss_variational.backward()

    optimizer_uniform.step()
    optimizer_variational.step()

    if i%50 == 0:
        print(loss_uniform, loss_variational)

tensor(1.6045, grad_fn=<NegBackward>) tensor(1.6243, grad_fn=<NegBackward>)
tensor(1.6036, grad_fn=<NegBackward>) tensor(1.6240, grad_fn=<NegBackward>)
tensor(1.6042, grad_fn=<NegBackward>) tensor(1.6147, grad_fn=<NegBackward>)
tensor(1.6041, grad_fn=<NegBackward>) tensor(1.6296, grad_fn=<NegBackward>)
tensor(1.6067, grad_fn=<NegBackward>) tensor(1.6261, grad_fn=<NegBackward>)
tensor(1.6058, grad_fn=<NegBackward>) tensor(1.6147, grad_fn=<NegBackward>)
tensor(1.6078, grad_fn=<NegBackward>) tensor(1.6238, grad_fn=<NegBackward>)
tensor(1.6053, grad_fn=<NegBackward>) tensor(1.6185, grad_fn=<NegBackward>)
tensor(1.6101, grad_fn=<NegBackward>) tensor(1.6298, grad_fn=<NegBackward>)
tensor(1.6028, grad_fn=<NegBackward>) tensor(1.6183, grad_fn=<NegBackward>)
tensor(1.6083, grad_fn=<NegBackward>) tensor(1.6271, grad_fn=<NegBackward>)
tensor(1.6060, grad_fn=<NegBackward>) tensor(1.6193, grad_fn=<NegBackward>)
tensor(1.6072, grad_fn=<NegBackward>) tensor(1.6135, grad_fn=<NegBackward>)
tensor(1.606

In [18]:
x, llh = generate(10, drop_indices)
print(torch.exp(llh))
print(torch.exp(uniform_dropout.log_prob(x)))
print(torch.exp(variational_dropout.log_prob(x)))

tensor([0.1870, 0.2154, 0.2140, 0.1785, 0.2154, 0.2140, 0.1785, 0.2140, 0.2140,
        0.1785])
tensor([0.1866, 0.2140, 0.2125, 0.1812, 0.2140, 0.2125, 0.1812, 0.2125, 0.2125,
        0.1812], grad_fn=<ExpBackward>)
tensor([0.1915, 0.2647, 0.2084, 0.2383, 0.1968, 0.2274, 0.2429, 0.2421, 0.1821,
        0.2173], grad_fn=<ExpBackward>)
