From 4ae1ec3c5276f098970e72e9e90b9c3d956f4f50 Mon Sep 17 00:00:00 2001 From: Alexander Rodin Date: Sun, 21 May 2017 20:30:59 +0300 Subject: [PATCH] Avoid storing bound methods as variables to prevent pickling problems --- pymc3/distributions/transforms.py | 10 +++++++--- pymc3/model.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 1fed434fe4..59e8418667 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -5,7 +5,6 @@ from . import distribution from ..math import logit, invlogit import numpy as np -from functools import partial __all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking'] @@ -70,8 +69,13 @@ def __init__(self, dist, transform, *args, **kwargs): b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False)) # force the last dim not broadcastable self.type = tt.TensorType(v.dtype, b) - - self._repr_latex_ = partial(dist._repr_latex_, dist=dist) + + def _repr_latex_(self, name=None, dist=None): + if name is None: + name = self.name + if dist is None: + dist = self.dist + return dist._repr_latex_(self, name=name, dist=dist) def logp(self, x): return (self.dist.logp(self.transform_used.backward(x)) + diff --git a/pymc3/model.py b/pymc3/model.py index 5f508139d7..0990553e2e 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1,7 +1,6 @@ import collections import threading import six -from functools import partial import numpy as np import scipy.sparse as sps @@ -823,7 +822,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, methods=['random'], wrapper=InstanceMethod) - self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution) + def _repr_latex_(self, name=None, dist=None): + if self.distribution is None: + return None + if name is None: + name = self.name + if dist is None: + dist = self.distribution + return self.distribution._repr_latex_(name=name, dist=dist) @property def init_value(self): @@ -916,8 +922,15 @@ def __init__(self, type=None, owner=None, index=None, name=None, data=None, inputs=[data], outputs=[self]) self.tag.test_value = theano.compile.view_op(data).tag.test_value - - self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution) + + def _repr_latex_(self, name=None, dist=None): + if self.distribution is None: + return None + if name is None: + name = self.name + if dist is None: + dist = self.distribution + return self.distribution._repr_latex_(name=name, dist=dist) @property def init_value(self): @@ -1016,6 +1029,7 @@ def __init__(self, type=None, owner=None, index=None, name=None, if distribution is not None: self.model = model + self.distribution = distribution transformed_name = get_transformed_name(name, transform) @@ -1032,7 +1046,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, methods=['random'], wrapper=InstanceMethod) - self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution) + def _repr_latex_(self, name=None, dist=None): + if self.distribution is None: + return None + if name is None: + name = self.name + if dist is None: + dist = self.distribution + return self.distribution._repr_latex_(name=name, dist=dist) @property def init_value(self):