[AOTAutograd / Functionalization] Fix incorrect expand_inverse (#122114)
This is a rebase of #114538, originally submitted by @jon-chuang.

Fixes #114302

Pull Request resolved: #122114
Approved by: https://github.com/bdhirsh
zou3519 authored and pytorchmergebot committed Mar 18, 2024
1 parent ba69dc6 commit e6cf3e9
Showing 2 changed files with 61 additions and 2 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -174,8 +174,8 @@ Tensor FunctionalInverses::expand_inverse(const Tensor& base, const Tensor& muta
     return mutated_view.as_strided_symint(
         base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
   } else {
-    return at::sum_to(
-        mutated_view,
+    return base + at::sum_to(
+        mutated_view - base,
         base.sym_sizes(),
         /*always_return_non_view=*/inverse_return_mode == InverseReturnMode::NeverView
     );
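
For illustration, a minimal standalone sketch of the arithmetic behind this change, written against the public Tensor.sum_to_size API rather than the internal at::sum_to call; the values are chosen to match the new test below. The old inverse reduced the mutated view directly, so every unmutated broadcast copy of the base was summed back in; the new inverse reduces only the delta against the base, leaving untouched entries unchanged.

import torch

# Sketch only: clone() stands in for functionalization's out-of-place update
# of the expanded view.
base = torch.arange(6, dtype=torch.float)        # [0., 1., 2., 3., 4., 5.]
mutated_view = base.expand(3, 6).clone()         # three broadcast copies of base
mutated_view[0, 0] += 5                          # mutate one element of the view

# Old inverse: reducing the mutated view itself also folds in the two
# untouched copies, so every unmutated entry comes back tripled.
old_base = mutated_view.sum_to_size(6)           # [5., 3., 6., 9., 12., 15.]  (wrong)

# New inverse: reduce only the delta against base; untouched entries
# contribute zero, matching eager in-place semantics on the expanded view.
new_base = base + (mutated_view - base).sum_to_size(6)
print(new_base)                                  # [5., 1., 2., 3., 4., 5.]
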
59 changes: 59 additions & 0 deletions test/dynamo/test_aot_autograd.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import re
+import unittest
 from textwrap import dedent
 from unittest.mock import patch

@@ -1062,6 +1063,64 @@ def f(x):
         ):
             f(x)

+    def test_aot_autograd_expand_mutation_functionalizes(self):
+        def fn(x):
+            y = x.expand(3, *x.shape)
+            y[0, 0].add_(5)
+            return y
+
+        opt_fn = torch.compile(fn, backend="aot_eager")
+
+        x = torch.arange(6)
+        x_opt = x.clone().detach()
+        self.assertEqual(fn(x), opt_fn(x_opt))
+        self.assertEqual(x, x_opt)
+
+    def test_aot_autograd_expand_mutation_backwards(self):
+        def fn(x, z):
+            y = x.expand(3, *x.shape)
+            y[1, 1].mul_(5)
+            ret = y * z
+            return ret
+
+        opt_fn = torch.compile(fn, backend="aot_eager")
+
+        x = torch.arange(6, dtype=torch.float)
+        z = x.clone().detach()
+        x_opt = x.clone().detach()
+        z_opt = x.clone().detach()
+
+        z.requires_grad = True
+        z_opt.requires_grad = True
+
+        res = fn(x, z)
+        opt_res = opt_fn(x_opt, z_opt)
+
+        self.assertEqual(res, opt_res)
+
+        res.sum().backward()
+        opt_res.sum().backward()
+
+        self.assertEqual(x, x_opt)
+        self.assertEqual(z.grad, z_opt.grad)
+
+    # We don't know how to catch multiple mutations to the same memory location
+    @unittest.expectedFailure
+    def test_aot_autograd_expand_mutation_error(self):
+        def fn(x):
+            y = x.expand(3, *x.shape)
+            y[0:3, 0].add_(5)
+            return y
+
+        opt_fn = torch.compile(fn, backend="aot_eager")
+
+        x = torch.arange(6)
+        x_opt = x.clone().detach()
+        with self.assertRaises(Exception):
+            fn(x)
+        with self.assertRaises(Exception):
+            opt_fn(x_opt)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
