[wip][inductor] convert layout of conv weight ahead of time
ghstack-source-id: 705643764afe28030fbc46bfe289acda5840cf37
Pull Request resolved: #103642
shunting314 committed Jun 15, 2023
1 parent 49dcf48 commit 3df114e
Showing 3 changed files with 38 additions and 1 deletion.
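Background for the change, as a minimal sketch using only standard PyTorch APIs (not part of this diff): converting a 4D conv weight to channels-last keeps its shape and values and only reorders the underlying memory, which shows up in the strides.

```python
import torch

w = torch.randn(64, 3, 7, 7)  # (out_channels, in_channels, kH, kW)
w_cl = w.to(memory_format=torch.channels_last)

print(w.stride())     # (147, 49, 7, 1)  -- contiguous (NCHW order)
print(w_cl.stride())  # (147, 1, 21, 3)  -- channels-last (NHWC order)
print(torch.equal(w, w_cl))  # True: same values, different memory layout
```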
4 changes: 3 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -30,7 +30,7 @@
 from .debug import DebugContext
 from .decomposition import select_decomp_table
 from .fx_passes.joint_graph import joint_graph_passes
-from .fx_passes.post_grad import post_grad_passes, view_to_reshape
+from .fx_passes.post_grad import post_grad_passes, view_to_reshape, convert_conv_weights_to_channels_last
 from .fx_passes.pre_grad import pre_grad_passes
 from .graph import GraphLowering
 from .pattern_matcher import clone_graph
@@ -679,6 +679,8 @@ def fw_compiler_freezing(
     # partition_fn won't be called
     joint_graph_passes(aot_autograd_model)
 
+    convert_conv_weights_to_channels_last(aot_autograd_model, aot_example_inputs)
+
     opt_model, preserved_arg_indices = freeze(
         dynamo_model,
         aot_autograd_model,
1 change: 1 addition & 0 deletions torch/_inductor/freezing.py
@@ -102,6 +102,7 @@ def run_node(self, node):
         # TODO - remove constant from node_replacement when it has no uses
         if node.op != "get_attr" and isinstance(out, torch.Tensor):
             node_replacements[node] = out
+            # print(f"Shunting: {node.format_node()} becomes a const")  # TODO
 
         return out

34 changes: 34 additions & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -410,3 +410,37 @@ def view_to_reshape(gm):
     for nd in gm.graph.nodes:
         if nd.target == torch.ops.aten.view.default:
             nd.target = torch.ops.aten.reshape.default

+def convert_conv_weights_to_channels_last(gm, example_inputs):
+    """
+    Convert 4D convolution weight tensors to channels-last format.
+
+    Two copies of each weight need their layout changed:
+    1. the tensors in example_inputs, since they are the compile-time inputs
+    2. the tensors in params_flat, since they are the runtime inputs
+
+    TODO: this may need to happen after freezing, since the parameters may
+    first need to be converted to fp16.
+    """
+    params_flat = torch._guards.TracingContext.get().params_flat
+    assert len(params_flat) <= len(example_inputs)
+    phs = [n for n in gm.graph.nodes if n.op == "placeholder"]
+    assert len(phs) == len(example_inputs)
+
+    for real_param, fake_param, ph in zip(params_flat, example_inputs, phs):
+        if len(real_param.shape) != 4:
+            # not a 4D tensor, skip
+            continue
+
+        # If any user of the placeholder node is a convolution and the
+        # placeholder is passed as the weight (args[1] of
+        # aten.convolution.default), convert the layout ahead of time.
+        conv_users = [n for n in ph.users if n.target == aten.convolution.default]
+        is_conv_weight = any(conv_user.args[1] is ph for conv_user in conv_users)
+
+        if is_conv_weight:
+            # Convert the real and fake param to channels-last in place, so
+            # the compile-time and runtime layouts stay consistent.
+            real_param.data = real_param.data.to(memory_format=torch.channels_last)
+            fake_param.data = fake_param.data.to(memory_format=torch.channels_last)
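
As a sanity check on the placeholder test above, here is a minimal, self-contained sketch (not part of this commit; the `make_fx` usage and tensor shapes are illustrative) showing how a placeholder's users reveal whether it feeds `aten.convolution.default` as the weight:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    # aten.convolution.default(input, weight, bias, stride, padding,
    #                          dilation, transposed, output_padding, groups)
    return torch.ops.aten.convolution.default(
        x, w, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
    )

gm = make_fx(f)(torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3))
for ph in (n for n in gm.graph.nodes if n.op == "placeholder"):
    conv_users = [n for n in ph.users if n.target == torch.ops.aten.convolution.default]
    is_weight = any(u.args[1] is ph for u in conv_users)
    print(ph.name, is_weight)  # expect: False for the input, True for the weight
```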
