[dynamo] Add guards for deterministic algos #96695

Closed
wants to merge 2 commits
10 changes: 9 additions & 1 deletion test/dynamo/test_comptime.py
@@ -180,10 +180,18 @@ def _(ctx):
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
-            FILE.getvalue().rstrip(),
+            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
"""\
-
local 'x' TENSOR_MATCH
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
-
global '' DETERMINISTIC_ALGORITHMS
{
'guard_types': None,
'code': None,
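A note on the assertion change in this hunk: `.rstrip()` only trims whitespace at the very end of the captured output, while the added `re.sub(r"\s+$", "", ..., flags=re.MULTILINE)` also strips trailing whitespace from each individual line of the guard dump. A minimal standalone sketch (the sample string is illustrative):

```python
import re

s = "local 'x' TENSOR_MATCH   \n{\t\n}"

# .rstrip() only touches whitespace at the very end of the whole string,
# so the trailing spaces and the tab inside the string survive.
print(repr(s.rstrip()))

# With re.MULTILINE, $ also matches just before every newline, so trailing
# whitespace is stripped from every line: "local 'x' TENSOR_MATCH\n{\n}"
print(repr(re.sub(r"\s+$", "", s, flags=re.MULTILINE)))
```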
4 changes: 2 additions & 2 deletions test/dynamo/test_misc.py
@@ -4406,9 +4406,9 @@ def guard_export_print(guards):
opt_fn(x, y)

if torch._dynamo.config.dynamic_shapes:
-            self.assertEqual(len(all_guards), 13)
+            self.assertEqual(len(all_guards), 17)
else:
-            self.assertEqual(len(all_guards), 9)
+            self.assertEqual(len(all_guards), 13)
for guard in all_guards:
# This guard was created
self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents")
20 changes: 20 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -7637,6 +7637,26 @@ def fn(arg3_1, arg3_2, relu, permute_1):
ref = torch.compile(fn, fullgraph=True)(*args)
assert same(ref, correct)

@requires_cuda()
def test_deterministic_algorithms(self):
N = 10000

@torch.compile
def fn(idx, values):
x = torch.zeros(1, device="cuda")
x[idx] += values
return x

idx = torch.zeros(N, dtype=torch.int64, device="cuda")
values = torch.randn(N, device="cuda")

r0 = fn(idx, values)
with DeterministicGuard(True):
r1 = fn(idx, values)
for _ in range(10):
rn = fn(idx, values)
self.assertEqual(r1, rn, atol=0, rtol=0)

class TritonCodeGenTests(TestCase):
from torch._inductor.triton_heuristics import CachingAutotuner

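For context, the new inductor test leans on the fact that `x[idx] += values` is an `index_put_` with accumulation, which uses CUDA atomics by default and is therefore not bitwise reproducible; `torch.use_deterministic_algorithms(True)` selects a deterministic implementation instead. A rough eager-mode sketch of the same idea, using the `DeterministicGuard` helper the test already relies on (assumes a CUDA device is available):

```python
import torch
from torch.testing._internal.common_utils import DeterministicGuard

idx = torch.zeros(10_000, dtype=torch.int64, device="cuda")
values = torch.randn(10_000, device="cuda")

def scatter_add():
    x = torch.zeros(1, device="cuda")
    x[idx] += values  # 10k accumulations race on element 0 unless deterministic mode is on
    return x

with DeterministicGuard(True):
    runs = [scatter_add() for _ in range(10)]

# With deterministic algorithms enabled, repeated runs should be bit-identical.
assert all(torch.equal(runs[0], r) for r in runs)
```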
6 changes: 6 additions & 0 deletions torch/_dynamo/convert_frame.py
@@ -79,6 +79,7 @@ def clear(self):


initial_grad_state = None
initial_deterministic_algorithms_state = None


@functools.wraps(original_forward_from_src)
@@ -273,6 +274,11 @@ def format_guard_failures(code):
global initial_grad_state
initial_grad_state = torch.is_grad_enabled()

global initial_deterministic_algorithms_state
initial_deterministic_algorithms_state = (
torch.are_deterministic_algorithms_enabled()
)

return _compile(
frame.f_code,
frame.f_globals,
14 changes: 14 additions & 0 deletions torch/_dynamo/guards.py
@@ -53,6 +53,10 @@
("___check_type_id", check_type_id),
("___check_obj_id", check_obj_id),
("___is_grad_enabled", torch.is_grad_enabled),
(
"___are_deterministic_algorithms_enabled",
torch.are_deterministic_algorithms_enabled,
),
("___odict_getitem", collections.OrderedDict.__getitem__),
("___dict_param_key_ids", dict_param_key_ids),
("___dict_const_keys", dict_const_keys),
@@ -414,6 +418,16 @@ def GRAD_MODE(self, guard: Guard):
code = "not ___is_grad_enabled()"
self._produce_guard_code(guard, [code])

def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
"""Guard on the initial determinism algorithms state"""
assert guard.source is GuardSource.GLOBAL
code = None
if convert_frame.initial_deterministic_algorithms_state:
code = "___are_deterministic_algorithms_enabled()"
else:
code = "not ___are_deterministic_algorithms_enabled()"
self._produce_guard_code(guard, [code])

def SHAPE_ENV(self, guard: Guard):
# Let's handle ShapeEnv guards. To do this, we will resolve
# shape variables to sources from tracked_fakes. This must happen after
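Taken together with the `convert_frame.py` snapshot above, the mechanism is: the deterministic-algorithms flag observed at compile time selects which of the two expressions is emitted, and at frame entry that expression is evaluated against the `___are_deterministic_algorithms_enabled` closure variable registered at the top of `guards.py`. A simplified sketch of the check (the `eval`-based check function only approximates Dynamo's generated guard code):

```python
import torch

closure_vars = {
    "___are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled,
}

# Suppose the frame was compiled while the mode was off:
guard_code = "not ___are_deterministic_algorithms_enabled()"

def check_fn() -> bool:
    # Dynamo compiles the guard strings into a real check function; eval is a stand-in.
    return eval(guard_code, dict(closure_vars))

assert check_fn()                           # mode unchanged -> cached code can be reused
torch.use_deterministic_algorithms(True)
assert not check_fn()                       # mode flipped -> guard fails, frame recompiles
torch.use_deterministic_algorithms(False)   # restore the default
```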
7 changes: 7 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -37,6 +37,7 @@
from .side_effects import SideEffects
from .source import (
ConstantSource,
DeterministicAlgorithmsSource,
is_constant_source,
LocalSource,
ParamBufferSource,
@@ -210,6 +211,12 @@ def __init__(
# that show up in ShapeEnv
self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

self.guards.add(
DeterministicAlgorithmsSource().make_guard(
GuardBuilder.DETERMINISTIC_ALGORITHMS
)
)

# tracked_fakes says where any tensor that was wrapped to fake came
# from. It is similar to GraphArg, in that all GraphArgs will get
# will get added to TrackedFakes, but TrackedFakes also contains
9 changes: 9 additions & 0 deletions torch/_dynamo/source.py
@@ -415,6 +415,15 @@ def guard_source(self):
return _GUARD_SOURCE_FSDP_MODULE[self.inner.guard_source()]


@dataclasses.dataclass
class DeterministicAlgorithmsSource(Source):
def name(self):
return ""

def guard_source(self):
return GuardSource.GLOBAL


@dataclasses.dataclass
class ConstantSource(Source):
source_name: str
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/__init__.py
@@ -23,6 +23,7 @@
ContextWrappingVariable,
CUDAStreamContextVariable,
CUDAStreamVariable,
DeterministicAlgorithmsVariable,
GetAttrVariable,
GradModeVariable,
InspectSignatureVariable,
@@ -60,6 +61,7 @@
"FakeItemVariable",
"GetAttrVariable",
"GradModeVariable",
"DeterministicAlgorithmsVariable",
"InspectSignatureVariable",
"LambdaVariable",
"ListIteratorVariable",
41 changes: 41 additions & 0 deletions torch/_dynamo/variables/misc.py
@@ -249,6 +249,47 @@ def fn_name(self):
return "set_grad_enabled"


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

_guards_singleton = {
Guard("", GuardSource.GLOBAL, GuardBuilder.DETERMINISTIC_ALGORITHMS)
}

@staticmethod
def create(tx, target_value, **kwargs):
var = DeterministicAlgorithmsVariable(
target_values=[target_value],
initial_values=[torch.are_deterministic_algorithms_enabled()],
**kwargs,
)
var._call_func(tx, [target_value])
return var

def __init__(self, target_values, initial_values=None, **kwargs):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton

def enter(self, tx):
return variables.ConstantVariable(None, **VariableTracker.propagate(self))

def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
tx.output.create_node(
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
),
torch._C._set_deterministic_algorithms(value)

def module_name(self):
return "torch"

def fn_name(self):
return "use_deterministic_algorithms"


class AutocastModeVariable(ContextWrappingVariable):
@staticmethod
def create(target_values, kwargs):
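Note the dual effect of `_call_func` above: it flips the mode immediately so the rest of tracing sees the new state, and it also records a `call_function` node on `torch._C._set_deterministic_algorithms` so the change is replayed when the compiled graph runs. A rough eager equivalent of the recorded node (the wrapper name is illustrative):

```python
import torch

def replay_set_deterministic(value: bool) -> None:
    # What the recorded call_function node does at graph runtime: call the
    # low-level setter behind torch.use_deterministic_algorithms.
    torch._C._set_deterministic_algorithms(value)

replay_set_deterministic(True)
assert torch.are_deterministic_algorithms_enabled()
replay_set_deterministic(False)
assert not torch.are_deterministic_algorithms_enabled()
```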
10 changes: 10 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -189,6 +189,7 @@ def call_function(
ConstantVariable,
CUDAStreamContextVariable,
CUDAStreamVariable,
DeterministicAlgorithmsVariable,
GradModeVariable,
SymNodeVariable,
TensorVariable,
@@ -275,6 +276,15 @@ def call_function(
return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
GradModeVariable._guards_singleton
)
elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create(
tx, args[0].as_python_constant(), **options
)
elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs)
return ConstantVariable(
torch.are_deterministic_algorithms_enabled(), **options
).add_guards(DeterministicAlgorithmsVariable._guards_singleton)
elif self.value is torch.cuda.stream:
log.warning(
"torch.cuda.stream() not fully supported, streams may be ignored"
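End to end, the pieces above mean that `torch.are_deterministic_algorithms_enabled()` is constant-folded inside compiled code and protected by the new `DETERMINISTIC_ALGORITHMS` guard, so flipping the global mode invalidates the cached frame. A rough sketch of that behavior (the counting backend is illustrative and not part of this PR):

```python
import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Minimal Dynamo backend: count compilations and run the captured graph as-is.
    global compile_count
    compile_count += 1
    return gm.forward

@torch.compile(backend=counting_backend)
def f(x):
    # Folded to a constant at trace time, guarded by DETERMINISTIC_ALGORITHMS.
    if torch.are_deterministic_algorithms_enabled():
        return x + 1
    return x - 1

x = torch.randn(4)
f(x)                                        # compiled with the mode off -> x - 1 branch
torch.use_deterministic_algorithms(True)
f(x)                                        # guard fails -> recompile -> x + 1 branch
assert compile_count == 2
torch.use_deterministic_algorithms(False)   # restore the default
```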