diff --git a/include/glow/Backends/Interpreter/InterpreterFunction.h b/include/glow/Backends/Interpreter/InterpreterFunction.h index 561f7a2f63..0ad07241bd 100644 --- a/include/glow/Backends/Interpreter/InterpreterFunction.h +++ b/include/glow/Backends/Interpreter/InterpreterFunction.h @@ -324,6 +324,9 @@ class BoundInterpreterFunction { template void fwdEmbeddingBagByteRowwiseOffsetsImpl( const EmbeddingBagByteRowwiseOffsetsInst *I); + + template void fwdFlipInstImpl(const FlipInst *I); + ///@} }; diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index 62bbd13c00..5c48e6f9f7 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -646,6 +646,10 @@ class Function final : public Named { llvm::ArrayRef shuffle, const std::string &layout = ANY_LAYOUT); + /// Create a node with the name \p name which flips (reorders) the elements + /// of the input \p input along the given axis \p axis. + FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis); + /// Create a series of nodes that implement a Broadcast operation. The \p /// input Tensor is broadcasted based on \p newShape and along the \p axis, /// which defines the offset from the leading dimension under which diff --git a/include/glow/Graph/VerifierHelper.h b/include/glow/Graph/VerifierHelper.h index b33f1c4f64..f4c801cf4c 100644 --- a/include/glow/Graph/VerifierHelper.h +++ b/include/glow/Graph/VerifierHelper.h @@ -106,6 +106,12 @@ struct CompareOperatorLessEqual : public CompareWithName { bool operator()(const Ty &a, const Ty &b) const override { return a <= b; } llvm::StringRef getCompareName() const override { return "LessEqual"; } }; + +/// Operator <. +template struct CompareOperatorLess : public CompareWithName { + bool operator()(const Ty &a, const Ty &b) const override { return a < b; } + llvm::StringRef getCompareName() const override { return "Less"; } +}; /// @} /// Main API of the verifier. 
diff --git a/include/glow/Importer/ONNXModelLoader.h b/include/glow/Importer/ONNXModelLoader.h index 959fc63bcc..1c211001df 100644 --- a/include/glow/Importer/ONNXModelLoader.h +++ b/include/glow/Importer/ONNXModelLoader.h @@ -252,6 +252,10 @@ class ONNXModelLoader Error loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict); + /// Load Flip Glow operator. + Error loadFlip(const ONNX_NAMESPACE::NodeProto &op, + const ArgumentDictionaryTy &dict); + protected: /// Load the network operators from the GraphProto. /// \returns Error if network cannot be loaded. diff --git a/lib/Backends/CPU/CPUBackend.cpp b/lib/Backends/CPU/CPUBackend.cpp index 6ca56a2586..e1a3a0f0ab 100644 --- a/lib/Backends/CPU/CPUBackend.cpp +++ b/lib/Backends/CPU/CPUBackend.cpp @@ -106,6 +106,12 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const { {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, ElemKind::BoolTy}); + case Kinded::Kind::FlipNodeKind: + return NI.allInputsAndOutputsHaveSameElemKind( + {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy, + ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy, + ElemKind::BoolTy}); + case Kinded::Kind::SparseLengthsSumNodeKind: return NI.allInputsAndOutputsHaveSameElemKind( {ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx, diff --git a/lib/Backends/CPU/libjit/libjit.cpp b/lib/Backends/CPU/libjit/libjit.cpp index 38d98cefa7..d69460b5bb 100644 --- a/lib/Backends/CPU/libjit/libjit.cpp +++ b/lib/Backends/CPU/libjit/libjit.cpp @@ -563,6 +563,39 @@ static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim, } } +template +static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims, + dim_t axis, dim_t numDims) { + + // Product of outer dimensions excluding the flip dimension. + dim_t outerLen = 1; + for (dim_t idx = 0; idx < axis; idx++) { + outerLen *= dims[idx]; + } + + // Flip dimension. 
+ dim_t len = dims[axis]; + + // Product of inner dimensions excluding the flip dimension. + dim_t innerLen = 1; + for (dim_t idx = axis + 1; idx < numDims; idx++) { + innerLen *= dims[idx]; + } + + // Flip axis such that input data is read linearly. + const T *inpPtr = inW; + T *outPtr = outW + (len - 1) * innerLen; + for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) { + for (dim_t idx = 0; idx < len; idx++) { + for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) { + *outPtr++ = *inpPtr++; + } + outPtr -= 2 * innerLen; + } + outPtr += 2 * len * innerLen; + } +} + template static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims, const dim_t *outWdims, dim_t *kernelSizes, @@ -1944,6 +1977,36 @@ void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim, libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims); } +void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims, + dim_t axis, dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + +void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims, + dim_t axis, dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + +void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims, + dim_t axis, dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + +void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims, + dim_t axis, dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + +void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis, + dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + +void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis, + dim_t numDims) { + libjit_flip_generic(inW, outW, dims, axis, numDims); +} + void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset, dim_t *tensorDim, dim_t *sliceDim, dim_t numDimsTensor, dim_t 
numDimsSlice, diff --git a/lib/Backends/Interpreter/Interpreter.cpp b/lib/Backends/Interpreter/Interpreter.cpp index e2e03199f1..dc7309fe68 100644 --- a/lib/Backends/Interpreter/Interpreter.cpp +++ b/lib/Backends/Interpreter/Interpreter.cpp @@ -523,6 +523,7 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const { case Kinded::Kind::TransposeNodeKind: case Kinded::Kind::ReshapeNodeKind: case Kinded::Kind::SaveNodeKind: + case Kinded::Kind::FlipNodeKind: // These work regardless of the underlying type. return true; diff --git a/lib/Backends/Interpreter/InterpreterNodes.cpp b/lib/Backends/Interpreter/InterpreterNodes.cpp index 5df8c6d502..5df8bc6da1 100644 --- a/lib/Backends/Interpreter/InterpreterNodes.cpp +++ b/lib/Backends/Interpreter/InterpreterNodes.cpp @@ -28,6 +28,36 @@ using namespace glow; +#define dispatchImpl(functionName, elemTy, ...) \ + switch (elemTy) { \ + case ElemKind::FloatTy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Float16Ty: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Int8QTy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Int16QTy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Int32QTy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Int32ITy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::Int64ITy: \ + functionName(__VA_ARGS__); \ + break; \ + case ElemKind::BoolTy: \ + functionName(__VA_ARGS__); \ + break; \ + default: \ + llvm_unreachable("Type is not supported"); \ + } + #define dispatchFloatingPointImpl(functionName, elemTy, ...) 
\ switch (elemTy) { \ case ElemKind::FloatTy: \ @@ -4012,3 +4042,53 @@ void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { #undef CONVERT llvm_unreachable("Type not supported"); } + +/// Copy Src of \p I into Dest with the order along the flip axis reversed. +template <typename ElemTy> +void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) { + static_assert(max_tensor_dimensions == 6, + "Loops below assume max_tensor_dimensions = 6."); + + auto *src = I->getSrc(); + auto *dest = I->getDest(); + // Get unowned handles of src and dest with dims expanded to maximum. + ShapeVector eDims = expandDimsToMax(src->dims()); + auto eSrc = getTensor(src)->getUnowned(eDims); + auto eDest = getTensor(dest)->getUnowned(eDims); + auto srcH = eSrc.getHandle<ElemTy>(); + auto destH = eDest.getHandle<ElemTy>(); + +#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5) \ + for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++) \ + for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++) \ + for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++) \ + for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++) \ + for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++) \ + for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) { \ + destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) = \ + srcH.at({idx0, idx1, idx2, idx3, idx4, idx5}); \ + } \ + return; + // Each LOOP_AXIS_CASE expansion returns, so no case falls through. + switch (I->getAxis()) { + case 0: + LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5); + case 1: + LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5); + case 2: + LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5); + case 3: + LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5); + case 4: + LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5); + case 5: + LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5); + default: + llvm_unreachable("Axis should be less than max_tensor_dimensions."); + } +#undef LOOP_AXIS_CASE +} + +void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) { + dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I); +} diff --git a/lib/Exporter/ONNXModelWriter.cpp
b/lib/Exporter/ONNXModelWriter.cpp index f020eb6d7d..11206f2b4d 100644 --- a/lib/Exporter/ONNXModelWriter.cpp +++ b/lib/Exporter/ONNXModelWriter.cpp @@ -725,6 +725,14 @@ Error ONNXModelWriter::writeTranspose(const TransposeNode *node, return writeAllWithNode("Transpose", node, graph, proto); } +Error ONNXModelWriter::writeFlip(const FlipNode *node, GraphType &graph) { + auto *proto = graph.add_node(); + // Add dictionary entries. + addValueAttribute(proto, "axis", node->getAxis()); + + return writeAllWithNode("Flip", node, graph, proto); +} + Error ONNXModelWriter::writeConvolution(const ConvolutionNode *node, GraphType &graph) { // Loading convolution creates a sandwich with Transpose nodes for Input, diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp index 82f07aadc1..c5e4832797 100644 --- a/lib/Graph/Graph.cpp +++ b/lib/Graph/Graph.cpp @@ -1070,6 +1070,12 @@ TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input, return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout)); } +FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input, + unsigned_t axis) { + auto OT = getParent()->uniqueType(*input.getType()); + return addNode(new FlipNode(name, OT, input, axis)); +} + Node *Function::createBroadcast(llvm::StringRef name, NodeValue input, UnsignedArrayRef newShape, unsigned_t axis) { const auto &origDims = input.dims(); diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp index adafc2e0a1..40eb57c799 100644 --- a/lib/Graph/Nodes.cpp +++ b/lib/Graph/Nodes.cpp @@ -905,6 +905,16 @@ bool TransposeNode::verify() const { return isValid; } +bool FlipNode::verify() const { + auto dest = getResult(); + auto src = getInput(); + dim_t axis = getAxis(); + bool isValid = checkSameType(src, dest, this); + isValid &= expectCompareTrue("Invalid axis", axis, (dim_t)src.dims().size(), + this, CompareOperatorLess()); + return isValid; +} + bool ChannelShuffleNode::verify() const { bool isValid = expectCompareTrue("Channel 
shuffle into a different size.", getResult().getType()->size(), diff --git a/lib/Graph/TensorLayout.cpp b/lib/Graph/TensorLayout.cpp index fea391114b..09e71a393c 100644 --- a/lib/Graph/TensorLayout.cpp +++ b/lib/Graph/TensorLayout.cpp @@ -648,6 +648,7 @@ static bool acceptsAnyInputLayout(const glow::Node *node) { case Kinded::Kind::ReshapeNodeKind: case Kinded::Kind::MeanVarNormalizationNodeKind: case Kinded::Kind::MatMulNodeKind: + case Kinded::Kind::FlipNodeKind: case Kinded::Kind::SGDNodeKind: { return true; } diff --git a/lib/Importer/ONNXModelLoader.cpp b/lib/Importer/ONNXModelLoader.cpp index b7f654694f..8096c22fa3 100644 --- a/lib/Importer/ONNXModelLoader.cpp +++ b/lib/Importer/ONNXModelLoader.cpp @@ -2097,6 +2097,22 @@ Error ONNXModelLoader::loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op, return Error::success(); } +Error ONNXModelLoader::loadFlip(const ONNX_NAMESPACE::NodeProto &op, + const ArgumentDictionaryTy &dict) { + NodeValue input; + ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0))); + + unsigned_t axis = 0; + if (dict.count("axis")) { + ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict.at("axis"))); + } + + Node *N = G_.createFlip("flip", input, axis); + + RETURN_IF_ERR(addNodeAsOutput(op, N)); + return Error::success(); +} + Error ONNXModelLoader::loadRowwiseQuantizedFullyConnected( const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict) { // TODO @@ -2248,6 +2264,9 @@ Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) { if (typeName == "AdaptiveAvgPool") { return loadAdaptiveAvgPool(op, dict); } + if (typeName == "Flip") { + return loadFlip(op, dict); + } if (typeName == "Identity") { return loadIdentity(op, dict); } diff --git a/lib/LLVMIRCodeGen/LLVMIRGen.cpp b/lib/LLVMIRCodeGen/LLVMIRGen.cpp index 976dc26e88..f84ff8f8fc 100644 --- a/lib/LLVMIRCodeGen/LLVMIRGen.cpp +++ b/lib/LLVMIRCodeGen/LLVMIRGen.cpp @@ -2280,6 +2280,20 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> 
&builder, break; } + case Kinded::Kind::FlipInstKind: { + auto *FI = cast(I); + auto *dest = FI->getDest(); + auto *src = FI->getSrc(); + auto *destPtr = emitValueAddress(builder, dest); + auto *srcPtr = emitValueAddress(builder, src); + auto *dims = emitValueDims(builder, src); + auto *axis = emitConstDimT(builder, FI->getAxis()); + auto *dimsSize = emitConstDimT(builder, src->getType()->dims().size()); + auto *F = getFunction("flip", src->getElementType()); + createCall(builder, F, {srcPtr, destPtr, dims, axis, dimsSize}); + break; + } + // Alloc and Dealloc instructions are handled by the memory allocator. case Kinded::Kind::AllocActivationInstKind: case Kinded::Kind::DeallocActivationInstKind: diff --git a/tests/models/onnxModels/flipNoAxis.onnxtxt b/tests/models/onnxModels/flipNoAxis.onnxtxt new file mode 100644 index 0000000000..8a253f2068 --- /dev/null +++ b/tests/models/onnxModels/flipNoAxis.onnxtxt @@ -0,0 +1,141 @@ +ir_version: 5 +producer_name: "onnx-flip" +graph { + node { + input: "X" + output: "Y" + name: "flip" + op_type: "Flip" + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "flip_test" + initializer { + dims: 2 + dims: 3 + dims: 4 + data_type: 1 + float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 5.0 + float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + float_data: 21.0 + float_data: 22.0 + float_data: 23.0 + float_data: 24.0 + name: "X" + } + initializer { + dims: 2 + dims: 3 + dims: 4 + data_type: 1 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + float_data: 21.0 + float_data: 22.0 + float_data: 23.0 + float_data: 24.0 + 
float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 5.0 + float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/models/onnxModels/flipWithAxis.onnxtxt b/tests/models/onnxModels/flipWithAxis.onnxtxt new file mode 100644 index 0000000000..e414a64d53 --- /dev/null +++ b/tests/models/onnxModels/flipWithAxis.onnxtxt @@ -0,0 +1,146 @@ +ir_version: 5 +producer_name: "onnx-flip" +graph { + node { + input: "X" + output: "Y" + name: "flip" + op_type: "Flip" + attribute { + name: "axis" + i: 1 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "flip_test" + initializer { + dims: 2 + dims: 3 + dims: 4 + data_type: 1 + float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 5.0 + float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + float_data: 21.0 + float_data: 22.0 + float_data: 23.0 + float_data: 24.0 + name: "X" + } + initializer { + dims: 2 + dims: 3 + dims: 4 + data_type: 1 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + float_data: 5.0 + 
float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 21.0 + float_data: 22.0 + float_data: 23.0 + float_data: 24.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/unittests/OnnxImporterTest.cpp b/tests/unittests/OnnxImporterTest.cpp index dd94b38455..f15725ffaf 100644 --- a/tests/unittests/OnnxImporterTest.cpp +++ b/tests/unittests/OnnxImporterTest.cpp @@ -2761,3 +2761,37 @@ TEST(onnx, importLSTMForwardInputForget) { importLSTM(GLOW_DATA_PATH "tests/models/onnxModels/lstmForwardInputForget.onnxtxt"); } + +/// Test loading Flip from a ONNX model. The ONNX model already computes +/// the error. +static void importFlip(std::string fileName) { + ExecutionEngine EE; + auto &mod = EE.getModule(); + Function *F = mod.createFunction("main"); + + PlaceholderBindings bindings; + { + ONNXModelLoader onnxLD(fileName, {}, {}, *F); + bindings.allocate(mod.getPlaceholders()); + } + + // Compile and run. + EE.compile(CompilationMode::Infer); + EE.run(bindings); + + // Verify error. 
+ Placeholder *Y_err_ph = mod.getPlaceholderByName("Y_err"); + EXPECT_TRUE(Y_err_ph); + auto err = bindings.get(Y_err_ph)->getHandle(); + for (size_t idx = 0; idx < Y_err_ph->getType()->size(); idx++) { + EXPECT_EQ(err.raw(idx), 0); + } +} + +TEST(onnx, importFlipWithAxis) { + importFlip(GLOW_DATA_PATH "tests/models/onnxModels/flipWithAxis.onnxtxt"); +} + +TEST(onnx, importFlipNoAxis) { + importFlip(GLOW_DATA_PATH "tests/models/onnxModels/flipNoAxis.onnxtxt"); +} diff --git a/tests/unittests/OperatorTest.cpp b/tests/unittests/OperatorTest.cpp index fc40758f9b..1c7d792064 100644 --- a/tests/unittests/OperatorTest.cpp +++ b/tests/unittests/OperatorTest.cpp @@ -3079,6 +3079,299 @@ TEST_P(OperatorTest, TransposeIntoReshapeOptim) { } } +/// Helper to check the code generation for flip nodes. +template +static void testFlip(glow::PlaceholderBindings &bindings, glow::Module &mod, + glow::Function *F, glow::ExecutionEngine &EE, + std::vector inputData, + std::vector expectedData, + llvm::ArrayRef dims, dim_t axis, + ElemKind elemKind = ElemKind::FloatTy) { + + // Create network. + auto *input = + createPlaceholderConditionallyQuantized(mod, elemKind, dims, "input", + /* isTrainable */ false); + auto *flip = F->createFlip("flip", input, axis); + Placeholder *output = F->createSave("save", flip)->getPlaceholder(); + + // Allocate input/output and initialize input. + auto inputH = bindings.allocate(input)->getHandle(); + auto outputH = bindings.allocate(output)->getHandle(); + inputH = inputData; + + // Compile and run. + EE.compile(CompilationMode::Infer); + EE.run(bindings); + + // Compare output with reference. + EXPECT_EQ(outputH.size(), expectedData.size()); + for (size_t i = 0; i < expectedData.size(); i++) { + EXPECT_EQ(outputH.raw(i), expectedData[i]); + } +} + +/// Test Flip 1D with Int8. 
+TEST_P(OperatorTest, Flip1D_Int8) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4}, 0, + ElemKind::Int8QTy); +} + +/// Test Flip 1D with Int32. +TEST_P(OperatorTest, Flip1D_Int32) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4}, + 0, ElemKind::Int32QTy); +} + +/// Test Flip 1D with Int64. +TEST_P(OperatorTest, Flip1D_Int64) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4}, + 0, ElemKind::Int64ITy); +} + +#define FLIP_3D_INPUT \ + { 1, 2, 3, 4, 5, 6, 7, 8 } +#define FLIP_3D_AXIS0 \ + { 5, 6, 7, 8, 1, 2, 3, 4 } +#define FLIP_3D_AXIS1 \ + { 3, 4, 1, 2, 7, 8, 5, 6 } +#define FLIP_3D_AXIS2 \ + { 2, 1, 4, 3, 6, 5, 8, 7 } + +#define FLIP_4D_INPUT \ + { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 } +#define FLIP_4D_AXIS0 \ + { 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8 } +#define FLIP_4D_AXIS1 \ + { 5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12 } +#define FLIP_4D_AXIS2 \ + { 3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14 } +#define FLIP_4D_AXIS3 \ + { 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15 } + +#define FLIP_5D_INPUT \ + { \ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, \ + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 \ + } +#define FLIP_5D_AXIS0 \ + { \ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 1, 2, 3, \ + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 \ + } +#define FLIP_5D_AXIS1 \ + { \ + 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 25, 26, 27, 28, 29, \ + 30, 31, 32, 17, 18, 19, 20, 21, 22, 23, 24 \ + } +#define FLIP_5D_AXIS2 \ + { \ + 5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 21, 22, 23, 24, 17, \ + 18, 19, 20, 29, 30, 31, 32, 25, 26, 27, 28 \ + } +#define FLIP_5D_AXIS3 \ + { \ + 3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14, 19, 20, 17, 18, 23, \ + 24, 21, 
22, 27, 28, 25, 26, 31, 32, 29, 30 \ + } +#define FLIP_5D_AXIS4 \ + { \ + 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, \ + 21, 24, 23, 26, 25, 28, 27, 30, 29, 32, 31 \ + } + +#define FLIP_6D_INPUT \ + { \ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, \ + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, \ + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, \ + 56, 57, 58, 59, 60, 61, 62, 63, 64 \ + } +#define FLIP_6D_AXIS0 \ + { \ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, \ + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 1, 2, 3, 4, 5, \ + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, \ + 24, 25, 26, 27, 28, 29, 30, 31, 32 \ + } +#define FLIP_6D_AXIS1 \ + { \ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 1, 2, 3, \ + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 49, 50, 51, 52, 53, 54, \ + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 33, 34, 35, 36, 37, 38, 39, \ + 40, 41, 42, 43, 44, 45, 46, 47, 48 \ + } +#define FLIP_6D_AXIS2 \ + { \ + 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 25, 26, 27, 28, 29, \ + 30, 31, 32, 17, 18, 19, 20, 21, 22, 23, 24, 41, 42, 43, 44, 45, 46, \ + 47, 48, 33, 34, 35, 36, 37, 38, 39, 40, 57, 58, 59, 60, 61, 62, 63, \ + 64, 49, 50, 51, 52, 53, 54, 55, 56 \ + } +#define FLIP_6D_AXIS3 \ + { \ + 5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 21, 22, 23, 24, 17, \ + 18, 19, 20, 29, 30, 31, 32, 25, 26, 27, 28, 37, 38, 39, 40, 33, 34, \ + 35, 36, 45, 46, 47, 48, 41, 42, 43, 44, 53, 54, 55, 56, 49, 50, 51, \ + 52, 61, 62, 63, 64, 57, 58, 59, 60 \ + } +#define FLIP_6D_AXIS4 \ + { \ + 3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14, 19, 20, 17, 18, 23, \ + 24, 21, 22, 27, 28, 25, 26, 31, 32, 29, 30, 35, 36, 33, 34, 39, 40, \ + 37, 38, 43, 44, 41, 42, 47, 48, 45, 46, 51, 52, 49, 50, 55, 56, 53, \ + 54, 59, 60, 57, 58, 63, 64, 61, 62 \ + } +#define 
FLIP_6D_AXIS5 \ + { \ + 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, \ + 21, 24, 23, 26, 25, 28, 27, 30, 29, 32, 31, 34, 33, 36, 35, 38, 37, \ + 40, 39, 42, 41, 44, 43, 46, 45, 48, 47, 50, 49, 52, 51, 54, 53, 56, \ + 55, 58, 57, 60, 59, 62, 61, 64, 63 \ + } + +/// Test Flip 1D with Float. +TEST_P(OperatorTest, Flip1D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2}, {2, 1}, {2}, 0); +} + +/// Test Flip 2D with Float. +TEST_P(OperatorTest, Flip2D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {3, 4, 1, 2}, {2, 2}, + 0); +} +TEST_P(OperatorTest, Flip2D_Axis1_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {2, 1, 4, 3}, {2, 2}, + 1); +} + +/// Test Flip 3D with Float. +TEST_P(OperatorTest, Flip3D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS0, + {2, 2, 2}, 0); +} +TEST_P(OperatorTest, Flip3D_Axis1_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS1, + {2, 2, 2}, 1); +} +TEST_P(OperatorTest, Flip3D_Axis2_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS2, + {2, 2, 2}, 2); +} + +/// Test Flip 4D with Float. 
+TEST_P(OperatorTest, Flip4D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS0, + {2, 2, 2, 2}, 0); +} +TEST_P(OperatorTest, Flip4D_Axis1_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS1, + {2, 2, 2, 2}, 1); +} +TEST_P(OperatorTest, Flip4D_Axis2_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS2, + {2, 2, 2, 2}, 2); +} +TEST_P(OperatorTest, Flip4D_Axis3_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS3, + {2, 2, 2, 2}, 3); +} + +/// Test Flip 5D with Float. +TEST_P(OperatorTest, Flip5D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS0, + {2, 2, 2, 2, 2}, 0); +} +TEST_P(OperatorTest, Flip5D_Axis1_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS1, + {2, 2, 2, 2, 2}, 1); +} +TEST_P(OperatorTest, Flip5D_Axis2_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS2, + {2, 2, 2, 2, 2}, 2); +} +TEST_P(OperatorTest, Flip5D_Axis3_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS3, + {2, 2, 2, 2, 2}, 3); +} +TEST_P(OperatorTest, Flip5D_Axis4_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS4, + {2, 2, 2, 2, 2}, 4); +} + +/// Test Flip 6D with Float. 
+TEST_P(OperatorTest, Flip6D_Axis0_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS0, + {2, 2, 2, 2, 2, 2}, 0); +} +TEST_P(OperatorTest, Flip6D_Axis1_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS1, + {2, 2, 2, 2, 2, 2}, 1); +} +TEST_P(OperatorTest, Flip6D_Axis2_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS2, + {2, 2, 2, 2, 2, 2}, 2); +} +TEST_P(OperatorTest, Flip6D_Axis3_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS3, + {2, 2, 2, 2, 2, 2}, 3); +} +TEST_P(OperatorTest, Flip6D_Axis4_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS4, + {2, 2, 2, 2, 2, 2}, 4); +} +TEST_P(OperatorTest, Flip6D_Axis5_Float) { + ENABLED_BACKENDS(Interpreter, CPU); + testFlip(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS5, + {2, 2, 2, 2, 2, 2}, 5); +} + +#undef FLIP_3D_INPUT +#undef FLIP_3D_AXIS0 +#undef FLIP_3D_AXIS1 +#undef FLIP_3D_AXIS2 +#undef FLIP_4D_INPUT +#undef FLIP_4D_AXIS0 +#undef FLIP_4D_AXIS1 +#undef FLIP_4D_AXIS2 +#undef FLIP_4D_AXIS3 +#undef FLIP_5D_INPUT +#undef FLIP_5D_AXIS0 +#undef FLIP_5D_AXIS1 +#undef FLIP_5D_AXIS2 +#undef FLIP_5D_AXIS3 +#undef FLIP_5D_AXIS4 +#undef FLIP_6D_INPUT +#undef FLIP_6D_AXIS0 +#undef FLIP_6D_AXIS1 +#undef FLIP_6D_AXIS2 +#undef FLIP_6D_AXIS3 +#undef FLIP_6D_AXIS4 +#undef FLIP_6D_AXIS5 + /// Check that gather on Int64ITy/size_t works. 
TEST_P(OperatorTest, GatherSizeT) { CHECK_IF_ENABLED(); diff --git a/tools/ClassGen/InstrGen.cpp b/tools/ClassGen/InstrGen.cpp index cf01528494..a603e00f30 100644 --- a/tools/ClassGen/InstrGen.cpp +++ b/tools/ClassGen/InstrGen.cpp @@ -680,6 +680,18 @@ int main(int argc, char **argv) { .autoVerify(VerifyKind::SameElementType, {"Dest", "Src"}) .autoIRGen(); + //===--------------------------------------------------------------------===// + // Reorder transformations + //===--------------------------------------------------------------------===// + + BB.newInstr("Flip") + .addOperand("Dest", OperandKind::Out) + .addOperand("Src", OperandKind::In) + .addMember(MemberType::Unsigned, "Axis") + .autoVerify(VerifyKind::SameElementType, {"Dest", "Src"}) + .autoVerify(VerifyKind::SameShape, {"Dest", "Src"}) + .autoIRGen(); + //===--------------------------------------------------------------------===// // Instructions used for debugging/profiling/printing //===--------------------------------------------------------------------===// diff --git a/tools/ClassGen/NodeGen.cpp b/tools/ClassGen/NodeGen.cpp index fa1f0de0a8..982376071c 100644 --- a/tools/ClassGen/NodeGen.cpp +++ b/tools/ClassGen/NodeGen.cpp @@ -832,6 +832,19 @@ int main(int argc, char **argv) { "neighbor interpolation. The Output tensor is of shape [N, " "floor(H*HeightScale), floor(W*WidthScale), C]"); + //===--------------------------------------------------------------------===// + // Reorder transformations + //===--------------------------------------------------------------------===// + + BB.newNode("Flip") + .addInput("Input") + .addMember(MemberType::Unsigned, "Axis") + .addResultFromCtorArg() + .setDocstring( + "Reverse the order of elements in a tensor along the given axis. The " + "shape of the tensor is preserved, but the elements are reordered. 
" + "The node is inspired from Python numpy."); + //===--------------------------------------------------------------------===// // Nodes used for network training //===--------------------------------------------------------------------===//