28 changes: 28 additions & 0 deletions test/dynamo/test_profiler.py
@@ -162,6 +162,34 @@ def fn(x, y, z):
            any(e.name == "TorchDynamo Cache Lookup" for e in prof.events())
        )

    def test_profiler_enabled_export(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.sin(x)
                if torch.autograd._profiler_enabled():
                    return torch.cos(x)
                else:
                    return torch.sigmoid(x)

        mod = Mod()

        x = torch.randn(4)
        opt_mod = torch._dynamo.export(mod, (x))

        ref = mod(x)
        res = opt_mod.graph_module(x)
        self.assertEqual(ref, res)

        with torch.autograd.profiler.profile():
            ref = mod(x)
            # Reexport because export skips guards
            opt_mod = torch._dynamo.export(mod, (x))
            res = opt_mod.graph_module(x)
            self.assertEqual(ref, res)

    def test_profiler_dynamo_compiled_region(self):
        def fn(x, y):
            r = y.sum(dim=1)
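For context on the `# Reexport because export skips guards` comment: a minimal sketch (not part of the PR, reusing the module shape and the `torch._dynamo.export(...).graph_module` entry point from the test above) of why the re-export matters. Export constant-folds `torch.autograd._profiler_enabled()` at trace time and installs no guard on it, so a graph exported outside the profiler keeps the non-profiler branch even when it is later run under the profiler.

```python
import torch

class Mod(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        if torch.autograd._profiler_enabled():
            return torch.cos(x)
        return torch.sigmoid(x)

x = torch.randn(4)
# Exported outside any profiler context: _profiler_enabled() folds to False,
# so the sigmoid branch is baked into the exported graph.
gm = torch._dynamo.export(Mod(), x).graph_module

with torch.autograd.profiler.profile():
    eager = Mod()(x)   # eager now takes the cos branch
    exported = gm(x)   # exported graph still takes the sigmoid branch; no guard, no retrace
print(torch.allclose(eager, exported))  # expected: False -- hence the re-export in the test
```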
5 changes: 5 additions & 0 deletions torch/_dynamo/config.py
@@ -675,6 +675,11 @@ def default_debug_dir_root() -> str:
    env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE",
)

# Does not graph break on torch.autograd._profiler_enabled if set to True. We
# want this flag to be True by default, but there is an unsolved bug that causes
# distributed jobs to time out with the Kineto profiler when this is set to True.
constant_fold_autograd_profiler_enabled = False

# Takes the function/module decorated with torch.compile and passes it through a
# wrapper. This ensures that nn.module hooks are also compiled in the same frame.
wrap_top_frame = False
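A sketch (assuming the flag lands with the name and semantics shown above) of opting into the constant fold for an ordinary `torch.compile` run, e.g. in a single-process job where the Kineto/distributed timeout mentioned in the comment is not a concern:

```python
import torch

def fn(x):
    if torch.autograd._profiler_enabled():
        return torch.cos(x)
    return torch.sigmoid(x)

# With the flag patched to True, Dynamo folds the profiler check at trace time
# instead of graph breaking on it, so fullgraph compilation of this pattern
# should go through.
with torch._dynamo.config.patch(constant_fold_autograd_profiler_enabled=True):
    compiled = torch.compile(fn, fullgraph=True)
    out = compiled(torch.randn(4))
```

The export entry points (see the `eval_frame.py` and `functional_export.py` changes below) patch this flag to `True` themselves, so no explicit opt-in is needed there.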
1 change: 1 addition & 0 deletions torch/_dynamo/eval_frame.py
@@ -2038,6 +2038,7 @@ def fakify_with_ambient(
        automatic_dynamic_shapes=False,
        capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
        constant_fold_autograd_profiler_enabled=True,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    ),
    _compiling_state_context(),
1 change: 1 addition & 0 deletions torch/_dynamo/functional_export.py
@@ -451,6 +451,7 @@ def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
        automatic_dynamic_shapes=False,
        capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
        constant_fold_autograd_profiler_enabled=True,
        log_graph_in_out_metadata=True,
    )

2 changes: 1 addition & 1 deletion torch/_dynamo/trace_rules.py
@@ -179,7 +179,6 @@
"torch.compiler.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_exporting": TorchInGraphFunctionVariable,
"torch.autograd._profiler_enabled": SkipFunctionVariable,
"torch._C._to_dlpack": SkipFunctionVariable,
"torch.to_dlpack": SkipFunctionVariable,
# We graph break on RNG state setters or getters like
@@ -2445,6 +2444,7 @@
"torch.atleast_3d",
"torch.autograd._calculate_shape",
"torch.autograd._is_checkpoint_valid",
"torch.autograd._profiler_enabled",
"torch.autograd._make_grads",
"torch.autograd._register_py_tensor_class_for_device",
"torch.autograd._tensor_or_tensors_to_tuple",
20 changes: 20 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -270,6 +270,26 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name):
    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True

        if (
            self.value is torch.autograd._profiler_enabled
            and config.constant_fold_autograd_profiler_enabled
        ):
            # The relevant flag is enabled only for export. One might wonder
            # why?
            #
            # Actually, we would like to not graph break even in the case of
            # Dynamo. But there is a weird, unsolved bug with Kineto + Dynamo
            # when there are distributed jobs that leads to NCCL timeouts. This
            # bug is a rare edge case, but we have not been able to root-cause
            # it yet. See https://www.internalfb.com/sevmanager/view/560336 for
            # more details.
            #
            # So is this safe for export? Yes. For export, we do not anticipate
            # JIT tracing in distributed job training, and the weird edge-case
            # interaction with Kineto is not a valid use case. So this is OK.
            return True

        return getattr(self.value, "__module__", None) == "math"


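Taken together, the folding decision above reduces to a small predicate. A condensed restatement as a standalone sketch (mirroring the method, not Dynamo's actual code path; `constant_fold_functions` is stubbed here, and the config flag is the one added by this PR):

```python
import math
import torch
from torch._dynamo import config

constant_fold_functions = set()  # stub; Dynamo maintains the real list

def can_constant_fold_through(value) -> bool:
    if value in constant_fold_functions:
        return True
    # Fold the profiler check only when the new config flag is set
    # (export patches it to True; it defaults to False for torch.compile).
    if (
        value is torch.autograd._profiler_enabled
        and config.constant_fold_autograd_profiler_enabled
    ):
        return True
    # Pure-Python math functions are always safe to fold.
    return getattr(value, "__module__", None) == "math"

print(can_constant_fold_through(math.sqrt))                         # True
print(can_constant_fold_through(torch.autograd._profiler_enabled))  # False unless the flag is set
```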