Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions torch/_functorch/partitioners.py
Original file line number Diff line number Diff line change
@@ -1506,9 +1506,12 @@ def dp_knapsack(


def _optimize_runtime_with_given_memory(
joint_graph: fx.Graph,
memory: List[float],
runtimes: List[float],
max_memory: float,
node_info: NodeInfo,
all_recomputable_banned_nodes: List[fx.Node],
) -> Tuple[float, List[int], List[int]]:
SOLVER = config.activation_memory_budget_solver
if SOLVER == "greedy":
@@ -1517,6 +1520,11 @@ def _optimize_runtime_with_given_memory(
return ilp_knapsack(memory, runtimes, max_memory)
elif SOLVER == "dp":
return dp_knapsack(memory, runtimes, max_memory)
elif callable(SOLVER):
saved_node_idx, recomp_node_idx = SOLVER(
memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes
)
return (0.0, saved_node_idx, recomp_node_idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I will rebase my changes onto this tuple return value.

else:
raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}")

@@ -1572,7 +1580,9 @@ def realize_symbol(d):


def choose_saved_values_set(
joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1
joint_graph: fx.Graph,
node_info: NodeInfo,
memory_budget=1,
) -> List[fx.Node]:
if memory_budget > 1 or memory_budget < 0:
raise RuntimeError(
@@ -1680,18 +1690,24 @@ def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]:
]
from torch.utils._mode_utils import no_dispatch

def get_saved_values_knapsack(memory_budget):
def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
with no_dispatch():
(
expected_runtime,
saved_node_idxs,
recomputable_node_idxs,
) = _optimize_runtime_with_given_memory(
memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0)
joint_graph,
memories_banned_nodes,
runtimes_banned_nodes,
max(memory_budget, 0),
node_info,
all_recomputable_banned_nodes,
)
dont_ban = set()
for idx in recomputable_node_idxs:
dont_ban.add(all_recomputable_banned_nodes[idx])
if idx in all_recomputable_banned_nodes:
dont_ban.add(all_recomputable_banned_nodes[idx])
assert dont_ban.issubset(all_recomputable_banned_nodes)

saved_values, _ = solve_min_cut(
@@ -1706,7 +1722,7 @@ def get_saved_values_knapsack(memory_budget):
options = []
for sweep_memory_budget in range(100, -1, -5):
saved_values, expected_runtime = get_saved_values_knapsack(
sweep_memory_budget / 100
sweep_memory_budget / 100, node_info, joint_graph
)
options.append(
(
@@ -1751,7 +1767,7 @@ def get_saved_values_knapsack(memory_budget):
# tensors we actually banned from recompute, but there may be other
# tensors that we choose to save.

return get_saved_values_knapsack(memory_budget=memory_budget)[0]
return get_saved_values_knapsack(memory_budget, node_info, joint_graph)[0]


def min_cut_rematerialization_partition(
@@ -1877,7 +1893,9 @@ def classify_nodes(joint_module):
break
# print("Memory Budget: ", memory_budget)
saved_values = choose_saved_values_set(
joint_graph, node_info, memory_budget=memory_budget
joint_graph,
node_info,
memory_budget=memory_budget,
)
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes = list(filter(is_sym_node, saved_values))
Expand Down
Loading