[inductor] Fall back to eager mode for deterministic algos.
Inductor generates indexing kernels that use atomic ops and behave
non-deterministically when indices are duplicated. Fall back to eager
mode when deterministic algorithms are enabled via
torch.use_deterministic_algorithms(True).

Fixes #93537
colesbury committed Mar 13, 2023
1 parent 13011af commit 93955b4
Showing 9 changed files with 101 additions and 0 deletions.
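
For context, a minimal sketch of the behavior this change targets, mirroring the new test_deterministic_algorithms test below (it assumes a CUDA device; the names scatter_add and baseline are illustrative, not from the commit). Every index collides on element 0, so the inductor-generated kernel's handling of the duplicate writes can differ between runs, while the eager fallback taken under deterministic mode is repeatable:

import torch

@torch.compile  # inductor is the default backend
def scatter_add(idx, values):
    x = torch.zeros(1, device="cuda")
    x[idx] += values  # indexed accumulation where every index hits x[0]
    return x

idx = torch.zeros(10_000, dtype=torch.int64, device="cuda")
values = torch.randn(10_000, device="cuda")

# Without the flag, inductor lowers the indexing with atomic ops, so the
# floating-point result can differ from run to run.
baseline = scatter_add(idx, values)

# With deterministic algorithms enabled, the lowering now falls back to the
# eager index_put_ kernel and repeated runs agree exactly.
torch.use_deterministic_algorithms(True)
try:
    first = scatter_add(idx, values)
    for _ in range(10):
        assert torch.equal(scatter_add(idx, values), first)
finally:
    torch.use_deterministic_algorithms(False)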
21 changes: 21 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -35,6 +35,7 @@
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import all_types
from torch.testing._internal.common_utils import (
DeterministicGuard,
IS_CI,
IS_MACOS,
IS_WINDOWS,
@@ -7289,6 +7290,26 @@ def fn(x, y):
fn_optimized = torch._dynamo.optimize("inductor")(fn)
assert same(fn(a, b), fn_optimized(a, b))

@requires_cuda()
def test_deterministic_algorithms(self):
N = 10000

@torch.compile
def fn(idx, values):
x = torch.zeros(1, device="cuda")
x[idx] += values
return x

idx = torch.zeros(N, dtype=torch.int64, device="cuda")
values = torch.randn(N, device="cuda")

r0 = fn(idx, values)
with DeterministicGuard(True):
r1 = fn(idx, values)
for _ in range(10):
rn = fn(idx, values)
assert (r1 == rn).all()

class TritonCodeGenTests(TestCase):
from torch._inductor.triton_ops.autotune import CachingAutotuner

3 changes: 3 additions & 0 deletions torch/_dynamo/convert_frame.py
@@ -260,6 +260,9 @@ def format_guard_failures(code):
global initial_grad_state
initial_grad_state = torch.is_grad_enabled()

global initial_deterministic_algorithms_state
initial_deterministic_algorithms_state = torch.are_deterministic_algorithms_enabled()

return _compile(
frame.f_code,
frame.f_globals,
11 changes: 11 additions & 0 deletions torch/_dynamo/guards.py
@@ -53,6 +53,7 @@
("___check_type_id", check_type_id),
("___check_obj_id", check_obj_id),
("___is_grad_enabled", torch.is_grad_enabled),
("___are_deterministic_algorithms_enabled", torch.are_deterministic_algorithms_enabled),
("___odict_getitem", collections.OrderedDict.__getitem__),
("___dict_param_key_ids", dict_param_key_ids),
("___dict_const_keys", dict_const_keys),
@@ -414,6 +415,16 @@ def GRAD_MODE(self, guard: Guard):
code = "not ___is_grad_enabled()"
self._produce_guard_code(guard, [code])

def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
"""Guard on the initial determinism algorithms state"""
assert guard.source is GuardSource.GLOBAL
code = None
if convert_frame.initial_deterministic_algorithms_state:
code = "___are_deterministic_algorithms_enabled()"
else:
code = "not ___are_deterministic_algorithms_enabled()"
self._produce_guard_code(guard, [code])

def SHAPE_ENV(self, guard: Guard):
# Let's handle ShapeEnv guards. To do this, we will resolve
# shape variables to sources from tracked_fakes. This must happen after
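A note on the guard machinery above, with a sketch that is not part of the commit: the deterministic-algorithms state observed when a frame is compiled is re-checked before the compiled code is reused, so toggling the setting invalidates the cached entry and forces a recompile rather than silently reusing code compiled under the other mode. The counting_backend helper below is hypothetical and exists only to make the recompile visible.

import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Hypothetical debug backend: count (re)compiles, then run the traced
    # graph as-is.
    global compile_count
    compile_count += 1
    return gm.forward

@torch.compile(backend=counting_backend)
def fn(x):
    return x.sin() + 1

x = torch.randn(8)
fn(x)  # first compile; guard records "not ___are_deterministic_algorithms_enabled()"

torch.use_deterministic_algorithms(True)
fn(x)  # guard fails under the new global state, so dynamo recompiles
torch.use_deterministic_algorithms(False)

assert compile_count == 2  # one compile per observed determinism state
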
4 changes: 4 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -30,6 +30,7 @@
from .source import (
ConstantSource,
is_constant_source,
DeterministicAlgorithmsSource,
LocalInputSource,
LocalSource,
ShapeEnvSource,
@@ -198,6 +199,9 @@ def __init__(
# that show up in ShapeEnv
self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

self.guards.add(DeterministicAlgorithmsSource().make_guard(
GuardBuilder.DETERMINISTIC_ALGORITHMS))

# tracked_fakes says where any tensor that was wrapped to fake came
# from. It is similar to GraphArg, in that all GraphArgs will get
# will get added to TrackedFakes, but TrackedFakes also contains
9 changes: 9 additions & 0 deletions torch/_dynamo/source.py
@@ -372,6 +372,15 @@ def guard_source(self):
return _GUARD_SOURCE_NOT_NN_MODULE[self.inner.guard_source()]


@dataclasses.dataclass
class DeterministicAlgorithmsSource(Source):
def name(self):
return ""

def guard_source(self):
return GuardSource.GLOBAL


@dataclasses.dataclass
class ConstantSource(Source):
source_name: str
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/__init__.py
@@ -25,6 +25,7 @@
CUDAStreamVariable,
GetAttrVariable,
GradModeVariable,
DeterministicAlgorithmsVariable,
InspectSignatureVariable,
LambdaVariable,
NewCellVariable,
Expand Down Expand Up @@ -60,6 +61,7 @@
"FakeItemVariable",
"GetAttrVariable",
"GradModeVariable",
"DeterministicAlgorithmsVariable",
"InspectSignatureVariable",
"LambdaVariable",
"ListIteratorVariable",
39 changes: 39 additions & 0 deletions torch/_dynamo/variables/misc.py
@@ -249,6 +249,45 @@ def fn_name(self):
return "set_grad_enabled"


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
"""represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

_guards_singleton = {Guard("", GuardSource.GLOBAL, GuardBuilder.DETERMINISTIC_ALGORITHMS)}

@staticmethod
def create(tx, target_value, **kwargs):
var = DeterministicAlgorithmsVariable(
target_values=[target_value],
initial_values=[torch.are_deterministic_algorithms_enabled()],
**kwargs,
)
var._call_func(tx, [target_value])
return var

def __init__(self, target_values, initial_values=None, **kwargs):
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.guards = self.guards | self._guards_singleton

def enter(self, tx):
return variables.ConstantVariable(None, **VariableTracker.propagate(self))

def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
tx.output.create_node(
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
),
torch._C._set_deterministic_algorithms(value)

def module_name(self):
return "torch"

def fn_name(self):
return "use_deterministic_algorithms"


class AutocastModeVariable(ContextWrappingVariable):
@staticmethod
def create(target_values, kwargs):
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -189,6 +189,7 @@ def call_function(
CUDAStreamContextVariable,
CUDAStreamVariable,
GradModeVariable,
DeterministicAlgorithmsVariable,
SymNodeVariable,
TensorVariable,
UserDefinedObjectVariable,
@@ -274,6 +275,13 @@ def call_function(
return ConstantVariable(torch.is_grad_enabled(), **options).add_guards(
GradModeVariable._guards_singleton
)
elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create(tx, args[0].as_python_constant(), **options)
elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs)
return ConstantVariable(torch.are_deterministic_algorithms_enabled(), **options).add_guards(
DeterministicAlgorithmsVariable._guards_singleton
)
elif self.value is torch.cuda.stream:
log.warning(
"torch.cuda.stream() not fully supported, streams may be ignored"
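Taken together with the DeterministicAlgorithmsVariable added in misc.py, the dispatch above means that inside a compiled function torch.use_deterministic_algorithms(<constant>) is recorded in the FX graph (via torch._C._set_deterministic_algorithms) and also applied during tracing, while torch.are_deterministic_algorithms_enabled() is constant-folded under the DETERMINISTIC_ALGORITHMS guard. An illustrative sketch, not taken from the commit:

import torch

@torch.compile
def force_deterministic(x):
    # Traced as a DeterministicAlgorithmsVariable: the setter is put in the
    # graph and applied while tracing, so the query below folds to True.
    torch.use_deterministic_algorithms(True)
    if torch.are_deterministic_algorithms_enabled():
        return x + 1
    return x - 1

out = force_deterministic(torch.randn(4))
torch.use_deterministic_algorithms(False)  # restore the global default
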
4 changes: 4 additions & 0 deletions torch/_inductor/lowering.py
@@ -2053,6 +2053,10 @@ def index_put_(self, indices, values, accumulate=False):
):
return index_put_as_masked_fill(self, indices, values, accumulate)

if torch.are_deterministic_algorithms_enabled():
# Fallback. Inductor lowerings are non-deterministic (they use atomic ops).
return index_put_fallback(self, indices, values, accumulate)

# Fallback if there is a boolean index
for index in indices:
if index is not None and index.get_dtype() in {torch.bool, torch.uint8}:
