Skip to content

Commit

Permalink
[quant][fx][graphmode] Run symbolic_trace in quantization (#45919)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #45919

As discussed with JIT team, we'll run symbolic trace in quantization functions
prepare_fx now takes orginal pytorch model (torch.nn.Module) instead of `GraphModule` as input

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D24145857

fbshipit-source-id: 2b7a4ca525a7a8c23a26af54ef594c6a951e4024
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Oct 9, 2020
1 parent c6672a6 commit 2b204e6
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 87 deletions.
69 changes: 17 additions & 52 deletions test/quantization/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
import torch.nn.intrinsic.quantized as nniq
import torch.multiprocessing as mp

# symbolic trace
from torch.fx import symbolic_trace

from torch.fx.symbolic_trace import Tracer

# graph mode quantization based on fx
from torch.quantization import (
QuantType,
Expand Down Expand Up @@ -175,10 +170,9 @@ def forward(self, x):
return F.linear(x, self.weight)

m = M(torch.rand(1, 1)).eval()
original = symbolic_trace(m)
qconfig = default_dynamic_qconfig
qconfig_dict = {'': qconfig}
prepared = prepare_fx(original, qconfig_dict)
prepared = prepare_fx(m, qconfig_dict)
quantized = convert_fx(prepared, debug=True)
qparams = (quantized._scale_0, quantized._zero_point_0)
weight_obs = qconfig.weight()
Expand Down Expand Up @@ -224,7 +218,6 @@ def forward(self, x):
if weight_prepack_node:
node_occurrence[weight_prepack_node] = 0
m = ModuleClass(*module_constructor_inputs).eval()
m = symbolic_trace(m)
qconfig_dict = {"": float16_dynamic_qconfig}
m = prepare_fx(m, qconfig_dict)
m = convert_fx(m, debug=debug)
Expand Down Expand Up @@ -259,9 +252,6 @@ def forward(self, x):
device = torch.device('cuda:0')
model.to(device)

# symbolically trace
model = symbolic_trace(model)

# QAT prepare
model = prepare_qat_fx(model, qconfig_dict)

Expand All @@ -287,7 +277,6 @@ def forward(self, x):
return self.conv(x)

model = M().eval()
model = symbolic_trace(model)
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(
model, qconfig_dict, inplace=False)
Expand Down Expand Up @@ -316,7 +305,7 @@ def forward(self, x):
return {"output": self.conv(x["input"])}

dict_input = {"input": torch.randn(1, 1, 1, 1)}
m = symbolic_trace(M()).eval()
m = M().eval()
qconfig_dict = {"": default_qconfig}
m = prepare_fx(m, qconfig_dict)
m(dict_input)
Expand All @@ -332,12 +321,6 @@ def __init__(self):
def forward(self, x):
return self.conv(x)

class CustomTracer(Tracer):
def is_leaf_module(self, m, module_qualified_name):
return (m.__module__.startswith('torch.nn') and
not isinstance(m, torch.nn.Sequential)) or \
isinstance(m, StandaloneModule)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -362,17 +345,16 @@ def forward(self, x):

data = torch.randn(1, 1, 1, 1)
# instantiate M and RefM and align the parameters
original_m = M()
original_ref_m = RefM()
original_m = M().eval()
original_ref_m = RefM().eval()
original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m)).eval()
qconfig_dict = {'': default_qconfig, 'standalone_module_name': ['standalone']}
# check prepared model
m = prepare_fx(m, qconfig_dict)
m = prepare_fx(original_m, qconfig_dict)
# calibration
m(data)
# input and output of first conv, observer for standalone module
Expand Down Expand Up @@ -406,8 +388,7 @@ def forward(self, x):
res = m(data)

# quantize the reference model
ref_m = symbolic_trace(original_ref_m).eval()
ref_m = prepare_fx(ref_m, qconfig_dict)
ref_m = prepare_fx(original_ref_m, qconfig_dict)
ref_m(data)
ref_m = convert_fx(ref_m)
ref_res = ref_m(data)
Expand All @@ -427,7 +408,6 @@ def forward(self, x):
return x

m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"": default_qconfig,
"module_name": [("conv2", None)]}
m = prepare_fx(m, qconfig_dict)
Expand Down Expand Up @@ -457,7 +437,6 @@ def forward(self, x):
return x

m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
Expand All @@ -482,7 +461,6 @@ def forward(self, x, y):
return x + y

m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
Expand Down Expand Up @@ -510,7 +488,6 @@ def forward(self, x):
return x

m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
Expand Down Expand Up @@ -547,7 +524,6 @@ def forward(self, x):
return x

m = M().eval()
m = symbolic_trace(m)
global_qconfig = default_qconfig
object_type_qconfig = default_dynamic_qconfig
module_name_regex_qconfig = float16_dynamic_qconfig
Expand All @@ -574,7 +550,6 @@ def forward(self, x):
return self.avg_pool(x)

m = M().eval()
m = symbolic_trace(m)
qconfig_dict = {'': default_qconfig}
m = prepare_fx(m, qconfig_dict)
data = torch.randn(1, 1, 1, 1)
Expand All @@ -587,13 +562,9 @@ def forward(self, x):

@skipIfNoFBGEMM
def test_qat_and_script(self):

model = LinearModelWithSubmodule()
model = LinearModelWithSubmodule().train()
qengine = torch.backends.quantized.engine
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)}

# symbolically trace
model = symbolic_trace(model)
model = prepare_qat_fx(model, qconfig_dict)

# ensure scripting works
Expand Down Expand Up @@ -629,8 +600,6 @@ def test_save_observer_state_dict(self):
orig = LinearModelWithSubmodule().eval()
model = orig
qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
# symbolically trace
model = symbolic_trace(model)
model = prepare_fx(model, qconfig_dict)

# run it through input
Expand All @@ -647,7 +616,6 @@ def test_save_observer_state_dict(self):

# Load the stats into new model
model_2 = orig
model_2 = symbolic_trace(model_2)
model_2 = prepare_fx(model_2, qconfig_dict)

loaded_dict = torch.load(b)
Expand All @@ -659,6 +627,7 @@ def test_save_observer_state_dict(self):
self.assertEqual(quant(x), quant_2(x))

@skipIfNoFBGEMM
@unittest.skip("Fix in next PR, will need to change API")
def test_custom_module_class(self):
class CustomModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -739,8 +708,8 @@ def forward(self, x):

data = torch.randn(1, 1, 1, 1)
# instantiate M and RefM and align the parameters
original_m = M()
original_ref_m = RefM()
original_m = M().eval()
original_ref_m = RefM().eval()
original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
Expand All @@ -762,7 +731,7 @@ def is_leaf_module(self, m, module_qualified_name):
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m)).eval()
m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m))
qconfig_dict = {'': default_qconfig}
# check prepared model
m = prepare_fx(m, qconfig_dict)
Expand All @@ -785,8 +754,7 @@ def is_leaf_module(self, m, module_qualified_name):
res = m(data)

# quantize the reference model
ref_m = symbolic_trace(original_ref_m).eval()
ref_m = prepare_fx(ref_m, qconfig_dict)
ref_m = prepare_fx(original_ref_m, qconfig_dict)
ref_m(data)
ref_m = convert_fx(ref_m)
ref_res = ref_m(data)
Expand Down Expand Up @@ -1351,10 +1319,9 @@ def forward(self, x):
# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(original, qconfig_dict)
prepared = prepare_fx(m, qconfig_dict)
# not runnable
quantized = convert_fx(prepared)

Expand Down Expand Up @@ -1440,10 +1407,9 @@ def forward(self, x):
# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(original, qconfig_dict)
prepared = prepare_fx(m, qconfig_dict)
# not runnable
quantized = convert_fx(prepared)

Expand Down Expand Up @@ -1492,12 +1458,11 @@ def _test_model_impl(

qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
qconfig_dict = {'': qconfig}
graph_module = symbolic_trace(model)
# print('graph module:', graph_module.src)
script = torch.jit.script(graph_module)
script = torch.jit.script(model)

# make sure graph module and script module are both runanble
original_out = graph_module(input_value)
original_out = model(input_value)
is_not_tuple_out = not isinstance(original_out, tuple)
script_out = script(input_value)
self.assertEqual(
Expand All @@ -1508,7 +1473,7 @@ def _test_model_impl(
if mode != 'static':
model.train()

prepared = prepare_fx(graph_module, qconfig_dict)
prepared = prepare_fx(model, qconfig_dict)

if mode == 'ddp':
mp.spawn(run_ddp,
Expand Down
4 changes: 1 addition & 3 deletions torch/quantization/fx/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from torch.fx import (
GraphModule,
Proxy,
symbolic_trace,
map_arg
)

Expand Down Expand Up @@ -413,9 +412,8 @@ def insert_observer(node, observer, device):
if isinstance(obj, StandaloneModuleQuantizeHandler):
# observe standalone module
standalone_module = self.modules[node.target]
traced_standalone_module = symbolic_trace(standalone_module)
prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
observed_standalone_module = prepare(traced_standalone_module, {'': qconfig})
observed_standalone_module = prepare(standalone_module, {'': qconfig})
observed_standalone_module.qconfig = qconfig
standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
Expand Down

0 comments on commit 2b204e6

Please sign in to comment.