From fc66e05ef71bd2f4f14412e2c705c35d38e66764 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 30 Dec 2020 14:41:47 -0800 Subject: [PATCH] eager quant: fix error with removing forward hooks Summary: https://github.com/pytorch/pytorch/issues/49739 reports a crash where removing forward hooks results in a ``` RuntimeError: OrderedDict mutated during iteration ``` Unfortunately I cannot repro this inside the PyTorch module, but the issue author has a good point and and we should not mutate the dict inside of the iteration. Test Plan: ``` // test plan from https://github.com/pytorch/pytorch/pull/46871 which // originally added this python test/test_quantization.py TestEagerModeQATOps ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: b7ef4e3c381c9451b2a6368c69f27e345837aee9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/49813 --- torch/quantization/quantize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index a9417ecb80f3..616744ab8bfd 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -252,9 +252,12 @@ def _remove_activation_post_process(module): delattr(module, 'activation_post_process') # remove activation_post_proceess hook + handle_ids_to_remove = set() for handle_id, hook_fn in module._forward_hooks.items(): if hook_fn is _observer_forward_hook: - module._forward_hooks.pop(handle_id) + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + module._forward_hooks.pop(handle_id) # TODO: rename to something more general def _remove_qconfig(module):