Refactor partitioner and clean it up (#126318)
Pull Request resolved: #126318
Approved by: https://github.com/anijain2305
Chillee authored and ZelboK committed May 19, 2024
1 parent b1770bd commit c271827
Showing 4 changed files with 547 additions and 484 deletions.
1 change: 0 additions & 1 deletion functorch/compile/__init__.py
@@ -25,7 +25,6 @@
 from torch._functorch.partitioners import (
     default_partition,
     draw_graph,
-    draw_joint_graph,
     min_cut_rematerialization_partition,
 )
 from torch._functorch.python_key import pythonkey_decompose
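
With draw_joint_graph dropped from the re-export list above, these are the partitioner helpers that remain importable from functorch.compile. A minimal sketch of the post-change import surface; whether draw_joint_graph is still reachable directly from torch._functorch.partitioners after the refactor is an assumption not confirmed by this diff.

from functorch.compile import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)
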
64 changes: 0 additions & 64 deletions test/functorch/test_aotdispatch.py
@@ -4835,70 +4835,6 @@ def f(a, b, c, d):
         self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
         self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))
 
-    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
-    def test_min_cut_partitioner_recomputable_ops(self):
-        def f(x):
-            return x * x * x
-
-        recomputable_ops = []
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                        kwargs
-        # -------------  ---------  ---------------  --------------------------  --------
-        # placeholder    primals_1  primals_1        ()                          {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)      {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)            {}
-        # output         output     output           ([mul_1, primals_1, mul],)  {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    mul         mul              ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
-
-        recomputable_ops = [torch.ops.aten.mul]
-        partition_fn = partial(
-            min_cut_rematerialization_partition, recomputable_ops=recomputable_ops
-        )
-        fw_graph, bw_graph = get_fw_bw_graph(
-            f, [torch.randn(3, requires_grad=True)], partition_fn
-        )
-        # Expected forward graph:
-        # opcode         name       target           args                    kwargs
-        # -------------  ---------  ---------------  ----------------------  --------
-        # placeholder    primals_1  primals_1        ()                      {}
-        # call_function  mul        aten.mul.Tensor  (primals_1, primals_1)  {}
-        # call_function  mul_1      aten.mul.Tensor  (mul, primals_1)        {}
-        # output         output     output           ([mul_1, primals_1],)   {}
-        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
-        # Expected backward graph:
-        # opcode         name        target           args                     kwargs
-        # -------------  ----------  ---------------  -----------------------  --------
-        # placeholder    primals_1   primals_1        ()                       {}
-        # placeholder    tangents_1  tangents_1       ()                       {}
-        # call_function  mul         aten.mul.Tensor  (primals_1, primals_1)   {}  # RECOMPUTED
-        # call_function  mul_2       aten.mul.Tensor  (tangents_1, mul)        {}
-        # call_function  mul_3       aten.mul.Tensor  (tangents_1, primals_1)  {}
-        # call_function  mul_4       aten.mul.Tensor  (mul_3, primals_1)       {}
-        # call_function  add         aten.add.Tensor  (mul_2, mul_4)           {}
-        # call_function  add_1       aten.add.Tensor  (add, mul_4)             {}
-        # output         output      output           ([add_1],)               {}
-        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
-
     def test_contiguous(self):
         # The test simulates the condition where transpose followed by view
         # happens in the backward pass.
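
For context, a minimal sketch of how a partition function like the one exercised by the deleted test plugs into AOTAutograd. print_compiler is a hypothetical helper written only for this sketch, and whether the recomputable_ops keyword keeps the same spelling after this refactor is an assumption, so it appears only in a comment.

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition


def f(x):
    return x * x * x


def print_compiler(gm, example_inputs):
    # Hypothetical helper: dump the partitioned graph, similar to the
    # "Expected forward/backward graph" tables in the deleted test.
    gm.graph.print_tabular()
    return gm


compiled = aot_function(
    f,
    fw_compiler=print_compiler,
    bw_compiler=print_compiler,
    partition_fn=min_cut_rematerialization_partition,
    # The deleted test customized recomputation via functools.partial, e.g.
    # partial(min_cut_rematerialization_partition, recomputable_ops=[torch.ops.aten.mul]);
    # whether that keyword is unchanged after this refactor is an assumption.
)
out = compiled(torch.randn(3, requires_grad=True))
out.sum().backward()
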
4 changes: 3 additions & 1 deletion torch/_functorch/compile_utils.py
@@ -1,6 +1,8 @@
 # mypy: ignore-errors
 
 
+from typing import Callable
+
 import torch
 import torch.fx as fx
 from torch.utils import _pytree as pytree
@@ -9,7 +11,7 @@
 aten = torch.ops.aten
 
 
-def get_aten_target(node):
+def get_aten_target(node: fx.Node) -> Callable:
     if hasattr(node.target, "overloadpacket"):
         return node.target.overloadpacket
     return node.target
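
A small usage sketch of the newly annotated helper, assuming a graph traced down to ATen ops with make_fx; the lambda and the print loop are illustrative only.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import get_aten_target

gm = make_fx(lambda x: x * x)(torch.randn(3))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # For an OpOverload target such as aten.mul.Tensor this returns the
        # OpOverloadPacket aten.mul; any other target is passed through unchanged.
        print(node.target, "->", get_aten_target(node))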
