Skip to content

Commit

Permalink
Update torch_patch.py (#3174)
Browse files Browse the repository at this point in the history
* Update torch_patch.py

Removed the PositiveDefinite() function for the one from PyTorch to be used instead, as is what less precise resulting in errors for certain cases.

* Update test_linear_models_eig.py

Passed scale_tril to dist.MultivariateNormal as keyword argument instead of positional argument as it otherwise gets identified as being a covariance matrix.

* Update gp.ipynb

* Update gp.ipynb

* Changed float precision to 64 in cell 27
  • Loading branch information
S163669 committed Jan 27, 2023
1 parent 1678ee2 commit 9a858da
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 14 deletions.
11 changes: 0 additions & 11 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,6 @@ def _HalfCauchy_logprob(self, value):
return log_prob


# TODO fix batch_shape have an extra singleton dimension upstream
@patch_dependency("torch.distributions.constraints._PositiveDefinite.check")
def _PositiveDefinite_check(self, value):
matrix_shape = value.shape[-2:]
batch_shape = value.shape[:-2]
flattened_value = value.reshape((-1,) + matrix_shape)
return torch.stack(
[torch.linalg.eigvalsh(v)[:1] > 0.0 for v in flattened_value]
).view(batch_shape)


@patch_dependency("torch.distributions.constraints._CorrCholesky.check")
def _CorrCholesky_check(self, value):
row_norm = torch.linalg.norm(value.detach(), dim=-1)
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/oed/test_linear_models_eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def marginal_guide(design, observation_labels, target_labels):
torch.eye(3),
constraint=torch.distributions.constraints.lower_cholesky,
)
pyro.sample("y", dist.MultivariateNormal(mu, scale_tril))
pyro.sample("y", dist.MultivariateNormal(mu, scale_tril=scale_tril))


def likelihood_guide(theta_dict, design, observation_labels, target_labels):
Expand Down
5 changes: 3 additions & 2 deletions tutorial/source/gp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
"\n",
"smoke_test = \"CI\" in os.environ # ignore; used to check code integrity in the Pyro repo\n",
"assert pyro.__version__.startswith('1.8.4')\n",
"pyro.set_rng_seed(0)"
"pyro.set_rng_seed(0)\n",
"torch.set_default_tensor_type(torch.DoubleTensor)"
]
},
{
Expand Down Expand Up @@ -1054,7 +1055,7 @@
"source": [
"# only take petal length and petal width\n",
"X = torch.from_numpy(\n",
" df[df.columns[2:4]].values.astype(\"float32\"),\n",
" df[df.columns[2:4]].values.astype(\"float64\"),\n",
")\n",
"df[\"species\"] = df[\"species\"].astype(\"category\")\n",
"# encode the species as 0, 1, 2\n",
Expand Down

0 comments on commit 9a858da

Please sign in to comment.