[dynamo] Add guards for deterministic algos.
Inductor now falls back to eager mode for deterministic algorithms. Add guards
in dynamo to check whether the deterministic-algorithms mode has changed since
the frame was compiled.

See #93537
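As a behavioral sketch of what the new guard buys (not part of the commit; the indexed accumulation below is only an assumed example of an op whose lowering needs the eager fallback, and the exact op list depends on the backend), toggling the global mode after a frame has been compiled should now invalidate the cached code instead of silently reusing it:

import torch

@torch.compile
def scatter_add(idx, values):
    # assumed example: indexed accumulation is the kind of op that takes the
    # eager fallback under deterministic mode
    x = torch.zeros(1)
    x[idx] += values
    return x

idx = torch.zeros(100, dtype=torch.int64)
vals = torch.randn(100)
scatter_add(idx, vals)                     # compiled while the mode is off
torch.use_deterministic_algorithms(True)
scatter_add(idx, vals)                     # DETERMINISTIC_ALGORITHMS guard fails -> recompile
torch.use_deterministic_algorithms(False)  # restore global state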
colesbury committed Mar 30, 2023
1 parent 0e4ddc2 commit 71c2849
Showing 10 changed files with 120 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test/dynamo/test_comptime.py
@@ -184,6 +184,15 @@ def _(ctx):
            """\
-
        local 'x' TENSOR_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
-
        global '' DETERMINISTIC_ALGORITHMS
        {
            'guard_types': None,
            'code': None,
4 changes: 2 additions & 2 deletions test/dynamo/test_misc.py
@@ -4406,9 +4406,9 @@ def guard_export_print(guards):
        opt_fn(x, y)

        if torch._dynamo.config.dynamic_shapes:
-            self.assertEqual(len(all_guards), 13)
+            self.assertEqual(len(all_guards), 17)
        else:
-            self.assertEqual(len(all_guards), 9)
+            self.assertEqual(len(all_guards), 13)
        for guard in all_guards:
            # This guard was created
            self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents")
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -7637,6 +7637,26 @@ def fn(arg3_1, arg3_2, relu, permute_1):
        ref = torch.compile(fn, fullgraph=True)(*args)
        assert same(ref, correct)

    @requires_cuda()
    def test_deterministic_algorithms(self):
        N = 10000

        @torch.compile
        def fn(idx, values):
            x = torch.zeros(1, device="cuda")
            x[idx] += values
            return x

        idx = torch.zeros(N, dtype=torch.int64, device="cuda")
        values = torch.randn(N, device="cuda")

        r0 = fn(idx, values)
        with DeterministicGuard(True):
            r1 = fn(idx, values)
            for _ in range(10):
                rn = fn(idx, values)
                assert (r1 == rn).all()


class TritonCodeGenTests(TestCase):
    from torch._inductor.triton_heuristics import CachingAutotuner

6 changes: 6 additions & 0 deletions torch/_dynamo/convert_frame.py
@@ -79,6 +79,7 @@ def clear(self):


initial_grad_state = None
initial_deterministic_algorithms_state = None


@functools.wraps(original_forward_from_src)
@@ -273,6 +274,11 @@ def format_guard_failures(code):
        global initial_grad_state
        initial_grad_state = torch.is_grad_enabled()

        global initial_deterministic_algorithms_state
        initial_deterministic_algorithms_state = (
            torch.are_deterministic_algorithms_enabled()
        )

        return _compile(
            frame.f_code,
            frame.f_globals,
14 changes: 14 additions & 0 deletions torch/_dynamo/guards.py
@@ -53,6 +53,10 @@
    ("___check_type_id", check_type_id),
    ("___check_obj_id", check_obj_id),
    ("___is_grad_enabled", torch.is_grad_enabled),
    (
        "___are_deterministic_algorithms_enabled",
        torch.are_deterministic_algorithms_enabled,
    ),
    ("___odict_getitem", collections.OrderedDict.__getitem__),
    ("___dict_param_key_ids", dict_param_key_ids),
    ("___dict_const_keys", dict_const_keys),
@@ -414,6 +418,16 @@ def GRAD_MODE(self, guard: Guard):
        code = "not ___is_grad_enabled()"
        self._produce_guard_code(guard, [code])

    def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
        """Guard on the initial deterministic algorithms state"""
        assert guard.source is GuardSource.GLOBAL
        code = None
        if convert_frame.initial_deterministic_algorithms_state:
            code = "___are_deterministic_algorithms_enabled()"
        else:
            code = "not ___are_deterministic_algorithms_enabled()"
        self._produce_guard_code(guard, [code])

    def SHAPE_ENV(self, guard: Guard):
        # Let's handle ShapeEnv guards. To do this, we will resolve
        # shape variables to sources from tracked_fakes. This must happen after
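The guard the builder produces is a single boolean expression evaluated against a registered closure variable. Below is a minimal standalone sketch of that check, a hypothetical reproduction outside dynamo's real codegen that uses only the names added in the hunk above:

import torch

closure_vars = {
    "___are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled,
}

# state captured at compile time (mirrors initial_deterministic_algorithms_state)
initial_state = torch.are_deterministic_algorithms_enabled()
code = (
    "___are_deterministic_algorithms_enabled()"
    if initial_state
    else "not ___are_deterministic_algorithms_enabled()"
)
guard_fn = eval("lambda: " + code, closure_vars)

assert guard_fn()                               # holds while the mode is unchanged
torch.use_deterministic_algorithms(not initial_state)
assert not guard_fn()                           # fails -> dynamo would recompile the frame
torch.use_deterministic_algorithms(initial_state)  # restore global state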
7 changes: 7 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -37,6 +37,7 @@
from .side_effects import SideEffects
from .source import (
    ConstantSource,
    DeterministicAlgorithmsSource,
    is_constant_source,
    LocalSource,
    ParamBufferSource,
@@ -210,6 +211,12 @@ def __init__(
        # that show up in ShapeEnv
        self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        self.guards.add(
            DeterministicAlgorithmsSource().make_guard(
                GuardBuilder.DETERMINISTIC_ALGORITHMS
            )
        )

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from. It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains
9 changes: 9 additions & 0 deletions torch/_dynamo/source.py
@@ -415,6 +415,15 @@ def guard_source(self):
        return _GUARD_SOURCE_FSDP_MODULE[self.inner.guard_source()]


@dataclasses.dataclass
class DeterministicAlgorithmsSource(Source):
    def name(self):
        return ""

    def guard_source(self):
        return GuardSource.GLOBAL


@dataclasses.dataclass
class ConstantSource(Source):
    source_name: str
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/__init__.py
@@ -23,6 +23,7 @@
    ContextWrappingVariable,
    CUDAStreamContextVariable,
    CUDAStreamVariable,
    DeterministicAlgorithmsVariable,
    GetAttrVariable,
    GradModeVariable,
    InspectSignatureVariable,
@@ -60,6 +61,7 @@
    "FakeItemVariable",
    "GetAttrVariable",
    "GradModeVariable",
    "DeterministicAlgorithmsVariable",
    "InspectSignatureVariable",
    "LambdaVariable",
    "ListIteratorVariable",
41 changes: 41 additions & 0 deletions torch/_dynamo/variables/misc.py
@@ -249,6 +249,47 @@ def fn_name(self):
        return "set_grad_enabled"


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = {
        Guard("", GuardSource.GLOBAL, GuardBuilder.DETERMINISTIC_ALGORITHMS)
    }

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.guards = self.guards | self._guards_singleton

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        ),
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"


class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(target_values, kwargs):
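DeterministicAlgorithmsVariable lets a compiled function call torch.use_deterministic_algorithms directly: the tracer flips the mode while tracing and records a call_function node so the compiled graph replays the change at runtime. A hedged usage sketch, assuming the call does not hit an unrelated graph break:

import torch

@torch.compile
def enable_and_add(x):
    # traced via DeterministicAlgorithmsVariable rather than causing a graph break
    torch.use_deterministic_algorithms(True)
    return x + 1

out = enable_and_add(torch.randn(4))
assert torch.are_deterministic_algorithms_enabled()  # the mode change is visible after the call
torch.use_deterministic_algorithms(False)            # restore global state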
10 changes: 10 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -189,6 +189,7 @@ def call_function(
            ConstantVariable,
            CUDAStreamContextVariable,
            CUDAStreamVariable,
            DeterministicAlgorithmsVariable,
            GradModeVariable,
            SymNodeVariable,
            TensorVariable,
@@ -275,6 +276,15 @@ def call_function(
            return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
                GradModeVariable._guards_singleton
            )
        elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
            return DeterministicAlgorithmsVariable.create(
                tx, args[0].as_python_constant(), **options
            )
        elif self.value is torch.are_deterministic_algorithms_enabled:
            assert not (args or kwargs)
            return ConstantVariable(
                torch.are_deterministic_algorithms_enabled(), **options
            ).add_guards(DeterministicAlgorithmsVariable._guards_singleton)
        elif self.value is torch.cuda.stream:
            log.warning(
                "torch.cuda.stream() not fully supported, streams may be ignored"
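Because torch.are_deterministic_algorithms_enabled() is constant-folded into the graph, the accompanying guard is what keeps a stale constant from being reused. A behavioral sketch, assuming no intervening graph breaks so the whole function is one guarded frame:

import torch

@torch.compile
def report(x):
    # the boolean below is baked into the graph as a constant under the new guard
    return x + int(torch.are_deterministic_algorithms_enabled())

x = torch.zeros(3)
torch.use_deterministic_algorithms(False)
assert report(x)[0].item() == 0
torch.use_deterministic_algorithms(True)
assert report(x)[0].item() == 1     # guard failure forces a recompile with the new constant
torch.use_deterministic_algorithms(False)  # restore global state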
