Skip to content

Commit

Permalink
[WIP] reduce overhead
Browse files Browse the repository at this point in the history
ghstack-source-id: 27bbfa172ea0f5ba67e6cf1308b13b606a1aaa33
Pull Request resolved: #98529
  • Loading branch information
eellison committed Apr 6, 2023
1 parent c905eb0 commit 0e5a104
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 66 deletions.
56 changes: 51 additions & 5 deletions test/inductor/test_cudagraph_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
)


def cdata(t):
    """Return the identity of *t*'s underlying untyped storage.

    Two tensors alias the same memory iff their storages share this value,
    so tests use it to assert aliasing across cudagraph replays.
    """
    storage = t.untyped_storage()
    return storage._cdata


class TestCase(TorchTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -127,13 +131,15 @@ def curr_node(self):
def get_root_children(self):
    """Return the descendant count of every root in the cudagraph tree."""
    roots = self.get_roots()
    return [node.num_descendants() for node in roots]

def cudagraphify_impl(self, *args, **kwargs):
def cudagraphify_impl(self, *args, is_inference=True, is_backward=False, **kwargs):
    """Forward to ``tree_cudagraphify_impl``, pinning this test's device index.

    ``is_inference``/``is_backward`` default to the inference configuration;
    tests override them to exercise the backward-graph path.
    """
    return tree_cudagraphify_impl(
        *args,
        **kwargs,
        device_index=self.device_idx,
        is_backward=is_backward,
        is_inference=is_inference,
    )

@staticmethod
Expand Down Expand Up @@ -418,7 +424,7 @@ def foo2(args):
self.assertEqual(all_live_block_count(), 0)

def test_aliased_storage_single_weakref(self):
@torch.compile
@torch.compile(mode="reduce-overhead")
def foo(x):
x = x * 20
x_alias = x[0]
Expand Down Expand Up @@ -447,6 +453,47 @@ def foo(x):

self.assertFalse(self.get_manager().new_graph_id().id == 0)

def test_aliasing_static_ref(self):
    """Aliases of a static parameter must keep sharing the parameter's
    storage across cudagraph replays, and an output fed into a second
    compiled graph must share storage with that graph's output alias.

    Fix: removed leftover commented-out debug ``print`` calls (which also
    contained a typo) from the loop body.
    """

    class Mod(torch.nn.Linear):
        def forward(self, x):
            # Return the matmul plus two views (transpose and slice) of
            # the static weight parameter.
            return self.weight.T @ x, self.weight.T, self.weight[0:4]

    m = Mod(10, 10).cuda()

    @torch.compile(mode="reduce-overhead")
    def foo(mod, x):
        return mod(x)

    @torch.compile(mode="reduce-overhead")
    def foo2(x):
        return x[2:]

    x = torch.rand([10, 10], device="cuda", requires_grad=True)
    param_c = cdata(m.weight)
    for _ in range(3):
        out1, alias_1, alias_2 = foo(m, x)
        # Both aliases must share the parameter's storage, so the set of
        # storage identities collapses to a single element.
        self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)

        out2 = foo2(out1)
        out2.sum().backward()
        # foo2's output is a view of its input, so storages must match.
        self.assertEqual(cdata(out1), cdata(out2))

def test_aliased_static_parameter(self):
    """A graph output that is a view of a static input must reuse the
    input's storage on every replay."""
    static_input = torch.rand([20, 20], device="cuda")

    def model(args):
        first = args[0]
        args.clear()
        # Output is a view (row 0) of the static input.
        return (first[0],)

    compiled = self.cudagraphify_impl(model, [static_input], (0,))

    for _ in range(3):
        result = compiled([static_input])[0]
        self.assertEqual(cdata(static_input), cdata(result))

@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
Expand Down Expand Up @@ -580,7 +627,6 @@ def foo(x):
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)

inp = torch.rand([4, 4], requires_grad=True, device="cuda")
print("Input ID", id(inp))
out = foo(inp)
out.sum().backward()

Expand Down

0 comments on commit 0e5a104

Please sign in to comment.