Add TensorFlowLite importer. (#4584)
**Summary**
- Add importer for TensorFlowLite models. The importer currently supports loading **57** of the **125** operators defined by the TensorFlowLite flatbuffer format (a minimal usage sketch follows after this list).
- Add extra operators for Interpreter and CPU:
  - Sqrt, Rsqrt, Reciprocal, Sin, Cos, CmpNEQ, CmpGT, CmpGTE, And, Or, Xor, Not, etc.
- Refactor ArgMax to work with any number of dimensions and add an ArgMin implementation.
- Add a Python script to generate TensorFlowLite models for unit testing.
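
A minimal usage sketch of the new importer, assuming the loader follows the constructor pattern of Glow's other importers (the exact signature lives in the `TFLiteModelLoader` header added by this commit):

```cpp
#include "glow/Graph/Graph.h"
#include "glow/Importer/TFLiteModelLoader.h"

using namespace glow;

int main() {
  Module mod;
  Function *F = mod.createFunction("main");
  // Parse the flatbuffer and populate F with Glow nodes.
  TFLiteModelLoader loader("model.tflite", F);
  return 0;
}
```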

**Documentation**
Mention the TFLite support in AOT.md.

**Test Plan**
Unit tests for the TFLite loader and for the newly added operators.
Pull Request resolved: #4584

Reviewed By: yinghai

Differential Revision: D22015060

Pulled By: jackm321

fbshipit-source-id: acf90145caaa51054a1a0058fa1aa815b68957c1
mciprian13 authored and facebook-github-bot committed Jun 18, 2020
1 parent 6798f6c commit 5855fdf
Showing 265 changed files with 25,599 additions and 378 deletions.
9 changes: 5 additions & 4 deletions docs/AOT.md
@@ -51,8 +51,8 @@

## Compile a bundle for a floating-point model

The **model-compiler** front-end tool is the main Glow tool used to compile ONNX, Caffe2 and
TensorFlowLite models into bundles. The tool is generic in the sense that it can compile
models with any number of inputs or outputs, without being limited to a particular
application.

@@ -72,12 +72,13 @@ model-compiler -backend=CPU -model=<model-path> -emit-bundle=<bundle-dir>
- The option `emit-bundle` specifies the output **directory** where all the bundle
artifacts will be generated. If the directory does not exist, it will be created.

There is a small difference when using this tool with ONNX/TFLite versus Caffe2 models:
- For **ONNX or TensorFlowLite models** the tool can automatically infer the inputs of the model
since the description of the input tensors is part of the model. Therefore the tool
will be used in the form shown above:
```
model-compiler -backend=CPU -model=<onnx-model-path> -emit-bundle=<bundle-dir>
model-compiler -backend=CPU -model=<tflite-model-path> -emit-bundle=<bundle-dir>
```
- For **Caffe2 models** the user must also explicitly provide the description
of the input tensors which are not part of the model. The option `model-input` will be used
8 changes: 8 additions & 0 deletions include/glow/Backends/Interpreter/InterpreterFunction.h
@@ -231,6 +231,10 @@ class BoundInterpreterFunction : public IRInstructionProcessingHandler {
typename CmpTy = ElemTy>
void fwdElementCmpEQInstImpl(const ElementCmpEQInst *I);

template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
typename CmpTy = ElemTy>
void fwdElementCmpNEQInstImpl(const ElementCmpNEQInst *I);

template <typename ElemTy>
void fwdBatchOneHotImpl(const glow::BatchOneHotInst *I);

@@ -263,6 +267,10 @@ class BoundInterpreterFunction : public IRInstructionProcessingHandler {
template <typename ElemTy>
void fwdElementMinInstArithmeticImpl(const ElementMinInst *I);

template <typename ElemTy, typename InstKind>
void fwdUnaryArithmeticImpl(const InstKind *I,
std::function<float(float)> func);

template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
typename CmpTy = ElemTy>
void fwdElementCmpLTEInstImpl(const ElementCmpLTEInst *I);
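The new `fwdUnaryArithmeticImpl` declaration above shares a single loop across Sqrt, Rsqrt, Reciprocal, Sin and Cos by taking the scalar operation as a `std::function<float(float)>`. Below is a standalone analogue of that pattern; it is a sketch, not the interpreter's actual member, which operates on Glow weight handles rather than `std::vector`:

```cpp
#include <cmath>
#include <functional>
#include <vector>

// One element-wise loop shared by all unary arithmetic operators.
template <typename ElemTy>
std::vector<ElemTy> unaryArithmetic(const std::vector<ElemTy> &in,
                                    std::function<float(float)> func) {
  std::vector<ElemTy> out(in.size());
  for (size_t i = 0; i < in.size(); i++) {
    out[i] = static_cast<ElemTy>(func(static_cast<float>(in[i])));
  }
  return out;
}

int main() {
  // Rsqrt and Sin reuse the same loop; only the functor changes.
  auto rsqrt = unaryArithmetic<float>(
      {1.0f, 4.0f, 16.0f}, [](float v) { return 1.0f / std::sqrt(v); });
  auto sine = unaryArithmetic<float>(
      {0.0f, 1.0f}, [](float v) { return std::sin(v); });
  return 0;
}
```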
17 changes: 17 additions & 0 deletions include/glow/Base/Tensor.h
@@ -49,6 +49,14 @@ void genericTranspose(const Tensor *src, Tensor *dest,
/// returned dims. For example, input {2,1,4} would result in {2,1,4,1,1,1}.
ShapeVector expandDimsToMax(llvm::ArrayRef<dim_t> currDims);

/// Helper function that \returns a ShapeVector obtained from \p dims by
/// reducing (setting to 1) the dimensions given by \p axes. If the flag
/// \p keepDims is set then the reduced dimensions are kept, otherwise they
/// are pruned. For example, given the dimensions [2,3,4] and axes [0,2] the
/// returned shape will be [1,3,1] for keepDims true and [3] for keepDims false.
ShapeVector reduceDims(llvm::ArrayRef<dim_t> dims,
llvm::ArrayRef<unsigned_t> axes, bool keepDims);

namespace runtime {
class DeviceManager;
}
@@ -685,6 +693,15 @@ class Tensor final {
std::copy(&t->getData()[0], &t->getData()[bufferSize], getData());
}

/// Update the raw data of the tensor from a raw buffer \p data.
void copyRawFrom(const char *data) {
assert(!isDeviceResident() && "Tensor must reside on host to access data.");
assert(data && "Null data pointer!");
assert(getData() != data && "Copying to self");
size_t bufferSize = type_.getSizeInBytes();
std::memcpy(getData(), data, bufferSize);
}

/// Update the content of the tensor with a slice from tensor \p t. A slice
/// is one index from the first dimension of the tensor.
void copySlice(const Tensor *t, size_t slice) {
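A standalone illustration of the `reduceDims` contract documented above; `std::vector` stands in for Glow's `ShapeVector` and `llvm::ArrayRef`, so this is a sketch of the semantics rather than the committed implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<uint64_t> reduceDims(const std::vector<uint64_t> &dims,
                                 const std::vector<unsigned> &axes,
                                 bool keepDims) {
  std::vector<uint64_t> newDims;
  for (unsigned dim = 0; dim < dims.size(); ++dim) {
    bool reduced = std::find(axes.begin(), axes.end(), dim) != axes.end();
    if (!reduced) {
      newDims.push_back(dims[dim]); // untouched dimension
    } else if (keepDims) {
      newDims.push_back(1);         // reduced dimension kept as size 1
    }                               // otherwise pruned
  }
  return newDims;
}

// reduceDims({2, 3, 4}, {0, 2}, /*keepDims*/ true)  -> {1, 3, 1}
// reduceDims({2, 3, 4}, {0, 2}, /*keepDims*/ false) -> {3}
```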
79 changes: 74 additions & 5 deletions include/glow/Graph/Graph.h
@@ -402,6 +402,16 @@ class Function final : public IRContainer {
/// @name High-level, operation-level IRBuilder.
///@{

/// Creates a PadNode with the given \p name and output type \p outTy which
/// pads the given \p input with the explicit pads \p pads according to the
/// padding mode \p mode and with the given value \p value. The padding mode
/// \p mode is one of enumeration values from \ref PaddingMode. For an input
/// with N dimensions (rank N) the \p pads must be a vector with 2*N values
/// with the following format:
/// pads = [pad_before(D1), pad_before(D2), ..., pad_before(DN),
/// pad_after (D1), pad_after (D2), ..., pad_after (DN)].
/// The mode PaddingMode::CONSTANT pads the input using the constant value
/// \p value and currently is the only mode supported.
PadNode *createPad(llvm::StringRef name, NodeValue input, TypeRef outTy,
unsigned_t mode, llvm::ArrayRef<int> pads, float value);
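/// Example (illustrative): for a rank-2 input (H, W), padding 1 row before,
/// 2 rows after, 3 columns before and 4 columns after is expressed as:
///   pads = {/*before H*/ 1, /*before W*/ 3, /*after H*/ 2, /*after W*/ 4}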

@@ -709,7 +719,8 @@ class Function final : public IRContainer {
LogitNode *createLogit(llvm::StringRef name, NodeValue input, float eps);

SoftMaxNode *createSoftMax(llvm::StringRef name, NodeValue input,
NodeValue selected, TypeRef outTy = nullptr,
float beta = 1.0);

CrossEntropyLossNode *createCrossEntropyLoss(llvm::StringRef name,
NodeValue input,
@@ -787,9 +798,18 @@ class Function final : public IRContainer {
/// Computes the indices of the max elements of the input tensor along the
/// provided \p axis. The resulting tensor has the same rank as the input if
/// \p keepDims equals 1. If \p keepDims equals 0, the resulting tensor has the
/// reduced dimension pruned. The type of the output tensor is \p elemTy.
ArgMaxNode *createArgMax(llvm::StringRef name, NodeValue input,
unsigned_t axis, bool keepDims,
ElemKind elemTy = ElemKind::Int64ITy);

/// Computes the indices of the min elements of the input tensor along the
/// provided \p axis. The resulting tensor has the same rank as the input if
/// \p keepDims equals 1. If \p keepDims equals 0, the resulting tensor has the
/// reduced dimension pruned. The type of the output tensor is \p elemTy.
ArgMinNode *createArgMin(llvm::StringRef name, NodeValue input,
unsigned_t axis, bool keepDims,
ElemKind elemTy = ElemKind::Int64ITy);
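/// Example (hypothetical "input" NodeValue; illustrative only):
///   ArgMaxNode *amax = F->createArgMax("amax", input, /*axis*/ 1,
///                                      /*keepDims*/ false, ElemKind::Int32ITy);
///   ArgMinNode *amin = F->createArgMin("amin", input, /*axis*/ 1,
///                                      /*keepDims*/ false, ElemKind::Int32ITy);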

/// Removes single-dimensional entries from the shape of a tensor. The
/// parameter \p axes is a list of positive integers, indicating the
@@ -854,6 +874,25 @@ class Function final : public IRContainer {
ModuloNode *createModulo(llvm::StringRef name, NodeValue input,
int64_t divisor, bool signFollowDivisor = false);

/// Create a logical NOT node with name \p name and input \p input.
NotNode *createNot(llvm::StringRef name, NodeValue input);

#define UNARY_ARITHMETIC_FUN_DECL(NODE_NAME_) \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue input); \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, TypeRef Ty, \
NodeValue input);
UNARY_ARITHMETIC_FUN_DECL(Abs)
UNARY_ARITHMETIC_FUN_DECL(Neg)
UNARY_ARITHMETIC_FUN_DECL(Floor)
UNARY_ARITHMETIC_FUN_DECL(Ceil)
UNARY_ARITHMETIC_FUN_DECL(Round)
UNARY_ARITHMETIC_FUN_DECL(Sqrt)
UNARY_ARITHMETIC_FUN_DECL(Rsqrt)
UNARY_ARITHMETIC_FUN_DECL(Reciprocal)
UNARY_ARITHMETIC_FUN_DECL(Sin)
UNARY_ARITHMETIC_FUN_DECL(Cos)
#undef UNARY_ARITHMETIC_FUN_DECL
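// For reference, a single UNARY_ARITHMETIC_FUN_DECL(Sqrt) line above expands
// to this pair of overloads:
//   SqrtNode *createSqrt(llvm::StringRef name, NodeValue input);
//   SqrtNode *createSqrt(llvm::StringRef name, TypeRef Ty, NodeValue input);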

#define ARITHMETIC_FUN_DECL(NODE_NAME_) \
NODE_NAME_##Node *create##NODE_NAME_(llvm::StringRef name, NodeValue LHS, \
NodeValue RHS); \
@@ -865,9 +904,13 @@ class Function final : public IRContainer {
ARITHMETIC_FUN_DECL(Div);
ARITHMETIC_FUN_DECL(Max);
ARITHMETIC_FUN_DECL(Min);
ARITHMETIC_FUN_DECL(CmpEQ);
ARITHMETIC_FUN_DECL(CmpNEQ);
ARITHMETIC_FUN_DECL(CmpLT);
ARITHMETIC_FUN_DECL(CmpLTE);
ARITHMETIC_FUN_DECL(And);
ARITHMETIC_FUN_DECL(Or);
ARITHMETIC_FUN_DECL(Xor);
ARITHMETIC_FUN_DECL(Pow);
#undef ARITHMETIC_FUN_DECL

@@ -935,6 +978,32 @@ class Function final : public IRContainer {
#undef DECLARE_CMP_BROADCAST_NODE
#undef BROADCAST_FUNC_COMMON_CODE

/// Create an element-wise GREATER THAN comparison between \p LHS and \p RHS
/// by creating a CmpLTNode with given \p name and swapped inputs.
CmpLTNode *createCmpGT(llvm::StringRef name, NodeValue LHS, NodeValue RHS);

/// Create an element-wise GREATER THAN or EQUAL comparison between \p LHS and
/// \p RHS by creating a CmpLTENode with given \p name and swapped inputs.
CmpLTENode *createCmpGTE(llvm::StringRef name, NodeValue LHS, NodeValue RHS);
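
// Sketch of the swapped-input trick the comments above describe: a > b is
// exactly b < a, so createCmpGT can plausibly be implemented as
//   return createCmpLT(name, RHS, LHS);
// and createCmpGTE as
//   return createCmpLTE(name, RHS, LHS);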

/// Create a MulNode with given \p name which multiplies \p input with itself
/// to produce an equivalent Square node.
MulNode *createSquare(llvm::StringRef name, NodeValue input);

/// Create a MulNode with given \p name and output type \p outTy which
/// multiplies \p input with itself to produce an equivalent Square node.
MulNode *createSquare(llvm::StringRef name, TypeRef outTy, NodeValue input);

/// Create an equivalent LeakyRELU node with given \p name, \p input and slope
/// \p alpha by using a SplatNode and a PRELU node.
PReluNode *createLeakyRELU(llvm::StringRef name, NodeValue input,
float alpha);

/// Create an equivalent LeakyRELU node with given \p name, \p outTy, \p input
/// and slope \p alpha by using a SplatNode and a PRELU node.
PReluNode *createLeakyRELU(llvm::StringRef name, TypeRef outTy,
NodeValue input, float alpha);
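
// Plausible lowering sketch for the LeakyRELU creators above (createSplat and
// createPRELU are existing Function methods; the ".alpha" name is illustrative):
//   auto *slope = createSplat(name.str() + ".alpha", input.getType(), alpha);
//   return createPRELU(name, input, slope);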

/// Create a node that produces a boolean output of the same shape as
/// \p input in which each element indicates whether or not the corresponding
/// element in \p input is NaN.
21 changes: 15 additions & 6 deletions include/glow/Importer/CommonOperatorLoader.h
@@ -34,6 +34,17 @@
#include <vector>

namespace glow {

/// Some model formats (Caffe2, PyTorch, TensorFlowLite) allow defining weights
/// and activations in UINT8 format. Since Glow supports only INT8 weights and
/// activations we do a transformation from UINT8 to INT8 quantized data by
/// subtracting the value 128 from both the quantized values and the offset:
/// val(int8) = val(uint8) - 128
/// scale(int8) = scale(uint8) (scale value is preserved)
/// offset(int8) = offset(uint8) - 128
/// The constant definition below defines the value used for subtraction.
constexpr int32_t UINT8_TO_INT8_SHIFT = 128;

/// Result of loading a weight, potentially with additional offsets and
/// scales tensors containing quantization parameters only if the loaded weight
/// was found to have multiple quantization parameters.
@@ -120,9 +131,6 @@ class CommonOperatorLoader : public ProtobufLoader {
return Expected<LoadWeightResult>(std::move(result));
}

// Load quantized tensor with either a single or multiple qparams.
float scale = 1.0;
int32_t offset = 0;
@@ -142,15 +150,16 @@

if (in.dataType == ONNXIFI_DATATYPE_UINT8) {
// Must copy the weights here because we will need to modify them by
// adjusting for UINT8_TO_INT8_SHIFT.
result.type =
Type(ElemKind::Int8QTy, dims, scale, offset - UINT8_TO_INT8_SHIFT);
if (!in.isOffline) {
result.t->reset(result.type);

auto TH = result.t->getHandle<int8_t>();
uint8_t *data = (uint8_t *)in.buffer;
for (size_t i = 0; i < TH.size(); ++i) {
TH.raw(i) = (int8_t)(data[i] - UINT8_TO_INT8_SHIFT);
}
}
} else if (in.dataType == ONNXIFI_DATATYPE_INT32) {
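A worked example of the UINT8-to-INT8 shift described in `CommonOperatorLoader.h` above, with made-up quantization parameters (the real loader applies the same arithmetic while copying the weight buffer):

```cpp
#include <cstdint>
#include <cstdio>

constexpr int32_t UINT8_TO_INT8_SHIFT = 128;

int main() {
  // Made-up uint8 quantization: scale 0.1, offset 140, quantized value 200.
  const float scale = 0.1f;
  const int32_t offsetU8 = 140;
  const uint8_t qU8 = 200;

  // Shift the value and the offset by 128; the represented real value
  // scale * (q - offset) is unchanged.
  const int8_t qI8 = static_cast<int8_t>(qU8 - UINT8_TO_INT8_SHIFT); // 72
  const int32_t offsetI8 = offsetU8 - UINT8_TO_INT8_SHIFT;           // 12

  std::printf("uint8 real value: %f\n", scale * (qU8 - offsetU8)); // 6.0
  std::printf("int8  real value: %f\n", scale * (qI8 - offsetI8)); // 6.0
  return 0;
}
```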
18 changes: 15 additions & 3 deletions include/glow/Importer/ProtobufLoader.h
@@ -138,15 +138,27 @@ template <typename T> std::string loadOperatorName(const T &op) {
/// example an axis value of -1 for a tensor with 3 dimensions (rank 3) is
/// converted to 2. A valid axis value is required to be in the
/// range [-rank, rank-1].
template <typename T> Expected<T> getPositiveAxis(int axis, int rank) {
RETURN_ERR_IF_NOT(
(-rank <= axis) && (axis < rank),
strFormat("Axis value %d is invalid! Should be in the range [%d, %d]!",
axis, -rank, rank - 1));
int axisPos = (axis < 0) ? axis + rank : axis;
return static_cast<T>(axisPos);
}

/// \returns the positive value of \p axis given the rank of the value \p val.
template <typename T> Expected<T> getPositiveAxis(int axis, NodeValue val) {
return getPositiveAxis<T>(axis, val.dims().size());
}
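
/// For instance (illustrative): axis -1 on a rank-3 value wraps to 2, while
/// axis 3 is out of range and yields an Error:
///   getPositiveAxis<unsigned_t>(-1, /*rank*/ 3); // -> 2
///   getPositiveAxis<unsigned_t>(3, /*rank*/ 3);  // -> Error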

/// Reads a single axis parameter which is wrapped if negative using \p rank
/// based on the logic of \ref getPositiveAxis.
template <typename ElemTy, typename T>
static Expected<ElemTy> loadAxis(const T *arg, int rank) {
int axis;
ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(arg));
ASSIGN_VALUE_OR_RETURN_ERR(axis, getPositiveAxis<int>(axis, rank));
return static_cast<ElemTy>(axis);
}

@@ -159,7 +171,7 @@ static Expected<std::vector<ElemTy>> loadAxes(const T *arg, int rank) {
std::vector<ElemTy> axesPos;
for (int axis : axes) {
int axisPos;
ASSIGN_VALUE_OR_RETURN_ERR(axisPos, getPositiveAxis<int>(axis, rank));
axesPos.push_back(static_cast<ElemTy>(axisPos));
}
return axesPos;
