[wip][inductor] convert layout of conv weight ahead of time for inference

ghstack-source-id: 43e0ce6937b992ccb4e13099fd1b57251a70bc36
Pull Request resolved: #103642
shunting314 committed Jun 16, 2023
1 parent 49dcf48 commit 5315329
Showing 4 changed files with 74 additions and 13 deletions.
35 changes: 35 additions & 0 deletions test/inductor/test_inductor_freezing.py
@@ -8,6 +8,7 @@
import weakref

import torch
from torch import nn

import torch._dynamo
from torch._inductor import config
@@ -229,6 +230,40 @@ def foo(mod, inp):
        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)

    def test_conv_weight_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )
            def forward(self, x):
                return self.conv(x)
            def get_example_inputs(self):
                return torch.rand(2, 3, 5, 5).cuda(),

        from torch._inductor.compile_fx import compile_fx, compile_fx_inner
        nconv = 0
        def my_inner_compile(gm, example_inputs, *args, **kwargs):
            out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

            nonlocal nconv
            convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
            nconv += len(convs)
            for conv in convs:
                weight_node = conv.args[1]
                weight_const_tensor = getattr(gm, weight_node.target)
                self.assertTrue(weight_const_tensor.is_contiguous(memory_format=torch.channels_last))
                self.assertTrue(weight_node.meta['val'].is_contiguous(memory_format=torch.channels_last))

            return out

        mod = torch.compile(Model().cuda(), backend=functools.partial(compile_fx, inner_compile=my_inner_compile))
        inp = mod.get_example_inputs()
        with torch.no_grad():
            mod(*inp)

        self.assertTrue(nconv == 1)

if HAS_CPU and not torch.backends.mps.is_available():

8 changes: 0 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -2783,14 +2783,6 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
torch._guards.TracingContext.get().fw_metadata = fw_metadata


# the compiler needs to use this field to find the original model outputs
# from the AOTAutograd fwd module's outputs. Thus the compiler can make sure
# optimizations like layout optimization do not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in,
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
7 changes: 2 additions & 5 deletions torch/_inductor/compile_fx.py
@@ -674,7 +674,7 @@ def fw_compiler_freezing(
    graph_id,
    forward_device,
):
    from torch._inductor.freezing import freeze
    from torch._inductor.freezing import freeze, convert_conv_weights_to_channels_last

    # partition_fn won't be called
    joint_graph_passes(aot_autograd_model)
@@ -684,6 +684,7 @@
        aot_autograd_model,
        fw_metadata=torch._guards.TracingContext.get().fw_metadata,
    )
    convert_conv_weights_to_channels_last(aot_autograd_model)

    aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
    num_fixed = len(preserved_arg_indices) - num_example_inputs
@@ -821,12 +822,8 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
                orig_model_outputs_node.args
            )
            num_orig_model_outputs = len(orig_model_outputs)
            original_output_start_index = model.meta.get(
                "original_output_start_index", 0
            )
        else:
            num_orig_model_outputs = num_model_outputs
            original_output_start_index = 0

        assert num_orig_model_outputs <= num_model_outputs

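For context, here is a minimal end-to-end sketch of how this freezing path might be exercised. It assumes that torch._inductor.config.freezing is the switch that routes inference compilation through fw_compiler_freezing (and hence through the new conv-weight layout pass); the model and input shape are illustrative.

import torch
from torch import nn
from torch._inductor import config

# Assumption: this flag enables inductor freezing so that parameters are
# constant-folded and the conv-weight layout pass can run at compile time.
config.freezing = True

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1)).cuda().eval()
compiled = torch.compile(model)

with torch.no_grad():
    out = compiled(torch.rand(2, 3, 16, 16, device="cuda"))

The new test above drives the same path more directly by passing a custom inner_compile to compile_fx.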
37 changes: 37 additions & 0 deletions torch/_inductor/freezing.py
@@ -3,9 +3,14 @@
from typing import List, Optional, Tuple

import torch
from torch import nn
import torch.utils._pytree as pytree
from . import config
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._mode_utils import no_dispatch
from torch._dynamo.utils import dynamo_timed

aten = torch.ops.aten

def replace_node_with_constant(gm, node, constant):
    g = gm.graph
@@ -102,6 +107,7 @@ def run_node(self, node):
        # TODO - remove constant from node_replacement when it has no uses
        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            node_replacements[node] = out
            # print(f"Shunting: {node.format_node()} becomes a const") # TODO

        return out

@@ -224,3 +230,34 @@ def discard_traced_gm_params(mod):
            e_t.requires_grad_(True)
            e_t._is_param = True
        setattr(mod, attr_name, e_t)


@dynamo_timed
def convert_conv_weights_to_channels_last(gm):
    """
    Convert 4d convolution weight tensors to channels-last format.
    This method assumes the graph has already been frozen.
    """
    convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
    for conv in convs:
        weight_node = conv.args[1]
        # only handle weights that are constant tensors (get_attr nodes)
        if weight_node.op == "get_attr":
            param_tensor = getattr(gm, weight_node.target)
            if len(param_tensor.shape) != 4:
                # not a 4d tensor, skip
                continue
            with no_dispatch():
                cl_param_tensor = param_tensor.to(memory_format=torch.channels_last)
            if isinstance(param_tensor, nn.Parameter):
                cl_param_tensor = nn.Parameter(cl_param_tensor)
            if cl_param_tensor is not param_tensor:
                setattr(gm, weight_node.target, cl_param_tensor)

                # Even though inductor does not use meta['val'] or meta['tensor_meta']
                # for get_attr nodes, we still update them to stay consistent.
                weight_node.meta['val'] = weight_node.meta['val'].to(memory_format=torch.channels_last)
                weight_node.meta['tensor_meta'] = _extract_tensor_metadata(
                    weight_node.meta['val']
                )
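
As a standalone illustration of what this pass does to each constant conv weight (standard torch APIs only; the shapes are illustrative): converting a 4d tensor to channels_last keeps the shape and values but permutes the strides so the channel dimension becomes innermost.

import torch

# A conv weight in the default contiguous (NCHW-style) layout.
weight = torch.randn(128, 3, 3, 3)
print(weight.stride())  # (27, 9, 3, 1)

# The same values, reordered in memory so channels are innermost (NHWC-style).
cl_weight = weight.to(memory_format=torch.channels_last)
print(cl_weight.stride())                                           # (27, 1, 9, 3)
print(cl_weight.is_contiguous(memory_format=torch.channels_last))   # True
print(torch.equal(weight, cl_weight))                                # True: values unchanged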
