reuse conjugate gaussian chain for gradient tests (#852)
martinjankowiak authored and fritzo committed Mar 6, 2018
1 parent 4ff3ba0 commit 748e344
Showing 2 changed files with 157 additions and 124 deletions.
31 changes: 31 additions & 0 deletions tests/infer/test_conjugate_gradients.py
@@ -0,0 +1,31 @@
+
+import pyro
+from pyro.infer.tracegraph_elbo import TraceGraph_ELBO
+from tests.common import assert_equal
+from tests.integration_tests.test_conjugate_gaussian_models import GaussianChain
+
+
+# TODO increase precision and number of particles once latter is parallelized properly
+class ConjugateChainGradientTests(GaussianChain):
+
+    def test_gradients(self):
+        for N in [3, 5]:
+            for reparameterized in [True, False]:
+                self.do_test_gradients(N, reparameterized)
+
+    def do_test_gradients(self, N, reparameterized):
+        pyro.clear_param_store()
+        self.setup_chain(N)
+
+        elbo = TraceGraph_ELBO(num_particles=1000)
+        elbo.loss_and_grads(self.model, self.guide, reparameterized=reparameterized)
+
+        for i in range(1, N + 1):
+            for param_prefix in ["mu_q_%d", "log_sig_q_%d", "kappa_q_%d"]:
+                if i == N and param_prefix == 'kappa_q_%d':
+                    continue
+                actual_grad = pyro.param(param_prefix % i).grad
+                assert_equal(actual_grad, 0.0 * actual_grad, prec=0.20, msg="".join([
+                    "parameter %s%d" % (param_prefix[:-2], i),
+                    "\nexpected = zero vector",
+                    "\n actual = {}".format(actual_grad.detach().cpu().numpy())]))
250 changes: 126 additions & 124 deletions tests/integration_tests/test_conjugate_gaussian_models.py
@@ -9,7 +9,7 @@
import numpy as np
import pytest
import torch
-from torch.autograd import Variable
+from torch.autograd import Variable, variable

import pyro
import pyro.distributions as dist
@@ -25,20 +25,14 @@ def param_mse(name, target):
    return torch.sum(torch.pow(target - pyro.param(name), 2.0)).detach().cpu().item()


-@pytest.mark.stage("integration", "integration_batch_1")
-@pytest.mark.init(rng_seed=0)
-class GaussianChainTests(TestCase):
+class GaussianChain(TestCase):
    # chain of normals with known covariances and latent means

    def setUp(self):
        self.mu0 = Variable(torch.Tensor([0.2]))
-        self.data = []
-        self.data.append(Variable(torch.Tensor([-0.1])))
-        self.data.append(Variable(torch.Tensor([0.03])))
-        self.data.append(Variable(torch.Tensor([0.20])))
-        self.data.append(Variable(torch.Tensor([0.10])))
-        self.n_data = Variable(torch.Tensor([len(self.data)]))
-        self.sum_data = self.data[0] + self.data[1] + self.data[2] + self.data[3]
+        self.data = Variable(torch.Tensor([-0.1, 0.03, 0.20, 0.10]))
+        self.n_data = self.data.size(0)
+        self.sum_data = self.data.sum()

    def setup_chain(self, N):
        self.N = N  # number of latent variables in the chain
@@ -53,7 +47,7 @@ def setup_chain(self, N):
        for k in range(1, self.N):
            lambda_k = self.lambdas[k] + self.lambda_tilde_posts[k - 1]
            self.lambda_posts.append(lambda_k)
-        lambda_N_post = (self.n_data.expand_as(self.lambdas[N]) * self.lambdas[N]) +\
+        lambda_N_post = (self.n_data * variable(1.0).expand_as(self.lambdas[N]) * self.lambdas[N]) +\
            self.lambda_tilde_posts[N - 1]
        self.lambda_posts.append(lambda_N_post)
        self.target_kappas = [None]
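Note: the recursion in this hunk is the usual conjugate-Gaussian bookkeeping. Each latent's posterior precision adds the precision tying it to its child (λ_k, or n·λ_N for the last latent, whose children are the n observations) to the collapsed upstream precision λ̃_{k-1}. In the code's notation (λ = lambdas, λ̃ = lambda_tilde_posts, λ̂ = lambda_posts, n = n_data):

\hat{\lambda}_k = \lambda_k + \tilde{\lambda}_{k-1}, \quad 1 \le k < N
\hat{\lambda}_N = n\,\lambda_N + \tilde{\lambda}_{N-1}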
@@ -73,6 +67,48 @@ def setup_reparam_mask(self, N):
            if torch.sum(mask) < 0.40 * N and torch.sum(mask) > 0.5:
                return mask

+    def model(self, reparameterized, difficulty=0.0):
+        next_mean = self.mu0
+        for k in range(1, self.N + 1):
+            latent_dist = dist.Normal(next_mean, torch.pow(self.lambdas[k - 1], -0.5))
+            mu_latent = pyro.sample("mu_latent_%d" % k, latent_dist)
+            next_mean = mu_latent
+
+        mu_N = next_mean
+        with pyro.iarange("data", self.data.size(0)):
+            pyro.sample("obs", dist.Normal(mu_N.expand_as(self.data),
+                        torch.pow(self.lambdas[self.N], -0.5).expand_as(self.data)), obs=self.data)
+        return mu_N
+
+    def guide(self, reparameterized, difficulty=0.0):
+        previous_sample = None
+        for k in reversed(range(1, self.N + 1)):
+            mu_q = pyro.param("mu_q_%d" % k, Variable(self.target_mus[k].data +
+                                                      difficulty * (0.1 * torch.randn(1) - 0.53),
+                                                      requires_grad=True))
+            log_sig_q = pyro.param("log_sig_q_%d" % k,
+                                   Variable(-0.5 * torch.log(self.lambda_posts[k]).data +
+                                            difficulty * (0.1 * torch.randn(1) - 0.53),
+                                            requires_grad=True))
+            sig_q = torch.exp(log_sig_q)
+            kappa_q = None
+            if k != self.N:
+                kappa_q = pyro.param("kappa_q_%d" % k, Variable(self.target_kappas[k].data +
+                                                                difficulty * (0.1 * torch.randn(1) - 0.53),
+                                                                requires_grad=True))
+            mean_function = mu_q if k == self.N else kappa_q * previous_sample + mu_q
+            node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False
+            Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal
+            mu_latent = pyro.sample("mu_latent_%d" % k, Normal(mean_function, sig_q),
+                                    infer=dict(baseline=dict(use_decaying_avg_baseline=True)))
+            previous_sample = mu_latent
+        return previous_sample


+@pytest.mark.stage("integration", "integration_batch_1")
+@pytest.mark.init(rng_seed=0)
+class GaussianChainTests(GaussianChain):

    def test_elbo_reparameterized_N_is_3(self):
        self.setup_chain(3)
        self.do_elbo_test(True, 4000, 0.0015, 0.03, difficulty=1.0)
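Note: the reuse works because Pyro's inference entry points forward extra arguments verbatim to both callables: TraceGraph_ELBO.loss_and_grads(model, guide, **kwargs) in the new gradient test and svi.step(**kwargs) below both call model and guide with the same keyword arguments, which is why the shared methods can take reparameterized (and difficulty) directly. A minimal sketch of the pattern with a hypothetical toy model/guide, using the same 0.1-era API as this diff:

import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
from pyro.infer.tracegraph_elbo import TraceGraph_ELBO

def toy_model(flag):
    # flag arrives untouched from the loss_and_grads call below
    pyro.sample("z", dist.Normal(Variable(torch.zeros(1)), Variable(torch.ones(1))))

def toy_guide(flag):
    mu = pyro.param("mu", Variable(torch.zeros(1), requires_grad=True))
    pyro.sample("z", dist.Normal(mu, Variable(torch.ones(1))))

elbo = TraceGraph_ELBO(num_particles=10)
elbo.loss_and_grads(toy_model, toy_guide, flag=True)  # flag=True reaches both callables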
@@ -116,49 +152,12 @@ def array_to_string(y):
        logger.debug("lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts))
        pyro.clear_param_store()

-        def model(*args, **kwargs):
-            next_mean = self.mu0
-            for k in range(1, self.N + 1):
-                latent_dist = dist.Normal(next_mean, torch.pow(self.lambdas[k - 1], -0.5))
-                mu_latent = pyro.sample("mu_latent_%d" % k, latent_dist)
-                next_mean = mu_latent
-
-            mu_N = next_mean
-            for i, x in enumerate(self.data):
-                pyro.sample("obs_%d" % i, dist.Normal(mu_N, torch.pow(self.lambdas[self.N], -0.5)),
-                            obs=x)
-            return mu_N
-
-        def guide(*args, **kwargs):
-            previous_sample = None
-            for k in reversed(range(1, self.N + 1)):
-                mu_q = pyro.param("mu_q_%d" % k, Variable(self.target_mus[k].data +
-                                                          difficulty * (0.1 * torch.randn(1) - 0.53),
-                                                          requires_grad=True))
-                log_sig_q = pyro.param("log_sig_q_%d" % k,
-                                       Variable(-0.5 * torch.log(self.lambda_posts[k]).data +
-                                                difficulty * (0.1 * torch.randn(1) - 0.53),
-                                                requires_grad=True))
-                sig_q = torch.exp(log_sig_q)
-                kappa_q = None if k == self.N \
-                    else pyro.param("kappa_q_%d" % k,
-                                    Variable(self.target_kappas[k].data +
-                                             difficulty * (0.1 * torch.randn(1) - 0.53),
-                                             requires_grad=True))
-                mean_function = mu_q if k == self.N else kappa_q * previous_sample + mu_q
-                node_flagged = True if self.which_nodes_reparam[k - 1] == 1.0 else False
-                Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal
-                mu_latent = pyro.sample("mu_latent_%d" % k, Normal(mean_function, sig_q),
-                                        infer=dict(baseline=dict(use_decaying_avg_baseline=True)))
-                previous_sample = mu_latent
-            return previous_sample

        adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
-        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)
+        svi = SVI(self.model, self.guide, adam, loss="ELBO", trace_graph=True)

        for step in range(n_steps):
            t0 = time.time()
-            svi.step()
+            svi.step(reparameterized=reparameterized, difficulty=difficulty)

            if step % 5000 == 0 or step == n_steps - 1:
                kappa_errors, log_sig_errors, mu_errors = [], [], []
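Note: both guides in this diff mark their latent sites with use_decaying_avg_baseline (the tree guide below also sets baseline_beta=0.96). For non-reparameterized sites, an exponential moving average of recent loss terms is subtracted as a baseline to reduce the variance of the score-function gradient. A rough standalone sketch of the idea, not Pyro's internal code (the bias-correction step is an assumption):

def decaying_avg_baseline(losses, beta):
    # exponential moving average of the losses, debiased at early steps
    avg, baselines = 0.0, []
    for t, loss in enumerate(losses, start=1):
        avg = beta * avg + (1.0 - beta) * loss
        baselines.append(avg / (1.0 - beta ** t))
    return baselines

print(decaying_avg_baseline([2.0, 1.5, 1.2, 1.1], beta=0.96))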
@@ -364,6 +363,76 @@ def add_edge(s):

        return g

+    def model(self, reparameterized, model_permutation, difficulty=0.0):
+        top_latent_dist = dist.Normal(self.mu0, torch.pow(self.lambdas[0], -0.5))
+        previous_names = ["mu_latent_1"]
+        top_latent = pyro.sample(previous_names[0], top_latent_dist)
+        previous_latents_and_names = list(zip([top_latent], previous_names))
+
+        # for sampling model variables in different sequential orders
+        def permute(x, n):
+            if model_permutation:
+                return [x[self.model_permutations[n - 1][i]] for i in range(len(x))]
+            return x
+
+        def unpermute(x, n):
+            if model_permutation:
+                return [x[self.model_unpermutations[n - 1][i]] for i in range(len(x))]
+            return x
+
+        for n in range(2, self.N + 1):
+            new_latents_and_names = []
+            for prev_latent, prev_name in permute(previous_latents_and_names, n - 1):
+                latent_dist = dist.Normal(prev_latent, torch.pow(self.lambdas[n - 1], -0.5))
+                couple = []
+                for LR in ['L', 'R']:
+                    new_name = prev_name + LR
+                    mu_latent_LR = pyro.sample(new_name, latent_dist)
+                    couple.append([mu_latent_LR, new_name])
+                new_latents_and_names.append(couple)
+            _previous_latents_and_names = unpermute(new_latents_and_names, n - 1)
+            previous_latents_and_names = []
+            for x in _previous_latents_and_names:
+                previous_latents_and_names.append(x[0])
+                previous_latents_and_names.append(x[1])
+
+        for i, data_i in enumerate(self.data):
+            for k, x in enumerate(data_i):
+                pyro.sample("obs_%s_%d" % (previous_latents_and_names[i][1], k),
+                            dist.Normal(previous_latents_and_names[i][0], torch.pow(self.lambdas[-1], -0.5)),
+                            obs=x)
+        return top_latent
+
+    def guide(self, reparameterized, model_permutation, difficulty=0.0):
+        latents_dict = {}
+
+        n_nodes = len(self.q_topo_sort)
+        for i, node in enumerate(self.q_topo_sort):
+            deps = self.q_dag.predecessors(node)
+            node_suffix = node[10:]
+            log_sig_node = pyro.param("log_sig_" + node_suffix,
+                                      Variable(-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
+                                               difficulty * (torch.Tensor([-0.3]) -
+                                                             0.3 * (torch.randn(1) ** 2)),
+                                               requires_grad=True))
+            mean_function_node = pyro.param("constant_term_" + node,
+                                            Variable(self.mu0.data +
+                                                     torch.Tensor([difficulty * i / n_nodes]),
+                                                     requires_grad=True))
+            for dep in deps:
+                kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[10:],
+                                       Variable(torch.Tensor([0.5 + difficulty * i / n_nodes]),
+                                                requires_grad=True))
+                mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]
+            node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False
+            Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal
+            latent_node = pyro.sample(node, Normal(mean_function_node, torch.exp(log_sig_node)),
+                                      infer=dict(baseline=dict(use_decaying_avg_baseline=True,
+                                                               baseline_beta=0.96)))
+            latents_dict[node] = latent_node
+
+        return latents_dict['mu_latent_1']

    def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                     difficulty=1.0, model_permutation=False):
        n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
@@ -374,79 +443,12 @@ def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                     len(self.q_topo_sort), model_permutation))
        pyro.clear_param_store()

-        def model(*args, **kwargs):
-            top_latent_dist = dist.Normal(self.mu0, torch.pow(self.lambdas[0], -0.5))
-            previous_names = ["mu_latent_1"]
-            top_latent = pyro.sample(previous_names[0], top_latent_dist)
-            previous_latents_and_names = list(zip([top_latent], previous_names))
-
-            # for sampling model variables in different sequential orders
-            def permute(x, n):
-                if model_permutation:
-                    return [x[self.model_permutations[n - 1][i]] for i in range(len(x))]
-                return x
-
-            def unpermute(x, n):
-                if model_permutation:
-                    return [x[self.model_unpermutations[n - 1][i]] for i in range(len(x))]
-                return x
-
-            for n in range(2, self.N + 1):
-                new_latents_and_names = []
-                for prev_latent, prev_name in permute(previous_latents_and_names, n - 1):
-                    latent_dist = dist.Normal(prev_latent, torch.pow(self.lambdas[n - 1], -0.5))
-                    couple = []
-                    for LR in ['L', 'R']:
-                        new_name = prev_name + LR
-                        mu_latent_LR = pyro.sample(new_name, latent_dist)
-                        couple.append([mu_latent_LR, new_name])
-                    new_latents_and_names.append(couple)
-                _previous_latents_and_names = unpermute(new_latents_and_names, n - 1)
-                previous_latents_and_names = []
-                for x in _previous_latents_and_names:
-                    previous_latents_and_names.append(x[0])
-                    previous_latents_and_names.append(x[1])
-
-            for i, data_i in enumerate(self.data):
-                for k, x in enumerate(data_i):
-                    pyro.sample("obs_%s_%d" % (previous_latents_and_names[i][1], k),
-                                dist.Normal(previous_latents_and_names[i][0], torch.pow(self.lambdas[-1], -0.5)),
-                                obs=x)
-            return top_latent
-
-        def guide(*args, **kwargs):
-            latents_dict = {}
-
-            n_nodes = len(self.q_topo_sort)
-            for i, node in enumerate(self.q_topo_sort):
-                deps = self.q_dag.predecessors(node)
-                node_suffix = node[10:]
-                log_sig_node = pyro.param("log_sig_" + node_suffix,
-                                          Variable(-0.5 * torch.log(self.target_lambdas[node_suffix]).data +
-                                                   difficulty * (torch.Tensor([-0.3]) -
-                                                                 0.3 * (torch.randn(1) ** 2)),
-                                                   requires_grad=True))
-                mean_function_node = pyro.param("constant_term_" + node,
-                                                Variable(self.mu0.data +
-                                                         torch.Tensor([difficulty * i / n_nodes]),
-                                                         requires_grad=True))
-                for dep in deps:
-                    kappa_dep = pyro.param("kappa_" + node_suffix + '_' + dep[10:],
-                                           Variable(torch.Tensor([0.5 + difficulty * i / n_nodes]),
-                                                    requires_grad=True))
-                    mean_function_node = mean_function_node + kappa_dep * latents_dict[dep]
-                node_flagged = True if self.which_nodes_reparam[i] == 1.0 else False
-                Normal = dist.Normal if reparameterized or node_flagged else fakes.NonreparameterizedNormal
-                latent_node = pyro.sample(node, Normal(mean_function_node, torch.exp(log_sig_node)),
-                                          infer=dict(baseline=dict(use_decaying_avg_baseline=True,
-                                                                   baseline_beta=0.96)))
-                latents_dict[node] = latent_node
-
-            return latents_dict['mu_latent_1']

        # check graph structure is as expected but only for N=2
        if self.N == 2:
-            guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace()
+            guide_trace = pyro.poutine.trace(self.guide,
+                                             graph_type="dense").get_trace(reparameterized=reparameterized,
+                                                                           model_permutation=model_permutation,
+                                                                           difficulty=difficulty)
            expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_mu_latent_1R', '_RETURN',
                                  'mu_latent_1R', 'mu_latent_1', 'constant_term_mu_latent_1', 'mu_latent_1L',
                                  'constant_term_mu_latent_1L', 'log_sig_1L', 'kappa_1_1R', 'kappa_1R_1L', 'log_sig_1'])
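Note: because guide is now a bound method with required arguments, the dense-graph trace above has to supply them; poutine.trace(fn, ...).get_trace(*args, **kwargs) simply runs fn with those arguments while recording each site. A minimal sketch of that contract with a hypothetical toy guide:

import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def toy_guide(scale):
    mu = pyro.param("mu", Variable(torch.zeros(1), requires_grad=True))
    pyro.sample("z", dist.Normal(mu, Variable(scale * torch.ones(1))))

# get_trace forwards scale=2.0 to toy_guide; the recorded graph includes
# '_INPUT' and '_RETURN' nodes like those asserted above
guide_trace = poutine.trace(toy_guide, graph_type="dense").get_trace(scale=2.0)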
@@ -456,11 +458,11 @@ def guide(*args, **kwargs):
            assert expected_edges == set(guide_trace.edges)

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
-        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)
+        svi = SVI(self.model, self.guide, adam, loss="ELBO", trace_graph=True)

        for step in range(n_steps):
            t0 = time.time()
-            svi.step()
+            svi.step(reparameterized=reparameterized, model_permutation=model_permutation, difficulty=difficulty)

            if step % 5000 == 0 or step == n_steps - 1:
                log_sig_errors = []
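Note: either changed file can be exercised on its own, e.g. with pytest tests/infer/test_conjugate_gradients.py from the repository root. A hedged sketch running the new gradient tests with the stdlib runner instead (assumes the repository root is on sys.path so the tests package imports):

import unittest

from tests.infer.test_conjugate_gradients import ConjugateChainGradientTests

# collect and run only the gradient tests added in this commit
suite = unittest.TestLoader().loadTestsFromTestCase(ConjugateChainGradientTests)
unittest.TextTestRunner(verbosity=2).run(suite)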