masking code for autoregressive nn (#203)
* initial commit

* remove dtype
martinjankowiak authored and fehiepsi committed Jun 14, 2019
1 parent 899cc8c commit 31ea972
Showing 3 changed files with 110 additions and 0 deletions.
Empty file added numpyro/contrib/nn/__init__.py
58 changes: 58 additions & 0 deletions numpyro/contrib/nn/auto_reg_nn.py
@@ -0,0 +1,58 @@
# lightly adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/nn/auto_reg_nn.py

from __future__ import absolute_import, division, print_function

import numpy as onp


def sample_mask_indices(input_dim, hidden_dim):
"""
Samples the indices assigned to hidden units during the construction of MADE masks
:param input_dim: the dimensionality of the input variable
:type input_dim: int
:param hidden_dim: the dimensionality of the hidden layer
:type hidden_dim: int
"""
indices = onp.linspace(1, input_dim, num=hidden_dim)
# Simple procedure tries to space fractional indices evenly by rounding to nearest int
return onp.round(indices)


def create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier):
"""
Creates (non-conditional) MADE masks
:param input_dim: the dimensionality of the input variable
:type input_dim: int
:param hidden_dims: the dimensionality of the hidden layers(s)
:type hidden_dims: list[int]
:param permutation: the order of the input variables
:type permutation: numpy array of integers of length `input_dim`
:param output_dim_multiplier: tiles the output (e.g. for when a separate mean and scale parameter are desired)
:type output_dim_multiplier: int
"""
# Create mask indices for input, hidden layers, and final layer
var_index = onp.zeros(permutation.shape[0])
var_index[permutation] = onp.arange(input_dim)

# Create the indices that are assigned to the neurons
input_indices = 1 + var_index
hidden_indices = [sample_mask_indices(input_dim, h) for h in hidden_dims]
output_indices = onp.tile(var_index + 1, output_dim_multiplier)

# Create mask from input to output for the skips connections
mask_skip = output_indices[:, None] > input_indices[None, :]

# Create mask from input to first hidden layer, and between subsequent hidden layers
# NOTE: The masks created follow a slightly different pattern than that given in Germain et al. Figure 1
# The output first in the order (e.g. x_2 in the figure) is connected to hidden units rather than being unattached
# Tracing a path back through the network, however, this variable will still be unconnected to any input variables
masks = [hidden_indices[0][:, None] > input_indices[None, :]]
for i in range(1, len(hidden_dims)):
masks.append(hidden_indices[i][:, None] >= hidden_indices[i - 1][None, :])

# Create mask from last hidden layer to output layer
masks.append(output_indices[:, None] >= hidden_indices[-1][None, :])

return masks, mask_skip
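
For context, here is a minimal sketch of how the returned masks might be consumed downstream. The shapes follow directly from create_mask above, but the masked_dense helper and the tiny forward pass are purely illustrative and not part of this commit.

# Hypothetical usage sketch (not part of the commit): apply the masks in a
# masked affine pass so each output only depends on permitted inputs.
import numpy as onp

from numpyro.contrib.nn.auto_reg_nn import create_mask

input_dim, hidden_dims, output_dim_multiplier = 4, [12], 2
permutation = onp.arange(input_dim)
masks, mask_skip = create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier)

# masks[0]: inputs -> first hidden layer; masks[-1]: last hidden layer -> outputs;
# mask_skip: inputs -> outputs (skip connections)
assert masks[0].shape == (hidden_dims[0], input_dim)
assert masks[-1].shape == (input_dim * output_dim_multiplier, hidden_dims[-1])
assert mask_skip.shape == (input_dim * output_dim_multiplier, input_dim)


def masked_dense(x, w, mask):
    # Zero out the weights that would violate the autoregressive ordering
    return x @ (w * mask).T


x = onp.random.randn(input_dim)
w1 = onp.random.randn(hidden_dims[0], input_dim)
w2 = onp.random.randn(input_dim * output_dim_multiplier, hidden_dims[0])
hidden = onp.tanh(masked_dense(x, w1, masks[0]))
out = masked_dense(hidden, w2, masks[-1])  # shape (input_dim * output_dim_multiplier,)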
52 changes: 52 additions & 0 deletions test/contrib/test_auto_regressive_nn.py
@@ -0,0 +1,52 @@
# lightly adapted from https://github.com/pyro-ppl/pyro/blob/dev/tests/nn/test_autoregressive.py

import pytest
import numpy as onp
from numpy.testing import assert_array_equal
from numpyro.contrib.nn.auto_reg_nn import create_mask


@pytest.mark.parametrize('input_dim', [5, 8])
@pytest.mark.parametrize('n_layers', [1, 3])
@pytest.mark.parametrize('output_dim_multiplier', [1, 4])
def test_masks(input_dim, n_layers, output_dim_multiplier):
    hidden_dim = input_dim * 3
    hidden_dims = [hidden_dim] * n_layers
    permutation = onp.random.permutation(input_dim)
    masks, mask_skip = create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier)

    # First test that the hidden layer masks are adequately connected:
    # tracing backwards, work out which inputs each output is connected to,
    # tracking the connections as sets of unit indices at each layer
    _permutation = list(permutation)

    # Loop over variables
    for idx in range(input_dim):
        # Calculate the correct answer: the inputs that precede idx in the permutation
        correct = onp.array(sorted(_permutation[0:onp.where(permutation == idx)[0][0]]))

        # Loop over parameters for each variable
        for jdx in range(output_dim_multiplier):
            prev_connections = set()
            # Do output-to-penultimate hidden layer mask
            for kdx in range(masks[-1].shape[1]):
                if masks[-1][idx + jdx * input_dim, kdx]:
                    prev_connections.add(kdx)

            # Do hidden-to-hidden, and hidden-to-input layer masks
            for m in reversed(masks[:-1]):
                this_connections = set()
                for kdx in prev_connections:
                    for ldx in range(m.shape[1]):
                        if m[kdx, ldx]:
                            this_connections.add(ldx)
                prev_connections = this_connections

            assert_array_equal(list(sorted(prev_connections)), correct)

            # Test the skip-connections mask
            skip_connections = set()
            for kdx in range(mask_skip.shape[1]):
                if mask_skip[idx + jdx * input_dim, kdx]:
                    skip_connections.add(kdx)
            assert_array_equal(list(sorted(skip_connections)), correct)

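To make the expected connectivity in the test concrete, here is a small worked example of the correct computation; the permutation below is an arbitrary illustration, not taken from the commit.

# Worked example (illustrative values): which inputs may influence output idx?
import numpy as onp

permutation = onp.array([2, 0, 3, 1])  # arbitrary ordering of 4 inputs
_permutation = list(permutation)

idx = 3
# Output 3 sits at position 2 in the ordering, so only the variables that come
# earlier in the ordering (2 and 0) are allowed to influence it.
position = onp.where(permutation == idx)[0][0]          # -> 2
correct = onp.array(sorted(_permutation[0:position]))   # -> array([0, 2])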