[quant][fx] Update name of packed weight attributes (#51259)
Summary:
Pull Request resolved: #51259

Store the FQN of the module that uses the packed weights (the quantized op).

In the case of fusion, we update the scope mapping to store the module path of the fused node.
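
As an illustration of the resulting naming scheme, here is a minimal, self-contained sketch (not the FX quantization code itself; the helper name `new_attr_name_with_prefix` and the stand-in tensor are made up for this example) of how a packed-weight attribute name can be derived from the owning module's path:

    import torch

    def new_attr_name_with_prefix(root: torch.nn.Module, prefix: str) -> str:
        # Dots in a module path (e.g. "mods1.0") are not valid in attribute names,
        # so replace them before searching for an unused numeric suffix.
        prefix = prefix.replace('.', '_')
        idx = 0
        while hasattr(root, f"{prefix}{idx}"):
            idx += 1
        return f"{prefix}{idx}"

    root = torch.nn.Module()
    # Suppose the quantized op produced for submodule "mods1.0" needs its packed weight stored:
    module_path = "mods1.0"
    name = new_attr_name_with_prefix(root, module_path + '_packed_weight_')
    setattr(root, name, torch.zeros(1))  # stand-in for the real packed weight
    print(name)  # mods1_0_packed_weight_0

Previously the prefix was the fixed string '_fx_pass_packed_weight_', so the attribute name gave no indication of which module the packed weight belongs to.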

Test Plan:
python test/test_quantization.py test_packed_weight_fused_op

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D26117964

fbshipit-source-id: 9d929997baafb1c91063dd9786a451b0040ae461
supriyar authored and facebook-github-bot committed Jan 29, 2021
1 parent 05c8cd7 commit 916af89
Showing 3 changed files with 55 additions and 5 deletions.
36 changes: 35 additions & 1 deletion test/quantization/test_quantize_fx.py
@@ -1525,7 +1525,6 @@ def forward(self, x):
        qconfig_dict = {"": default_qconfig}
        m = prepare_fx(model, qconfig_dict)
        m(torch.rand(5, 5))

        m = convert_fx(m)
        keys = m.state_dict().keys()
        quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0
@@ -1557,6 +1556,41 @@ def forward(self, x):
assert hasattr(m, "mods2_scale_0")
assert hasattr(m, "mods2_zero_point_0")

@skipIfNoFBGEMM
def test_packed_weight_fused_op(self):
class Linear(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.ones(5, 5)
self.b = torch.zeros(5)

def forward(self, x):
return F.linear(x, self.w, self.b)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mods1 = torch.nn.Sequential(
Linear(),
Linear()
)
self.mods2 = Linear()
self.relu = F.relu

def forward(self, x):
x = self.mods1(x)
x = self.mods2(x)
x = self.relu(x)
return x

model = M().eval()
qconfig_dict = {"": default_qconfig}
m = prepare_fx(model, qconfig_dict)
m(torch.rand(5, 5))
m = convert_fx(m)
assert hasattr(m, "mods1_0_packed_weight_0")
assert hasattr(m, "mods1_1_packed_weight_0")
assert hasattr(m, "mods2_packed_weight_0")

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
18 changes: 16 additions & 2 deletions torch/quantization/fx/quantization_patterns.py
@@ -339,8 +339,13 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
            scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
            qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
            kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
            return quantizer.quantized_graph.create_node(
            op = quantizer.quantized_graph.create_node(
                'call_function', qconv_op, qconv_args, kwargs)
            # Store the name of the fused op to get the path of the node after fusion as well.
            # TODO: may need to change the key to Node or regenerate the map in each transformation,
            # since we might not be able to rely on the name
            quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.conv_node.name]
            return op
        else:
            # conv2d_dynamic branch
            raise Exception("Only static quant is supported for conv")
@@ -496,13 +501,22 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
            scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)

            qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
            return quantizer.quantized_graph.create_node(
            op = quantizer.quantized_graph.create_node(
                "call_function", qlinear_op, qlinear_args, kwargs)
            # Store the name of the fused op to get the path of the node after fusion as well.
            # TODO: may need to change the key to Node or regenerate the map in each transformation,
            # since we might not be able to rely on the name
            quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.linear_node.name]
            return op
        else:
            linear_input = load_arg(quantized=False)(self.linear_node.args[0])
            qlinear_args = (linear_input, packed_weight)  # type: ignore
            op_out = quantizer.quantized_graph.create_node(
                "call_function", torch.ops.quantized.linear_dynamic, qlinear_args, kwargs)
            # Store the name of the dynamic op to get the path of the node after replacement as well.
            # TODO: may need to change the key to Node or regenerate the map in each transformation,
            # since we might not be able to rely on the name
            quantizer.node_name_to_scope[op_out.name] = quantizer.node_name_to_scope[self.linear_node.name]
            if self.relu_node:
                op_out = quantizer.quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
            return op_out
6 changes: 4 additions & 2 deletions torch/quantization/fx/quantize.py
@@ -932,15 +932,17 @@ def _fold_weight(self, quantized: GraphModule) -> GraphModule:

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        get_new_packed_weight_name = \
            get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                op_node = list(prepack_node.users)[0]
                module_path, _ = self.node_name_to_scope[op_node.name]
                get_new_packed_weight_name = \
                    get_new_attr_name_with_prefix(module_path + '_packed_weight_')
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
