From f8e98b815f69b1622b7b0d495e46740b04410e50 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 23 Dec 2020 16:07:25 -0800 Subject: [PATCH] eager quant: fix error with removing forward hooks Summary: https://github.com/pytorch/pytorch/issues/49739 reports a crash where removing forward hooks results in a ``` RuntimeError: OrderedDict mutated during iteration ``` Unfortunately I cannot repro this inside the PyTorch module, but the issue author has a good point and and we should not mutate the dict inside of the iteration. Test Plan: ``` // test plan from https://github.com/pytorch/pytorch/pull/46871 which // originally added this python test/test_quantization.py TestEagerModeQATOps ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 462777c1295c03f823f03c0ce173ebf2068b2cad Pull Request resolved: https://github.com/pytorch/pytorch/pull/49813 --- torch/quantization/quantize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index a9417ecb80f3..616744ab8bfd 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -252,9 +252,12 @@ def _remove_activation_post_process(module): delattr(module, 'activation_post_process') # remove activation_post_proceess hook + handle_ids_to_remove = set() for handle_id, hook_fn in module._forward_hooks.items(): if hook_fn is _observer_forward_hook: - module._forward_hooks.pop(handle_id) + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + module._forward_hooks.pop(handle_id) # TODO: rename to something more general def _remove_qconfig(module):