-
Notifications
You must be signed in to change notification settings - Fork 222
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
masking code for autoregressive nn (#203)
* initial commit * remove dtype
- Loading branch information
1 parent
899cc8c
commit 31ea972
Showing
3 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# lightly adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/nn/auto_reg_nn.py | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import numpy as onp | ||
|
||
|
||
def sample_mask_indices(input_dim, hidden_dim):
    """
    Samples the indices assigned to hidden units during the construction of MADE masks

    :param input_dim: the dimensionality of the input variable
    :type input_dim: int
    :param hidden_dim: the dimensionality of the hidden layer
    :type hidden_dim: int
    """
    # Spread `hidden_dim` fractional degrees evenly across [1, input_dim],
    # then snap each one to the nearest integer (numpy round-half-to-even).
    evenly_spaced = onp.linspace(1, input_dim, num=hidden_dim)
    return onp.round(evenly_spaced)
|
||
|
||
def create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier):
    """
    Creates (non-conditional) MADE masks

    :param input_dim: the dimensionality of the input variable
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden layers(s)
    :type hidden_dims: list[int]
    :param permutation: the order of the input variables
    :type permutation: numpy array of integers of length `input_dim`
    :param output_dim_multiplier: tiles the output (e.g. for when a separate mean and scale parameter are desired)
    :type output_dim_multiplier: int
    """
    # Invert the permutation: position k of the input is assigned its rank
    # in the autoregressive ordering.
    var_index = onp.zeros(permutation.shape[0])
    var_index[permutation] = onp.arange(input_dim)

    # 1-based degrees for the neurons of every layer.
    input_indices = var_index + 1
    hidden_indices = [sample_mask_indices(input_dim, h) for h in hidden_dims]
    output_indices = onp.tile(input_indices, output_dim_multiplier)

    # Skip connections (input -> output) may only see strictly lower degrees.
    mask_skip = output_indices[:, None] > input_indices[None, :]

    # Input -> first hidden layer, then hidden -> hidden masks.
    # NOTE: The masks created follow a slightly different pattern than that given in Germain et al. Figure 1
    # The output first in the order (e.g. x_2 in the figure) is connected to hidden units rather than being unattached
    # Tracing a path back through the network, however, this variable will still be unconnected to any input variables
    masks = [hidden_indices[0][:, None] > input_indices[None, :]]
    for below, above in zip(hidden_indices[:-1], hidden_indices[1:]):
        masks.append(above[:, None] >= below[None, :])

    # Last hidden layer -> output layer.
    masks.append(output_indices[:, None] >= hidden_indices[-1][None, :])

    return masks, mask_skip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# lightly adapted from https://github.com/pyro-ppl/pyro/blob/dev/tests/nn/test_autoregressive.py | ||
|
||
import pytest | ||
import numpy as onp | ||
from numpy.testing import assert_array_equal | ||
from numpyro.contrib.nn.auto_reg_nn import create_mask | ||
|
||
|
||
@pytest.mark.parametrize('input_dim', [5, 8])
@pytest.mark.parametrize('n_layers', [1, 3])
@pytest.mark.parametrize('output_dim_multiplier', [1, 4])
def test_masks(input_dim, n_layers, output_dim_multiplier):
    """Check MADE masks by tracing every output back to the inputs it can see."""
    hidden_dims = [input_dim * 3] * n_layers
    permutation = onp.random.permutation(input_dim)
    masks, mask_skip = create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier)

    # Tracing backwards, work out what inputs each output is connected to.
    _permutation = list(permutation)

    # Loop over variables
    for idx in range(input_dim):
        # The inputs that may legally influence variable `idx` are exactly
        # those preceding it in the autoregressive ordering.
        rank = onp.where(permutation == idx)[0][0]
        correct = onp.array(sorted(_permutation[0:rank]))

        # Loop over parameters for each variable
        for jdx in range(output_dim_multiplier):
            row = idx + jdx * input_dim

            # Output -> last hidden layer connections.
            prev_connections = {kdx for kdx in range(masks[-1].shape[1])
                                if masks[-1][row, kdx]}

            # Walk back through hidden-to-hidden and hidden-to-input masks.
            for m in reversed(masks[:-1]):
                prev_connections = {ldx
                                    for kdx in prev_connections
                                    for ldx in range(m.shape[1])
                                    if m[kdx, ldx]}

            assert_array_equal(sorted(prev_connections), correct)

            # The skip-connections mask must obey the same ordering constraint.
            skip_connections = {kdx for kdx in range(mask_skip.shape[1])
                                if mask_skip[row, kdx]}
            assert_array_equal(sorted(skip_connections), correct)