From fffae03e4f5dc1b2df433755c8c6136b657d318a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 4 Jan 2021 15:43:47 -0800 Subject: [PATCH] eager quant: fix error with removing forward hooks Summary: https://github.com/pytorch/pytorch/issues/49739 reports a crash where removing forward hooks results in a ``` RuntimeError: OrderedDict mutated during iteration ``` Unfortunately I cannot repro this inside the PyTorch module, but the issue author has a good point and and we should not mutate the dict inside of the iteration. Test Plan: ``` // test plan from https://github.com/pytorch/pytorch/pull/46871 which // originally added this python test/test_quantization.py TestEagerModeQATOps ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 52660b6634a933a9e73c1f6dce6d8586a6e1c2c6 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49813 --- torch/quantization/quantize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 1be867e0a299..511cf5c9988a 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -256,9 +256,12 @@ def _remove_activation_post_process(module): delattr(module, 'activation_post_process') # remove activation_post_proceess hook + handle_ids_to_remove = set() for handle_id, hook_fn in module._forward_hooks.items(): if hook_fn is _observer_forward_hook: - module._forward_hooks.pop(handle_id) + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + module._forward_hooks.pop(handle_id) # TODO: rename to something more general def _remove_qconfig(module):