Skip to content

Commit

Permalink
Update on "[reland][quant][eagermode] Move custom_module registration…
Browse files Browse the repository at this point in the history
… to prepare/convert_custom_config_dict (#46293)"

Summary:

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: [D24322747](https://our.internmc.facebook.com/intern/diff/D24322747)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 16, 2020
1 parent e28da93 commit 50a32d5
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions torch/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,23 @@ def prepare(model, inplace=False, allow_list=None,
`observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
`prepare_custom_config_dict`: customization configuration dictionary for prepare function
Example:
.. note::
import torch
class CustomModule(torch.nn.Module):
pass
class ObservedCustomModule(torch.nn.Module):
pass
prepare_custom_config_dict = {
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
Example of prepare_custom_config_dict:
import torch
class CustomModule(torch.nn.Module):
pass
class ObservedCustomModule(torch.nn.Module):
pass
prepare_custom_config_dict = {
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
"float_to_observed_custom_module_class": {
CustomModule: ObservedCustomModule
}
}
# get around the no code/output warning
print(prepare_custom_config_dict)
Expand Down Expand Up @@ -420,28 +419,25 @@ def convert(
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function
Returns: A module with all the children quantized
Example:
xo
.. note::
import torch
class ObservedCustomModule(torch.nn.Module):
pass
class QuantizedCustomModule(torch.nn.Module):
pass
convert_custom_config_dict = {
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
# get around the no code/output warning
print(convert_custom_config_dict)
Example of convert_custom_config_dict:
import torch
class ObservedCustomModule(torch.nn.Module):
pass
class QuantizedCustomModule(torch.nn.Module):
pass
convert_custom_config_dict = {
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
ObservedCustomModule: QuantizedCustomModule
}
}
Returns: A module with all the children quantized
"""
torch._C._log_api_usage_once("quantization_api.quantize.convert")
Expand Down

0 comments on commit 50a32d5

Please sign in to comment.