Skip to content

Commit

Permalink
Merge 2ae5980 into 53e4966
Browse files Browse the repository at this point in the history
  • Loading branch information
rpgoldman committed Oct 11, 2019
2 parents 53e4966 + 2ae5980 commit d8d1020
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 59 deletions.
17 changes: 4 additions & 13 deletions pymc3/distributions/distribution.py
Expand Up @@ -7,7 +7,7 @@
from ..memoize import memoize
from ..model import (
Model, get_named_nodes_and_relations, FreeRV,
ObservedRV, MultiObservedRV, Context, InitContextMeta
ObservedRV, MultiObservedRV, ContextMeta
)
from ..vartypes import string_types, theano_constant
from .shape_utils import (
Expand Down Expand Up @@ -449,23 +449,14 @@ def random(self, point=None, size=None, **kwargs):
"Define a custom random method and pass it as kwarg random")


class _DrawValuesContext(Context, metaclass=InitContextMeta):
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
""" A context manager class used while drawing values with draw_values
"""

def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if cls.get_contexts():
potential_parent = cls.get_contexts()[-1]
# We have to make sure that the context is a _DrawValuesContext
# and not a Model
if isinstance(potential_parent, _DrawValuesContext):
instance._parent = potential_parent
else:
instance._parent = None
else:
instance._parent = None
instance._parent = cls.get_context(error_if_none=False)
return instance

def __init__(self):
Expand All @@ -485,7 +476,7 @@ def parent(self):
return self._parent


class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
class _DrawValuesContextBlocker(_DrawValuesContext):
"""
Context manager that starts a new drawn variables context disregarding all
parent contexts. This can be used inside a random method to ensure that
Expand Down
159 changes: 115 additions & 44 deletions pymc3/model.py
Expand Up @@ -3,7 +3,8 @@
import itertools
import threading
import warnings
from typing import Optional
from typing import Optional, Tuple, TypeVar, Type, List, Union
from sys import modules

import numpy as np
from pandas import Series
Expand Down Expand Up @@ -161,49 +162,130 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
node_children.update(temp_tree)
return leaf_nodes, node_parents, node_children

T = TypeVar('T', bound='ContextMeta')

class Context:

class ContextMeta(type):
"""Functionality for objects that put themselves in a context using
the `with` statement.
"""
contexts = threading.local()

def __enter__(self):
type(self).get_contexts().append(self)
# self._theano_config is set in Model.__new__
if hasattr(self, '_theano_config'):
self._old_theano_config = set_theano_conf(self._theano_config)
return self

def __exit__(self, typ, value, traceback):
type(self).get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, '_old_theano_config'):
set_theano_conf(self._old_theano_config)

@classmethod
def get_contexts(cls):
# no race-condition here, cls.contexts is a thread-local object
def __new__(cls, name, bases, dct, **kargs):
# this serves only to strip off keyword args, per the warning from
# StackExchange:
# DO NOT send "**kargs" to "type.__new__". It won't catch them and
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
return super().__new__(cls, name, bases, dct)

# FIXME: is there a more elegant way to automatically add methods to the class that
# are instance methods instead of class methods?
def __init__(cls, name, bases, nmspc, context_class: Optional[Type]=None, **kwargs):
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
if context_class is not None:
cls._context_class = context_class
super().__init__(name, bases, nmspc)

def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
# self._theano_config is set in Model.__new__
if hasattr(self, '_theano_config'):
self._old_theano_config = set_theano_conf(self._theano_config)
return self

def __exit__(self, typ, value, traceback):
self.__class__.context_class.get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, '_old_theano_config'):
set_theano_conf(self._old_theano_config)

cls.__enter__ = __enter__
cls.__exit__ = __exit__


def get_context(cls, error_if_none=True) -> Optional[T]:
"""Return the most recently pushed context object of type ``cls``
on the stack, or ``None``. If ``error_if_none`` is True (default),
raise a ``TypeError`` instead of returning ``None``."""
idx = -1
while True:
try:
candidate = cls.get_contexts()[idx] # type: Optional[T]
except IndexError as e:
# Calling code expects to get a TypeError if the entity
# is unfound, and there's too much to fix.
if error_if_none:
raise TypeError("No %s on context stack"%str(cls))
return None
return candidate
idx = idx - 1

def get_contexts(cls) -> List[T]:
"""Return a stack of context instances for the ``context_class``
of ``cls``."""
# This lazily creates the context class's contexts
# thread-local object, as needed. This seems inelegant to me,
# but since the context class is not guaranteed to exist when
# the metaclass is being instantiated, I couldn't figure out a
# better way. [2019/10/11:rpg]

# no race-condition here, contexts is a thread-local object
# be sure not to override contexts in a subclass however!
if not hasattr(cls.contexts, 'stack'):
cls.contexts.stack = []
return cls.contexts.stack
context_class = cls.context_class
assert isinstance(context_class, type), \
"Name of context class, %s was not resolvable to a class"%context_class
if not hasattr(context_class, 'contexts'):
context_class.contexts = threading.local()

contexts = context_class.contexts

if not hasattr(contexts, 'stack'):
contexts.stack = []
return contexts.stack

# the following complex property accessor is necessary because the
# context_class may not have been created at the point it is
# specified, so the context_class may be a class *name* rather
# than a class.
@property
def context_class(cls) -> Type:
def resolve_type(c: Union[Type, str]) -> Type:
if isinstance(c, str):
c = getattr(modules[cls.__module__], c)
if isinstance(c, type):
return c
raise ValueError("Cannot resolve context class %s"%c)
assert cls is not None
if isinstance(cls._context_class, str):
cls._context_class = resolve_type(cls._context_class)
if not isinstance(cls._context_class, (str, type)):
raise ValueError("Context class for %s, %s, is not of the right type"%\
(cls.__name__, cls._context_class))
return cls._context_class

# Inherit context class from parent
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.context_class = super().context_class

# Initialize object in its own context...
# Merged from InitContextMeta in the original.
def __call__(cls, *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance

@classmethod
def get_context(cls):
"""Return the deepest context on the stack."""
try:
return cls.get_contexts()[-1]
except IndexError:
raise TypeError("No context on context stack")


def modelcontext(model: Optional['Model']) -> 'Model':
def modelcontext(model: Optional['Model']) -> Optional['Model']:
"""return the given model or try to find it in the context if there was
none supplied.
"""
if model is None:
return Model.get_context()
found: Optional['Model'] = Model.get_context(error_if_none=False)
if found is None:
raise ValueError("No model on context stack.")
return found
return model


Expand Down Expand Up @@ -291,15 +373,6 @@ def logp_nojact(self):
return logp


class InitContextMeta(type):
"""Metaclass that executes `__init__` of instance in it's context"""
def __call__(cls, *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def withparent(meth):
"""Helper wrapper that passes calls to parent's instance"""
def wrapped(self, *args, **kwargs):
Expand Down Expand Up @@ -554,7 +627,7 @@ def _build_joined(self, cost, args, vmap):
return args_joined, theano.clone(cost, replace=replace)


class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta):
class Model(Factor, WithMemoization, metaclass=ContextMeta, context_class='Model'):
"""Encapsulates the variables and likelihood factors of a model.
Model class can be used for creating class based models. To create
Expand Down Expand Up @@ -647,10 +720,8 @@ def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
if kwargs.get('model') is not None:
instance._parent = kwargs.get('model')
elif cls.get_contexts():
instance._parent = cls.get_contexts()[-1]
else:
instance._parent = None
instance._parent = cls.get_context(error_if_none=False)
theano_config = kwargs.get('theano_config', None)
if theano_config is None or 'compute_test_value' not in theano_config:
theano_config = {'compute_test_value': 'raise'}
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_data_container.py
Expand Up @@ -77,7 +77,7 @@ def test_sample_after_set_data(self):
atol=1e-1)

def test_creation_of_data_outside_model_context(self):
with pytest.raises(TypeError) as error:
with pytest.raises((IndexError, TypeError)) as error:
pm.Data('data', [1.1, 2.2, 3.3])
error.match('No model on context stack')

Expand Down
7 changes: 6 additions & 1 deletion pymc3/tests/test_model.py
Expand Up @@ -52,10 +52,15 @@ def test_setattr_properly_works(self):

def test_context_passes_vars_to_parent_model(self):
with pm.Model() as model:
assert pm.model.modelcontext(None) == model
assert pm.Model.get_context() == model
# a set of variables is created
NewModel()
nm = NewModel()
assert pm.Model.get_context() == model
# another set of variables are created but with prefix 'another'
usermodel2 = NewModel(name='another')
assert pm.Model.get_context() == model
assert usermodel2._parent == model
# you can enter in a context with submodel
with usermodel2:
usermodel2.Var('v3', pm.Normal.dist())
Expand Down
38 changes: 38 additions & 0 deletions pymc3/tests/test_modelcontext.py
@@ -1,5 +1,8 @@
import threading
from pytest import raises
from pymc3 import Model, Normal
from pymc3.distributions.distribution import _DrawValuesContext, _DrawValuesContextBlocker
from pymc3.model import modelcontext


class TestModelContext:
Expand Down Expand Up @@ -42,3 +45,38 @@ def make_model_b():
list(modelA.named_vars),
list(modelB.named_vars),
) == (['a'],['b'])

def test_mixed_contexts():
modelA = Model()
modelB = Model()
with raises((ValueError, TypeError)):
modelcontext(None)
with modelA:
with modelB:
assert Model.get_context() == modelB
assert modelcontext(None) == modelB
dvc = _DrawValuesContext()
with dvc:
assert Model.get_context() == modelB
assert modelcontext(None) == modelB
assert _DrawValuesContext.get_context() == dvc
dvcb = _DrawValuesContextBlocker()
with dvcb:
assert _DrawValuesContext.get_context() == dvcb
assert _DrawValuesContextBlocker.get_context() == dvcb
assert _DrawValuesContext.get_context() == dvc
assert _DrawValuesContextBlocker.get_context() is dvc
assert Model.get_context() == modelB
assert modelcontext(None) == modelB
assert _DrawValuesContext.get_context(error_if_none=False) is None
with raises(TypeError):
_DrawValuesContext.get_context()
assert Model.get_context() == modelB
assert modelcontext(None) == modelB
assert Model.get_context() == modelA
assert modelcontext(None) == modelA
assert Model.get_context(error_if_none=False) is None
with raises(TypeError):
Model.get_context(error_if_none=True)
with raises((ValueError, TypeError)):
modelcontext(None)

0 comments on commit d8d1020

Please sign in to comment.