Modularizing distribution tests using parametrized fixtures #162

Merged · 7 commits · Oct 1, 2017
Conversation

@neerajprad (Member) commented Sep 28, 2017

This uses parametrized fixtures to remove code duplication in our distribution tests, making it easier to add new tests without having to code them separately for each distribution.

Additional details:

  • Most of the fixture creation logic and helper functions reside in a separate module, dist_fixture.py - this provides the corresponding scipy function, the number of samples to generate for a desired precision, etc. This abstracts out any common fixture-related utilities from the tests themselves, so that they can be shared across multiple tests if needed.
  • Fixtures are specified as:

    Fixture(dist.diagnormal,               # pyro dist to be tested
            sp.multivariate_normal,        # corresponding scipy function
            [(2.0, 4.0), (50.0, 100.0)],   # dist params to pass; can be batched as here
            [2.0, 50.0],                   # observed test values
            lambda mean, sigma: ((), {"mean": mean, "cov": sigma ** 2}),
                                           # convert dist params for calling the scipy function
            prec=0.1,                      # desired precision for computing mean and variance
            min_samples=50000),            # the precision above guides how many samples we need
                                           # to get a good estimate of the mean and variance;
                                           # the actual number sampled is the max of this override
                                           # and the number computed from the precision, since the
                                           # normal approximation for the sample variance does not
                                           # hold for all distributions
  • The parametrized fixtures have some common fields and some optional fields (like the file from which to read the expected support for discrete distributions) that are only needed for certain distributions or tests. The fixtures are further grouped into continuous and discrete distributions, so that tests can consume them separately; e.g. the test_support function only consumes the fixtures for the discrete distributions (see the sketch at the end of this description).
  • I have made an attempt to move away from unittest.TestCase to pytest (aided by Abstract out test utilities for pytest #155), but this is an ongoing endeavor and a few tests still use the setUp method from unittest.
  • analytic_mean and analytic_var methods are added to the distributions so that it is easy to validate the sample means/variances against the analytical ones.

Some tests for the categorical and delta distributions have not yet been migrated, as they have behavior that does not easily fit the patterns the other tests employ. They are kept in separate classes.
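To make the pattern concrete, here is a minimal sketch of how a list of such Fixture objects can be exposed to tests through a parametrized pytest fixture. It is illustrative only: the Fixture class below is a simplified stand-in for the one in dist_fixture.py, and the actual conftest.py may differ.

    import pytest

    # simplified stand-in for the Fixture class in dist_fixture.py
    class Fixture(object):
        def __init__(self, pyro_dist=None, scipy_dist=None, dist_params=None,
                     test_data=None, scipy_arg_fn=None, prec=0.05,
                     min_samples=1000, is_discrete=False,
                     expected_support_file=None, expected_support_key=''):
            self.pyro_dist = pyro_dist
            self.scipy_dist = scipy_dist
            self.dist_params = dist_params
            self.test_data = test_data
            self.scipy_arg_fn = scipy_arg_fn
            self.prec = prec
            self.min_samples = min_samples
            self.is_discrete = is_discrete
            self.expected_support_file = expected_support_file
            self.expected_support_key = expected_support_key

    continuous_dists = []  # populated with Fixture(...) entries like the example above
    discrete_dists = []    # ditto, with is_discrete=True

    # any test that takes `dist` as an argument runs once per distribution
    @pytest.fixture(params=continuous_dists + discrete_dists)
    def dist(request):
        return request.param

    # tests that only make sense for discrete distributions (e.g. test_support)
    # consume this narrower fixture instead
    @pytest.fixture(params=discrete_dists)
    def discrete_dist(request):
        return request.param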

@eb8680 (Member) commented Sep 28, 2017

This is awesome, great work!

min_samples=1000,
is_discrete=False,
expected_support_file=None,
expected_support_key=''):
Member:
Can we make all of these into named arguments?

Member Author:
that's a good idea, will do.

Member Author:
done

@@ -31,3 +31,9 @@ def batch_log_pdf(self, x, batch_size):

def support(self):
raise NotImplementedError("Support not supported for {}".format(str(type(self))))

def analytic_mean(self, *args, **kwargs):
raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))
Member:
Can you add docstrings for these methods here and on the rest of the distributions and verify that documentation is correctly generated?

Member Author:
Yes, will add docstrings and check.

Member Author:
Added to the base class. I am not sure it makes sense to repeat this in all the implementing classes; will look into how sphinx handles inheritance during documentation generation.

raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))

def analytic_var(self, *args, **kwargs):
raise NotImplementedError("Method not implemented by the subclass {}".format(str(type(self))))
Member:
Docstring


pytestmark = pytest.mark.init(rng_seed=123)

continuous_dists = [
Member:
Should we split the test generation code here into a separate file?

Member Author:
the fixtures have now been moved to the local conftest.py file and should be available to all the distribution tests.

assert_equal(log_px_val, log_px_np, prec=1e-4)


def test_float_type(float_test_data, float_alpha, float_beta, test_data, alpha, beta):
Member:
We should also have a test_cuda_type marked for Travis to skip.

Member Author:
It would be nice if we could run tests (or at least a subset of integration tests) using Variable.cuda() based on an env flag (CI will always set this flag to false, and so will our laptops), but on a GPU machine with the flag enabled, we should be able to run these tests without having to modify the test suite. Will look into that. Otherwise, just calling Variable.cuda() will throw an error on our local setup.
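One possible shape for that, sketched below; PYRO_TEST_CUDA is a made-up flag name, not an existing one in the repo, and the test body is only illustrative:

    import os

    import pytest
    import torch

    requires_cuda = pytest.mark.skipif(
        os.environ.get("PYRO_TEST_CUDA") != "1" or not torch.cuda.is_available(),
        reason="run only when PYRO_TEST_CUDA=1 and a GPU is available")


    @requires_cuda
    def test_cuda_type(test_data, alpha, beta):
        # mirror test_float_type, but with the data moved to the GPU first
        cuda_data = test_data.cuda()
        assert cuda_data.data.is_cuda

Travis (and our laptops) would simply never set the flag, so these tests get skipped there.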

@@ -77,3 +77,11 @@ def support(self, ps=None, *args, **kwargs):
size = functools.reduce(lambda x, y: x * y, _ps.size())
return (Variable(torch.Tensor(list(x)).view_as(_ps))
for x in itertools.product(torch.Tensor([0, 1]).type_as(_ps.data), repeat=size))

def analytic_mean(self, ps=None):
_ps = self._sanitize_input(ps)
Member Author:
Retaining the convention, but will change as part of #167

import pyro.distributions as dist
from tests.common import TestCase

pytestmark = pytest.mark.init(rng_seed=123)
Member:
nit: Can you remove the pytestmark = or rename it to unused = ... or unused_pytestmark = ...? (here and in other files)

@neerajprad (Member Author) Sep 28, 2017:
That is used to set the random seed for all the tests in the file. Removing it or changing the variable name will result in non-deterministic tests. See - https://docs.pytest.org/en/latest/example/markers.html#marking-whole-classes-or-modules.

We are using the same seed for almost all our unit tests; so I think we can set this in conftest.py itself, to keep it clean. Any tests that need to override this default can do that using the decorator explicitly. I will tackle this in a separate PR though, since this is true of all our unit tests, not just the ones for distribution.
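As a rough sketch of that consolidation (pytest_runtest_setup is a real pytest hook; the exact marker handling in our conftest.py may end up different, and older pytest versions use item.get_marker instead of item.get_closest_marker):

    # conftest.py
    import torch

    DEFAULT_RNG_SEED = 123


    def pytest_runtest_setup(item):
        # use the rng_seed from a pytest.mark.init(...) marker if present,
        # otherwise fall back to the shared default
        marker = item.get_closest_marker("init")
        seed = marker.kwargs.get("rng_seed", DEFAULT_RNG_SEED) if marker else DEFAULT_RNG_SEED
        torch.manual_seed(seed)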

Member:
Oh thanks, I didn't know about that!

Member Author:
I concur, looks like this is one of the most heavily requested features.

@fritzo (Member) left a comment:

@neerajprad This looks great! I'm so glad that we're modularizing our tests; it should make it much easier to add tests that apply to all distributions. Also, it's great to learn from you how to use pytest fixtures; I've only used them a little bit in the past.

I'd like to take one more pass verifying math before merging.


def analytic_mean(self, *args, **kwargs):
"""
Analytic mean of the distribution, to be implemented by derived classes.
Member:
Consider mentioning that this is optional and "is used only in testing".

Member Author:
yes, will do.


def analytic_var(self, ps=None, n=None):
_ps, _n = self._sanitize_input(ps, n)
return _n * ps * (1 - _ps)
Member:
Typo: ps should be _ps. It would sure be easier to read without underscores #167 .

Member Author:
Good catch, will change.
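For reference, the corrected method would presumably read as follows (a sketch; the binomial variance is n * p * (1 - p)):

    def analytic_var(self, ps=None, n=None):
        _ps, _n = self._sanitize_input(ps, n)
        # use the sanitized inputs throughout
        return _n * _ps * (1 - _ps)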

scipy_dist=sp.dirichlet,
dist_params=[([2.4, 3, 6],), ([3.2, 1.2, 0.4],)],
test_data=[[0.2, 0.45, 0.35], [0.3, 0.4, 0.3]],
scipy_arg_fn=lambda alpha: ((alpha,), {}))
Member:
nit: Add a trailing comma. This makes it less error-prone to add new distributions, to review changes, and to resolve merge conflicts.

Member Author:
good point, will change.

min_samples=10000,
is_discrete=True,
expected_support_file='tests/test_data/support_categorical.json',
expected_support_key='one_hot')
Member:
nit: ditto: add trailing comma


def test_batch_log_pdf(dist):
# TODO (@npradhan) - remove once #144 is resolved
try:
Member:
I like this skip-if-not-implemented behavior for researchy codebases. Modularised tests can cause some friction when testing new distributions or when partially implementing distributions for a single science project. This caused some pain for @ericmjonas when he partially implemented exotic distributions in posterior/distributions. On researchy codebases I now often use an @xfail_if_not_implemented decorator (code). I expect @ngoodman and students would appreciate that 😄

Member:

Just to be clear, the issue is that on one-off data science projects, we often create a partial implementation of some crazy distribution. It's nice to test that it is correct, but unless we use something like an xfail_if_not_implemented pattern, the author will have to comment-out a bunch of tests on some branch of pyro.

Member Author:

Thanks for sharing those links! It will be useful to have a similar decorator for our code, especially as we start putting in more distributions.

Member:

(on second look, it was a context manager that's used like

with xfail_if_not_implemented():
    ...do stuff...

Member Author:

That's even better, as I was not sure how test decorators interact with pytest. We can have something like this in testing.common.
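A minimal sketch of such a context manager, assuming it lives in a shared test utility module like tests.common:

    from contextlib import contextmanager

    import pytest


    @contextmanager
    def xfail_if_not_implemented():
        # turn a NotImplementedError from a partially implemented distribution
        # into an expected failure instead of a hard error
        try:
            yield
        except NotImplementedError as e:
            pytest.xfail(reason="not implemented: {}".format(e))

A test body can then wrap the optional call, e.g. with xfail_if_not_implemented(): call the method that may not be implemented.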

_, counts = np.unique(torch_samples, return_counts=True)
exp_ = float(counts[0]) / self.n_samples
torch_var = float(counts[0]) * np.power(0.1 * (0 - np.mean(torch_samples)), 2)
torch_var = np.square(np.mean(torch_samples)) / 16
Member:
Could you please add a comment explaining this computation and the magic number 16?

Member Author:

I am curious to know what this magic number is too. :)

I have removed the variance check now, as I am not even sure what variance in this context would mean. Thanks for flagging this.


with open('tests/test_data/support_categorical.json') as data_file:
data = json.load(data_file)
self.support = list(map(lambda x: torch.Tensor(x), data['one_hot']))
Member:
nit: I find it easier to read list comprehensions. What do you think?

self.support = [torch.Tensor(x) for x in data['one_hot']]

return None
with open(self.support_file) as data:
data = json.load(data)
return list(map(lambda x: torch.Tensor(x), data[self.support_key]))
Member:
nit: Consider using a list comprehension instead of list(map(lambda ...))
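i.e., the return above could become (a sketch of the suggested change):

    return [torch.Tensor(x) for x in data[self.support_key]]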

_sum_alpha = torch.sum(_alpha)
return _alpha / _sum_alpha

def analytic_var(self, alpha):
Member:
If this computes the element-wise variance, could you please add a comment verifying so?

Member Author:

done
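For reference, the element-wise variance of a Dirichlet is Var[X_i] = alpha_i * (alpha_0 - alpha_i) / (alpha_0^2 * (alpha_0 + 1)) with alpha_0 = sum(alpha), so the method could look roughly like this (a sketch, not necessarily the exact merged code):

    def analytic_var(self, alpha):
        # element-wise variance of a Dirichlet:
        #   Var[X_i] = alpha_i * (alpha_0 - alpha_i) / (alpha_0 ** 2 * (alpha_0 + 1))
        _alpha = self._sanitize_input(alpha)
        alpha_0 = torch.sum(_alpha)
        return _alpha * (alpha_0 - _alpha) / (alpha_0 ** 2 * (alpha_0 + 1))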

@@ -70,7 +70,7 @@ def log_pdf(self, x, alpha=None, beta=None, *args, **kwargs):
def batch_log_pdf(self, x, alpha=None, beta=None, batch_size=1, *args, **kwargs):
_alpha, _beta = self._sanitize_input(alpha, beta)
if x.dim() == 1 and _beta.dim() == 1 and batch_size == 1:
return self.log_pdf(x. _alpha, _beta)
return self.log_pdf(x, _alpha, _beta)
Member:
I'm so glad that we now have tests that will catch this 😄

@fritzo (Member) commented Sep 28, 2017

@neerajprad Do you have a plan to make these test fixtures available to authors of new distributions? For example, suppose I partially implement a spherical von Mises distribution in my own codebase that imports pyro. Is there a way that I can import some testing utilities to ensure my implementation is correct? Or should I simply fork Pyro? (this is out of scope for this PR, but do you have any ideas?)

@fritzo previously approved these changes Sep 29, 2017 and left a comment:

LGTM

In future tests, I'd also like to verify that we're not accidentally promoting FloatTensors to DoubleTensors by checking .dtype or something. I worry that I may have introduced such a performance bug in #163.
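A sketch of what such a check could look like; assert_float_type is a hypothetical helper, not something that exists in tests/common.py today:

    import torch
    from torch.autograd import Variable


    def assert_float_type(value):
        # hypothetical helper: fail if a computation silently promoted a
        # FloatTensor to a DoubleTensor
        tensor = value.data if isinstance(value, Variable) else value
        assert isinstance(tensor, torch.FloatTensor), \
            "expected torch.FloatTensor, got {}".format(type(tensor).__name__)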

dist_params=[([1.0, 1.0], [[2.0, 0.0], [1.0, 3.0]])],
test_data=[(0.4, 0.6)],
scipy_arg_fn=lambda mean, L: ((), {"mean": np.array(mean),
"cov": np.matmul(np.array(L), np.transpose(np.array(L)))}),
Member:
nit: you can also write np.matmul(np.array(L), np.array(L).T)

Member Author:
ahh..nice!


@pytest.fixture()
def test_data():
return Variable(torch.DoubleTensor([0.4]))
Member:
Do we have a convention of using float vs double in pyro?

Member Author:
By default, we get a float tensor when we call torch.Tensor. So I suppose we follow the same convention as pytorch in that we use the float tensor by default.

dist_params=[(1.4, 0.4), (2.6, 0.5)],
test_data=[5.5, 6.4],
scipy_arg_fn=lambda mean, sigma: ((sigma,), {"scale": math.exp(mean)}),
prec=0.1),
Member:
These prec values seem pretty loose. Do you have a sense why?

I only ask because I'm not reviewing all the analytic_mean and analytic_var math for correctness, and I'd feel more comfortable if we had higher precision tests for that math.

Member Author:
The tests need to generate many more samples to get to the desired precision otherwise. I have increased the number of samples so that we can aim for the default 0.05 precision on these tests now. Worth noting that this slowness is mostly due to the naive sampling mentioned in #146. Once we have batching enabled for sampling from distributions, we can easily generate much bigger samples and validate these quantities with a finer precision.
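For context, the precision/sample-size relationship described in the PR summary is essentially the standard normal-approximation bound; roughly (a sketch, the exact constant used in dist_fixture.py may differ):

    import math


    def samples_needed(variance, prec, z=1.96):
        # normal approximation: the sample mean lies within `prec` of the true
        # mean with ~95% confidence once z * sqrt(variance / n) <= prec
        return int(math.ceil(z ** 2 * variance / prec ** 2))

The fixture then samples the max of this number and min_samples, since the approximation is not reliable for every distribution.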

@fritzo (Member) commented Sep 29, 2017

@eb8680 @jpchen Do you have any more comments before we merge?

@neerajprad (Member Author) commented:
@fritzo - Thanks for such a detailed review, this was really helpful! I will take another pass at the remainder of the comments, and get them addressed. We can get this merged after that, unless there are more comments from the reviewers.

@jpchen (Member) left a comment:
looks fine, pending fritz's comments

@fritzo (Member) commented Sep 29, 2017

OK, I'll squash-and-merge as soon as tests pass.

@neerajprad (Member Author) replied:

> @neerajprad Do you have a plan to make these test fixtures available to authors of new distributions? For example, suppose I partially implement a spherical von Mises distribution in my own codebase that imports pyro. Is there a way that I can import some testing utilities to ensure my implementation is correct? Or should I simply fork Pyro? (this is out of scope for this PR, but do you have any ideas?)

Right now we are excluding the test module in setup.py. As you pointed out in #143, we will need to start packaging any generic test utils for distributions so that they are available through pip install. The existing tests are still somewhat coupled to the way we have chosen our implementation (relying on the scipy.stats library, for instance), so I am not sure how useful they would be to someone implementing a custom distribution. As we build more generic testing tools for inference and distribution testing, we should certainly design them to be usable on their own, and include them either with the pyro package or in a separate test utility package. Let me know if you have any suggestions.
