From c26830e084ffe42c6f72ab525f99b97104640a25 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 29 Dec 2020 13:47:18 -0800 Subject: [PATCH] Clean up some type annotations in caffe2/torch/quantization (#49942) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49942 Upgrades type annotations from Python2 to Python3 Test Plan: Sandcastle tests Reviewed By: vkuzo Differential Revision: D25717551 fbshipit-source-id: d52555d701793b7c07e561df56acd82c966bdb7c --- torch/quantization/_numeric_suite_fx.py | 2 +- torch/quantization/fake_quantize.py | 6 ++---- torch/quantization/observer.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py index eb1596832c4d..aeba95bb4e8f 100644 --- a/torch/quantization/_numeric_suite_fx.py +++ b/torch/quantization/_numeric_suite_fx.py @@ -21,7 +21,7 @@ def remove_qconfig_observer_fx(model): # remove activation post process act_post_process_removed_graph = Graph() - env = {} # type: Dict[str, Any] + env: Dict[str, Any] = {} modules = dict(model.named_modules()) diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index f0ee8453557d..460b1c277a93 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -41,8 +41,7 @@ def calculate_qparams(self, **kwargs): pass @torch.jit.export - def enable_fake_quant(self, enabled=True): - # type: (bool) -> None + def enable_fake_quant(self, enabled: bool = True) -> None: self.fake_quant_enabled[0] = 1 if enabled else 0 @torch.jit.export @@ -50,8 +49,7 @@ def disable_fake_quant(self): self.enable_fake_quant(False) @torch.jit.export - def enable_observer(self, enabled=True): - # type: (bool) -> None + def enable_observer(self, enabled: bool = True) -> None: self.observer_enabled[0] = 1 if enabled else 0 @torch.jit.export diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 32d07c939695..7addaa622962 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -877,8 +877,7 @@ def _combine_histograms(self, orig_hist = orig_hist + interpolated_histogram.to(torch.float) return orig_hist - def forward(self, x_orig): - # type: (torch.Tensor) -> torch.Tensor + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: x = x_orig.detach() min_val = self.min_val max_val = self.max_val