[caffe2/tools/autograd] Fix non-determinism in code gen (#101425)
Fix several cases where Python set iteration order leaked into the generated sources, causing non-determinism in the generated code.

Pull Request resolved: #101425
Approved by: https://github.com/albanD
Authored by andrewjcg, committed by pytorchmergebot on May 16, 2023
Commit 799ef7e (1 parent: a837609)
Showing 3 changed files with 11 additions and 9 deletions.
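
Editor's note: the bug class fixed here is easy to reproduce outside the code generator. CPython randomizes str hashes per process (unless PYTHONHASHSEED is pinned), and set iteration order follows those hashes, so any generated source built by iterating a set can differ from run to run. A minimal standalone sketch, assuming nothing from the PyTorch codebase (the set contents are made up):

    # Demonstrate that set iteration order shifts with the per-process hash seed.
    # Each subprocess stands in for one run of the code generator.
    import os
    import subprocess
    import sys

    SNIPPET = "print(list({'alpha', 'beta', 'gamma', 'delta'}))"

    for seed in ("0", "1", "2"):
        result = subprocess.run(
            [sys.executable, "-c", SNIPPET],
            env={**os.environ, "PYTHONHASHSEED": seed},
            capture_output=True,
            text=True,
        )
        print(f"PYTHONHASHSEED={seed}: {result.stdout.strip()}")

    # The printed order typically differs across seeds; sorted(...) of the same
    # set is identical for every seed, which is the fix applied in each hunk below.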
tools/autograd/gen_autograd_functions.py (3 additions, 3 deletions)

@@ -745,9 +745,9 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
                 PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
             )
 
-    for var in info.all_saved_inputs:
+    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
         save_var(var, is_output=False)
-    for var in info.all_saved_outputs:
+    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
         save_var(var, is_output=True)
 
     # lock the mutex when we release variables and in Node::apply to protect thread safety
@@ -770,7 +770,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None:
     # Generate aliases for gradients named for returned values.
     body.extend(
         f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
-        for name in info.used_named_gradients
+        for name in sorted(info.used_named_gradients)
     )
 
     def emit_derivative(
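
Editor's note: the sort key is str(sa.nctype.name) rather than sa.nctype.name because, in torchgen, a NamedCType's name can be either a plain str or a SpecialArgName enum member (the isinstance check in the last hunk of gen_variable_type.py below relies on the same union), and the two types do not compare with <. A minimal sketch of the idea, with simplified stand-ins for the real torchgen classes:

    # Simplified stand-ins for torchgen's SavedAttribute / NamedCType /
    # SpecialArgName; the real definitions live in the torchgen package.
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Union

    class SpecialArgName(Enum):
        possibly_redundant_memory_format = auto()

    @dataclass(frozen=True)
    class NamedCType:
        name: Union[str, SpecialArgName]

    @dataclass(frozen=True)
    class SavedAttribute:
        nctype: NamedCType

    attrs = [
        SavedAttribute(NamedCType("self")),
        SavedAttribute(NamedCType(SpecialArgName.possibly_redundant_memory_format)),
        SavedAttribute(NamedCType("other")),
    ]

    # key=lambda sa: sa.nctype.name would raise TypeError ('<' not supported
    # between str and SpecialArgName); str() makes the key homogeneous and
    # yields one stable order for the generated getters.
    for sa in sorted(attrs, key=lambda sa: str(sa.nctype.name)):
        print(sa.nctype.name)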
tools/autograd/gen_python_functions.py (3 additions, 1 deletion)

@@ -365,7 +365,9 @@ def gen(
 def gen_tags_enum() -> Dict[str, str]:
     return {
         "enum_of_valid_tags": (
-            "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
+            "".join(
+                [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
+            )
         )
     }

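Editor's note: the string built by gen_tags_enum is spliced into the generated Python bindings, so an unstable tag order previously meant a different generated file on each run. A sketch of the now-deterministic output, using a hypothetical subset of tags (the real valid_tags set is collected by the codegen, not hard-coded):

    # Hypothetical valid_tags set, for illustration only.
    valid_tags = {"nondeterministic_seeded", "core", "inplace_view"}

    enum_of_valid_tags = "".join(
        [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
    )
    print(enum_of_valid_tags)
    # Prints the same thing on every run:
    # .value("core", at::Tag::core)
    # .value("inplace_view", at::Tag::inplace_view)
    # .value("nondeterministic_seeded", at::Tag::nondeterministic_seeded)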
tools/autograd/gen_variable_type.py (5 additions, 5 deletions)

@@ -818,7 +818,7 @@ def gen_variable_type(
     # dispatch key that appears in derivatives.yaml
     def wrapper_registrations(used_keys: Set[str]) -> str:
         library_impl_macro_list: List[str] = []
-        for key in used_keys:
+        for key in sorted(used_keys):
             dispatch_key = key
             if key == "Default":
                 dispatch_key = "Autograd"
@@ -843,7 +843,7 @@ def wrapper_registrations(used_keys: Set[str]) -> str:
             "type_derived_method_definitions": "\n\n".join(
                 [
                     "${" + f"type_derived_method_definitions_{key}" + "}"
-                    for key in used_keys
+                    for key in sorted(used_keys)
                 ]
             ),
             "wrapper_registrations": wrapper_registrations(used_keys),
@@ -854,8 +854,8 @@ def wrapper_registrations(used_keys: Set[str]) -> str:
     fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False)
 
     sharded_keys = set(
-        [f"type_derived_method_definitions_{key}" for key in used_keys]
-        + [f"wrapper_registrations_{key}" for key in used_keys]
+        [f"type_derived_method_definitions_{key}" for key in sorted(used_keys)]
+        + [f"wrapper_registrations_{key}" for key in sorted(used_keys)]
     )
     # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
     # template regarding sharding of the generated files.
@@ -1337,7 +1337,7 @@ def save_variables(
 ) -> Sequence[str]:
     # assign the saved variables to the generated grad_fn
     stmts: List[str] = []
-    for arg in saved_variables:
+    for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
         name = (
             arg.nctype.name.name
             if isinstance(arg.nctype.name, SpecialArgName)
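Editor's note: in gen_variable_type.py the same treatment stabilizes both the order of the per-dispatch-key template placeholders and the order of the registration blocks emitted into VariableType.cpp. A simplified sketch of wrapper_registrations under that fix; the macro text here is illustrative, not copied from the real template:

    from typing import List, Set

    def wrapper_registrations(used_keys: Set[str]) -> str:
        # One registration block per dispatch key; "Default" maps to "Autograd".
        # The TORCH_LIBRARY_IMPL text below is an assumed stand-in for the
        # generator's real registration template.
        library_impl_macro_list: List[str] = []
        for key in sorted(used_keys):  # stable block order across generator runs
            dispatch_key = "Autograd" if key == "Default" else key
            library_impl_macro_list.append(
                f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) {{\n"
                f"  ${{wrapper_registrations_{key}}}\n"
                f"}}"
            )
        return "\n\n".join(library_impl_macro_list)

    print(wrapper_registrations({"Default", "AutogradNestedTensor"}))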
