Modularizing distribution tests using parametrized fixtures #162
Conversation
This is awesome, great work!
            min_samples=1000,
            is_discrete=False,
            expected_support_file=None,
            expected_support_key=''):
Can we make all of these into named arguments?
that's a good idea, will do.
done
pyro/distributions/distribution.py
Outdated
@@ -31,3 +31,9 @@ def batch_log_pdf(self, x, batch_size):

    def support(self):
        raise NotImplementedError("Support not supported for {}".format(str(type(self))))

    def analytic_mean(self, *args, **kwargs):
        raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))
Can you add docstrings for these methods here and on the rest of the distributions and verify that documentation is correctly generated?
Yes, will add docstrings and check.
Added to the base class. I am not sure if it makes sense to repetitively add this to all the implementing classes; will look into how sphinx handles inheritance during document generation.
pyro/distributions/distribution.py
Outdated
        raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))

    def analytic_var(self, *args, **kwargs):
        raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))
Docstring
pytestmark = pytest.mark.init(rng_seed=123)

continuous_dists = [
Should we split the test generation code here into a separate file?
The fixtures have now been moved to the local conftest.py file and should be available to all the distribution tests.
    assert_equal(log_px_val, log_px_np, prec=1e-4)


def test_float_type(float_test_data, float_alpha, float_beta, test_data, alpha, beta):
We should also have a test_cuda_type marked for Travis to skip.
It would be nice if we could run tests (or at least a subset of integration tests) using Variable.cuda() based on an env flag (CI will always set this flag to false, and so will our laptops), but on a GPU machine with the flag enabled we should be able to run these tests without having to modify the test suite. Will look into that. Just testing Variable.cuda() unconditionally will otherwise throw an error on our local setup.
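As a rough sketch of that env-flag idea (the PYRO_TEST_CUDA name and the test body are assumptions, not code from this PR; test_data is assumed to be the existing fixture in this file):

```python
# Hypothetical sketch: gate CUDA tests behind an env flag so CI and laptops
# skip them, while a GPU machine with the flag set runs them unmodified.
import os

import pytest

requires_cuda = pytest.mark.skipif(
    os.environ.get("PYRO_TEST_CUDA", "0") != "1",
    reason="PYRO_TEST_CUDA is not set; skipping GPU-only tests",
)


@requires_cuda
def test_cuda_type(test_data):
    cuda_data = test_data.cuda()  # would raise on a CPU-only machine
    assert cuda_data.is_cuda
```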
@@ -77,3 +77,11 @@ def support(self, ps=None, *args, **kwargs):
        size = functools.reduce(lambda x, y: x * y, _ps.size())
        return (Variable(torch.Tensor(list(x)).view_as(_ps))
                for x in itertools.product(torch.Tensor([0, 1]).type_as(_ps.data), repeat=size))

    def analytic_mean(self, ps=None):
        _ps = self._sanitize_input(ps)
Retaining the convention, but will change as part of #167
import pyro.distributions as dist
from tests.common import TestCase

pytestmark = pytest.mark.init(rng_seed=123)
nit: Can you remove the pytestmark = or rename it to unused = ... or unused_pytestmark = ...? (here and in other files)
That is used to set the random seed for all the tests in the file. Removing it or changing the variable name will result in non-deterministic tests. See https://docs.pytest.org/en/latest/example/markers.html#marking-whole-classes-or-modules.

We are using the same seed for almost all our unit tests, so I think we can set this in conftest.py itself to keep it clean. Any tests that need to override this default can do that using the decorator explicitly. I will tackle this in a separate PR though, since this is true of all our unit tests, not just the ones for distributions.
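For readers unfamiliar with the pattern, a conftest.py hook could consume the mark roughly like this (a sketch; pyro's actual hook may differ):

```python
# Hypothetical sketch of a conftest.py hook that reads the "init" mark and
# seeds the RNGs before each test; pyro's real implementation may differ.
import numpy as np
import pytest
import torch


@pytest.fixture(autouse=True)
def _seed_rngs(request):
    marker = request.node.get_closest_marker("init")
    if marker is not None:
        seed = marker.kwargs.get("rng_seed", 0)
        torch.manual_seed(seed)
        np.random.seed(seed)
```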
Oh thanks, I didn't know about that!
I wish I could mark this comment as resolved.
I concur, looks like this is one of the most heavily requested features.
@neerajprad This looks great! I'm so glad that we're modularizing our tests; it should make it much easier to add tests that apply to all distributions. Also, it's great to learn from you how to use pytest fixtures; I've used them only a little bit in the past.
I'd like to take one more pass verifying math before merging.
    def analytic_mean(self, *args, **kwargs):
        """
        Analytic mean of the distribution, to be implemented by derived classes.
Consider mentioning that this is optional and "is used only in testing".
yes, will do.
pyro/distributions/multinomial.py
Outdated
    def analytic_var(self, ps=None, n=None):
        _ps, _n = self._sanitize_input(ps, n)
        return _n * ps * (1 - _ps)
Typo: ps should be _ps. It would sure be easier to read without underscores (#167).
Good catch, will change.
tests/unit/distributions/conftest.py
Outdated
        scipy_dist=sp.dirichlet,
        dist_params=[([2.4, 3, 6],), ([3.2, 1.2, 0.4],)],
        test_data=[[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]],
        scipy_arg_fn=lambda alpha: ((alpha,), {}))
nit: Add a trailing comma. This makes it less error-prone to add new distributions, to review changes, and to resolve merge conflicts.
good point, will change.
tests/unit/distributions/conftest.py
Outdated
        min_samples=10000,
        is_discrete=True,
        expected_support_file='tests/test_data/support_categorical.json',
        expected_support_key='one_hot')
nit: ditto: add trailing comma
def test_batch_log_pdf(dist):
    # TODO (@npradhan) - remove once #144 is resolved
    try:
I like this skip-if-not-implemented behavior for researchy codebases. Modularised tests can cause some friction when testing new distributions or when partially implementing distributions for a single science project. This caused some pain for @ericmjonas when he partially implemented exotic distributions in posterior/distributions. On researchy codebases I now often use an xfail_if_not_implemented decorator (code). I expect @ngoodman and students would appreciate that 😄
Just to be clear, the issue is that on one-off data science projects, we often create a partial implementation of some crazy distribution. It's nice to test that it is correct, but unless we use something like an xfail_if_not_implemented pattern, the author will have to comment out a bunch of tests on some branch of pyro.
Thanks for sharing those links! It will be useful to have a similar decorator for our code, especially as we start putting in more distributions.
(On second look, it was a context manager that's used like:

    with xfail_if_not_implemented():
        ...do stuff...

)
That's even better, as I was not sure how test decorators interact with pytest. We can have something like this in testing.common.
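A minimal sketch of such a context manager, assuming it lives in a shared test helpers module (not the exact code from the linked gist):

```python
# Hypothetical sketch of xfail_if_not_implemented for a shared test helper
# module; it xfails the surrounding test instead of erroring out.
from contextlib import contextmanager

import pytest


@contextmanager
def xfail_if_not_implemented():
    try:
        yield
    except NotImplementedError as e:
        pytest.xfail("not implemented: {}".format(e))
```

A test would then wrap only the optional call inside with xfail_if_not_implemented(): so the remaining assertions stay strict.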
        _, counts = np.unique(torch_samples, return_counts=True)
        exp_ = float(counts[0]) / self.n_samples
        torch_var = float(counts[0]) * np.power(0.1 * (0 - np.mean(torch_samples)), 2)
        torch_var = np.square(np.mean(torch_samples)) / 16
Could you please add a comment explaining this computation and the magic number 16?
I am curious to know what this magic number is too. :)
I have removed the variance check now, as I am not even sure what variance in this context would mean. Thanks for flagging this.
        with open('tests/test_data/support_categorical.json') as data_file:
            data = json.load(data_file)
            self.support = list(map(lambda x: torch.Tensor(x), data['one_hot']))
nit: I find it easier to read list comprehensions. What do you think?
self.support = [torch.Tensor(x) for x in data['one_hot']]
            return None
        with open(self.support_file) as data:
            data = json.load(data)
            return list(map(lambda x: torch.Tensor(x), data[self.support_key]))
nit: Consider using a list comprehension instead of list(map(lambda ...))
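For illustration, the two forms side by side on toy data (not the PR's actual data):

```python
import torch

data = {"one_hot": [[1, 0], [0, 1]]}

# list(map(lambda ...)) form used in the diff above:
support = list(map(lambda x: torch.Tensor(x), data["one_hot"]))

# Equivalent, arguably more readable, list-comprehension form:
support = [torch.Tensor(x) for x in data["one_hot"]]
```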
        _sum_alpha = torch.sum(_alpha)
        return _alpha / _sum_alpha

    def analytic_var(self, alpha):
If this computes the element-wise variance, could you please add a comment verifying so?
done
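For reference, the standard element-wise Dirichlet moments being discussed here are (textbook result, not quoted from the diff):

```latex
\mathbb{E}[X_i] = \frac{\alpha_i}{\alpha_0}, \qquad
\operatorname{Var}(X_i) = \frac{\alpha_i (\alpha_0 - \alpha_i)}{\alpha_0^2 (\alpha_0 + 1)},
\qquad \alpha_0 = \sum_k \alpha_k
```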
@@ -70,7 +70,7 @@ def log_pdf(self, x, alpha=None, beta=None, *args, **kwargs):
     def batch_log_pdf(self, x, alpha=None, beta=None, batch_size=1, *args, **kwargs):
         _alpha, _beta = self._sanitize_input(alpha, beta)
         if x.dim() == 1 and _beta.dim() == 1 and batch_size == 1:
-            return self.log_pdf(x. _alpha, _beta)
+            return self.log_pdf(x, _alpha, _beta)
I'm so glad that we now have tests that will catch this 😄
@neerajprad Do you have a plan to make these test fixtures available to authors of new distributions? For example, suppose I partially implement a spherical von Mises distribution in my own codebase that imports pyro.
LGTM

In future tests, I'd also like to verify that we're not accidentally promoting FloatTensors to DoubleTensors by checking .dtype or something. I worry that I may have introduced such a performance bug in #163.
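A minimal sketch of the kind of check being suggested (illustrative only, not from the PR):

```python
# Illustrative check that arithmetic has not silently promoted a FloatTensor
# to a DoubleTensor.
import torch

x = torch.Tensor([0.4])   # FloatTensor by default
y = x * 2.0               # stays a FloatTensor
z = x.double() * 2.0      # explicit promotion to DoubleTensor

assert isinstance(x, torch.FloatTensor)
assert isinstance(y, torch.FloatTensor)
assert isinstance(z, torch.DoubleTensor)
```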
tests/unit/distributions/conftest.py
Outdated
        dist_params=[([1.0, 1.0], [[2.0, 0.0], [1.0, 3.0]])],
        test_data=[(0.4, 0.6)],
        scipy_arg_fn=lambda mean, L: ((), {"mean": np.array(mean),
                                           "cov": np.matmul(np.array(L), np.transpose(np.array(L)))}),
nit: you can also write np.matmul(np.array(L), np.array(L).T)
ahh..nice!
@pytest.fixture()
def test_data():
    return Variable(torch.DoubleTensor([0.4]))
Do we have a convention of using float vs double in pyro?
By default, we get a float tensor when we call torch.Tensor. So I suppose we follow the same convention as pytorch in that we use the float tensor by default.
tests/unit/distributions/conftest.py
Outdated
        dist_params=[(1.4, 0.4), (2.6, 0.5)],
        test_data=[5.5, 6.4],
        scipy_arg_fn=lambda mean, sigma: ((sigma,), {"scale": math.exp(mean)}),
        prec=0.1),
These prec values seem pretty loose. Do you have a sense why?

I only ask because I'm not reviewing all the analytic_mean and analytic_var math for correctness, and I'd feel more comfortable if we had higher-precision tests for that math.
The tests need to generate many more samples to get to the desired precision otherwise. I have increased the number of samples so that we can aim for the default 0.05 precision on these tests now. Worth noting that this slowness is mostly due to the naive sampling mentioned in #146. Once we have batching enabled for sampling from distributions, we can easily generate much bigger samples and validate these quantities with finer precision.
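For intuition on that trade-off (back-of-the-envelope, not from the PR): the sample-mean error shrinks roughly like sigma / sqrt(n), so halving prec costs about four times the samples.

```python
import math


def samples_needed(sigma, prec, z=2.0):
    """Rough sample count for the sample mean to land within `prec` of the
    true mean with ~95% confidence (z ~= 2)."""
    return math.ceil((z * sigma / prec) ** 2)


samples_needed(sigma=1.0, prec=0.1)   # 400
samples_needed(sigma=1.0, prec=0.05)  # 1600, i.e. 4x more samples
```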
@fritzo - Thanks for such a detailed review, this was really helpful! I will take another pass at the remainder of the comments and get them addressed. We can get this merged after that, unless there are more comments from the reviewers.
looks fine, pending fritz's comments
OK, I'll squash-and-merge as soon as tests pass.
Right now we are excluding the
This uses parametrized fixtures to remove code duplication in our distribution tests, so that it is easier to add new tests without having to specifically code them in for each distribution.

Additional details:

- dist_fixture.py provides the corresponding scipy function, the number of samples to generate to get a desired precision, etc. The intent is to abstract out any common fixture-related utility from the tests themselves, so that it can be shared across multiple tests if needed.
- The test_support function only consumes the fixtures for the discrete distributions.
- Tests have been moved from unittest.TestCase to using pytest (aided by Abstract out test utilities for pytest #155), but this is an ongoing endeavor and a few tests still use the setup method from unittest.
- analytic_mean and analytic_var methods are added to dists so that it is easy to validate consistency of the sample means/vars with the analytical ones.

Some tests for the categorical distribution and the delta distribution have not yet been migrated, as they have some specific behavior that does not easily fit the patterns that the other tests employ. They are put into separate classes.