Tagging based min cut partitioner
ghstack-source-id: de40da68dd7d58a37196a87da22cbe53d17dd3f0
Pull Request resolved: #103357
anijain2305 committed Jun 10, 2023
1 parent 4d8f564 commit b0ef590
Showing 1 changed file with 210 additions and 4 deletions.
torch/_functorch/partitioners.py
@@ -12,7 +12,7 @@
import os
import itertools
import sympy
from collections import defaultdict
from collections import defaultdict, deque
from torch.fx.passes import graph_drawer
from typing import Tuple
from .compile_utils import fx_graph_cse, get_aten_target
@@ -26,13 +26,30 @@ def is_symint_node(node):
assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
return "val" in node.meta and isinstance(node.meta['val'], torch.SymInt)

def must_recompute(node):
return "recompute" in node.meta and node.meta["recompute"]

def has_recomputable_ops(fx_g):
for node in fx_g.graph.nodes:
if must_recompute(node):
return True
return False

def has_recomputable_rng_ops(fx_g):
for node in fx_g.graph.nodes:
if must_recompute(node) and hasattr(node.target, "tags") and torch.Tag.nondeterministic_seeded in node.target.tags:
return True
return False
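
# A minimal, hypothetical usage sketch of the helpers above. It assumes
# `joint_gm` is a joint fwd/bwd fx.GraphModule whose checkpointed ops carry a
# "recompute" entry in node.meta, set by the activation-checkpoint tagging pass:
#
#     recompute_nodes = [n for n in joint_gm.graph.nodes if must_recompute(n)]
#     if has_recomputable_rng_ops(joint_gm):
#         # RNG ops inside a checkpointed region need the fwd and bwd graphs to
#         # replay the same random state; see functionalize_rng_ops below.
#         pass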

class InvalidNodeBase:
def __repr__(self):
return "Invalid Node"


InvalidNode = InvalidNodeBase()
counter = itertools.count()



def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
@@ -99,7 +116,6 @@ def _is_primal(node):
def _is_tangent(node):
return node.op == "placeholder" and "tangents" in node.target


def _is_bwd_seed_offset(node):
return node.op == "placeholder" and ("bwd_seed" in node.target or "bwd_base_offset" in node.target)

@@ -231,6 +247,8 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
@@ -357,6 +375,172 @@ def pointwise_ops():
return ops


def tangent_driven_topological_reordering(gm):
new_graph = fx.Graph()
env = {}

# Add new placeholder nodes in the order specified by the inputs
for node in gm.graph.nodes:
if node.op == "placeholder":
new_node = new_graph.placeholder(node.name)
# Can't use node_copy here as we may be turning previous call_function into placeholders
new_node.meta = node.meta
env[node] = new_node


def insert_node_in_graph(node):
if node in env:
return env[node]

if node.op == "output":
raise RuntimeError("Should not be here")

for arg in node.all_input_nodes:
env[arg] = insert_node_in_graph(arg)

env[node] = new_graph.node_copy(node, lambda x: env[x])

return env[node]

tangent_inputs = list(filter(_is_tangent, gm.graph.nodes))
q = deque(tangent_inputs)

seen = set()
while len(q):
node = q.popleft()
insert_node_in_graph(node)
for user in node.users:
            # Check if the node is already visited
if user not in seen and user.op != "output":
q.append(user)
seen.add(user)



output_node = [node for node in gm.graph.nodes if node.op == "output"][0]

new_outputs = []
for output in output_node.args[0]:
if isinstance(output, torch.fx.node.Node):
new_outputs.append(insert_node_in_graph(output))
else:
new_outputs.append(output)

new_graph.output(new_outputs)
new_gm = torch.fx.GraphModule(gm, new_graph)
return new_gm
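
# Hypothetical usage sketch: `bw_module` is assumed to be a backward
# fx.GraphModule produced by the partitioner below, with tangent placeholders
# among its inputs.
#
#     bw_module = tangent_driven_topological_reordering(bw_module)
#     bw_module.graph.lint()  # the rebuilt graph should still be well formed
#
# The walk starts from the tangent placeholders, so nodes that recompute
# forward values are copied into the new graph only once a tangent-driven user
# actually needs their outputs.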


def functionalize_rng_ops(joint_module, fw_module, bw_module):
    # We will use functionalize wrappers to wrap the random ops and share rng
    # state between the fwd and bwd graphs. The map contains the pairs of nodes
    # that should run with the same rng state.

    # To make this transformation, in the fwd pass
    # 1) Replace rand with the run_and_save_rng_state wrapper
    # 2) Replace the users of the original op with output[1] of this op.
    # 3) Collect all the rng states - output[0] of each op - and make them output nodes.

    # In the bwd pass
    # 1) Add the input nodes just before the tangents for the stashed rng states
    # 2) Replace rand with the run_with_rng_state wrapper
    # 3) Use the stashed states as inputs to these ops
    # (An eager-style sketch of this rewrite follows the function.)
def get_random_nodes(gmod):
random_nodes = {}
for node in gmod.graph.nodes:
if (
node.op == "call_function"
and hasattr(node.target, "tags")
and torch.Tag.nondeterministic_seeded in node.target.tags
):
random_nodes[node.name] = node
return random_nodes

joint_graph_rng_ops = get_random_nodes(joint_module)
fw_graph_rng_ops = get_random_nodes(fw_module)
bw_graph_rng_ops = get_random_nodes(bw_module)
recomputable_rng_ops_map = dict()
for node in joint_module.graph.nodes:
if (
must_recompute(node)
and hasattr(node.target, "tags")
and torch.Tag.nondeterministic_seeded in node.target.tags
):
base_node = joint_graph_rng_ops[node.name]
fw_node = fw_graph_rng_ops[node.name]
bw_node = bw_graph_rng_ops[node.name]
recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}

run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
run_with_rng_state = torch._prims.rng_prims.run_with_rng_state

for node in bw_module.graph.nodes:
if node.op == "placeholder" and "tangent" in node.name:
bw_tangent_start_node = node
break

fw_rng_state_outputs = []
for base_node, node_pair in recomputable_rng_ops_map.items():
fw_node = node_pair["fwd"]
bw_node = node_pair["bwd"]
fw_graph = fw_module.graph
with fw_graph.inserting_before(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
run_and_save_rng,
args=(fw_node.target, *fw_node.args),
kwargs=fw_node.kwargs
)
state = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 0), kwargs={})
rng_output = fw_graph.create_node("call_function", operator.getitem, args=(functional_fw_node, 1,), kwargs={})
fw_node.replace_all_uses_with(rng_output)
fw_graph.erase_node(fw_node)
fw_rng_state_outputs.append(state)


bw_graph = bw_module.graph
with bw_graph.inserting_before(bw_tangent_start_node):
state_name = f"rng_state_output_{next(counter)}"
bw_rng_state_node = bw_graph.placeholder(state_name)
bw_rng_state_node.meta["val"] = torch.cuda.get_rng_state()

with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
run_with_rng_state,
args=(bw_rng_state_node, bw_node.target, *bw_node.args),
kwargs=bw_node.kwargs
)

bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)


    # Add the rng states to the outputs of the fwd graph
fw_output = [node for node in fw_module.graph.nodes if node.op == "output"][0]
outputs = fw_output.args[0] + fw_rng_state_outputs
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output)
fw_module.recompile()
bw_module.recompile()
return fw_module, bw_module
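
# Eager-style sketch of the rewrite described in the comments above; the op and
# its arguments are illustrative only. A tagged random op that appears in both
# graphs,
#
#     fwd:  y = rand_op(shape)
#     bwd:  y = rand_op(shape)        # recomputed, would draw different values
#
# becomes, after functionalize_rng_ops,
#
#     fwd:  out = run_and_save_rng_state(rand_op, shape)
#           state, y = out[0], out[1]  # `state` becomes an extra fwd output
#     bwd:  y = run_with_rng_state(state, rand_op, shape)
#
# so the recomputed op in the backward graph replays the forward rng state.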


def cleanup_recompute_tags(joint_module):
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
checkpointed blocks. The following pass makes the last output node
non-recomputable to allow for that.
"""
for node in joint_module.graph.nodes:
if must_recompute(node):
for user in node.users:
if must_recompute(user) and user.meta["recompute"] > node.meta["recompute"]:
node.meta["recompute"] = 0
return joint_module
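
# Illustration with hypothetical tag values: given two back-to-back
# checkpointed regions tagged 1 and 2,
#
#     a = op1(x)   # meta["recompute"] = 1
#     b = op2(a)   # meta["recompute"] = 1  <- last node of the first region
#     c = op3(b)   # meta["recompute"] = 2  <- first node of the second region
#
# b's user c carries a larger recompute tag, so the pass resets b's tag to 0
# and b is stashed at the block boundary instead of being recomputed.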


def min_cut_rematerialization_partition(
joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None,
*, num_fwd_outputs
@@ -403,6 +587,11 @@ def min_cut_rematerialization_partition(
joint_module.graph = cse_graph
full_bw_graph = joint_module.graph

graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)

name_to_node = {}
for node in joint_module.graph.nodes:
name_to_node[node.name] = node
@@ -501,11 +690,21 @@ def is_materialized_backwards(node):
return False

def ban_recomputation(node):
if AGGRESSIVE_RECOMPUTATION:
# if graph_has_recomputable_ops:
# return not must_recompute(node)
if must_recompute(node):
return False
elif AGGRESSIVE_RECOMPUTATION:
return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
else:
if node.op != 'call_function':
return False

            # Ban recomputation if one of its users is a must-recompute node
for user in node.users:
if must_recompute(user):
return True

if get_aten_target(node) not in recomputable_ops:
return True
if node.target == operator.getitem:
@@ -611,12 +810,19 @@ def get_node_weight(node) -> int:
# To make this stuff deterministic
node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)

if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module
)
bw_module = tangent_driven_topological_reordering(bw_module)

if AOT_PARTITIONER_DEBUG:
print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
