diff --git a/test/cpp/test_custom_operators.cpp b/test/cpp/test_custom_operators.cpp new file mode 100644 index 00000000000..d9a950fd7cd --- /dev/null +++ b/test/cpp/test_custom_operators.cpp @@ -0,0 +1,30 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include +#include +#include + +// FIXME: the include path differs from OSS due to the extra csrc +#include + +TEST(test_custom_operators, nms) { + // make sure that the torchvision ops are visible to the jit interpreter + auto& ops = torch::jit::getAllOperatorsFor(torch::jit::Symbol::fromQualString("torchvision::nms")); + ASSERT_EQ(ops.size(), 1); + + auto& op = ops.front(); + ASSERT_EQ(op->schema().name(), "torchvision::nms"); + + torch::jit::Stack stack; + at::Tensor boxes = at::rand({50, 4}), scores = at::rand({50}); + double thresh = 0.7; + + torch::jit::push(stack, boxes, scores, thresh); + op->getOperation()(&stack); + at::Tensor output_jit; + torch::jit::pop(stack, output_jit); + + at::Tensor output = vision::ops::nms(boxes, scores, thresh); + ASSERT_TRUE(output_jit.allclose(output)); + +}