Inductor Optimize For Inference/Freezing
ghstack-source-id: 0ebf73bfa14ebebe376a2bba0d1acf0564379d44
Pull Request resolved: #100652
eellison committed May 16, 2023
1 parent c4fe5ff commit ee2318b
Showing 6 changed files with 274 additions and 2 deletions.
4 changes: 3 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -1529,7 +1529,9 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
adjusted_flat_args = [seed, offset, *flat_args]
flat_args.clear() # Don't hold extra reference
compiled_fw = compiler(fw_module, adjusted_flat_args)

torch._guards.TracingContext.get().fw_metadata = fw_metadata
compiled_fw = compiler(fw_module, flat_args)

# This boxed_call handling happens inside create_runtime_wrapper as well.
# However, create_runtime_wrapper does not expect the rng offsets in the
2 changes: 2 additions & 0 deletions torch/_guards.py
@@ -404,6 +404,8 @@ def __init__(self, fake_mode)
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None

    @staticmethod
    def extract_stack():
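The new fw_metadata field is the handshake between aot_autograd, which sets it on the active TracingContext right before invoking the backend compiler (see the aot_autograd.py change above), and the inductor freezing path, which reads it to learn which inputs are mutated and therefore cannot be constant-folded. A minimal sketch of a backend consuming it, assuming it runs under an active TracingContext; the backend name and body are illustrative, not from this commit:

import torch

def my_inference_backend(fw_module, example_inputs):
    ctx = torch._guards.TracingContext.get()
    fw_metadata = ctx.fw_metadata if ctx is not None else None
    if fw_metadata is not None:
        # indices of graph inputs that the traced graph mutates;
        # freezing must keep these as real inputs (see freezing.py below)
        print("mutated inputs:", fw_metadata.mutated_inp_indices)
    return fw_module.forward  # run the traced graph unchanged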
46 changes: 46 additions & 0 deletions torch/_inductor/Untitled-20.py
@@ -0,0 +1,46 @@

import torch


@torch.compile
def foo(x):
    # GRAPH 1
    y = x * x * x
    # graph break triggered here
    if y.sum() > 0:
        # GRAPH 2
        z = y ** y
    else:
        # GRAPH 3
        z = (y.abs() ** y.abs())
    torch._dynamo.graph_break()

    return z * torch.rand_like(z)

# Running Graph 1, 2, and 4
foo(torch.arange(0, 10, device="cuda"))
# Replaying Graph 1, 2, and 4
foo(torch.arange(0, 10, device="cuda"))



@torch.compile
def foo(x):
    # GRAPH 1
    y = x * x * x
    # graph break triggered here
    if y.sum() > 0:
        # GRAPH 2
        z = y ** y
    else:
        # GRAPH 3
        z = (y.abs() ** y.abs())
    torch._dynamo.graph_break()

    return z * torch.rand_like(z)

...

# Graph 1 (x * x * x) is executed as part of the first tape, but then execution
# diverges from its recording, as we hit the abs() path
foo(torch.arange(-10, 0, device="cuda"))
46 changes: 45 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -3,6 +3,7 @@
import itertools
import logging
import sys
import unittest
import warnings

from copy import deepcopy
@@ -667,7 +668,50 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference)
)

    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
    inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    if config.freezing:
        from torch._inductor.freezing import freeze

        def inference_compiler(model: torch.fx.GraphModule, example_inputs):
            # partition_fn won't be called
            joint_graph_passes(model)

            opt_model, preserved_arg_indices = freeze(
                model_,
                model,
                example_inputs,
                fw_metadata=torch._guards.TracingContext.get().fw_metadata,
            )

            example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
            num_fixed = len(preserved_arg_indices) - num_example_inputs

            fake_mode = detect_fake_mode(example_inputs)

            # constant params will be real tensors, not fake
            with unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
                optimized_function = inner_compile(
                    opt_model,
                    example_inputs,
                    num_fixed=num_fixed,
                    cudagraphs=cudagraphs,
                    graph_id=graph_id,
                    is_inference=True,
                    boxed_forward_device_index=forward_device,
                )

            # Need to drop the args we have constant-ified.
            # TODO - find way for aot_autograd to not update calling convention
            def wrapper(args):
                args_new = [args[ind] for ind in preserved_arg_indices]
                args.clear()
                return optimized_function(args_new)

            wrapper._boxed_call = True
            return wrapper

    else:
        inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

    def partition_fn(graph, joint_inputs, **kwargs):
        joint_graph_passes(graph)
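The wrapper defined above is needed because aot_autograd still hands the compiled function one flat, boxed list containing every parameter plus the real inputs, while the frozen graph only accepts the inputs that were not turned into constants. A standalone sketch of that boxed convention, using made-up values for optimized_function and preserved_arg_indices:

def make_boxed_wrapper(optimized_function, preserved_arg_indices):
    # args is a single mutable list (the boxed convention): keep only the
    # preserved inputs and clear the list so parameter references are dropped
    def wrapper(args):
        args_new = [args[ind] for ind in preserved_arg_indices]
        args.clear()
        return optimized_function(args_new)

    wrapper._boxed_call = True
    return wrapper

# e.g. inputs 0 and 1 were parameters frozen into the graph; only 2 and 3 survive
wrapper = make_boxed_wrapper(lambda args: sum(args), preserved_arg_indices=[2, 3])
print(wrapper([10, 20, 3, 4]))  # prints 7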
4 changes: 4 additions & 0 deletions torch/_inductor/config.py
@@ -180,6 +180,10 @@ def decide_compile_threads():

disable_cpp_codegen = is_fbcode()

# Inline unmutated parameters as constants in inference graphs and constant-fold them.
freezing = True

# After freezing, discard the eager parameters (via ErasedTensor) to reclaim memory.
freezing_discard_parameters = False


# config specific to codegen/cpp.py
class cpp:
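A hedged end-to-end sketch of how these flags would be used (not part of this commit): set torch._inductor.config.freezing before compiling and run the model in eval mode under no_grad so the inference compiler path in compile_fx.py above is taken. The module below is illustrative and a CUDA build is assumed:

import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True
# inductor_config.freezing_discard_parameters = True  # optionally also drop eager params

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))

m = M().cuda().eval()
opt_m = torch.compile(m)
with torch.no_grad():
    out = opt_m(torch.randn(8, 16, device="cuda"))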
174 changes: 174 additions & 0 deletions torch/_inductor/freezing.py
@@ -0,0 +1,174 @@
import itertools
import weakref
from typing import List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from . import config


def replace_node_with_constant(gm, node, constant):
    g = gm.graph

    i = 0
    while True:
        qualname = f"_frozen_param{i}"
        if not hasattr(gm, qualname):
            break
        i += 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


def replace_params_with_constants(gm, real_inputs, example_inputs_, fw_metadata):
    fake_inp_nodes = [node for (_, node) in zip(real_inputs, gm.graph.nodes)]

    g = gm.graph

    preserved_arg_indices = []

    for i, (real_input, fake_input, node) in enumerate(
        zip(real_inputs, example_inputs_, fake_inp_nodes)
    ):
        assert real_input.shape == fake_input.shape

        if i in fw_metadata.mutated_inp_indices:
            preserved_arg_indices.append(i)
            continue

        replace_node_with_constant(gm, node, real_input)

    # add on non param inputs
    preserved_arg_indices.extend(range(len(real_inputs), len(example_inputs_)))

    g.lint()
    # is this necessary ?
    gm.recompile()
    return gm, preserved_arg_indices


@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm):
    unknown_value = object()

    node_replacements = {}

    class ConstantFolder(torch.fx.Interpreter):
        def run_node(self, node):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            if unknown_value in pytree.tree_flatten((args, kwargs))[0]:
                return unknown_value

            # All mutations should either be removed or on inputs which we did not make constant
            if (
                isinstance(node.target, torch._ops.OpOverload)
                and torch.Tag.nondeterministic_seeded in node.target.tags
            ):
                return unknown_value

            out = super().run_node(node)

            # TODO - remove constant from node_replacement when it has no uses
            if node.op != "get_attr" and isinstance(out, torch.Tensor):
                node_replacements[node] = out

            return out

        def run(self):
            env = {}
            for n in self.module.graph.nodes:
                if n.op == "placeholder":
                    env[n] = unknown_value
            return super().run(initial_env=env)

    ConstantFolder(gm).run()

    for node, constant in node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()


def freeze(
    original_gm: torch.fx.GraphModule,
    gm: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    fw_metadata,
) -> Tuple[torch.fx.GraphModule, List[int]]:
    "Inlines unmutated parameters into constants and runs constant propagation and other optimizations"

    params = {
        **dict(original_gm.named_parameters(remove_duplicate=False)),
        **dict(original_gm.named_buffers(remove_duplicate=False)),
    }
    params_flat, _ = pytree.tree_flatten(params)
    params_flat = tuple(params_flat)

    # TODO - aot_autograd currently doesn't have a way of not updating the calling convention to include
    # parameters, so we need to drop parameters that became constants from inputs. This also prevents
    # deallocating unused parameters if `freezing_discard_parameters` is True.
    gm, preserved_arg_indices = replace_params_with_constants(
        gm, params_flat, example_inputs_, fw_metadata
    )

    constant_fold(gm)

    # invalidate nn Modules
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
    return gm, preserved_arg_indices


class ErasedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod):
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        erased_tensors = [
            e
            for e in pytree.tree_flatten((args, kwargs))[0]
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            f"Trying to run a PyTorch eager module after Dynamo freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )


@torch.utils._python_dispatch._disable_current_modes()
def invalidate_eager_modules():
    # TODO - could just invalidate the parameters that were folded
    for mod in torch._guards.TracingContext.get().module_context.nn_modules.values():
        if not isinstance(mod, torch.nn.Module):
            continue

        for attr_name, tensor in list(
            itertools.chain(
                mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
            )
        ):
            e_t = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                e_t.requires_grad_(True)
                e_t._is_param = True
            setattr(mod, attr_name, e_t)
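
For illustration, a small self-contained sketch of what constant_fold does to a traced module whose computation partly depends only on parameters; the Toy module and names below are illustrative, not part of this commit. Nodes that do not depend on a placeholder are evaluated once and replaced by a get_attr to a _frozen_paramN constant:

import torch
import torch.fx
from torch._inductor.freezing import constant_fold  # module added by this PR

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        scaled = self.w * 2.0   # depends only on the parameter -> foldable
        return x @ scaled       # depends on the placeholder -> kept

gm = torch.fx.symbolic_trace(Toy())
constant_fold(gm)
print(gm.graph)  # the `self.w * 2.0` subgraph is now a single _frozen_param0 get_attr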
