Skip to content

Commit

Permalink
[WIP] reduce overhead
Browse files Browse the repository at this point in the history
ghstack-source-id: d820ef3702340468a90b5189ec37b808b3018427
Pull Request resolved: #98529
  • Loading branch information
eellison committed Apr 6, 2023
1 parent c905eb0 commit 5468ae4
Show file tree
Hide file tree
Showing 4 changed files with 363 additions and 67 deletions.
87 changes: 82 additions & 5 deletions test/inductor/test_cudagraph_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
)


def cdata(t):
    """Return the raw ``_cdata`` handle of *t*'s untyped storage.

    Tensors that alias the same underlying storage report the same value,
    so the tests below use this to assert aliasing relationships.
    """
    storage = t.untyped_storage()
    return storage._cdata


class TestCase(TorchTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -127,13 +131,15 @@ def curr_node(self):
def get_root_children(self):
    """Return the number of descendants beneath each tracked root node."""
    roots = self.get_roots()
    return [node.num_descendants() for node in roots]

def cudagraphify_impl(
    self, *args, is_inference=True, is_backward=False, **kwargs
):
    """Wrap ``tree_cudagraphify_impl`` with this test's device index.

    The keyword-only defaults preserve the old behavior (an inference
    graph, not a backward graph) while letting individual tests override
    either flag per call. Note: the diff shown here merged the old and
    new keyword lines, which would pass ``is_inference``/``is_backward``
    twice (a SyntaxError); only the forwarded values are kept.
    """
    return tree_cudagraphify_impl(
        *args,
        **kwargs,
        device_index=self.device_idx,
        is_inference=is_inference,
        is_backward=is_backward,
    )

@staticmethod
Expand Down Expand Up @@ -418,7 +424,7 @@ def foo2(args):
self.assertEqual(all_live_block_count(), 0)

def test_aliased_storage_single_weakref(self):
@torch.compile
@torch.compile(mode="reduce-overhead")
def foo(x):
x = x * 20
x_alias = x[0]
Expand Down Expand Up @@ -447,6 +453,78 @@ def foo(x):

self.assertFalse(self.get_manager().new_graph_id().id == 0)

def test_aliasing_static_ref(self):
    """Aliases of static (parameter) storages must keep their identity
    across cudagraph replays, and a downstream graph consuming an
    aliased output must still see the same storage.
    """

    class Mod(torch.nn.Linear):
        def forward(self, x):
            # Return the raw output plus two views of the parameter.
            return self.weight.T @ x, self.weight.T, self.weight[0:4]

    m = Mod(10, 10).cuda()

    @torch.compile(mode="reduce-overhead")
    def foo(mod, x):
        return mod(x)

    @torch.compile(mode="reduce-overhead")
    def foo2(x):
        return x[2:]

    x = torch.rand([10, 10], device="cuda", requires_grad=True)
    param_c = cdata(m.weight)
    for _ in range(3):
        out1, alias_1, alias_2 = foo(m, x)
        # Both aliases must share the parameter's storage — a single
        # distinct _cdata among the three values.
        self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)

        out2 = foo2(out1)
        out2.sum().backward()
        # foo2 returns a view, so its output aliases foo's output.
        self.assertEqual(cdata(out1), cdata(out2))

def test_aliased_static_parameter(self):
    """A graph output that merely aliases a static input must report the
    same storage as that input on every replay."""
    inp = torch.rand([20, 20], device="cuda")

    def fn(args):
        (x,) = args
        args.clear()
        return (x[0],)

    compiled = self.cudagraphify_impl(fn, [inp], (0,))

    for _ in range(3):
        (out,) = compiled([inp])
        self.assertEqual(cdata(inp), cdata(out))

def test_alias_of_earlier_graph(self):
    # NOTE(review): WIP — the cudagraphify_impl wrappings are commented
    # out, so this currently exercises the eager functions only.
    # Presumably the commented lines are the intended end state; confirm
    # before landing.
    inp = torch.rand([20, 20], device="cuda")

    def foo(args):
        x = args[0]
        args.clear()
        # Returns a tensor and a view of it — both share one storage.
        out = x + x
        return (out, out[1:])

    foo_cg = foo
    # foo_cg = self.cudagraphify_impl(foo, [inp], (0,), is_inference=False)

    def foo2(args):
        x = args[0]
        args.clear()
        return x[1:]

    foo2_cg = foo2

    # out = foo([inp])[0]

    # foo2_cg = self.cudagraphify_impl(foo2, [out], is_inference=False)

    for _ in range(1):
        out1, out2 = foo_cg([inp])
        # out2 is a view of out1, so their storages must match.
        self.assertEqual(cdata(out1), cdata(out2))

        # foo2's slice output must alias the earlier graph's output.
        out3 = foo2_cg([out1])
        self.assertEqual(cdata(out3), cdata(out1))
        del out1, out2, out3

@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
Expand Down Expand Up @@ -580,7 +658,6 @@ def foo(x):
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)

inp = torch.rand([4, 4], requires_grad=True, device="cuda")
print("Input ID", id(inp))
out = foo(inp)
out.sum().backward()

Expand Down

0 comments on commit 5468ae4

Please sign in to comment.