Skip to content
Permalink
Browse files

Add Flip node. (#3889)

Summary:
**Summary**
- Add Flip node inspired from Numpy.
- The node reverses the order of the elements along a given axis.
- The node currently supports a single axis; multiple nodes can be chained to flip along multiple axes.
- The node can be used for example for RGB <-> BGR conversion inside the graph.

**Documentation**
None

**Test Plan**
Unit tests for operator, ONNX importer and exporter.
Pull Request resolved: #3889

Differential Revision: D19326679

Pulled By: jfix71

fbshipit-source-id: 103f0f84418e7197a937590c99347366897edd48
  • Loading branch information
mciprian13 authored and facebook-github-bot committed Jan 11, 2020
1 parent 35d2ce6 commit 2661e05bd0f1ee0b0d3322ef244ba2c11edc8b16
@@ -324,6 +324,9 @@ class BoundInterpreterFunction {
template <typename T, typename AccumT>
void fwdEmbeddingBagByteRowwiseOffsetsImpl(
const EmbeddingBagByteRowwiseOffsetsInst *I);

template <typename ElemTy> void fwdFlipInstImpl(const FlipInst *I);

///@}
};

@@ -646,6 +646,10 @@ class Function final : public Named {
llvm::ArrayRef<unsigned_t> shuffle,
const std::string &layout = ANY_LAYOUT);

/// Create a node with the name \p name which flips (reorders) the elements
/// of the input \p input along the given axis \p axis.
FlipNode *createFlip(llvm::StringRef name, NodeValue input, unsigned_t axis);

/// Create a series of nodes that implement a Broadcast operation. The \p
/// input Tensor is broadcasted based on \p newShape and along the \p axis,
/// which defines the offset from the leading dimension under which
@@ -106,6 +106,12 @@ struct CompareOperatorLessEqual : public CompareWithName<Ty> {
bool operator()(const Ty &a, const Ty &b) const override { return a <= b; }
llvm::StringRef getCompareName() const override { return "LessEqual"; }
};

/// Operator <.
template <typename Ty> struct CompareOperatorLess : public CompareWithName<Ty> {
  /// \returns true iff \p a compares strictly less than \p b.
  bool operator()(const Ty &a, const Ty &b) const override {
    return a < b;
  }
  /// \returns the printable name of this comparison, used in diagnostics.
  llvm::StringRef getCompareName() const override {
    return "Less";
  }
};
/// @}

/// Main API of the verifier.
@@ -252,6 +252,10 @@ class ONNXModelLoader
Error loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load Flip Glow operator.
Error loadFlip(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

protected:
/// Load the network operators from the GraphProto.
/// \returns Error if network cannot be loaded.
@@ -106,6 +106,12 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
{ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
ElemKind::BoolTy});

case Kinded::Kind::FlipNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
ElemKind::BoolTy});

case Kinded::Kind::SparseLengthsSumNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx,
@@ -563,6 +563,39 @@ static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim,
}
}

template <typename T>
static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims,
dim_t axis, dim_t numDims) {

// Product of outer dimensions excluding the flip dimension.
dim_t outerLen = 1;
for (dim_t idx = 0; idx < axis; idx++) {
outerLen *= dims[idx];
}

// Flip dimension.
dim_t len = dims[axis];

// Product of inner dimensions excluding the flip dimension.
dim_t innerLen = 1;
for (dim_t idx = axis + 1; idx < numDims; idx++) {
innerLen *= dims[idx];
}

// Flip axis such that input data is read linearly.
const T *inpPtr = inW;
T *outPtr = outW + (len - 1) * innerLen;
for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) {
for (dim_t idx = 0; idx < len; idx++) {
for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) {
*outPtr++ = *inpPtr++;
}
outPtr -= 2 * innerLen;
}
outPtr += 2 * len * innerLen;
}
}

template <typename T>
static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims,
const dim_t *outWdims, dim_t *kernelSizes,
@@ -1954,6 +1987,36 @@ void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim,
libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}

// libjit entry points for the Flip kernel: one thin, non-template wrapper per
// supported element type, each forwarding its arguments unchanged to
// libjit_flip_generic above. The suffix encodes the element type (i8/i16/i32
// integer widths, f = float, b = bool; "u" presumably denotes the 64-bit
// index type — confirm against the libjit naming convention used by the
// other kernels in this file).
void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
                    dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims,
                     dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims,
                     dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims,
                   dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis,
                   dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis,
                   dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}

void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset,
dim_t *tensorDim, dim_t *sliceDim,
dim_t numDimsTensor, dim_t numDimsSlice,
@@ -523,6 +523,7 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
case Kinded::Kind::TransposeNodeKind:
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::SaveNodeKind:
case Kinded::Kind::FlipNodeKind:
// These work regardless of the underlying type.
return true;

@@ -28,6 +28,36 @@

using namespace glow;

/// Dispatch macro: selects the concrete C++ element type corresponding to the
/// runtime tag \p elemTy and invokes the template \p functionName with that
/// type, forwarding the remaining arguments. Note that both Int32QTy
/// (quantized) and Int32ITy (index) map to int32_t. Element kinds without a
/// case below abort via llvm_unreachable.
#define dispatchImpl(functionName, elemTy, ...)                                \
  switch (elemTy) {                                                            \
  case ElemKind::FloatTy:                                                      \
    functionName<float>(__VA_ARGS__);                                          \
    break;                                                                     \
  case ElemKind::Float16Ty:                                                    \
    functionName<float16_t>(__VA_ARGS__);                                      \
    break;                                                                     \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t>(__VA_ARGS__);                                         \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32QTy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32ITy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int64ITy:                                                     \
    functionName<int64_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::BoolTy:                                                       \
    functionName<bool>(__VA_ARGS__);                                           \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchFloatingPointImpl(functionName, elemTy, ...) \
switch (elemTy) { \
case ElemKind::FloatTy: \
@@ -4038,3 +4068,53 @@ void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) {
#undef CONVERT
llvm_unreachable("Type not supported");
}

/// Interpreter implementation of the Flip instruction for element type
/// \p ElemTy: copies every element of the source tensor into the destination,
/// mirroring the coordinate of the flipped axis.
template <typename ElemTy>
void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) {

  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");

  auto *src = I->getSrc();
  auto *dest = I->getDest();

  // Get unowned handles of src and dest with dims expanded to maximum.
  ShapeVector eDims = expandDimsToMax(src->dims());
  auto eSrc = getTensor(src)->getUnowned(eDims);
  auto eDest = getTensor(dest)->getUnowned(eDims);
  auto srcH = eSrc.getHandle<ElemTy>();
  auto destH = eDest.getHandle<ElemTy>();

// Expands to 6 nested loops that copy every source element to destination
// coordinates (_D0.._D5); each case below passes the mirrored coordinate
// (eDims[a] - 1 - idxa) for the flipped axis a. The trailing `return` exits
// fwdFlipInstImpl once the copy completes, which is why the switch cases
// below need no `break` and do not fall through.
#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5)                           \
  for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++)                                \
    for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++)                              \
      for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++)                            \
        for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++)                          \
          for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++)                        \
            for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) {                    \
              destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) =                       \
                  srcH.at({idx0, idx1, idx2, idx3, idx4, idx5});               \
            }                                                                  \
  return;

  switch (I->getAxis()) {
  case 0:
    LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5);
  case 1:
    LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5);
  case 2:
    LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5);
  case 3:
    LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5);
  case 4:
    LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5);
  case 5:
    LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5);
  default:
    llvm_unreachable("Axis should be less than max_tensor_dimensions.");
  }
}

void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) {
  // Dispatch the templated implementation on the source's element type; the
  // dispatchImpl macro covers all kinds the backend declares as supported.
  dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I);
}
@@ -725,6 +725,14 @@ Error ONNXModelWriter::writeTranspose(const TransposeNode *node,
return writeAllWithNode("Transpose", node, graph, proto);
}

/// Serializes a Glow Flip node into the ONNX graph as a custom "Flip" op,
/// carrying the flipped dimension in the "axis" attribute.
Error ONNXModelWriter::writeFlip(const FlipNode *node, GraphType &graph) {
  auto *proto = graph.add_node();
  // Add dictionary entries.
  addValueAttribute(proto, "axis", node->getAxis());

  return writeAllWithNode("Flip", node, graph, proto);
}

Error ONNXModelWriter::writeConvolution(const ConvolutionNode *node,
GraphType &graph) {
// Loading convolution creates a sandwich with Transpose nodes for Input,
@@ -1070,6 +1070,12 @@ TransposeNode *Function::createTranspose(llvm::StringRef name, NodeValue input,
return addNode(new TransposeNode(name, NT, input, shuffle.vec(), currLayout));
}

/// Creates a Flip node named \p name that reverses \p input along dimension
/// \p axis. Flipping preserves both shape and element type, so the result
/// type is simply a uniqued copy of the input's type.
FlipNode *Function::createFlip(llvm::StringRef name, NodeValue input,
                               unsigned_t axis) {
  TypeRef resultTy = getParent()->uniqueType(*input.getType());
  return addNode(new FlipNode(name, resultTy, input, axis));
}

Node *Function::createBroadcast(llvm::StringRef name, NodeValue input,
UnsignedArrayRef newShape, unsigned_t axis) {
const auto &origDims = input.dims();
@@ -905,6 +905,16 @@ bool TransposeNode::verify() const {
return isValid;
}

bool FlipNode::verify() const {
  auto dest = getResult();
  auto src = getInput();
  dim_t axis = getAxis();
  // The result must have exactly the same type as the input (flip only
  // reorders elements), and the axis must index a valid input dimension,
  // i.e. axis < rank.
  bool isValid = checkSameType(src, dest, this);
  isValid &= expectCompareTrue("Invalid axis", axis, (dim_t)src.dims().size(),
                               this, CompareOperatorLess<dim_t>());
  return isValid;
}

bool ChannelShuffleNode::verify() const {
bool isValid = expectCompareTrue("Channel shuffle into a different size.",
getResult().getType()->size(),
@@ -648,6 +648,7 @@ static bool acceptsAnyInputLayout(const glow::Node *node) {
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::MeanVarNormalizationNodeKind:
case Kinded::Kind::MatMulNodeKind:
case Kinded::Kind::FlipNodeKind:
case Kinded::Kind::SGDNodeKind: {
return true;
}
@@ -2097,6 +2097,22 @@ Error ONNXModelLoader::loadAdaptiveAvgPool(const ONNX_NAMESPACE::NodeProto &op,
return Error::success();
}

/// Loads the Glow custom "Flip" operator: reads the single input tensor and
/// the optional "axis" attribute, and emits a FlipNode.
Error ONNXModelLoader::loadFlip(const ONNX_NAMESPACE::NodeProto &op,
                                const ArgumentDictionaryTy &dict) {
  NodeValue input;
  ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

  // "axis" is optional; when absent the outermost dimension (0) is flipped.
  unsigned_t axis = 0;
  if (dict.count("axis")) {
    ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict.at("axis")));
  }

  // NOTE(review): the node is given the fixed name "flip" rather than a name
  // derived from the op — presumably Function uniquifies node names; confirm
  // against how the other load* methods in this loader name their nodes.
  Node *N = G_.createFlip("flip", input, axis);

  RETURN_IF_ERR(addNodeAsOutput(op, N));
  return Error::success();
}

Error ONNXModelLoader::loadRowwiseQuantizedFullyConnected(
const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict) {
// TODO
@@ -2248,6 +2264,9 @@ Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
if (typeName == "AdaptiveAvgPool") {
return loadAdaptiveAvgPool(op, dict);
}
if (typeName == "Flip") {
return loadFlip(op, dict);
}
if (typeName == "Identity") {
return loadIdentity(op, dict);
}
@@ -2280,6 +2280,20 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
break;
}

case Kinded::Kind::FlipInstKind: {
auto *FI = cast<FlipInst>(I);
auto *dest = FI->getDest();
auto *src = FI->getSrc();
auto *destPtr = emitValueAddress(builder, dest);
auto *srcPtr = emitValueAddress(builder, src);
auto *dims = emitValueDims(builder, src);
auto *axis = emitConstDimT(builder, FI->getAxis());
auto *dimsSize = emitConstDimT(builder, src->getType()->dims().size());
auto *F = getFunction("flip", src->getElementType());
createCall(builder, F, {srcPtr, destPtr, dims, axis, dimsSize});
break;
}

// Alloc and Dealloc instructions are handled by the memory allocator.
case Kinded::Kind::AllocActivationInstKind:
case Kinded::Kind::DeallocActivationInstKind:

0 comments on commit 2661e05

Please sign in to comment.
You can’t perform that action at this time.