# Model + RandomVariableクラスのコードリーディング

## setup

import

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading

from tensorflow_probability import edward2 as ed

  return f(*args, **kwds)
  from ._conv import register_converters as _register_converters


Contextクラス

In [41]:
class Context(object):
    """Functionality for objects that put themselves in a context using
    the `with` statement.
    """
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

withparentメソッド

In [42]:
def withparent(meth):
    """Helper wrapper that passes calls to parent's instance"""
    def wrapped(self, *args, **kwargs):
        res = meth(self, *args, **kwargs)
        if getattr(self, 'parent', None) is not None:
            getattr(self.parent, meth.__name__)(*args, **kwargs)
        return res
    # Unfortunately functools wrapper fails
    # when decorating built-in methods so we
    # need to fix that improper behaviour
    wrapped.__name__ = meth.__name__
    return wrapped

treedictクラス

In [43]:
class treedict(dict):
    """A dict that passes mutable extending operations used in Model
    to parent dict instance.
    Extending treedict you will also extend its parent
    """
    def __init__(self, iterable=(), parent=None, **kwargs):
        super(treedict, self).__init__(iterable, **kwargs)
        assert isinstance(parent, dict) or parent is None
        self.parent = parent
        if self.parent is not None:
            self.parent.update(self)
    # typechecking here works bad
    __setitem__ = withparent(dict.__setitem__)
    update = withparent(dict.update)

    def tree_contains(self, item):
        # needed for `add_random_variable` method
        if isinstance(self.parent, treedict):
            return (dict.__contains__(self, item) or
                    self.parent.tree_contains(item))
        elif isinstance(self.parent, dict):
            return (dict.__contains__(self, item) or
                    self.parent.__contains__(item))
        else:
            return dict.__contains__(self, item)


Modelクラス

In [47]:
class Model(Context):
    def __new__(cls, *args, **kwargs):
        instance = super(Model, cls).__new__(cls)
        if kwargs.get('model') is not None:
            instance.parent = kwargs.get('model')
        elif cls.get_contexts():
            instance.parent = cls.get_contexts()[-1]
        else:
            instance.parent = None
        return instance
            
    
    def __init__(self, name="", model=None, ):
        self.name = name
        if self.parent is not None:
            self.named_vars = treedict(parent=self.parent.named_vars)
        else:
            self.named_vars = treedict()
            
    @property
    def model(self):
        return self
    
    @property
    def decription(self):
        return
    
    @classmethod
    def get_contexts(cls):
        # no race-condition here, cls.contexts is a thread-local object
        # be sure not to override contexts in a subclass however!
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Return the deepest context on the stack."""
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")
            
    def add_random_variable(self, var):
        """Add a random variable to the named variables of the model."""
        if self.named_vars.tree_contains(var.name):
            raise ValueError(
                "Variable name {} already exists.".format(var.name))
        self.named_vars[var.name] = var

RandomVariableクラス

In [57]:
class RandomVariable(ed.RandomVariable):
    
    def __init__(
                self,
                distribution,
                sample_shape=(),
                value=None,
                name="RV"
                ):
        self.model = Model.get_context()
        self.name = name

        super(RandomVariable, self).__init__(
                                                distribution,
                                                sample_shape,
                                                value,
                                                )
        
        self.model.add_random_variable(self)

# 動作確認

`with` でModelインスタンスを生成する度に、stackの末尾にそのmodelインスタンスが生成され、`parent` が設定されるテスト

In [102]:
print("before: ", Model.get_contexts())

with Model(name="model1") as model1:
    print("[1] inside with", ["{}:{}".format(x.name, getattr(x, 'parent', None)) for x in Model.get_contexts()])
    with Model(name="model2") as model2:
        print("[2] inside with", ["{}:{}".format(x.name, getattr(x, 'parent', None)) for x in Model.get_contexts()])
        print("model2's parent: ", model2.parent.name)
    
with Model(name="model3") as model3:
    print("[3] inside with", ["{}:{}".format(x.name, getattr(x, 'parent', None)) for x in Model.get_contexts()])

print("after: ", Model.get_contexts())

before:  []
[1] inside with ['model1:None']
[2] inside with ['model1:None', 'model2:<__main__.Model object at 0x122666f60>']
model2's parent:  model1
[3] inside with ['model3:None']
after:  []


RandomVariableが設定されるテスト

In [101]:
with Model(name="model1") as model1:
    rv = RandomVariable(ed.Normal(0., 1.), name="x", sample_shape=10000)

AttributeError: 'RandomVariable' object has no attribute 'sample'