Skip to content

Commit

Permalink
[pytorch] Replace "blacklist" in test/test_mobile_optimizer.py
Browse files Browse the repository at this point in the history
Summary:
This diff addresses #41443.
It is a clone of D23205313 which could not be imported from GitHub
for strange reasons.

Test Plan: Continuous integration.

Reviewed By: AshkanAliabadi

Differential Revision: D23967322

fbshipit-source-id: 1140c9de3c58fd155e40f4e21c7cdf9d927b2ad2
  • Loading branch information
Meghan Lele authored and facebook-github-bot committed Sep 29, 2020
1 parent 17be7c6 commit a1de00d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/test_mobile_optimizer.py
Expand Up @@ -100,8 +100,8 @@ def forward(self, x):
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)


-optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
-optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blacklist_no_prepack)
+optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \
Expand All @@ -118,14 +118,14 @@ def forward(self, x):
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(bn_scripted_module._c).graph))

-optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
-bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack)
+optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
+bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
bn_input = torch.rand(1, 1, 6, 6)
torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

-optimization_blacklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
-no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_fold_bn)
+optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
+no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
FileCheck().check_count("aten::batch_norm", 1, exactly=True) \
.run(str(get_forward_graph(no_bn_fold_scripted_module._c)))
bn_input = torch.rand(1, 1, 6, 6)
Expand Down

0 comments on commit a1de00d

Please sign in to comment.