From 2b507a2f20c1f15ce78ff8576bf6caa14934398d Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Sun, 28 Apr 2024 01:47:56 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torch/_inductor/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index c6d363ab4ba21..aa2c29adfcc96 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1286,7 +1286,7 @@ def debug(msg): ir.get_stride_order(n.meta["val"].stride()), allow_padding=True, ) - if user.target in need_fixed_channels_last_layout: + if user.target in need_fixed_channels_last_layout and n is user.args[0]: result = ir.ExternKernel.require_stride_order( result, ir.get_stride_order( From c2a83e963a765b0a55844059cc510ed663c787ca Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Sun, 28 Apr 2024 21:21:52 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torch/_inductor/graph.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index aa2c29adfcc96..4c5c6a26aeb88 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1286,7 +1286,10 @@ def debug(msg): ir.get_stride_order(n.meta["val"].stride()), allow_padding=True, ) - if user.target in need_fixed_channels_last_layout and n is user.args[0]: + if ( + user.target in need_fixed_channels_last_layout + and n is user.args[0] + ): result = ir.ExternKernel.require_stride_order( result, ir.get_stride_order(