Reduce overhead in CUDAGraph Trees #98529

Closed · wants to merge 10 commits
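For context: `mode="reduce-overhead"` is the `torch.compile` mode backed by CUDAGraph Trees, which the new tests below exercise. A minimal sketch of that entry point (the toy function, shapes, and warmup loop are illustrative, not taken from this PR):

    import torch

    # "reduce-overhead" makes Inductor capture kernels into CUDA graphs
    # (CUDAGraph Trees) and replay them, instead of re-launching each call.
    @torch.compile(mode="reduce-overhead")
    def f(x):
        return torch.sin(x) * 2

    x = torch.rand([10, 10], device="cuda")
    for _ in range(3):
        out = f(x)  # early calls warm up and record; later calls replay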
test/inductor/test_cudagraph_trees.py (56 changes: 51 additions, 5 deletions)
@@ -43,6 +43,10 @@
 )


+def cdata(t):
+    return t.untyped_storage()._cdata
+
+
 class TestCase(TorchTestCase):
     @classmethod
     def setUpClass(cls):
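The `cdata` helper added above gives the tests a storage-identity check: two tensors alias the same memory exactly when their untyped storages wrap the same underlying StorageImpl, whose address the private `_cdata` attribute exposes. A small illustration (CPU tensors for portability; not part of the PR):

    import torch

    x = torch.rand(4)
    view = x[1:]      # a view shares x's storage
    copy = x.clone()  # a clone gets fresh storage

    assert x.untyped_storage()._cdata == view.untyped_storage()._cdata
    assert x.untyped_storage()._cdata != copy.untyped_storage()._cdata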
@@ -127,13 +131,15 @@ def curr_node(self):
     def get_root_children(self):
         return [root.num_descendants() for root in self.get_roots()]

-    def cudagraphify_impl(self, *args, **kwargs):
+    def cudagraphify_impl(
+        self, *args, is_inference=True, is_backward=False, **kwargs
+    ):
         return tree_cudagraphify_impl(
             *args,
             **kwargs,
             device_index=self.device_idx,
-            is_backward=False,
-            is_inference=True,
+            is_inference=is_inference,
+            is_backward=is_backward,
         )

     @staticmethod
@@ -418,7 +424,7 @@ def foo2(args):
         self.assertEqual(all_live_block_count(), 0)

     def test_aliased_storage_single_weakref(self):
-        @torch.compile
+        @torch.compile(mode="reduce-overhead")
         def foo(x):
             x = x * 20
             x_alias = x[0]
@@ -447,6 +453,47 @@ def foo(x):

         self.assertFalse(self.get_manager().new_graph_id().id == 0)

+    def test_aliasing_static_ref(self):
+        class Mod(torch.nn.Linear):
+            def forward(self, x):
+                return self.weight.T @ x, self.weight.T, self.weight[0:4]
+
+        m = Mod(10, 10).cuda()
+
+        @torch.compile(mode="reduce-overhead")
+        def foo(mod, x):
+            return mod(x)
+
+        @torch.compile(mode="reduce-overhead")
+        def foo2(x):
+            return x[2:]
+
+        x = torch.rand([10, 10], device="cuda", requires_grad=True)
+        param_c = cdata(m.weight)
+        for _ in range(3):
+            # print("Running foo")
+            out1, alias_1, alias_2 = foo(m, x)
+            self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
+
+            # print("Running foo2")
+            out2 = foo2(out1)
+            out2.sum().backward()
+            self.assertEqual(cdata(out1), cdata(out2))
+
+    def test_aliased_static_parameter(self):
+        inp = torch.rand([20, 20], device="cuda")
+
+        def foo(args):
+            x = args[0]
+            args.clear()
+            return (x[0],)
+
+        foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
+
+        for _ in range(3):
+            out = foo_cg([inp])[0]
+            self.assertEqual(cdata(inp), cdata(out))
+
     @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
     def test_aliased_output_checkpoint(self):
         def foo(args):
@@ -580,7 +627,6 @@ def foo(x):
             return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)

         inp = torch.rand([4, 4], requires_grad=True, device="cuda")
-        print("Input ID", id(inp))
         out = foo(inp)
         out.sum().backward()
