Skip to content

Commit

Permalink
Reduce time for NUTS tests (#804)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed Feb 23, 2018
1 parent 1801774 commit 60bcab7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions tests/infer/mcmc/test_hmc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import absolute_import, division, print_function

import logging
from collections import defaultdict, namedtuple

import logging
import os

import pytest
import torch
from torch.autograd import Variable
Expand Down
14 changes: 8 additions & 6 deletions tests/infer/mcmc/test_nuts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import absolute_import, division, print_function

import logging
from collections import defaultdict
import logging
import os

import pytest
import torch
Expand All @@ -13,18 +14,19 @@
from pyro.infer.mcmc.mcmc import MCMC
from tests.common import assert_equal

from .test_hmc import rmse, TEST_CASES, TEST_IDS
from .test_hmc import rmse, T, TEST_CASES, TEST_IDS

logging.basicConfig(format='%(levelname)s %(message)s')
logger = logging.getLogger('pyro')
logger.setLevel(logging.INFO)

TEST_CASES[0] = TEST_CASES[0]._replace(mean_tol=0.04, std_tol=0.04)
TEST_CASES[1] = TEST_CASES[1]._replace(mean_tol=0.04, std_tol=0.04)

# TODO: re-enable once https://github.com/uber/pyro/issues/799 is resolved, with lower number of samples
TEST_CASES[2] = TEST_CASES[2]._replace(marks=[pytest.mark.skip(reason="Slow test using NUTS.")])
TEST_CASES[3] = TEST_CASES[3]._replace(marks=[pytest.mark.skip(reason="Slow test using NUTS.")])
T2 = T(*TEST_CASES[2].values)._replace(num_samples=600, warmup_steps=100)
TEST_CASES[2] = pytest.param(*T2, marks=pytest.mark.skipif(
'CI' in os.environ and os.environ['CI'] == 'true', reason='Slow test - skip on CI'))
T3 = T(*TEST_CASES[3].values)._replace(num_samples=700, warmup_steps=100)
TEST_CASES[3] = T3


@pytest.mark.parametrize(
Expand Down

0 comments on commit 60bcab7

Please sign in to comment.