
Debug issue with AOTAutograd for speech_transformer/hf_GPT2/hf_T5 #85

Closed
anijain2305 opened this issue Mar 22, 2022 · 3 comments

anijain2305 commented Mar 22, 2022

The three models - speech_transformer, hf_GPT2, and hf_T5 - fail with a similar error signature.

TorchDynamo finds static subgraphs and sends them to AOT Autograd, which generates the forward and backward graphs. The output of AOT Autograd is an autograd.Function (code). During the forward pass, AOT Autograd saves some tensors that the backward pass needs for gradient computation.
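
For context, the pattern AOT Autograd generates looks roughly like the sketch below. This is a simplified, self-contained stand-in (compiled_fw / compiled_bw here are toy Python functions, not the actual compiled graphs):

import torch

# Simplified sketch of the autograd.Function pattern AOT Autograd emits.
# compiled_fw / compiled_bw are toy stand-ins for the compiled graphs.
def compiled_fw(x, w):
    out = x * w
    # return the outputs plus the tensors the backward graph will need
    return out, x, w

def compiled_bw(saved_x, saved_w, grad_out):
    return grad_out * saved_w, grad_out * saved_x

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out, saved_x, saved_w = compiled_fw(x, w)
        # Tensors saved at the end of the forward pass; they are read back
        # via ctx.saved_tensors in backward. The bug reported here is that
        # one of these entries is no longer a Tensor when read back.
        ctx.save_for_backward(saved_x, saved_w)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        saved_x, saved_w = ctx.saved_tensors
        return compiled_bw(saved_x, saved_w, grad_out)

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
CompiledFunction.apply(x, w).sum().backward()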

The issue arises in the backward pass. When we read ctx.saved_tensors, one of the items is no longer of Tensor type, which causes cryptic error messages like the one below. The offending type also changes from run to run - I have seen immutable_dict, tuple, and even weakref and builtin.

ERROR:root:unhandled error
Traceback (most recent call last):
  File "torchbench.py", line 1006, in run_one_model
    new_result = model_iter_fn(model, example_inputs)
  File "torchbench.py", line 482, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  File "torchbench.py", line 482, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  File "torchbench.py", line 482, in forward_and_backward_pass
    def forward_and_backward_pass(mod, inputs, collect_outputs=True):
  [Previous line repeated 2 more times]
  File "/fsx/users/anijain/functorch/functorch/_src/monkey_patching.py", line 97, in _backward
    return _old_backward(*args, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/_tensor.py", line 395, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/fsx/users/anijain/torchdynamo/torchdynamo/eval_frame.py", line 58, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/anijain/functorch/functorch/_src/aot_autograd.py", line 188, in backward
    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
  File "/fsx/users/anijain/torchdynamo/torchdynamo/eval_frame.py", line 58, in _fn
    return fn(*args, **kwargs)
  File "/data/home/anijain/miniconda/envs/pytorch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: forward() Expected a value of type 'Tensor (inferred)' for argument 'primals_14' but instead found type 'tuple'.
Inferred 'primals_14' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 19
Value: ('___check_obj_id', '___check_tensors', '___check_type_id', '___guarded_code')

I looked further into the C++ side and started printing the type of the objects while saving the tensors at the end of the forward pass and while reading them back in the backward pass. I observed the weird behavior at this line - https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_function.cpp#L834 - which is called in the backward pass when we access ctx.saved_tensors.

When I print the unpacked_var, it is a tensor - it has a dim, and I can print its shape and so on.
But Py_TYPE(value)->tp_name is immutable_dict here.
The unpack_fn is essentially THPVariable_Wrap (https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_function.cpp#L849).
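
At the Python level, the same mismatch could be surfaced with a more readable check right before the compiled_bw(*ctx.saved_tensors, *contiguous_args) call in aot_autograd.py. This is a hypothetical debugging aid, not something that exists in functorch:

import torch

def assert_all_tensors(items):
    # Hypothetical debug helper (not part of functorch or AOT Autograd):
    # raise a readable error if any saved item is not a Tensor, instead of
    # the cryptic "Expected a value of type 'Tensor'" failure shown above.
    for i, item in enumerate(items):
        if not isinstance(item, torch.Tensor):
            raise TypeError(
                f"saved item {i} is {type(item).__name__}, expected Tensor: {item!r}"
            )

# The corrupted entry observed above was a tuple of TorchDynamo guard names.
assert_all_tensors((torch.randn(2), torch.randn(3)))  # passes
try:
    assert_all_tensors((torch.randn(2), ("___check_obj_id", "___guarded_code")))
except TypeError as e:
    print(e)  # saved item 1 is tuple, expected Tensor: (...)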

For completeness, the repro command for each failure (failure screenshots are attached in the original issue):

Repro - python torchbench.py --training --devices=cuda --accuracy-aot-nop --only=hf_GPT2

Repro - python torchbench.py --training --devices=cuda --accuracy-aot-nop --only=speech_transformer

Repro - python torchbench.py --training --devices=cuda --accuracy-aot-nop --only=hf_T5

anijain2305 commented Apr 12, 2022

Repro for the bug

import torch
from torch.nn import *
from functorch.compile import aot_module_simplified, nop
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
from torchdynamo.eval_frame import skip_code

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_ln_1_weight = torch.nn.Parameter(torch.empty([768], dtype=torch.float32))
        self.self_ln_1_bias = torch.nn.Parameter(torch.empty([768], dtype=torch.float32))

    def forward(self, hidden_states):
        self_ln_1_weight = self.self_ln_1_weight
        self_ln_1_bias = self.self_ln_1_bias
        layer_norm = torch.nn.functional.layer_norm(hidden_states, (768,), weight = self_ln_1_weight, bias = self_ln_1_bias, eps = 1e-05)
        view = layer_norm.view(-1, 768)
        return (view, )

modules = []
aot_modules = []
num_hidden_layers = 12
for idx in range(num_hidden_layers):
    mod = Bar().to(device="cuda")
    modules.append(mod)
    aot_modules.append(aot_module_simplified(mod, nop))

class MockModule(torch.nn.Module):
    def __init__(self, use_aot):
        super().__init__()
        self.embed_dim = 768
        self.use_aot = use_aot
    
    def forward(self, x):
        hidden_states = x
        for idx in range(num_hidden_layers):
            if self.use_aot:
                outputs = aot_modules[idx](hidden_states)
            else:
                outputs = modules[idx](hidden_states)
            hidden_states = outputs[0]
        return hidden_states


x = torch.randn(torch.Size([4, 512, 768]), device="cuda", requires_grad=True)

# Torch baseline
mod = MockModule(use_aot=False).to("cuda")
ref = mod(x)
ref[0].sum().backward()

# AOT - no Dynamo
# aot_mod = MockModule(use_aot=True).to("cuda")
# res = aot_mod(x)
# res[0].sum().backward()

# Dynamo with AOT
with torchdynamo.optimize(aot_autograd_debug_strategy1):
    # use_aot should be False here; the backend is already set to AOT via torchdynamo.optimize
    aot_mod = MockModule(use_aot=False).to("cuda")
    skip_code(aot_mod.forward.__code__)
    res = aot_mod(x)
    res[0].sum().backward()

print("Success")

anijain2305 commented

@Chillee has been looking at this and has isolated the problem with an even smaller repro than the one above. Assigning this to him.

Chillee commented Apr 19, 2022

Resolved in pytorch/pytorch#75933
