From ecabe753039ac0a09451558ab3a04c90e0ee8cf7 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 25 Aug 2023 19:59:54 -0700 Subject: [PATCH] Add lru_cache for node comparison (#148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/148 cache the args to optimize the number of calls. For `is_same_node` function, we're expecting the same results each time it's called with the same inputs. For mv2, ``` buck run mode/dev-nosan //executorch/examples/backend:xnnpack_examples -- -m mv2 -q -d ``` The total number of nodes in mv2 graph after quantize is 520. Before the change: mv2: The number of call to `is_same_node` is 6388036 🙈 mv2: to_backend call time: 30.79 second (including dumping the logs) mv3: takes forever After the change: mv2: The number of call to `is_same_node` is 520. It's more reasonable mv2: to_backend call time: 10.15 second (including dumping the logs) mv3: `to_backend` call takes 16 second. Reviewed By: digantdesai, mcr229 Differential Revision: D48708658 fbshipit-source-id: ce2b7f5835ed17fa6df75a3bfcee495b2ab0a4a2 --- exir/backend/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exir/backend/utils.py b/exir/backend/utils.py index 2ee48a105a5..2defad588a7 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from functools import lru_cache from typing import Iterable, List, Tuple import torch @@ -16,6 +17,7 @@ T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default +@lru_cache(maxsize=128) def is_same_node( node_left: Iterable[torch.fx.Node], node_right: Iterable[torch.fx.Node], @@ -39,8 +41,6 @@ def is_same_node( if len(list(node_left)) != len(list(node_right)): return False for n_left, n_right in zip(node_left, node_right): - # pyre-fixme[6]: For 1st argument expected `Iterable[Node]` but got `Node`. - # pyre-fixme[6]: For 2nd argument expected `Iterable[Node]` but got `Node`. if not is_same_node(n_left, n_right): return False return True