Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type assertion for better code clarity #3036

Merged
merged 8 commits into from
Mar 15, 2022
Merged

Conversation

GautamV234
Copy link
Contributor

@GautamV234 GautamV234 commented Mar 4, 2022

I am working with Prof. @nipunbatra and I am currently working on this PR to solve #3026.

@fritzo
Copy link
Member

fritzo commented Mar 4, 2022

Can you make lint and ensure your code is well formatted? I believe you'll need to change the assertion to

assert isinstance(X, torch.Tensor)

@fehiepsi
Copy link
Member

fehiepsi commented Mar 5, 2022

I thought that it was recommended to use torch.is_tensor(x) but looking at its docs again, it recommends to use is_instance(...) pattern.

@GautamV234
Copy link
Contributor Author

Can you make lint and ensure your code is well formatted? I believe you'll need to change the assertion to

assert isinstance(X, torch.Tensor)

Sure I will do that and resubmit

@GautamV234
Copy link
Contributor Author

I have modified all the files (added assertion checks) having this problem in the pyro/contrib/gp/models folder

Below is an example that illustrates the assertion error.

import pyro
import pyro.contrib.gp as gp
import numpy as np
import torch 
X = np.random.randn(200, 2)
Y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0).astype(np.float64) 
kernel = gp.kernels.RBF(input_dim=2)
likelihood = gp.likelihoods.Binary()
model = gp.models.VariationalGP(X, Y, kernel, likelihood=likelihood, whiten=True)

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [6], in <cell line: 11>()
      9 kernel = gp.kernels.RBF(input_dim=2)
     10 likelihood = gp.likelihoods.Binary()
---> 11 model = gp.models.VariationalGP(X, Y, kernel, likelihood=likelihood, whiten=True)

File C:\work\Pyro\pyro\pyro\contrib\gp\models\vgp.py:74, in VariationalGP.__init__(self, X, y, kernel, likelihood, mean_function, latent_shape, whiten, jitter)     63 def __init__(
     64     self,
     65     X,
   (...)
     72     jitter=1e-6,
     73 ):
---> 74     assert isinstance(
     75         X, torch.Tensor
     76     ), "X needs to be a torch Tensor instead of a {}".format(type(X))
     77     assert isinstance(
     78         y, torch.Tensor
     79     ), "y needs to be a torch Tensor instead of a {}".format(type(y))
     80     super().__init__(X, y, kernel, mean_function, jitter)

AssertionError: X needs to be a torch Tensor instead of a <class 'numpy.ndarray'>

Please let me know if some tests need to be added pertaining to the updated code.

@GautamV234 GautamV234 closed this Mar 8, 2022
@GautamV234 GautamV234 reopened this Mar 8, 2022
@fehiepsi
Copy link
Member

fehiepsi commented Mar 8, 2022

Thanks for adding the checks! I think no need to add tests. Pyro is based on PyTorch, not numpy, so we would expect all inputs are torch tensors.

Could you fix the lint issue? You can run make lint, make format to catch and fix the issues.

@fritzo
Copy link
Member

fritzo commented Mar 10, 2022

Run make format and push

@GautamV234
Copy link
Contributor Author

Run make format and push

Thanks for the help!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @GautamV234 !

@fehiepsi fehiepsi merged commit ad8520a into pyro-ppl:dev Mar 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants