diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 5745374df..0c977cbc5 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -10,10 +10,14 @@ __all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete', 'NoDistribution', 'TensorType', 'draw_values'] +class _Unpickling(object): + pass class Distribution(object): """Statistical distribution""" def __new__(cls, name, *args, **kwargs): + if name is _Unpickling: + return object.__new__(cls) # for pickle try: model = Model.get_context() except TypeError: @@ -25,13 +29,11 @@ def __new__(cls, name, *args, **kwargs): data = kwargs.pop('observed', None) dist = cls.dist(*args, **kwargs) return model.Var(name, dist, data) - elif name is None: - return object.__new__(cls) # for pickle else: - raise TypeError("needed name or None but got: %s" % name) + raise TypeError("Name needs to be a string but got: %s" % name) def __getnewargs__(self): - return None, + return _Unpickling, @classmethod def dist(cls, *args, **kwargs): diff --git a/pymc3/tests/test_pickling.py b/pymc3/tests/test_pickling.py new file mode 100644 index 000000000..7a58b1935 --- /dev/null +++ b/pymc3/tests/test_pickling.py @@ -0,0 +1,21 @@ +import unittest +import pickle +import traceback +from .models import simple_model + + +class TestPickling(unittest.TestCase): + def setUp(self): + _, self.model, _ = simple_model() + + def test_model_roundtrip(self): + m = self.model + for proto in range(pickle.HIGHEST_PROTOCOL+1): + try: + s = pickle.dumps(m, proto) + n = pickle.loads(s) + except Exception as ex: + raise AssertionError( + "Exception while trying roundtrip with pickle protocol %d:\n"%proto + + ''.join(traceback.format_exc()) + )