TORCHINDUCTOR_TRACE=1 asserts on self.debug_path #1752

Closed
ngimel opened this issue Oct 22, 2022 · 0 comments

ngimel commented Oct 22, 2022

After pytorch/pytorch#87438 I can't run with TORCHINDUCTOR_TRACE=1:

```python
import torch
import torch._inductor
import torch._dynamo as torchdynamo

@torchdynamo.optimize("inductor")
def div(x, y):
    return torch.ops.aten.div(x, y)


x = torch.randn(4, dtype=torch.float, device="cuda:0", requires_grad=True)
y = torch.randn(4, dtype=torch.float, device="cuda:0")
out = div(x, y)
out.backward(torch.rand_like(x))
```

fails with:

```
  File "/scratch/ngimel/work/pytorch/torch/_inductor/compile_fx.py", line 362, in bw_compiler
    return compile_fx_inner(
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/debug_utils.py", line 466, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/scratch/ngimel/work/pytorch/torch/_inductor/debug.py", line 176, in inner
    with DebugContext():
  File "/scratch/ngimel/work/pytorch/torch/_inductor/debug.py", line 243, in __enter__
    self._path = self.create_debug_dir()
  File "/scratch/ngimel/work/pytorch/torch/_inductor/debug.py", line 185, in create_debug_dir
    dynamo_utils.get_debug_dir(),
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/utils.py", line 965, in get_debug_dir
    return debug_dir.get()
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/utils.py", line 957, in get
    assert self.debug_path is not None
AssertionError
```

This only happens when running with TORCHINDUCTOR_TRACE=1.
cc @mlazos
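The traceback shows `get()` asserting that `debug_path` was never set while the backward graph is being compiled (`bw_compiler`). The sketch below is a simplified, torch-free model of that failure mode, assuming (as the fix description suggests) that the path was set only while dynamo tracing was active. `DebugDirHolder` and `during_tracing` are hypothetical names for illustration, not the actual `torch._dynamo.utils` code.

```python
import contextlib
from typing import Iterator, Optional


class DebugDirHolder:
    """Hypothetical model of a debug-dir holder that must be set up before use."""

    def __init__(self) -> None:
        self.debug_path: Optional[str] = None

    @contextlib.contextmanager
    def during_tracing(self, path: str) -> Iterator[None]:
        # Assumption: the path is only valid while dynamo tracing is active.
        self.debug_path = path
        try:
            yield
        finally:
            self.debug_path = None

    def get(self) -> str:
        # Mirrors the assertion seen in the traceback.
        assert self.debug_path is not None
        return self.debug_path


holder = DebugDirHolder()

# Forward graph: compiled while "tracing" is active, so get() succeeds.
with holder.during_tracing("/tmp/debug"):
    forward_dir = holder.get()

# Backward graph: compiled later, after tracing has finished.
try:
    holder.get()
    backward_failed = False
except AssertionError:
    backward_failed = True

print(forward_dir, backward_failed)  # /tmp/debug True
```

In this model any consumer that runs outside the tracing window, such as a backward compile, hits the assertion, which matches the shape of the reported failure.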

@mlazos mlazos self-assigned this Oct 25, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Oct 25, 2022
Fixes pytorch/torchdynamo#1758, pytorch/torchdynamo#1752

- minifier_launcher.py now dumps checkpoints to `<cwd>/checkpoints` when run
- a single debug directory is created per script invocation, so the assertion failure when no directory has been set up no longer occurs
- TorchInductor debug tracing now dumps correctly to the debug directory, because no prior setup is needed (previously the directory was incorrectly initialized only during dynamo tracing)

cc @jansel @lezcano @fdrocha @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: #87682
Approved by: https://github.com/ezyang
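The second bullet replaces "set up first, then assert on use" with one directory per process that exists no matter which code path asks for it first. A minimal sketch of that lazy-initialization pattern follows; the `LazyDebugDir` class and the `debug_run_<timestamp>_pid_<pid>` naming scheme are assumptions for illustration, not the code from the PR.

```python
import os
import tempfile
import time
from typing import Optional


class LazyDebugDir:
    """Hypothetical sketch: one debug directory per process, created on first use."""

    def __init__(self, root: Optional[str] = None) -> None:
        self._root = root if root is not None else os.getcwd()
        self._path: Optional[str] = None

    def get(self) -> str:
        # No setup step and no assertion: the first caller creates the directory,
        # and every later caller (forward or backward compile) reuses it.
        if self._path is None:
            # Assumed naming scheme, for illustration only.
            stamp = time.strftime("%Y_%m_%d_%H_%M_%S")
            name = f"debug_run_{stamp}_pid_{os.getpid()}"
            self._path = os.path.join(self._root, name)
            os.makedirs(self._path, exist_ok=True)
        return self._path


with tempfile.TemporaryDirectory() as root:
    debug_dir = LazyDebugDir(root)
    first = debug_dir.get()
    second = debug_dir.get()
    print(first == second, os.path.isdir(first))  # True True
```

Because `get()` never depends on an earlier setup call, the ordering problem from the original report cannot arise in this design.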
@mlazos mlazos closed this as completed Oct 26, 2022
sgrigory pushed a commit to sgrigory/pytorch that referenced this issue Oct 28, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Nov 5, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022