Skip to content

Commit

Permalink
[static runtime] binding for aten::sub_out (#56656)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #56656

Test Plan:
```
./buck-out/opt/gen/caffe2/caffe2/fb/predictor/ptvsc2_predictor_bench --scripted_model=/data/users/ansha/tmp/adfinder/aug_1x/210616848_0.predictor.disagg.local.local.pt --pt_inputs=/data/users/ansha/tmp/adfinder/aug_1x/210616848_0.predictor.disagg.input_data.container.pt --iters=500 --warmup_iters=500 --num_threads=1 --pt_enable_static_runtime=1 --pt_cleanup_activations=true --pt_enable_out_variant=1 --pt_optimize_memory=1 --compare_results=1 --do_profile=1 --adsfinder_compatibility=1
```
```
Time per node type:
        1.85766 ms.    35.7817%. fb::sigrid_transforms_torch_bind (1 nodes)
         1.1238 ms.    21.6464%. aten::linear (6 nodes)
       0.858116 ms.    16.5288%. aten::argmin (1 nodes)
       0.334183 ms.    6.43694%. aten::matmul (1 nodes)
       0.173697 ms.     3.3457%. fb::clip_ranges_gather_sigrid_hash_v3 (77 nodes)
       0.118827 ms.    2.28881%. fb::clip_ranges_gather (263 nodes)
       0.101348 ms.    1.95215%. aten::sub (1 nodes)
      0.0748209 ms.    1.44118%. aten::repeat (1 nodes)
      0.0582576 ms.    1.12214%. aten::norm (1 nodes)
      0.0474353 ms.   0.913686%. fb::batch_box_cox (1 nodes)
      0.0457588 ms.   0.881393%. aten::__getitem__ (506 nodes)
      0.0435175 ms.   0.838222%. prim::TupleUnpack (254 nodes)
      0.0425416 ms.   0.819425%. aten::sigmoid (2 nodes)
      0.0383822 ms.   0.739308%. fb::offsets_to_ranges (253 nodes)
      0.0330187 ms.   0.635996%. aten::mul (3 nodes)
       0.027534 ms.   0.530352%. fb::simple_embedding_bag_sum (3 nodes)
      0.0274914 ms.   0.529532%. aten::pow (1 nodes)
      0.0236733 ms.   0.455989%. fb::casted_batch_one_hot_lengths (1 nodes)
       0.023348 ms.   0.449723%. fb::concat_add_mul_replacenan_clip (1 nodes)
      0.0193511 ms.   0.372735%. aten::sum (3 nodes)
      0.0188839 ms.   0.363737%. prim::DictConstruct (2 nodes)
      0.0183191 ms.   0.352858%. prim::TupleConstruct (1 nodes)
      0.0119029 ms.    0.22927%. aten::div (1 nodes)
      0.0103263 ms.   0.198902%. static_runtime::to_copy (8 nodes)
     0.00977658 ms.   0.188314%. prim::ListConstruct (4 nodes)
     0.00924042 ms.   0.177986%. fb::sigrid_hash_precompute (1 nodes)
     0.00692162 ms.   0.133322%. aten::contiguous (1 nodes)
     0.00567485 ms.   0.109307%. aten::narrow (4 nodes)
     0.00362285 ms.  0.0697823%. aten::logit (1 nodes)
     0.00329995 ms.  0.0635627%. aten::add (1 nodes)
     0.00285633 ms.  0.0550178%. aten::full (1 nodes)
     0.00268469 ms.  0.0517118%. fb::gather_ranges (4 nodes)
     0.00248577 ms.  0.0478803%. aten::stack (1 nodes)
     0.00241782 ms.  0.0465715%. aten::relu (1 nodes)
     0.00233674 ms.  0.0450096%. aten::clamp_min (1 nodes)
     0.00222238 ms.  0.0428068%. static_runtime::reshape_copy (2 nodes)
     0.00171177 ms.  0.0329716%. aten::size (3 nodes)
     0.00120008 ms.  0.0231155%. aten::expand_as (1 nodes)
     0.00112628 ms.  0.0216942%. fb::clip_ranges (2 nodes)
     0.00103193 ms.  0.0198768%. fb::lengths_to_offsets (3 nodes)
    0.000598624 ms.  0.0115305%. static_runtime::flatten_copy (1 nodes)
    0.000236196 ms. 0.00454954%. prim::device (1 nodes)
        5.19164 ms. in Total
StaticRuntime setup time: 0.000868 ms
Memory allocation time: 0.0109619 ms
Memory deallocation time: 0.071791 ms
Outputs deallocation time: 0.0560187 ms
Total memory managed: 1232320 bytes
Total number of reused tensors: 32
W0421 17:40:52.053653 1746499 PyTorchPredictorContainer.cpp:200] Failed to load metadata file
W0421 17:40:52.053757 1746499 PyTorchPredictorContainer.cpp:457] Couldn't find model param config file xl_model_weights/model_param_config
I0421 17:40:52.053779 1746499 PyTorchPredictorBenchLib.cpp:137] PyTorch predictor: number of prediction threads 1
I0421 17:40:52.185776 1746499 PyTorchPredictorBenchLib.cpp:230] PyTorch run finished. Milliseconds per iter: 131.985. Iters per second: 7.57661
I0421 17:40:52.337853 1746499 PtVsBlackBoxPredictorBenchLib.cpp:132] Finished comparing PT static runtime and jit interpreter results
```

Reviewed By: hlu1

Differential Revision: D27929253

fbshipit-source-id: 5a7984ba3ce2d6d4bce0a0ab6c5e09e8c037b44e
  • Loading branch information
ajyu authored and facebook-github-bot committed Apr 22, 2021
1 parent 3355c30 commit 690c8b4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
20 changes: 20 additions & 0 deletions benchmarks/static_runtime/test_scripts.h
Expand Up @@ -244,3 +244,23 @@ const auto div_scalar_mode = R"JIT(
def forward(self, a: Tensor, b: float, c: str):
return torch.div(a, b, rounding_mode=c)
)JIT";

const auto sub_tensor = R"JIT(
def forward(self, a: Tensor, b: Tensor):
return torch.sub(a, b)
)JIT";

const auto sub_scalar = R"JIT(
def forward(self, a: Tensor, b: int):
return torch.sub(a, b)
)JIT";

const auto sub_tensor_alpha = R"JIT(
def forward(self, a: Tensor, b: Tensor, c: float):
return torch.sub(a, b, alpha=c)
)JIT";

const auto sub_scalar_alpha = R"JIT(
def forward(self, a: Tensor, b: float, c: int):
return torch.sub(a, b, alpha=c)
)JIT";
17 changes: 17 additions & 0 deletions benchmarks/static_runtime/test_static_runtime.cc
Expand Up @@ -174,6 +174,23 @@ TEST(StaticRuntime, IndividualOps_Div) {
testStaticRuntime(div_scalar_mode, args3);
}

TEST(StaticRuntime, IndividualOps_Sub) {
auto a = at::randn({2, 3});
auto b = at::randn({2, 3});

std::vector<IValue> args0{a, b};
testStaticRuntime(sub_tensor, args0);

std::vector<IValue> args1{a, 3};
testStaticRuntime(sub_scalar, args1);

std::vector<IValue> args2{a, b, 2.3};
testStaticRuntime(sub_tensor_alpha, args2);

std::vector<IValue> args3{a, 2.3, 4};
testStaticRuntime(sub_scalar_alpha, args3);
}

TEST(StaticRuntime, IndividualOps_Reshape) {
auto a = at::randn({2, 3});
auto b = std::vector<int64_t>({3, 2});
Expand Down
18 changes: 18 additions & 0 deletions torch/csrc/jit/runtime/static/ops.cpp
Expand Up @@ -1150,5 +1150,23 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
};
});

REGISTER_OPERATOR_FUNCTOR(aten::sub, aten_sub, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
const auto alpha = p_node->Input(2).toScalar();

if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);

const auto& in1_t = p_node->Input(1).isTensor()
? p_node->Input(1).toTensor()
: at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
at::cpu::sub_out(out_t, in0_t, in1_t, alpha);
};
});
} // namespace jit
} // namespace torch

0 comments on commit 690c8b4

Please sign in to comment.