[Graph] Add Flip node. #3889

Closed · wants to merge 7 commits
3 changes: 3 additions & 0 deletions include/glow/Backends/Interpreter/InterpreterFunction.h
@@ -324,6 +324,9 @@ class BoundInterpreterFunction {
template <typename T, typename AccumT>
void fwdEmbeddingBagByteRowwiseOffsetsImpl(
const EmbeddingBagByteRowwiseOffsetsInst *I);

template <typename ElemTy> void fwdFlipInstImpl(const FlipInst *I);

///@}
};

4 changes: 4 additions & 0 deletions include/glow/Graph/Graph.h
@@ -646,6 +646,10 @@ class Function final : public Named {
llvm::ArrayRef<unsigned_t> shuffle,
const std::string &layout = ANY_LAYOUT);

/// Create a node with the name \p name which flips (reverses) the order of
/// the elements of the input \p input along the given axis \p axis.
FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis);

/// Create a series of nodes that implement a Broadcast operation. The \p
/// input Tensor is broadcasted based on \p newShape and along the \p axis,
/// which defines the offset from the leading dimension under which
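For reference, a minimal usage sketch of the new createFlip API (not part of this PR; the Module/Placeholder setup and all names are illustrative):

#include "glow/Graph/Graph.h"

using namespace glow;

// Flip a 2 x 3 float tensor along its first dimension: element [i][j] of
// the result is element [1 - i][j] of the input.
void buildFlipExample() {
  Module mod;
  Function *F = mod.createFunction("main");
  auto *input = mod.createPlaceholder(ElemKind::FloatTy, {2, 3}, "input",
                                      /* isTrainable */ false);
  FlipNode *flip = F->createFlip("flip", input, /* axis */ 0);
  F->createSave("save", flip);
}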
6 changes: 6 additions & 0 deletions include/glow/Graph/VerifierHelper.h
@@ -106,6 +106,12 @@ struct CompareOperatorLessEqual : public CompareWithName<Ty> {
bool operator()(const Ty &a, const Ty &b) const override { return a <= b; }
llvm::StringRef getCompareName() const override { return "LessEqual"; }
};

/// Operator <.
template <typename Ty> struct CompareOperatorLess : public CompareWithName<Ty> {
bool operator()(const Ty &a, const Ty &b) const override { return a < b; }
llvm::StringRef getCompareName() const override { return "Less"; }
};
/// @}

/// Main API of the verifier.
4 changes: 4 additions & 0 deletions include/glow/Importer/ONNXModelLoader.h
@@ -248,6 +248,10 @@ class ONNXModelLoader
Error loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load the Glow Flip operator.
Error loadFlip(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

protected:
/// Load the network operators from the GraphProto.
/// \returns Error if network cannot be loaded.
6 changes: 6 additions & 0 deletions lib/Backends/CPU/CPUBackend.cpp
@@ -106,6 +106,12 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
{ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
ElemKind::BoolTy});

case Kinded::Kind::FlipNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
ElemKind::BoolTy});

case Kinded::Kind::SparseLengthsSumNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx,
63 changes: 63 additions & 0 deletions lib/Backends/CPU/libjit/libjit.cpp
@@ -563,6 +563,39 @@ static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim,
}
}

template <typename T>
static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {

// Product of outer dimensions excluding the flip dimension.
dim_t outerLen = 1;
for (dim_t idx = 0; idx < axis; idx++) {
outerLen *= dims[idx];
}

// Flip dimension.
dim_t len = dims[axis];

// Product of inner dimensions excluding the flip dimension.
dim_t innerLen = 1;
for (dim_t idx = axis + 1; idx < numDims; idx++) {
innerLen *= dims[idx];
}

// Arrange the traversal so that the input is read linearly while the
// output pointer walks the flip axis in reverse.
const T *inpPtr = inW;
T *outPtr = outW + (len - 1) * innerLen;
for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) {
for (dim_t idx = 0; idx < len; idx++) {
for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) {
*outPtr++ = *inpPtr++;
}
outPtr -= 2 * innerLen;
}
outPtr += 2 * len * innerLen;
}
}

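To make the traversal concrete: for a {2, 3} tensor flipped along axis 1, outerLen = 2, len = 3, and innerLen = 1, so each row of three elements is written in reverse while the input is read sequentially. Below is a self-contained sketch that mirrors the template above (dim_t is assumed to be a 64-bit unsigned index, matching libjit):

#include <cstdint>
#include <cstdio>

using dim_t = std::uint64_t; // assumption: matches libjit's index type

// Same traversal as libjit_flip_generic above, restated so the example is
// self-contained and runnable.
template <typename T>
static void flip_generic(const T *inW, T *outW, const dim_t *dims, dim_t axis,
                         dim_t numDims) {
  dim_t outerLen = 1, innerLen = 1;
  for (dim_t i = 0; i < axis; i++)
    outerLen *= dims[i];
  dim_t len = dims[axis];
  for (dim_t i = axis + 1; i < numDims; i++)
    innerLen *= dims[i];
  const T *inpPtr = inW;
  T *outPtr = outW + (len - 1) * innerLen;
  for (dim_t o = 0; o < outerLen; o++) {
    for (dim_t i = 0; i < len; i++) {
      for (dim_t k = 0; k < innerLen; k++)
        *outPtr++ = *inpPtr++;
      outPtr -= 2 * innerLen;
    }
    outPtr += 2 * len * innerLen;
  }
}

int main() {
  const int in[6] = {1, 2, 3, 4, 5, 6}; // shape {2, 3}
  int out[6];
  const dim_t dims[2] = {2, 3};
  flip_generic(in, out, dims, /* axis */ 1, /* numDims */ 2);
  for (int v : out)
    std::printf("%d ", v); // prints: 3 2 1 6 5 4
  return 0;
}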
template <typename T>
static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims,
const dim_t *outWdims, dim_t *kernelSizes,
@@ -1944,6 +1977,36 @@ void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim,
libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}

void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis,
dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis,
dim_t numDims) {
libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset,
dim_t *tensorDim, dim_t *sliceDim,
dim_t numDimsTensor, dim_t numDimsSlice,
1 change: 1 addition & 0 deletions lib/Backends/Interpreter/Interpreter.cpp
@@ -523,6 +523,7 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
case Kinded::Kind::TransposeNodeKind:
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::SaveNodeKind:
case Kinded::Kind::FlipNodeKind:
// These work regardless of the underlying type.
return true;

80 changes: 80 additions & 0 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -28,6 +28,36 @@

using namespace glow;

#define dispatchImpl(functionName, elemTy, ...) \
switch (elemTy) { \
case ElemKind::FloatTy: \
functionName<float>(__VA_ARGS__); \
break; \
case ElemKind::Float16Ty: \
functionName<float16_t>(__VA_ARGS__); \
break; \
case ElemKind::Int8QTy: \
functionName<int8_t>(__VA_ARGS__); \
break; \
case ElemKind::Int16QTy: \
functionName<int16_t>(__VA_ARGS__); \
break; \
case ElemKind::Int32QTy: \
functionName<int32_t>(__VA_ARGS__); \
break; \
case ElemKind::Int32ITy: \
functionName<int32_t>(__VA_ARGS__); \
break; \
case ElemKind::Int64ITy: \
functionName<int64_t>(__VA_ARGS__); \
break; \
case ElemKind::BoolTy: \
functionName<bool>(__VA_ARGS__); \
break; \
default: \
llvm_unreachable("Type is not supported"); \
}

#define dispatchFloatingPointImpl(functionName, elemTy, ...) \
switch (elemTy) { \
case ElemKind::FloatTy: \
@@ -4008,3 +4038,53 @@ void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) {
#undef CONVERT
llvm_unreachable("Type not supported");
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) {

static_assert(max_tensor_dimensions == 6,
"Loops below assume max_tensor_dimensions = 6.");

auto *src = I->getSrc();
auto *dest = I->getDest();

// Get unowned handles of src and dest with dims expanded to maximum.
ShapeVector eDims = expandDimsToMax(src->dims());
auto eSrc = getTensor(src)->getUnowned(eDims);
auto eDest = getTensor(dest)->getUnowned(eDims);
auto srcH = eSrc.getHandle<ElemTy>();
auto destH = eDest.getHandle<ElemTy>();

#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5) \
for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++) \
for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++) \
for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++) \
for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++) \
for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++) \
for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) { \
destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) = \
srcH.at({idx0, idx1, idx2, idx3, idx4, idx5}); \
} \
return;

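// Note: each case below expands to the full 6-D loop nest followed by a
// return, so there is no fall-through between the switch cases.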
switch (I->getAxis()) {
case 0:
LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5);
case 1:
LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5);
case 2:
LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5);
case 3:
LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5);
case 4:
LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5);
case 5:
LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5);
default:
llvm_unreachable("Axis should be less than max_tensor_dimensions.");
}
#undef LOOP_AXIS_CASE
}

void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) {
dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I);
}
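As a sanity check on the semantics (a sketch, not a test from this PR): flipping {{1, 2}, {3, 4}} along axis 0 yields {{3, 4}, {1, 2}}. The reference computation, written directly with Glow's Tensor/Handle API:

#include "glow/Base/Tensor.h"

using namespace glow;

// Reference semantics for a 2 x 2 flip along axis 0; the interpreter's
// 6-D loop nest above computes the same mapping with trailing dims of 1.
void flipReferenceExample() {
  Tensor src(ElemKind::FloatTy, {2, 2});
  src.getHandle<float>() = {1, 2, 3, 4};
  Tensor dest(ElemKind::FloatTy, {2, 2});
  auto srcH = src.getHandle<float>();
  auto destH = dest.getHandle<float>();
  for (dim_t i = 0; i < 2; i++)
    for (dim_t j = 0; j < 2; j++)
      destH.at({1 - i, j}) = srcH.at({i, j}); // dest is {{3, 4}, {1, 2}}
}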
8 changes: 8 additions & 0 deletions lib/Exporter/ONNXModelWriter.cpp
@@ -687,6 +687,14 @@ Error ONNXModelWriter::writeTranspose(const TransposeNode *node,
return writeAllWithNode("Transpose", node, proto);
}

Error ONNXModelWriter::writeFlip(const FlipNode *node, GraphType &graph) {
auto *proto = graph.add_node();
// Add dictionary entries.
addValueAttribute(proto, "axis", node->getAxis());

return writeAllWithNode("Flip", node, proto);
}

Error ONNXModelWriter::writeConvolution(const ConvolutionNode *node,
GraphType &graph) {
// Loading convolution creates a sandwich with Transpose nodes for Input,
6 changes: 6 additions & 0 deletions lib/Graph/Graph.cpp
@@ -1070,6 +1070,12 @@ TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout));
}

FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input,
unsigned_t axis) {
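// The result type is a uniqued copy of the input type: a flip only permutes
// elements, so shape and quantization parameters are preserved.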
auto OT = getParent()->uniqueType(*input.getType());
return addNode(new FlipNode(name, OT, input, axis));
}

Node *Function::createBroadcast(llvm::StringRef name, NodeValue input,
UnsignedArrayRef newShape, unsigned_t axis) {
const auto &origDims = input.dims();
10 changes: 10 additions & 0 deletions lib/Graph/Nodes.cpp
@@ -905,6 +905,16 @@ bool TransposeNode::verify() const {
return isValid;
}

bool FlipNode::verify() const {
auto dest = getResult();
auto src = getInput();
dim_t axis = getAxis();
bool isValid = checkSameType(src, dest, this);
isValid &= expectCompareTrue("Invalid axis", axis, (dim_t)src.dims().size(),
this, CompareOperatorLess<dim_t>());
return isValid;
}

bool ChannelShuffleNode::verify() const {
bool isValid = expectCompareTrue("Channel shuffle into a different size.",
getResult().getType()->size(),
1 change: 1 addition & 0 deletions lib/Graph/TensorLayout.cpp
@@ -648,6 +648,7 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::MeanVarNormalizationNodeKind:
case Kinded::Kind::MatMulNodeKind:
case Kinded::Kind::FlipNodeKind:
case Kinded::Kind::SGDNodeKind: {
return true;
}
19 changes: 19 additions & 0 deletions lib/Importer/ONNXModelLoader.cpp
@@ -2089,6 +2089,22 @@ Error ONNXModelLoader::loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
return Error::success();
}

Error ONNXModelLoader::loadFlip(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict) {
NodeValue input;
ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

unsigned_t axis = 0;
if (dict.count("axis")) {
ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict.at("axis")));
}

Node *N = G_.createFlip("flip", input, axis);

RETURN_IF_ERR(addNodeAsOutput(op, N));
return Error::success();
}
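Flip is not part of the standard ONNX operator set, so this path handles it as a Glow-specific custom operator. A sketch (illustrative names, standard protobuf setters) of a NodeProto that this loader would accept:

// Illustrative only: a NodeProto that loadFlip would parse. "X" and "Y"
// are placeholder tensor names; the onnx protobuf headers are assumed to
// be available as in the rest of the loader.
ONNX_NAMESPACE::NodeProto makeFlipNode() {
  ONNX_NAMESPACE::NodeProto node;
  node.set_op_type("Flip");
  node.add_input("X");
  node.add_output("Y");
  auto *axisAttr = node.add_attribute();
  axisAttr->set_name("axis");
  axisAttr->set_type(ONNX_NAMESPACE::AttributeProto::INT);
  axisAttr->set_i(1); // flip along the second dimension
  return node;
}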

Error ONNXModelLoader::loadRowwiseQuantizedFullyConnected(
const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict) {
// TODO
@@ -2240,6 +2256,9 @@ Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
if (typeName == "AdaptiveAvgPool") {
return loadAdaptiveAvgPool(op, dict);
}
if (typeName == "Flip") {
return loadFlip(op, dict);
}

RETURN_ERR("Failed to load operator " + typeName + ".",
           ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
14 changes: 14 additions & 0 deletions lib/LLVMIRCodeGen/LLVMIRGen.cpp
@@ -2280,6 +2280,20 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
break;
}

case Kinded::Kind::FlipInstKind: {
auto *FI = cast<FlipInst>(I);
auto *dest = FI->getDest();
auto *src = FI->getSrc();
auto *destPtr = emitValueAddress(builder, dest);
auto *srcPtr = emitValueAddress(builder, src);
auto *dims = emitValueDims(builder, src);
auto *axis = emitConstDimT(builder, FI->getAxis());
auto *dimsSize = emitConstDimT(builder, src->getType()->dims().size());
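// getFunction mangles "flip" with the element type, resolving the call to
// one of the libjit_flip_* kernels defined in libjit.cpp.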
auto *F = getFunction("flip", src->getElementType());
createCall(builder, F, {srcPtr, destPtr, dims, axis, dimsSize});
break;
}

// Alloc and Dealloc instructions are handled by the memory allocator.
case Kinded::Kind::AllocActivationInstKind:
case Kinded::Kind::DeallocActivationInstKind: