
AOT Autograd - LSTM - grads not generated (model tts_angular) #586

Closed
anijain2305 opened this issue Mar 10, 2022 · 3 comments · Fixed by pytorch/torchdynamo#975
@anijain2305 (Contributor)

This is a subgraph from the tts_angular model.

The generated backward pass has many None outputs, suggesting that requires_grad is somehow not propagated correctly when an LSTM cell is used.

import copy
import itertools

import torch
from torch.nn import *

from functorch.compile import aot_module, print_compile

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_lstm = LSTM(40, 768, batch_first=True)
        # requires_grad belongs on torch.randn, not on torch.Size
        self.weight = Parameter(torch.randn(256, 768, requires_grad=True))

    def forward(self, x):
        self_lstm = self.self_lstm(x);  x = None
        getitem = self_lstm[0];  self_lstm = None
        linear = torch.nn.functional.linear(getitem, self.weight, bias=None);  getitem = None
        return (linear,)

def reduce_out(out):
    if isinstance(out, torch.Tensor):
        return torch.sigmoid(out).sum()
    elif isinstance(out, (tuple, list)):
        return sum([reduce_out(x) for x in out])
    raise NotImplementedError("Don't know how to reduce", type(out))


def checkpoint_params(gm):
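    """Snapshot the RNG state and all parameters/buffers; return a restore() callback that rolls back any in-place updates."""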
    rng_state = torch.clone(torch.random.get_rng_state())
    saved_state = []
    for param in itertools.chain(gm.parameters(), gm.buffers()):
        saved_state.append((param, param._version, torch.clone(param)))

    def restore():
        with torch.no_grad():
            torch.random.set_rng_state(rng_state)
            for param, version, original_value in saved_state:
                if param._version != version:
                    param.copy_(original_value)

    return restore


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)

def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results

def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)) or len(a) != len(b):
            return False
        return all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        if not torch.allclose(a, b, atol=1e-5, rtol=1e-5):
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return torch.allclose(a, b, atol=1e-5, rtol=1e-5)
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")


def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        c.grad = None
    return clones


def get_results(mod, inputs):
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = reduce_out(ref)
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results

def test_module():
    inp0 = torch.randn(64, 50, 40, device="cuda", requires_grad=True)
    inputs = [inp0, ]

    mod = Bar().to(device="cuda")
    restore = checkpoint_params(mod)
    orig_mod_results = get_results(mod, inputs)

    restore()
    new_mod = copy.deepcopy(mod)
    copy_mod_results = get_results(new_mod, inputs)
    print("Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    # assert same(orig_mod_results, copy_mod_results), "Deepcopy of a mod fails, what the hell"

    restore()
    aot_mod = aot_module(mod, fw_compiler=print_compile)
    aot_mod_results = get_results(aot_mod, inputs)

    print("Recheck Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    print("Are Orig_mod and AOT_mod same:", same(orig_mod_results, aot_mod_results))
    print("Are Copy_mod and AOT_mod same:", same(copy_mod_results, aot_mod_results))

test_module()
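
For quick diagnosis, here is a minimal sketch (reusing Bar and reduce_out from the repro above; it assumes functorch's bw_compiler defaults to fw_compiler, so print_compile should also print the generated backward graph) that runs the AOT-compiled module once and reports which parameter gradients come back as None:

```python
# Hedged sketch: any parameter whose .grad is still None after backward
# exhibits the reported symptom.
def report_missing_grads():
    inp = torch.randn(64, 50, 40, device="cuda", requires_grad=True)
    mod = Bar().to(device="cuda")
    aot_mod = aot_module(mod, fw_compiler=print_compile)
    out = aot_mod(inp)
    reduce_out(out).backward()
    for name, p in mod.named_parameters():
        print(name, "grad is None" if p.grad is None else "grad ok")

report_missing_grads()
```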
@ezyang (Contributor) commented Aug 12, 2022

This now errors:

   File "/raid/ezyang/pytorch-scratch2/torch/nn/modules/rnn.py", line 774, in forward                                                       
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/utils/_python_dispatch.py", line 74, in wrapped                                                
    return f(self, *args, **kwargs)                                                                                                        
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 408, in __torch_dispatch__                              
    return proxy_call(self, func_overload, args, kwargs)                                                                                   
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 168, in proxy_call                                      
    proxy_res = func_overload(*proxy_args, **proxy_kwargs)                                                                                 
  File "/raid/ezyang/pytorch-scratch2/torch/_ops.py", line 60, in __call__                                                                 
    return self._op(*args, **kwargs or {})                                                                                                 
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 321, in __torch_function__                                                  
    return tracer.create_proxy('call_function', orig_method, args, kwargs,                                         
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 65, in create_proxy                                                         
    args_ = self.create_arg(args)                                                                                                          
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 347, in create_arg                                      
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/_symbolic_trace.py", line 343, in create_arg                                                
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 127, in create_arg                                                          
    return type(a)(self.create_arg(elem) for elem in a)                                                                                    
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 127, in <genexpr>                                                           
    return type(a)(self.create_arg(elem) for elem in a)                                                                                    
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 347, in create_arg                                      
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/_symbolic_trace.py", line 343, in create_arg                                                
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 153, in create_arg                                                          
    raise NotImplementedError(f"argument of type: {type(a)}")        
NotImplementedError: argument of type: <class 'torch.storage.UntypedStorage'>   

The last time I saw this, it was because we tried to copy a Proxy.

@ezyang (Contributor) commented Aug 23, 2022

I think the minimal repro may no longer be valid. When I run the original tts_angular, it passes:

$ python benchmarks/torchbench.py --training --devices=cuda --accuracy-aot-nop --use-eval-mode -k tts_angular  

@ezyang ezyang closed this as completed Aug 23, 2022
@ezyang ezyang reopened this Aug 23, 2022
@ezyang (Contributor) commented Aug 23, 2022

OK, it turns out inductor still triggers this:

./benchmarks/torchbench.py --inductor  -dcuda --no-skip -k tts_angular

ezyang added a commit to pytorch/torchdynamo that referenced this issue Aug 23, 2022
They call Tensor.set_ internally with a Storage, which is a no-go for AOTAutograd. Inline into them so that we can graph break.

Fixes pytorch/functorch#586

Test strategy:

```
./benchmarks/torchbench.py --inductor  -dcuda --no-skip -k tts_angular
```

Note that inductor is still failing, but differently, after this PR.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
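
For context, a minimal illustration (not from the PR; assumes a PyTorch build that exposes Tensor.untyped_storage()) of the pattern the commit message describes — rebinding a tensor's storage via Tensor.set_, whose raw Storage argument has no FX proxy representation under AOTAutograd tracing:

```python
import torch

# A raw Storage passed to Tensor.set_ rebinds the tensor's storage in place.
# Proxy tracing cannot wrap the Storage argument, which is what produces the
# "NotImplementedError: argument of type: <class 'torch.storage.UntypedStorage'>"
# shown in the traceback above.
w = torch.empty(0)
src = torch.randn(8)
w.set_(src.untyped_storage(), 0, (8,))  # source storage, offset, size
print(w.shape)  # torch.Size([8]); w now aliases src's storage
```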