diff --git a/torch/_inductor/mkldnn.py b/torch/_inductor/mkldnn.py index ebbf9acfd808..45a31c2f97c5 100644 --- a/torch/_inductor/mkldnn.py +++ b/torch/_inductor/mkldnn.py @@ -295,8 +295,8 @@ def pack_module(gm: torch.fx.GraphModule): ): continue else: - computation_node_input_size = ( - node.args[0].meta.get("tensor_meta").shape + computation_node_input_size = tuple( + int(x) for x in node.args[0].meta.get("tensor_meta").shape ) if any(size == 0 for size in computation_node_input_size): continue