Skip to content

Commit

Permalink
Improve the Theano scan_args class
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 16, 2020
1 parent 663bfa2 commit b1bb05a
Show file tree
Hide file tree
Showing 2 changed files with 875 additions and 2 deletions.
325 changes: 323 additions & 2 deletions symbolic_pymc/theano/opt.py
Expand Up @@ -4,10 +4,13 @@
import theano.tensor as tt

from functools import wraps
from unittest.mock import patch
from collections import namedtuple, OrderedDict

from theano.gof.opt import LocalOptimizer, local_optimizer
from theano.gof.graph import inputs as tt_inputs
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import scan_args as ScanArgs
from theano.scan_module.scan_utils import scan_args

from unification import var, variables

Expand Down Expand Up @@ -36,6 +39,13 @@ def eval_and_reify_meta(x):
return res


def safe_index(lst, x):
    """Return the index of the first occurrence of `x` in `lst`.

    This is a non-raising wrapper around `lst.index`: when `lst.index(x)`
    raises `ValueError`, `None` is returned instead.
    """
    try:
        position = lst.index(x)
    except ValueError:
        position = None
    return position


class FunctionGraph(theano.gof.fg.FunctionGraph):
"""A version of `FunctionGraph` that knows not to merge non-deterministic `Op`s.
Expand Down Expand Up @@ -217,13 +227,324 @@ def transform(self, node):
else:
raise ValueError(
"Unsupported FunctionGraph replacement variable type: {chosen_res}"
)
) # pragma: no cover

return new_node
else:
return False


# Location record for a variable found among the fields of a `ScanArgs`
# instance (see `ScanArgs.find_among_fields`, which constructs these):
#   name        -- name of the field (attribute) that contains the variable
#   agg_name    -- name of the aggregate list attribute derived from the field
#                  name's prefix (e.g. "inner_inputs", "outer_outputs")
#   index       -- index within the field's list; for nested-list fields this
#                  is the index of the sub-list that contains the variable
#   inner_index -- index within that sub-list for nested-list fields, otherwise
#                  `None`
#   agg_index   -- index of the variable within the aggregate list, or `None`
#                  when it is not present there
FieldInfo = namedtuple("FieldInfo", ("name", "agg_name", "index", "inner_index", "agg_index"))


class ScanArgs(scan_args):
    """An improved version of `theano.scan_module.scan_utils.scan_args`.

    In addition to the parent class' fields, this class can locate a variable
    among its fields (`find_among_fields`), look up the corresponding element
    in a related field (`get_alt_field`), compute the variables that depend on
    a given variable (`get_dependent_nodes`), and remove variables from its
    fields (`remove_from_fields`).
    """

    # Default field-name predicate: only consider the inner/outer
    # input/output fields.  This is referenced as a default argument value in
    # the method definitions below, so it must stay a plain callable defined
    # in the class body before those methods.
    default_filter = lambda x: x.startswith(("inner_", "outer_"))

    # Fields whose values are lists of lists (one sub-list per entry).
    nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot")

    def __init__(self, *args, **kwargs):
        # Prevent unnecessary and counter-productive cloning.
        # If you want to clone the inner graph, do it before you call this!
        # While the parent constructor runs, `reconstruct_graph` is replaced by
        # a mock that simply hands back its first two arguments.
        with patch(
            "theano.scan_module.scan_utils.reconstruct_graph",
            side_effect=lambda x, y, z=None: [x, y],
        ):
            super().__init__(*args, **kwargs)

    @classmethod
    def from_node(cls, node):
        """Create an instance from the inputs, outputs and `Op` of a `Scan` node.

        Raises
        ------
        TypeError
            If `node.op` is not a `Scan` instance.
        """
        if not isinstance(node.op, Scan):
            raise TypeError("{} is not a Scan node".format(node))
        return cls(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info)

    @classmethod
    def create_empty(cls):
        """Create an instance with no sequences, taps or outputs.

        The instance is built from a placeholder outer-input list (`[1]`) and
        an all-zero/empty `info` dictionary; `n_steps` is then reset to `None`.
        """
        info = OrderedDict(
            [
                ("n_seqs", 0),
                ("n_mit_mot", 0),
                ("n_mit_sot", 0),
                ("tap_array", []),
                ("n_sit_sot", 0),
                ("n_nit_sot", 0),
                ("n_shared_outs", 0),
                ("n_mit_mot_outs", 0),
                ("mit_mot_out_slices", []),
                ("truncate_gradient", -1),
                ("name", None),
                ("mode", None),
                ("destroy_map", OrderedDict()),
                ("gpua", False),
                ("as_while", False),
                ("profile", False),
                ("allow_gc", False),
            ]
        )
        res = cls([1], [], [], [], info)
        res.n_steps = None
        return res

    @property
    def n_nit_sot(self):
        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
        return self.info["n_nit_sot"]

    @property
    def inputs(self):
        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
        return self.inner_inputs

    @property
    def n_mit_mot(self):
        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
        return self.info["n_mit_mot"]

    @property
    def var_mappings(self):
        """Return the result of `Scan.get_oinp_iinp_iout_oout_mappings` for this object."""
        return Scan.get_oinp_iinp_iout_oout_mappings(self)

    @property
    def field_names(self):
        """Return the names of the fields tracked by this object.

        The list starts with the three slice fields and is followed by every
        instance attribute whose name starts with `inner_in`, `inner_out`,
        `outer_in` or `outer_out`, plus `n_steps`.
        """
        res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
        prefixes = ("inner_in", "inner_out", "outer_in", "outer_out")
        res.extend(
            attr for attr in self.__dict__ if attr.startswith(prefixes) or attr == "n_steps"
        )
        return res

    def get_alt_field(self, var_info, alt_prefix):
        """Get the alternate input/output field for a given element of `ScanArgs`.

        For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
        `get_alt_field(var_info, "inner_out")` returns the element corresponding
        `var_info` in `ScanArgs.inner_out_sit_sot`.

        Parameters
        ----------
        var_info: TensorVariable or FieldInfo
            The element for which we want the alternate
        alt_prefix: str
            The string prefix for the alternate field type.  It can be one of
            the following: "inner_out", "inner_in", "outer_in", and "outer_out".

        Returns
        -------
        TensorVariable
            Returns the alternate variable.

        Raises
        ------
        ValueError
            If `var_info` is a variable that cannot be found among the fields.
        AttributeError
            If there is no field named `"{alt_prefix}_{type}"` for the type of
            the field containing the element.
        """
        if not isinstance(var_info, FieldInfo):
            variable = var_info
            var_info = self.find_among_fields(variable)
            if var_info is None:
                raise ValueError("{} not found among fields.".format(variable))

        # Strip the "inner_in_"/"inner_out_"/"outer_in_"/"outer_out_" prefix to
        # get the type suffix (e.g. "sit_sot").  The search for "_" starts at
        # position 6, i.e. right after the leading "inner_"/"outer_".
        alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]

        # Honor the requested prefix; the field was previously hard-coded to
        # "inner_out_*" regardless of `alt_prefix`.
        alt_var = getattr(self, "{}_{}".format(alt_prefix, alt_type))[var_info.index]
        return alt_var

    def find_among_fields(self, i, field_filter=default_filter):
        """Find the type and indices of the field containing a given element.

        NOTE: This only returns the *first* field containing the given element.

        Parameters
        ----------
        i: theano.gof.graph.Variable
            The element to find among this object's fields.
        field_filter: function
            A function passed to `filter` that determines which fields to
            consider.  It must take a string field name and return a truthy
            value.

        Returns
        -------
        FieldInfo or None
            A `FieldInfo` containing the field name string, the name of the
            corresponding aggregate list (e.g. `inner_inputs`,
            `outer_outputs`), the first index, the second index (for nested
            lists; otherwise `None`), and the "major" index (i.e. the index
            within the aggregate list, or `None` when the element is not in
            it).  `None` is returned when no match is found.
        """

        field_names = filter(field_filter, self.field_names)

        for field_name in field_names:
            lst = getattr(self, field_name)

            # The first eight characters are "inner_in", "outer_in",
            # "inner_ou" or "outer_ou"; complete them to the name of the
            # aggregate list ("*_inputs" or "*_outputs").
            field_prefix = field_name[:8]
            if field_prefix.endswith("in"):
                agg_field_name = "{}puts".format(field_prefix)
            else:
                agg_field_name = "{}tputs".format(field_prefix)

            agg_list = getattr(self, agg_field_name)

            if field_name in self.nested_list_fields:
                for n, sub_lst in enumerate(lst):
                    idx = safe_index(sub_lst, i)
                    if idx is not None:
                        agg_idx = safe_index(agg_list, i)
                        return FieldInfo(field_name, agg_field_name, n, idx, agg_idx)
            else:
                idx = safe_index(lst, i)
                if idx is not None:
                    agg_idx = safe_index(agg_list, i)
                    return FieldInfo(field_name, agg_field_name, idx, None, agg_idx)

        return None

    def _remove_from_fields(self, i, field_filter=default_filter):
        """Remove `i` from the first field that contains it.

        Returns the `FieldInfo` describing where `i` was found (computed
        before the removal), or `None` when `i` is not in any field.
        """

        field_info = self.find_among_fields(i, field_filter=field_filter)

        if field_info is None:
            return None

        if field_info.inner_index is not None:
            # Nested-list field: remove from the sub-list that holds `i`.
            getattr(self, field_info.name)[field_info.index].remove(i)
        else:
            getattr(self, field_info.name).remove(i)

        return field_info

    def get_dependent_nodes(self, i, seen=None):
        """Collect the variables among the fields that depend on `i`.

        Dependencies are found through `var_mappings` and, when `i` is an
        inner-input, through the inner-outputs that have `i` among their graph
        inputs.  The search is applied recursively to each dependent found.

        Parameters
        ----------
        i: theano.gof.graph.Variable
            The source variable; it must be present in one of the fields.
        seen: set, optional
            Variables that were already visited.  This set is updated
            in-place.

        Returns
        -------
        set
            The dependent variables (excluding those that were already in
            `seen`).

        Raises
        ------
        ValueError
            If `i` is not found among the fields.
        """

        if seen is None:
            seen = {i}
        else:
            seen.add(i)

        var_mappings = self.var_mappings

        field_info = self.find_among_fields(i)

        if field_info is None:
            raise ValueError("{} not found among fields.".format(i))

        # Find the `var_mappings` key suffix that matches the field/set of
        # arguments containing our source node
        if field_info.name[:8].endswith("_in"):
            map_key_suffix = "{}p".format(field_info.name[:8])
        else:
            map_key_suffix = field_info.name[:9]

        dependent_nodes = set()
        for k, v in var_mappings.items():

            if not k.endswith(map_key_suffix):
                continue

            dependent_idx = v[field_info.agg_index]
            dependent_idx = dependent_idx if isinstance(dependent_idx, list) else [dependent_idx]

            # Get the `ScanArgs` field name for the aggregate list property
            # corresponding to these dependent argument types (i.e. either
            # "outer_inputs", "inner_inputs", "inner_outputs", or
            # "outer_outputs").
            # To do this, we need to parse the "shared" prefix of the
            # current `var_mappings` key and append the missing parts so that
            # it either forms `"*_inputs"` or `"*_outputs"`.
            to_agg_field_prefix = k[:9]
            if to_agg_field_prefix.endswith("p"):
                to_agg_field_name = "{}uts".format(to_agg_field_prefix)
            else:
                to_agg_field_name = "{}puts".format(to_agg_field_prefix)

            to_agg_field = getattr(self, to_agg_field_name)

            for d_id in dependent_idx:
                # Negative indices are skipped rather than used to index from
                # the end of the aggregate list.
                if d_id < 0:
                    continue

                dependent_var = to_agg_field[d_id]

                if dependent_var not in seen:
                    dependent_nodes.add(dependent_var)

        if field_info.name.startswith("inner_in"):
            # If starting from an inner-input, then we need to find any
            # inner-outputs that depend on it.
            for out_n in self.inner_outputs:
                if i in tt_inputs([out_n]):
                    if out_n not in seen:
                        dependent_nodes.add(out_n)

        # Iterate over a snapshot, because `dependent_nodes` grows in the loop.
        for n in tuple(dependent_nodes):
            if n in seen:
                continue
            sub_dependent_nodes = self.get_dependent_nodes(n, seen=seen)
            dependent_nodes |= sub_dependent_nodes
            seen |= sub_dependent_nodes

        return dependent_nodes

    def remove_from_fields(self, i, rm_dependents=True):
        """Remove `i` and, optionally, its dependents from the fields.

        Parameters
        ----------
        i: theano.gof.graph.Variable
            The variable to remove.
        rm_dependents: bool
            When truthy, the variables returned by `get_dependent_nodes(i)`
            are removed as well.

        Returns
        -------
        list
            A list of `(variable, field_info)` pairs, where `field_info` is
            the `FieldInfo` of the removed variable, or `None` if the variable
            was not found.
        """

        if rm_dependents:
            vars_to_remove = self.get_dependent_nodes(i) | {i}
        else:
            vars_to_remove = {i}

        rm_info = []
        for v in vars_to_remove:
            dependent_rm_info = self._remove_from_fields(v)
            rm_info.append((v, dependent_rm_info))

        return rm_info

    def __str__(self):
        # Fields are listed in the order: outer-inputs (and `n_steps`),
        # inner-inputs, input slices, inner-outputs, output slices,
        # outer-outputs.
        inner_arg_strs = [
            "\t{}={}".format(p, getattr(self, p))
            for p in self.field_names
            if p.startswith("outer_in") or p == "n_steps"
        ]
        inner_arg_strs += [
            "\t{}={}".format(p, getattr(self, p))
            for p in self.field_names
            if p.startswith("inner_in")
        ]
        inner_arg_strs += [
            "\tmit_mot_in_slices={}".format(self.mit_mot_in_slices),
            "\tmit_sot_in_slices={}".format(self.mit_sot_in_slices),
        ]
        inner_arg_strs += [
            "\t{}={}".format(p, getattr(self, p))
            for p in self.field_names
            if p.startswith("inner_out")
        ]
        inner_arg_strs += [
            "\tmit_mot_out_slices={}".format(self.mit_mot_out_slices),
        ]
        inner_arg_strs += [
            "\t{}={}".format(p, getattr(self, p))
            for p in self.field_names
            if p.startswith("outer_out")
        ]
        res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs))
        return res

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        """Compare every field in `field_names` with the same field of `other`.

        NOTE: Because this class defines `__eq__` without `__hash__`, Python
        sets `__hash__` to `None` and instances are not hashable.
        """
        if not isinstance(other, type(self)):
            return NotImplemented

        for field_name in self.field_names:
            if not hasattr(other, field_name) or getattr(self, field_name) != getattr(
                other, field_name
            ):
                return False

        return True


@local_optimizer([Scan])
def push_out_rvs_from_scan(node):
"""Push `RandomVariable`s out of `Scan` nodes.
Expand Down

0 comments on commit b1bb05a

Please sign in to comment.