Skip to content

Commit

Permalink
Make cond_indep_stack a tuple (hence hashable) (#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored and martinjankowiak committed Mar 6, 2018
1 parent 748e344 commit ff88887
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def sample(name, fn, *args, **kwargs):
"value": None,
"infer": infer,
"scale": 1.0,
"cond_indep_stack": [],
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None
Expand Down Expand Up @@ -328,7 +328,7 @@ def param(name, *args, **kwargs):
"kwargs": kwargs,
"infer": {},
"scale": 1.0,
"cond_indep_stack": [],
"cond_indep_stack": (),
"value": None,
"done": False,
"stop": False,
Expand Down
3 changes: 2 additions & 1 deletion pyro/poutine/indep_poutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ def next_context(self):
self.counter += 1

def _process_message(self, msg):
msg["cond_indep_stack"].insert(0, CondIndepStackFrame(self.name, self.counter, self.vectorized, self.size))
frame = CondIndepStackFrame(self.name, self.counter, self.vectorized, self.size)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
return None
4 changes: 2 additions & 2 deletions pyro/poutine/trace_poutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def get_vectorized_map_data_info(trace):
if site_is_subsample(node):
continue
if node["type"] in ("sample", "param"):
stack = tuple(node["cond_indep_stack"])
stack = node["cond_indep_stack"]
vec_mds = [x for x in stack if x.vectorized]
stack_dict[name] = vec_mds

for name, node in nodes.items():
if site_is_subsample(node):
continue
if node["type"] in ("sample", "param"):
stack = tuple(node["cond_indep_stack"])
stack = node["cond_indep_stack"]
vec_mds = [x for x in stack if x.vectorized]
stack_dict[name] = vec_mds
# check for nested vectorized map datas
Expand Down

0 comments on commit ff88887

Please sign in to comment.