Skip to content

Commit

Permalink
Add Pow ONNX operator with the corresponding unit tests to validate the code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Neel-Shah-29 committed Jul 15, 2022
1 parent a1155ed commit 1808e4b
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 2 deletions.
1 change: 1 addition & 0 deletions tmva/sofie/CMakeLists.txt
Expand Up @@ -31,6 +31,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
TMVA/ROperator_Concat.hxx
TMVA/ROperator_Identity.hxx
TMVA/ROperator_Softmax.hxx
TMVA/ROperator_Pow.hxx
TMVA/SOFIE_common.hxx
TMVA/SOFIEHelpers.hxx
SOURCES
Expand Down
1 change: 1 addition & 0 deletions tmva/sofie/inc/TMVA/OperatorList.hxx
Expand Up @@ -16,3 +16,4 @@
#include "TMVA/ROperator_Identity.hxx"
#include "TMVA/ROperator_Softmax.hxx"
#include "TMVA/ROperator_Concat.hxx"
#include "TMVA/ROperator_Pow.hxx"
94 changes: 94 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_Pow.hxx
@@ -0,0 +1,94 @@
#ifndef TMVA_SOFIE_ROPERATOR_Pow
#define TMVA_SOFIE_ROPERATOR_Pow

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

template <typename T>
class ROperator_Pow final : public ROperator
{

private:

   std::string fNX1;            // name of the first (base) input tensor
   std::string fNX2;            // name of the second (exponent) input tensor
   std::string fNY;             // name of the output tensor
   std::vector<size_t> fShape;  // output tensor shape, filled in Initialize()

public:
   ROperator_Pow(){}

   /// Construct a Pow operator computing Y = X1 ^ X2 element-wise.
   /// Tensor names are cleaned so they are valid C++ identifiers in the
   /// generated code.
   ROperator_Pow(std::string nameX1, std::string nameX2, std::string nameY):
      fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}

   // type of output given input (Pow preserves the input type)
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
      return input;
   }

   // shape of output tensors given input tensors
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
      // assume now inputs have same shape (no broadcasting)
      auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
      return ret;
   }

   /// Validate the inputs against the model and register the output tensor.
   /// Throws std::runtime_error if an input tensor is unknown or if the two
   /// input tensors cannot be combined (different lengths).
   void Initialize(RModel& model){
      // input must be a graph input, or already initialized intermediate tensor
      if (model.CheckIfTensorAlreadyExist(fNX1) == false){
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX1 + " is not found in model");
      }
      if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX2 + " is not found in model");
      }
      auto shapeX1 = model.GetTensorShape(fNX1);
      auto shapeX2 = model.GetTensorShape(fNX2);
      if (shapeX1 != shapeX2) {
         // If the shapes of the 2 tensors are not the same we perform
         // multi-directional broadcasting, but we only support tensors with
         // the same total length, and the resulting output must keep it.
         fShape = UTILITY::Multidirectional_broadcast(shapeX1, shapeX2);
         size_t length1 = ConvertShapeToLength(shapeX1);
         size_t length2 = ConvertShapeToLength(shapeX2);
         size_t output_length = ConvertShapeToLength(fShape);
         if (length1 != length2 || length1 != output_length) {
            throw std::runtime_error(std::string("TMVA SOFIE Binary Op does not support input tensors with different lengths. The output tensor should also have the same length as the input tensors."));
         }
      } else {
         // Both tensors have the same shape: the output shares it.
         fShape = shapeX1;
      }
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX1), fShape);
   }

   /// Emit the C++ code computing the element-wise power.
   /// The generated code relies on std::pow, so the model must add <cmath>
   /// (the ONNX parser registers it for the "Pow" op type).
   std::string Generate(std::string OpName){
      OpName = "op_" + OpName;
      if (fShape.empty()) {
         throw std::runtime_error("TMVA SOFIE Pow Op called to Generate without being initialized first");
      }
      std::stringstream out;
      size_t length = ConvertShapeToLength(fShape);
      out << "\n//------ POW\n";
      out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
      out << SP << SP << "tensor_" << fNY << "[id] = std::pow(tensor_" << fNX1 << "[id] , tensor_" << fNX2 << "[id]);\n";
      out << SP << "}\n";
      return out.str();
   }

};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_Pow

56 changes: 56 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Expand Up @@ -63,6 +63,12 @@
#include "AvgPool_FromONNX.hxx"
#include "input_models/references/AvgPool.ref.hxx"

#include "Pow_FromONNX.hxx"
#include "input_models/references/Pow.ref.hxx"

#include "Pow_broadcast_FromONNX.hxx"
#include "input_models/references/Pow_broadcast.ref.hxx"

#include "RNNBatchwise_FromONNX.hxx"
#include "input_models/references/RNNBatchwise.ref.hxx"

Expand Down Expand Up @@ -643,6 +649,56 @@ TEST(ONNX, AvgPool){

}

TEST(ONNX, Pow){
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;

   // Element-wise bases and exponents for the Pow model
   std::vector<float> input1({1, 2, 3});
   std::vector<float> input2({4, 5, 6});

   TMVA_SOFIE_Pow::Session s("Pow_FromONNX.dat");
   std::vector<float> output = s.infer(input2.data(), input1.data());

   // The inferred output must match the reference size
   EXPECT_EQ(output.size(), sizeof(Pow_ExpectedOutput::outputs) / sizeof(float));

   // Compare each inferred value against the reference within tolerance
   float *correct = Pow_ExpectedOutput::outputs;
   for (size_t i = 0; i < output.size(); ++i) {
      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
   }
}

TEST(ONNX, Pow_broadcast){
   constexpr float TOLERANCE = DEFAULT_TOLERANCE;

   // Element-wise bases and exponents for the broadcasting Pow model
   std::vector<float> input1({1, 2, 3, 3, 4, 5});
   std::vector<float> input2({2, 3, 4, 2, 3, 4});

   TMVA_SOFIE_Pow_broadcast::Session s("Pow_broadcast_FromONNX.dat");
   std::vector<float> output = s.infer(input2.data(), input1.data());

   // The inferred output must match the reference size
   EXPECT_EQ(output.size(), sizeof(Pow_broadcast_ExpectedOutput::outputs) / sizeof(float));

   // Compare each inferred value against the reference within tolerance
   float *correct = Pow_broadcast_ExpectedOutput::outputs;
   for (size_t i = 0; i < output.size(); ++i) {
      EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
   }
}

TEST(ONNX, RNNBatchwise)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down
16 changes: 16 additions & 0 deletions tmva/sofie/test/input_models/Pow.onnx
@@ -0,0 +1,16 @@
pytorch1.11.0:�
)
onnx::Pow_0
onnx::Pow_12Pow_0"Powtorch-jit-exportZ
onnx::Pow_0


Z
onnx::Pow_1


b
2


B
18 changes: 18 additions & 0 deletions tmva/sofie/test/input_models/Pow_broadcast.onnx
@@ -0,0 +1,18 @@
pytorch1.11.0:�
)
onnx::Pow_0
onnx::Pow_12Pow_0"Powtorch-jit-exportZ!
onnx::Pow_0



Z
onnx::Pow_1


b
2



B
5 changes: 5 additions & 0 deletions tmva/sofie/test/input_models/references/Pow.ref.hxx
@@ -0,0 +1,5 @@
// Reference output for the ONNX Pow test: element-wise base^exponent
// of the test inputs (1,2,3) raised to (4,5,6) -> 1^4, 2^5, 3^6.
namespace Pow_ExpectedOutput{
float outputs[] = {
1, 32, 729
};
} // namespace Pow_ExpectedOutput
6 changes: 6 additions & 0 deletions tmva/sofie/test/input_models/references/Pow_broadcast.ref.hxx
@@ -0,0 +1,6 @@
// Reference output for the ONNX Pow broadcast test: element-wise
// base^exponent of (1,2,3,3,4,5) raised to (2,3,4,2,3,4)
// -> 1^2, 2^3, 3^4, 3^2, 4^3, 5^4.
namespace Pow_broadcast_ExpectedOutput{
float outputs[] = {
1, 8, 81,
9, 64, 625
};
} // namespace Pow_broadcast_ExpectedOutput
4 changes: 3 additions & 1 deletion tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx
Expand Up @@ -44,6 +44,7 @@ std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto &nod
std::unique_ptr<ROperator> make_ROperator_Identity(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
std::unique_ptr<ROperator> make_ROperator_Softmax(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
std::unique_ptr<ROperator> make_ROperator_Concat(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
std::unique_ptr<ROperator> make_ROperator_Pow(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);

using factoryMethodMap = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(const onnx::NodeProto&, const onnx::GraphProto&, std::unordered_map<std::string, ETensorType>&)>;
const factoryMethodMap mapOptypeOperator = {
Expand Down Expand Up @@ -73,7 +74,8 @@ const factoryMethodMap mapOptypeOperator = {
{"Flatten", &make_ROperator_Reshape},
{"Identity", &make_ROperator_Identity},
{"Softmax", &make_ROperator_Softmax},
{"Concat", &make_ROperator_Concat}
{"Concat", &make_ROperator_Concat},
{"Pow", &make_ROperator_Pow}
};

using factoryMethodMap1 = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(const onnx::NodeProto&,const onnx::NodeProto&, const onnx::GraphProto&, std::unordered_map<std::string, ETensorType>&)>;
Expand Down
34 changes: 33 additions & 1 deletion tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Expand Up @@ -127,6 +127,38 @@ std::unique_ptr<ROperator> make_ROperator_Transpose(const onnx::NodeProto& nodep
return op;
}

/// Build a ROperator_Pow from the ONNX node. The input type is looked up
/// from the already-registered tensors; only FLOAT is supported for now.
/// The output tensor type is registered so downstream operators can find it.
std::unique_ptr<ROperator> make_ROperator_Pow(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){

   ETensorType input_type;

   auto input_name = nodeproto.input(0);
   auto it = tensor_type.find(input_name);
   if (it != tensor_type.end()){
      input_type = it->second;
   } else {
      throw std::runtime_error("TMVA::SOFIE ONNX Parser Pow op has input tensor " + input_name + " but its type is not yet registered");
   }

   std::unique_ptr<ROperator> op;

   switch(input_type){
   case ETensorType::FLOAT:
      op.reset(new ROperator_Pow<float>(nodeproto.input(0), nodeproto.input(1), nodeproto.output(0)));
      break;
   default:
      throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Pow does not yet support input type " + std::to_string(static_cast<int>(input_type)));
   }

   // Register the output tensor type if it is not already known.
   ETensorType output_type = (op->TypeInference({input_type}))[0];
   auto it2 = tensor_type.find(nodeproto.output(0));
   if (it2 == tensor_type.end()){
      tensor_type[nodeproto.output(0)] = output_type;
   }

   return op;
}

std::unique_ptr<ROperator> make_ROperator_Identity(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){

ETensorType input_type;
Expand Down Expand Up @@ -1381,7 +1413,7 @@ RModel RModelParser_ONNX::Parse(std::string filename, bool verbose){
rmodel.AddBlasRoutines({"Gemm", "Axpy"});
} else if (op_type == "RNN") {
rmodel.AddBlasRoutines({"Gemm", "Axpy"});
} else if (op_type == "Selu" || op_type == "Sigmoid") {
} else if (op_type == "Selu" || op_type == "Sigmoid" || op_type == "Pow") {
rmodel.AddNeededStdLib("cmath");
} else if (op_type == "LSTM") {
rmodel.AddBlasRoutines({"Gemm", "Axpy"});
Expand Down

0 comments on commit 1808e4b

Please sign in to comment.