Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fix for writing to closures #233

Merged
merged 3 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,38 @@ def fn3():
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 11)

def test_write_to_closures_in_inlining(self):
out = []
for use_dynamo in [False, True]:

def make_counter():
x = torch.randn(10)

def counter():
nonlocal x
x = x + 1
return x

return counter

torch.manual_seed(0)
counter = make_counter()
if not use_dynamo:
out.append(counter() + counter())
else:
cnts = torchdynamo.testing.CompileCounter()

@torchdynamo.optimize(cnts, nopython=True)
def fn(counter):
return counter() + counter()

out.append(fn(counter))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 3)
self.assertFalse(same(counter() + counter(), out[-1]))

self.assertTrue(same(out[0], out[1]))

def test_top_package_import(self):
def fn(x):
import torch.fx
Expand Down
17 changes: 13 additions & 4 deletions torchdynamo/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ def track_cell_new(
self.keepalive.append(obj)
return variable

def track_cell_existing(self, source: Source, item: Any):
variable = variables.NewCellVariable(
mutable_local=AttributeMutationExisting(source),
)
self.id_to_variable[id(item)] = variable
self.keepalive.append(item)
return variable

def prune_dead_object_new(self, tx):
live_new_objects = set()
skip_obj = None
Expand Down Expand Up @@ -232,13 +240,14 @@ def codegen(self, cg: PyCodegen):
]

for var in modified_vars:
if isinstance(var.mutable_local, AttributeMutationNew) and isinstance(
var, variables.NewCellVariable
):
if isinstance(
var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
) and isinstance(var, variables.NewCellVariable):
cg.load_import_from(utils.__name__, "make_cell")
cg.extend_output([create_instruction("CALL_FUNCTION", 0)])
cg.add_cache(var)
var.mutable_local.source = LocalSource(cg.tempvars[var])
if isinstance(var.mutable_local, AttributeMutationNew):
var.mutable_local.source = LocalSource(cg.tempvars[var])
elif isinstance(var.mutable_local, AttributeMutationNew):
cg.load_import_from(utils.__name__, "object_new")
cg(var.mutable_local.cls_source)
Expand Down
16 changes: 14 additions & 2 deletions torchdynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,15 @@ def STORE_DEREF(self, inst):
else:
self.output.side_effects.store_cell(cell, val)
else:
unimplemented("write to __closure__ while inlining")
if isinstance(
self.symbolic_locals.get(inst.argval),
torchdynamo.variables.NewCellVariable,
):
self.output.side_effects.store_cell(
self.symbolic_locals[inst.argval], self.pop()
)
else:
unimplemented("write to __closure__ while inlining")

def LOAD_DEREF(self, inst):
if inst.argval in self.closure_cells:
Expand All @@ -1303,7 +1311,11 @@ def LOAD_DEREF(self, inst):
else:
self.push(self.output.side_effects.load_cell(cell))
else:
super().LOAD_DEREF(inst)
maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
if isinstance(maybe_sym_local, torchdynamo.variables.NewCellVariable):
self.push(self.output.side_effects.load_cell(maybe_sym_local))
else:
super().LOAD_DEREF(inst)

def LOAD_CLOSURE(self, inst):
assert inst.argval in self.cell_and_freevars()
Expand Down
28 changes: 23 additions & 5 deletions torchdynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,29 @@ def bind_args(self, parent, args, kwargs):
elif self.source:
from .builder import VariableBuilder

source = AttrSource(
GetItemSource(AttrSource(self.source, "__closure__"), idx),
"cell_contents",
)
result[name] = VariableBuilder(parent, source)(cell.cell_contents)
side_effects = parent.output.side_effects
if cell in side_effects:
out = side_effects[cell]
else:
closure_cell = GetItemSource(
AttrSource(self.source, "__closure__"), idx
)
closure_cell_contents = AttrSource(
closure_cell, "cell_contents"
)

# cells are written to with "cell_contents",
# so the source should just be the closure_cell, not its contents
out = side_effects.track_cell_existing(closure_cell, cell)
side_effects.store_cell(
out,
VariableBuilder(parent, closure_cell_contents)(
cell.cell_contents
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I know this is ancient history, but the new code here looks worse than the old code in one sense: we're unconditionally writing to the cell on function creation, triggering side effects, but in fact we may not actually need to write the closure cell if there aren't any variables that will get mutated in the function. I wonder if there's a way to avoid unconditionally store_cell here (context: I'm rewriting cond() operator to be bytecode based and for simplicity I want to assert that there are no side effects, but benign use of closures in our test cases are triggering side effects.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what I am going to try is analyzing the bytecode to see if there is a STORE_DEREF on the relevant freevar. If there isn't I can assume I don't need to allocate a cell.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up doing a different strategy by restarting analysis


result[name] = out

else:
unimplemented("inline with __closure__")

Expand Down