Skip to content

Commit

Permalink
[wip][inductor] convert layout of conv weight ahead of time for inference
Browse files Browse the repository at this point in the history

ghstack-source-id: 3288073a1050922f36e67ccaad694f58942b9f5f
Pull Request resolved: #103642
  • Loading branch information
shunting314 committed Jun 15, 2023
1 parent 49dcf48 commit 025d708
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch/_inductor/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def fw_compiler_freezing(
graph_id,
forward_device,
):
from torch._inductor.freezing import freeze
from torch._inductor.freezing import freeze, convert_conv_weights_to_channels_last

# partition_fn won't be called
joint_graph_passes(aot_autograd_model)
Expand All @@ -684,6 +684,7 @@ def fw_compiler_freezing(
aot_autograd_model,
fw_metadata=torch._guards.TracingContext.get().fw_metadata,
)
convert_conv_weights_to_channels_last(aot_autograd_model)

aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
num_fixed = len(preserved_arg_indices) - num_example_inputs
Expand Down
27 changes: 27 additions & 0 deletions torch/_inductor/freezing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils._pytree as pytree
from . import config

aten = torch.ops.aten

def replace_node_with_constant(gm, node, constant):
g = gm.graph
Expand Down Expand Up @@ -102,6 +103,7 @@ def run_node(self, node):
# TODO - remove constant from node_replacement when it has no uses
if node.op != "get_attr" and isinstance(out, torch.Tensor):
node_replacements[node] = out
# print(f"Shunting: {node.format_node()} becomes a const") # TODO

return out

Expand Down Expand Up @@ -224,3 +226,28 @@ def discard_traced_gm_params(mod):
e_t.requires_grad_(True)
e_t._is_param = True
setattr(mod, attr_name, e_t)


@torch.utils._python_dispatch._disable_current_modes()
def convert_conv_weights_to_channels_last(gm):
    """
    Convert the weight tensor of every 4d convolution in ``gm`` to
    channels-last (NHWC) memory format, in place.

    Only weights that the graph reads via ``get_attr`` nodes (i.e. constant
    tensors on the module) are converted; call-time weights are left alone.
    Non-4d weights (e.g. conv1d/conv3d) are skipped, since
    ``torch.channels_last`` is only defined for 4d tensors.  The graph
    structure itself is unchanged -- only the attribute tensors are swapped.

    This method assumes the graph has already been frozen, so parameters
    have been replaced by plain tensor attributes.

    NOTE(review): the original docstring claimed the layout in
    ``example_inputs`` and ``params_flat`` is also updated, but this
    function does not touch either -- callers must handle that themselves
    if required.

    Args:
        gm: a frozen ``torch.fx.GraphModule``.
    """
    conv_nodes = [
        n for n in gm.graph.nodes
        if n.target is torch.ops.aten.convolution.default
    ]
    for conv in conv_nodes:
        weight_node = conv.args[1]
        # The weight may be a call-time value (non-Node or a computed
        # node) rather than a module constant; only handle get_attr.
        if not isinstance(weight_node, torch.fx.Node) or weight_node.op != "get_attr":
            continue
        # Support dotted targets ("sub.weight") owned by nested modules;
        # plain getattr(gm, "sub.weight") would raise AttributeError.
        *owner_path, attr_name = weight_node.target.split(".")
        owner = gm
        for part in owner_path:
            owner = getattr(owner, part)
        weight = getattr(owner, attr_name)
        if weight.ndim != 4:
            # Not a conv2d-style weight; channels_last does not apply.
            continue
        converted = weight.to(memory_format=torch.channels_last)
        # ``.to`` returns the same tensor when it is already channels-last;
        # only write back when a new tensor was actually produced.
        if converted is not weight:
            setattr(owner, attr_name, converted)
            # TODO: update weight_node.meta["val"] layout as well?

0 comments on commit 025d708

Please sign in to comment.