Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numeric suite: add types to eager #51168

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 41 additions & 27 deletions torch/quantization/_numeric_suite.py
@@ -1,10 +1,9 @@

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict
from typing import Dict, List, Optional, Any, Union, Callable, Set

from .quantization_mappings import (
get_default_compare_output_module_list,
Expand All @@ -18,7 +17,10 @@
}


def _find_match(str_list, key_str, postfix):
def _find_match(
str_list: Union[Dict[str, Any], List[str]], key_str: str,
postfix: str,
) -> Optional[str]:
split_str = key_str.split(".")
if split_str[-1] == postfix:
match_string = "".join(key_str.split(".")[0:-1])
Expand All @@ -42,11 +44,14 @@ def _find_match(str_list, key_str, postfix):
return s2
if match_string == pattern2:
return s2
return None
else:
return None


def compare_weights(float_dict, quantized_dict):
def compare_weights(
float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Compare the weights of the float module with its corresponding quantized
module. Return a dict with key corresponding to module names and each entry being
a dictionary with two keys 'float' and 'quantized', containing the float and
Expand Down Expand Up @@ -105,7 +110,10 @@ def compare_weights(float_dict, quantized_dict):
return weight_dict


def _get_logger_dict_helper(mod, target_dict, prefix=""):
def _get_logger_dict_helper(
mod: nn.Module, target_dict: Dict[str, Any],
prefix: str = "",
) -> None:
r"""This is the helper function for get_logger_dict

Args:
Expand All @@ -127,7 +135,7 @@ def get_prefix(prefix):
_get_logger_dict_helper(child, target_dict, module_prefix)


def get_logger_dict(mod, prefix=""):
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
r"""Traverse the modules and save all logger stats into target dict.
This is mainly used for quantization accuracy debug.

Expand Down Expand Up @@ -195,11 +203,11 @@ def forward(self, x):
return x


def _convert_tuple_to_list(t):
def _convert_tuple_to_list(t: Any) -> Any:
return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t


def _dequantize_tensor_list(t):
def _dequantize_tensor_list(t: Any) -> Any:
return (
list(_dequantize_tensor_list(x) for x in t)
if type(t) is list
Expand Down Expand Up @@ -228,52 +236,52 @@ def __init__(self, q_module, float_module, Logger):
self.dequant = nnq.DeQuantize()
self.logger = Logger()

def forward(self, *x) -> torch.Tensor:
    """Run the quantized module and its float shadow on the same inputs.

    The quantized output is what callers receive; the shadow (float) output
    exists only to be compared against it by the logger.
    """
    args = _convert_tuple_to_list(x)
    quant_out = self.orig_module(*args)
    float_args = _dequantize_tensor_list(args)
    ref_out = self.shadow_module(*float_args)
    self.logger(quant_out, ref_out)
    return quant_out

def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Quantized add mirrored on the float shadow; both results are logged."""
    quant_out = self.orig_module.add(x, y)
    x_fp = x.dequantize()
    y_fp = y.dequantize()
    ref_out = self.shadow_module.add(x_fp, y_fp)
    self.logger(quant_out, ref_out)
    return quant_out

def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
    """Quantized tensor+scalar add mirrored on the float shadow; both logged."""
    quant_out = self.orig_module.add_scalar(x, y)
    x_fp = x.dequantize()
    ref_out = self.shadow_module.add_scalar(x_fp, y)
    self.logger(quant_out, ref_out)
    return quant_out

def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Quantized mul mirrored on the float shadow; both results are logged."""
    quant_out = self.orig_module.mul(x, y)
    x_fp = x.dequantize()
    y_fp = y.dequantize()
    ref_out = self.shadow_module.mul(x_fp, y_fp)
    self.logger(quant_out, ref_out)
    return quant_out

def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
    """Quantized tensor*scalar mul mirrored on the float shadow; both logged."""
    quant_out = self.orig_module.mul_scalar(x, y)
    x_fp = x.dequantize()
    ref_out = self.shadow_module.mul_scalar(x_fp, y)
    self.logger(quant_out, ref_out)
    return quant_out

def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """Quantized concat mirrored on the float shadow; both results are logged."""
    quant_out = self.orig_module.cat(x, dim)
    float_inputs = [t.dequantize() for t in x]
    ref_out = self.shadow_module.cat(float_inputs, dim)
    self.logger(quant_out, ref_out)
    return quant_out

def add_relu(self, x, y):
def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = self.orig_module.add_relu(x, y)
x = x.dequantize()
y = y.dequantize()
Expand All @@ -282,7 +290,10 @@ def add_relu(self, x, y):
return output


def prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger):
def prepare_model_with_stubs(
float_module: nn.Module, q_module: nn.Module,
module_swap_list: Set[type], Logger: Callable,
) -> None:
r"""Prepare the model by attaching the float module to its matching quantized
module as the shadow if the float module type is in module_swap_list.

Expand Down Expand Up @@ -322,8 +333,9 @@ def prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger):


def compare_model_stub(
float_model, q_model, module_swap_list, *data, Logger=ShadowLogger
):
float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
*data, Logger=ShadowLogger
) -> Dict[str, Dict]:
r"""Compare quantized module in a model with its floating point counterpart,
feeding both of them the same input. Return a dict with key corresponding to
module names and each entry being a dictionary with two keys 'float' and
Expand Down Expand Up @@ -361,7 +373,9 @@ def compare_model_stub(
return ob_dict


def get_matching_activations(float_module, q_module):
def get_matching_activations(
float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Find the matching activation between float and quantized modules.

Args:
Expand All @@ -387,11 +401,11 @@ def get_matching_activations(float_module, q_module):


def prepare_model_outputs(
float_module,
q_module,
float_module: nn.Module,
q_module: nn.Module,
Logger=OutputLogger,
allow_list=None
):
) -> None:
r"""Prepare the model by attaching the logger to both float module
and quantized module if they are in the allow_list.

Expand All @@ -406,9 +420,9 @@ def prepare_model_outputs(
allow_list = get_default_compare_output_module_list()

qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None)
float_module.qconfig = qconfig_debug
float_module.qconfig = qconfig_debug # type: ignore
prepare(float_module, inplace=True, allow_list=allow_list)
q_module.qconfig = qconfig_debug
q_module.qconfig = qconfig_debug # type: ignore
prepare(
q_module,
inplace=True,
Expand All @@ -418,12 +432,12 @@ def prepare_model_outputs(


def compare_model_outputs(
float_model,
q_model,
float_model: nn.Module,
q_model: nn.Module,
*data,
Logger=OutputLogger,
allow_list=None
):
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Compare output activations between float and quantized models at
corresponding locations for the same input. Return a dict with key corresponding
to quantized module names and each entry being a dictionary with two keys
Expand Down