Skip to content

Commit

Permalink
Update on "[dynamo] Collect cell_and_freevars correctly"
Browse files Browse the repository at this point in the history
cc ezyang msaroufim bdhirsh chauhang voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed Apr 27, 2024
2 parents dd8d8f1 + 8f246b0 commit 51ccbc5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 97 deletions.
10 changes: 0 additions & 10 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,6 @@ def constant3(a, b):
return a - b + (1.0 + 2)


_variable = 0


def update_global(x):
global _variable
_variable += 1
# Check that updated global variable value is picked up
return x * _variable


def func_with_default(a, b, some_default_arg=True):
if some_default_arg:
return a - b
Expand Down
26 changes: 0 additions & 26 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@
import test_functions


_variable = 0
_variable1 = 0


def update_global():
global _variable, _variable1
_variable += 1
_variable1 += 1


class BasicModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -2445,22 +2435,6 @@ def forward(self, inp):

self.assertEqual(model.x, compiled_model.x)

def test_globals_change_in_other_file(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
update_global()
a = test_functions.update_global(x)
# Ensure that the updated global values are read
return x * a * (_variable + _variable1 + test_functions._variable)

res = fn(torch.ones(10))
self.assertEqual(_variable, 1)
self.assertEqual(_variable1, 1)
# Ensure that the reconstructed bytecode updates the global value in the
# other file.
self.assertEqual(test_functions._variable, 1)
self.assertEqual(res, 3 * torch.ones(10))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
3 changes: 1 addition & 2 deletions torch/_dynamo/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalSource, Source
from .source import LocalSource, Source
from .utils import nn_module_new, object_new
from .variables.base import (
is_side_effect_safe,
Expand Down Expand Up @@ -485,7 +485,6 @@ def codegen_update_mutated(self, cg: PyCodegen):
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
assert isinstance(var.mutable_local.source, GlobalSource) # type: ignore[attr-defined]
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
Expand Down
77 changes: 18 additions & 59 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,22 @@ def _load_const(self, inst):
def LOAD_CONST(self, inst):
self.push(self._load_const(inst))

def get_global_source(self, name):
source: Source
if self.output.global_scope is self.f_globals:
source = GlobalSource(name)
else:
if "__name__" in self.f_globals:
source = AttrSource(
self.import_source(self.f_globals["__name__"]), name
)
else:
mangled_name = self.output.install_global_by_id(
"___unnamed_scope", self.f_globals
)
source = GetItemSource(GlobalSource(mangled_name), name)
return source

def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11):
if inst.arg % 2:
Expand Down Expand Up @@ -996,13 +1012,13 @@ def LOAD_GLOBAL(self, inst):
except KeyError:
return self.load_builtin(inst)

source = GlobalSource(name)
source = self.get_global_source(name)
self.push(VariableBuilder(self, source)(value))

def STORE_GLOBAL(self, inst):
value = self.pop()
name = inst.argval
source = GlobalSource(name)
source = self.get_global_source(name)
if name not in self.symbolic_globals:
self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object
variable = self.output.side_effects.track_global_existing(
Expand Down Expand Up @@ -2672,63 +2688,6 @@ def RETURN_CONST(self, inst):
self.instruction_pointer = None
raise ReturnValueOp

def get_globals_source_and_value(self, name):
if "__name__" in self.f_globals:
module_name = self.f_globals["__name__"]
module_source = self.import_source(module_name)
if "torch_package" in module_name:
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
else:
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
global_source = AttrSource(module_source, name)
else:
globals_name = self.output.install_global_by_id(
"___unnamed_scope", self.f_globals
)
globals_source = GlobalSource(globals_name)
fglobals_value = self.f_globals # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
return fglobals_value, fglobals_vt, global_source

def LOAD_GLOBAL(self, inst):
if self.output.global_scope is self.f_globals:
super().LOAD_GLOBAL(inst)
else:
if sys.version_info >= (3, 11):
if inst.arg % 2:
self.PUSH_NULL(inst)

name = inst.argval
if inst.argval == "AssertionError":
unimplemented("assert with non-string message")

_, fglobals_vt, global_source = self.get_globals_source_and_value(name)
if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
self.push(self.output.side_effects.load_attr(fglobals_vt, name))
else:
try:
value = self.f_globals[name]
except KeyError:
return self.load_builtin(inst)

self.push(VariableBuilder(self, global_source)(value))

def STORE_GLOBAL(self, inst):
if self.f_globals is self.parent.f_globals:
super().STORE_GLOBAL(inst)
else:
value = self.pop()
if isinstance(value, RemovableHandleVariable):
unimplemented("Storing handles in globals - NYI")
name = inst.argval
fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
fglobals_vt = self.output.side_effects.track_object_existing(
fglobals_value, fglobals_vt
)
self.output.side_effects.store_attr(fglobals_vt, name, value)


class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
generated_items: List[VariableTracker]
Expand Down

0 comments on commit 51ccbc5

Please sign in to comment.