[export] Temporarily bypass torch_fn in partitioner #134292
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134292
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit c1b4102 with merge base 78d69bf:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D61569049
Summary:
Pull Request resolved: pytorch#134292

"torch_fn" is not correct for the decomposed add node from batch norm. This is a temporary workaround to bypass torch_fn.

For example, for the graph below (test_qat_conv2d_unary graph):

```
graph():
    %conv_weight : [num_users=1] = get_attr[target=conv.weight]
    %bn_weight : [num_users=1] = get_attr[target=bn.weight]
    %bn_bias : [num_users=1] = get_attr[target=bn.bias]
    %bn_running_mean : [num_users=1] = get_attr[target=bn.running_mean]
    %bn_running_var : [num_users=1] = get_attr[target=bn.running_var]
    %bn_num_batches_tracked : [num_users=1] = get_attr[target=bn.num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %conv_weight, None, [1, 1], [1, 1]), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %bn_weight, %bn_bias, %bn_running_mean, %bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%batch_norm,), kwargs = {})
    %max_pool2d : [num_users=1] = call_function[target=torch.ops.aten.max_pool2d.default](args = (%relu, [3, 3], [3, 3]), kwargs = {})
    return (max_pool2d,)
```

the add_ node has `'torch_fn': ('add__1', 'method_descriptor.add_'),` in its meta. If we run the line below in `_annotate_qat_conv2d_bn_binary_unary`, we'll have a partition without output nodes.

```
find_sequential_partitions(
    gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
)
```

```
partition_list [
    SourcePartition(nodes=[conv_weight, conv2d], source=<class 'torch.nn.modules.conv.Conv2d'>, input_nodes=[x], output_nodes=[conv2d], params=[conv_weight]),
    SourcePartition(nodes=[bn_weight, bn_bias, bn_running_mean, bn_running_var, bn_num_batches_tracked, add_, batch_norm], source=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, input_nodes=[conv2d], output_nodes=[batch_norm], params=[bn_num_batches_tracked, bn_running_var, bn_bias, bn_weight, bn_running_mean]),
    SourcePartition(nodes=[add_], source='add_', input_nodes=[bn_num_batches_tracked], output_nodes=[], params=[])
]
```

We should not have the last partition.

Test Plan:

```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_qat_conv2d
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:fx -- -r TestSourceMatcher
```

Reviewed By: angelayi

Differential Revision: D61569049
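For readers who want a concrete picture of what "bypass torch_fn" could mean, here is a minimal sketch using plain dictionaries in place of `node.meta`. It is not the code from this PR. The helper name `pick_source`, the policy of consulting `torch_fn` only when `source_fn_stack` is absent, and the `source_fn_stack` value shown for the add_ node are all assumptions made for illustration; only the `torch_fn` tuple is taken from the description above.

```python
from typing import Any, Dict, Optional, Tuple


def pick_source(meta: Dict[str, Any]) -> Optional[Tuple[str, Any]]:
    """Hypothetical helper: choose the (unique_name, source) pair used to group a node.

    Assumed policy: trust "source_fn_stack" when it is present and non-empty,
    and fall back to "torch_fn" only when it is missing.
    """
    source_fn_stack = meta.get("source_fn_stack")
    if source_fn_stack:
        # Use the last (innermost) entry of the stack.
        return source_fn_stack[-1]
    # No source_fn_stack: fall back to torch_fn, which may also be absent.
    return meta.get("torch_fn")


# Meta of the decomposed add_ node. The torch_fn tuple comes from the PR
# description; the source_fn_stack value is an illustrative assumption.
add_meta = {
    "torch_fn": ("add__1", "method_descriptor.add_"),
    "source_fn_stack": [("bn", "BatchNorm2d")],
}

# A node that only carries torch_fn (hypothetical).
torch_fn_only_meta = {"torch_fn": ("add__1", "method_descriptor.add_")}

print(pick_source(add_meta))            # grouped with batch norm
print(pick_source(torch_fn_only_meta))  # falls back to torch_fn
```

Under this assumed policy the add_ node is grouped only with the batch norm partition, so no separate `add_` partition would be created for it.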
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Summary:
"torch_fn" is not correct for the decomposed add node from batch norm. This is a temporary workaround to bypass torch fn.
For example, in the test_qat_conv2d_unary graph (printed in the commit message above), the add_ node has `'torch_fn': ('add__1', 'method_descriptor.add_')` in its meta. If we run the `find_sequential_partitions` call in `_annotate_qat_conv2d_bn_binary_unary`, we'll have a partition without output nodes. We should not have that partition.
Test Plan: the test_qat_conv2d and TestSourceMatcher runs listed in the commit message above.
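The symptom described above, a partition without output nodes, is easy to check for mechanically. The sketch below is a standalone illustration, not part of the PR's tests. `PartitionSketch` is a hypothetical stand-in that copies only the field names visible in the `SourcePartition` listing, with node names as plain strings, and `partitions_without_outputs` is an assumed helper name.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class PartitionSketch:
    """Hypothetical stand-in for SourcePartition (same field names as the listing)."""
    nodes: List[str]
    source: Any
    input_nodes: List[str] = field(default_factory=list)
    output_nodes: List[str] = field(default_factory=list)
    params: List[str] = field(default_factory=list)


def partitions_without_outputs(partitions: List[PartitionSketch]) -> List[PartitionSketch]:
    """Return the partitions whose output_nodes list is empty."""
    return [p for p in partitions if not p.output_nodes]


# The three partitions from the listing in the commit message, as strings.
partition_list = [
    PartitionSketch(
        nodes=["conv_weight", "conv2d"],
        source="Conv2d",
        input_nodes=["x"],
        output_nodes=["conv2d"],
        params=["conv_weight"],
    ),
    PartitionSketch(
        nodes=["bn_weight", "bn_bias", "bn_running_mean", "bn_running_var",
               "bn_num_batches_tracked", "add_", "batch_norm"],
        source="BatchNorm2d",
        input_nodes=["conv2d"],
        output_nodes=["batch_norm"],
        params=["bn_num_batches_tracked", "bn_running_var", "bn_bias",
                "bn_weight", "bn_running_mean"],
    ),
    PartitionSketch(
        nodes=["add_"],
        source="add_",
        input_nodes=["bn_num_batches_tracked"],
    ),
]

dangling = partitions_without_outputs(partition_list)
print([p.source for p in dangling])  # ['add_']
```

A check like this flags exactly the extra `add_` partition; with the workaround in place, the expectation is that such a list comes back empty.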
Differential Revision: D61569049