
AOT Autograd fails to get correct grads for view and Inplace Relu #514

Open
anijain2305 opened this issue Feb 17, 2022 · 4 comments

anijain2305 (Contributor) commented Feb 17, 2022

While working on the TorchDynamo + AOT Autograd integration, I came across the following bug:

import torch
from torch.nn import *
from functorch.compile import print_compile, aot_module
import copy

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # self.conv = Conv2d(3, 2, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        # self.instance_norm = InstanceNorm2d(2, affine=True, track_running_stats=True)
        self.relu = ReLU(inplace=True)

    def forward(self, x : torch.Tensor):
        # self_main_0 = self.conv(x)
        # self_main_1 = self.instance_norm(self_main_0)
        self_main_0 = x * 2
        self_main_1 = self_main_0.view([1, 3, 128, 128])
        self_main_2 = self.relu(self_main_1)
        return self_main_2



mod = Bar().to(device="cuda")
# Reduce randomness bits
mod.eval()

inp0 = torch.randn(1, 3, 128, 128, device='cuda', requires_grad=True)
inputs = (inp0, )

cloned_inp0 = inp0.clone().detach().requires_grad_(True)
cloned_inputs = (cloned_inp0, )

# Reference calculation
mod.zero_grad()
duplicated_mod = copy.deepcopy(mod)
ref = duplicated_mod(*inputs)
ref.sum().backward()
ref_grads = []
for param in duplicated_mod.parameters():
    ref_grads.append(param.grad)



# AOT stuff
fx_mod = torch.fx.symbolic_trace(mod)
aot_mod = aot_module(fx_mod, print_compile)
aot_mod.zero_grad()
with torch.jit.fuser("fuser2"):
    res = aot_mod(*cloned_inputs)
    res.sum().backward()

res_grads = []
for param in aot_mod.parameters():
    res_grads.append(param.grad)


assert torch.allclose(ref, res)

for (a, b) in zip(ref_grads, res_grads):
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-4), print(a, b)

for (a, b) in zip(inputs, cloned_inputs):
    assert torch.allclose(a.grad, b.grad, atol=1e-4, rtol=1e-4), print(a.grad, b.grad)

view + in-place ReLU seems to produce a wrong backward trace.
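
For comparison, here is a sketch (my own rewrite, not part of the failing repro) of the same forward with the in-place ReLU replaced by its out-of-place form, which is roughly what a functionalization pass would produce. The only change is that nothing mutates a view of self_main_0 anymore:

import torch

class BarOutOfPlace(torch.nn.Module):
    # Hypothetical rewrite for comparison only: same math as Bar above, but the
    # ReLU no longer mutates a view, so no in-place op reaches autograd.
    def forward(self, x: torch.Tensor):
        self_main_0 = x * 2
        self_main_1 = self_main_0.view([1, 3, 128, 128])
        self_main_2 = torch.relu(self_main_1)  # out-of-place, instead of ReLU(inplace=True)
        return self_main_2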

@Chillee @jansel

Chillee (Contributor) commented Feb 17, 2022

Do you need the instance norm/batch norm stuff here?

anijain2305 (Contributor Author)

Not necessarily. I kept it there because the original graph was larger and contained instance_norm. I minimized the example by inspecting the FX graphs to make it easier to follow.
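
For reference, the inspection used for the minimization is just the standard torch.fx tooling; a small sketch, reusing the Bar module from the repro above:

import torch

# Trace the repro module and dump its graph (no GPU needed just to trace).
fx_mod = torch.fx.symbolic_trace(Bar())
print(fx_mod.code)            # generated Python for the traced forward
fx_mod.graph.print_tabular()  # per-node table of opcode/name/target/args (needs the `tabulate` package)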

Chillee self-assigned this Feb 24, 2022
anijain2305 (Contributor Author) commented Mar 3, 2022

Another issue, from pytorch_struct:

import torch
from functorch.compile import memory_efficient_fusion

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, einsum, unsqueeze_3):
        # exp followed by a gather along dim 3 with an integer index tensor
        exp = einsum.exp()
        gather = torch.gather(exp, 3, unsqueeze_3)
        return gather

inp0 = torch.randn([1, 3, 2, 10], device='cuda', requires_grad=True)
inp1 = torch.ones([1, 3, 2, 1], dtype=torch.int64, device='cuda')
inps = [inp0, inp1]

cloned_inps = [x.clone().detach() for x in inps]
cloned_inps[0].requires_grad_(True)
cloned_inps[0].grad = None

mod = FxModule().to(device="cuda")
ref = mod(*inps)
ref.sum().backward()

aot_mod = memory_efficient_fusion(mod)
res = aot_mod(*cloned_inps)
res.sum().backward()

assert torch.allclose(ref, res)
print(inps[0].grad)
print(cloned_inps[0].grad)
assert torch.allclose(inps[0].grad, cloned_inps[0].grad)

print("Success")

anijain2305 (Contributor Author) commented Mar 4, 2022

Both of these cases would require special, hacky handling in AOT Autograd if they have to be supported quickly.

A better approach is the functionalization pass. CC'ing @bdhirsh to try functionalization on these two test cases as well.
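
Not run here, but a minimal sketch of what trying functionalization on the first case might look like, assuming the functionalize transform is available (functorch.experimental.functionalize, or torch.func.functionalize on newer builds) together with make_fx; the printed graph should contain only out-of-place ops:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
# Assumption: the experimental functionalize transform is available.
from functorch.experimental import functionalize

def f(x):
    # Same pattern as the first repro: multiply, view, then in-place ReLU on the view.
    a = x * 2
    b = a.view([1, 3, 128, 128])
    return torch.nn.functional.relu(b, inplace=True)

x = torch.randn(1, 3, 128, 128)

# remove="mutations_and_views" should replace the in-place relu and the view
# with out-of-place / copying equivalents before AOT Autograd sees the graph.
traced = make_fx(functionalize(f, remove="mutations_and_views"))(x)
print(traced.code)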
