Skip to content
Permalink
Browse files

Add fb::embedding_bag_byte_rowwise_offsets (#3792)

Summary:
* Add loader for `fb::embedding_bag_byte_rowwise_offsets`
* Add `EmbeddingBagByteRowwiseOffsetsNode` similar to `FusedRowwiseQuantizedSparseLengthsWeightedSum` but with offsets instead of lengths

Documentation:
doxygen
Pull Request resolved: #3792

Test Plan:
Added an operator test.
A loader test is still needed; it is not yet clear how to write one for the loader.

Differential Revision: D18594773

Pulled By: jackm321

fbshipit-source-id: 505a3fd8843eb8ebc469644af8082fbf2f9ec7f9
  • Loading branch information
jackm321 authored and facebook-github-bot committed Nov 21, 2019
1 parent d80ef34 commit 49b1a1587d9070409e048f21d292dc39d1b6df27
@@ -914,6 +914,20 @@ class Function final : public Named {
NodeValue weights, NodeValue indices,
NodeValue offsets);

/// Create an EmbeddingBagByteRowwiseOffsetsNode. Like
/// FusedRowwiseQuantizedSparseLengthsWeightedSum, but bags of \p indices
/// into \p data are delimited by \p offsets instead of lengths, and each
/// selected row is scaled by the matching entry of \p weights. If
/// \p useFP16Accumulation is true, accumulation is performed in FP16.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, NodeValue data, NodeValue weights,
NodeValue indices, NodeValue offsets, bool useFP16Accumulation = false);

/// Same as \ref createEmbeddingBagByteRowwiseOffsets(), but
/// expects float input \p data, which is rowwise-quantized and fused
/// internally. \p fusedElemKind represents the element kind to use for the
/// final fused rowwise-quantized data.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
NodeValue offsets, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
bool useFP16Accumulation = false);

/// Same as \ref createSparseLengthsWeightedSum(), but with \p outTy
/// specified.
SparseLengthsWeightedSumNode *
@@ -310,6 +310,18 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
{ElemKind::FloatTy}, {LengthsSumNode::LengthsIdx}) &&
(NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);

// The CPU backend only supports the UInt8FusedQTy flavor of
// EmbeddingBagByteRowwiseOffsets: float weights/result, int64 indices,
// int32 offsets. (The FP16-fused kinds are interpreter-only here.)
case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx) ==
ElemKind::UInt8FusedQTy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::FloatTy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
ElemKind::Int64ITy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
ElemKind::Int32ITy) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::FloatTy);

case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
return (NI.getInElemTy(
FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
@@ -1458,6 +1458,31 @@ void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
}
}

/// Embedding-bag over fused rowwise-quantized \p data. Each row of \p data
/// holds \p outLineSize quantized uint8 values followed by a float scale and
/// a float offset (so inLineSize == outLineSize + 2 * sizeof(float)). For
/// each of the \p segments bags, the rows selected by \p indices in the
/// half-open range [offsets[i], offsets[i + 1]) — the last bag extends to
/// \p numIndices — are dequantized, scaled by the matching \p weights entry,
/// and summed into row i of \p dest.
void libjit_embedding_bag_byte_rowwise_offsets_f(
    float *dest, int8_t *data, float *weights, size_t *indices,
    int32_t *offsets, size_t segments, size_t numIndices, size_t inLineSize,
    size_t outLineSize) {
  memset(dest, 0, segments * outLineSize * sizeof(float));
  for (size_t i = 0; i < segments; i++) {
    size_t start = offsets[i];
    // The final segment runs to the end of the indices array.
    size_t end = i == segments - 1 ? numIndices : offsets[i + 1];
    // Use size_t for the induction variable: the previous int32_t caused a
    // signed/unsigned comparison against `end` and would overflow for more
    // than INT32_MAX total indices.
    for (size_t j = start; j < end; j++) {
      const float weight = weights[j];
      const size_t line = indices[j];
      // Scale and offset are fused into the last 8 bytes of the row; read
      // them with memcpy to avoid unaligned/aliasing UB.
      const int8_t *currRowScaleOffsetPtr =
          data + ((line + 1) * inLineSize) - 2 * sizeof(float);
      float scale, offset;
      memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
      memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
      for (size_t k = 0; k < outLineSize; k++) {
        // Dequantize (interpret the stored byte as unsigned) and accumulate.
        const float fData =
            (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
        dest[i * outLineSize + k] += weight * fData;
      }
    }
  }
}

void libjit_sparse_to_dense_f(float *dest, const size_t *indices,
const float *values, size_t numIndices,
size_t destSize, size_t valueSize) {
@@ -96,6 +96,9 @@ std::set<std::string> glow::backendTestBlacklist = {
"back/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_"
"back2/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
@@ -164,6 +164,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"back2/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"FusedRWQSLSAllZeroLengths_Float/0",
"FusedRWQSLSAllZeroLengths_Float16/0",
"GatherDataFloat16IdxInt32/0",
@@ -321,6 +321,31 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
ElemKind::Int32ITy);

// EmbeddingBagByteRowwiseOffsets: indices must be int64 and offsets int32;
// the required weights/result precision follows the fused data kind.
case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
if (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) !=
ElemKind::Int64ITy ||
NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) !=
ElemKind::Int32ITy) {
return false;
}

switch (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx)) {
// FP16-fused data (4-bit or 8-bit payload) pairs with Float16 weights and
// a Float16 result.
case ElemKind::UInt4FusedFP16QTy:
case ElemKind::UInt8FusedFP16QTy:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::Float16Ty) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::Float16Ty);
// Float-fused data pairs with float weights and a float result.
case ElemKind::UInt8FusedQTy:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::FloatTy) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::FloatTy);
// Any other data kind is unsupported.
default:
return false;
}
}

case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
if (NI.getInElemTy(
FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
@@ -316,6 +316,10 @@ class BoundInterpreterFunction {
template <typename T, typename AccumT>
void fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I);

template <typename T, typename AccumT>
void fwdEmbeddingBagByteRowwiseOffsetsImpl(
const EmbeddingBagByteRowwiseOffsetsInst *I);
///@}
};

@@ -3523,6 +3523,82 @@ void BoundInterpreterFunction::
}
}

template <typename T, typename AccumT>
void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl(
    const EmbeddingBagByteRowwiseOffsetsInst *I) {
  // Fetch the tensors backing every operand of the instruction.
  auto *destT = getTensor(I->getDest());
  auto *dataT = getTensor(I->getData());
  auto *weightsT = getTensor(I->getWeights());
  auto *indicesT = getTensor(I->getIndices());
  auto *offsetsT = getTensor(I->getOffsets());

  destT->zero();

  auto indicesH = indicesT->getHandle<int64_t>();
  auto offsetsH = offsetsT->getHandle<int32_t>();

  const size_t numSegments = offsetsT->dims()[0];
  const size_t totalIndices = indicesT->dims()[0];

  // 4-bit fused data packs two quantized values into each stored byte.
  const bool is4Bit =
      dataT->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;

  const size_t lineSize = destT->size() / destT->dims()[0];

  auto dataH = dataT->getHandle<uint8_t>();
  auto weightsH = weightsT->getHandle<T>();
  auto destH = destT->getHandle<T>();

  for (size_t seg = 0; seg < numSegments; ++seg) {
    // Accumulate this segment's row in AccumT precision.
    std::vector<AccumT> sums(lineSize, 0.0f);
    const size_t begin = offsetsH.raw(seg);
    // The last segment extends to the end of the indices array.
    const size_t stop =
        (seg == numSegments - 1) ? totalIndices : offsetsH.raw(seg + 1);
    for (size_t pos = begin; pos < stop; ++pos) {
      const float w = static_cast<float>(weightsH.raw(pos));
      const size_t row = indicesH.raw(pos);
      // The scale/offset pair is fused at the end of each data row.
      T scale, offset;
      std::tie(scale, offset) = dataH.getFusedScaleOffsetFromRow<T>(row);
      for (size_t col = 0; col < lineSize; ++col) {
        float elem;
        if (is4Bit) {
          // Odd output columns read the high nibble of the packed byte.
          elem = quantization::dequantize4BitWithFloatOffset(
              dataH.at({row, col / 2}), static_cast<float>(scale),
              static_cast<float>(offset), (col % 2 == 1));
        } else {
          elem = quantization::dequantizeWithFloatOffset(
              dataH.at({row, col}), static_cast<float>(scale),
              static_cast<float>(offset));
        }
        sums[col] += elem * w;
      }
    }
    // Cast the accumulated values back to T and store the output row.
    for (size_t col = 0; col < lineSize; ++col) {
      destH.raw(seg * lineSize + col) = static_cast<T>(sums[col]);
    }
  }
}

void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst(
    const EmbeddingBagByteRowwiseOffsetsInst *I) {
  // Dispatch on the destination element type; for Float16 the node also
  // selects the accumulator precision.
  const auto destKind = I->getDest()->getElementType();
  if (destKind == ElemKind::FloatTy) {
    fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float>(I);
    return;
  }
  if (destKind == ElemKind::Float16Ty) {
    if (I->getUseFP16Accumulation()) {
      fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t>(I);
    } else {
      fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float>(I);
    }
    return;
  }
  llvm_unreachable("Type is not supported");
}

void BoundInterpreterFunction::fwdLengthsToRangesInst(
const LengthsToRangesInst *I) {
auto ranges = getTensor(I->getDest())->getHandle<int32_t>();
@@ -94,6 +94,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"EmbeddingBag_1D_Float16/0",
"EmbeddingBag_2D_Float/0",
"EmbeddingBag_2D_Float16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16_back_",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"SparseToDense/0",
"SparseToDenseMask1/0",
"SparseToDenseMask2/0",
@@ -144,7 +148,8 @@ struct EmulatorOnlyTests {
"Exp/0",
"FloatArgMaxKeepDim/0",
"FloatArgMaxNoKeepDim/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_"
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_"
"back_",
"to_back2/0",
"GroupDilatedConvolution/0",
"less_int32Cases/0",
@@ -178,6 +178,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"back/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_"
"back2/0",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
@@ -1397,6 +1397,7 @@ DEF_ALL_WRITER_NODE(Regression)
DEF_ALL_WRITER_NODE(RowwiseQuantizedFullyConnected)
DEF_ALL_WRITER_NODE(RowwiseQuantizedSparseLengthsWeightedSum)
DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsSum)
DEF_ALL_WRITER_NODE(EmbeddingBagByteRowwiseOffsets)
DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsWeightedSum)

Error ONNXModelWriter::writeClip(const ClipNode *node, GraphType &graph) {
@@ -1652,19 +1652,6 @@ Function::createSparseLengthsWeightedSum(llvm::StringRef name, TypeRef outTy,
indices, lengths));
}

EmbeddingBagNode *Function::createEmbeddingBag(llvm::StringRef name,
                                               NodeValue data,
                                               NodeValue weights,
                                               NodeValue indices,
                                               NodeValue offsets) {
  // The result keeps data's trailing dimensions but has one row per bag,
  // i.e. one row per entry of \p offsets.
  ShapeVector outDims(data.dims().begin(), data.dims().end());
  outDims[0] = offsets.dims()[0];
  auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
  return addNode(
      new EmbeddingBagNode(name, outTy, data, weights, indices, offsets));
}

RowwiseQuantizedSparseLengthsWeightedSumNode *
Function::createRowwiseQuantizedSparseLengthsWeightedSum(
llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
@@ -1755,15 +1742,15 @@ Function::createRowwiseQuantizedSparseLengthsSum(
}

/// Helper used to get specific output type required for
/// createRowwiseQuantizedSparseLengthsSum and
/// createRowwiseQuantizedSparseLengthsWeightedSum.
/// Function \p F is used to get the specific type, using inputs \p data and
/// \p lengthsDims to compute output dimensions.
/// createRowwiseQuantizedSparseLengthsSum,
/// createRowwiseQuantizedSparseLengthsWeightedSum, and
/// EmbeddingBagByteRowwiseOffsets. Function \p F is used to get the specific
/// type, using inputs \p data and \p segmentsDim to compute output dimensions.
static TypeRef
getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
llvm::ArrayRef<size_t> lengthsDims) {
llvm::ArrayRef<size_t> segmentsDim) {
ShapeVector outDims(data.dims().begin(), data.dims().end());
outDims[0] = lengthsDims[0];
outDims[0] = segmentsDim[0];
// The output column count is the same as the input column count, but
// without the extra bytes for the fused scale/offset, as the output is not
// fused.
@@ -1876,6 +1863,40 @@ Function::createFusedRowwiseQuantizedSparseLengthsSum(
name, rwqData, indices, lengths, useFP16Accumulation);
}

EmbeddingBagNode *Function::createEmbeddingBag(llvm::StringRef name,
                                               NodeValue data,
                                               NodeValue weights,
                                               NodeValue indices,
                                               NodeValue offsets) {
  auto inDims = data.dims();
  // Output keeps data's trailing dims; the first dim is one row per bag,
  // i.e. one row per offsets entry.
  ShapeVector outDims(inDims.begin(), inDims.end());
  outDims[0] = offsets.dims()[0];
  auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
  return addNode(
      new EmbeddingBagNode(name, outTy, data, weights, indices, offsets));
}

EmbeddingBagByteRowwiseOffsetsNode *
Function::createEmbeddingBagByteRowwiseOffsets(
    llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
    NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation) {
  // Fuse-quantize the float data into \p fusedElemKind form, then delegate
  // to the NodeValue-based overload.
  return createEmbeddingBagByteRowwiseOffsets(
      name,
      quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
          this, data, fusedElemKind),
      weights, indices, offsets, useFP16Accumulation);
}

EmbeddingBagByteRowwiseOffsetsNode *
Function::createEmbeddingBagByteRowwiseOffsets(
    llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
    NodeValue offsets, bool useFP16Accumulation) {
  // One output row per segment; the helper derives the result type from the
  // fused data type and the offsets dimension.
  TypeRef outTy =
      getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, offsets.dims());
  return addNode(new EmbeddingBagByteRowwiseOffsetsNode(
      name, outTy, data, weights, indices, offsets, useFP16Accumulation));
}

LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
NodeValue lengths) {
ShapeVector outDims({lengths.dims()[0], 2});
@@ -1325,6 +1325,12 @@ static bool verifyFusedRowwiseQuantizedSparseLengthsSum(
return isValid;
}

/// Reuses the shared fused-rowwise SLS verifier; the offsets operand is
/// passed where that helper expects a lengths operand.
bool EmbeddingBagByteRowwiseOffsetsNode::verify() const {
  return verifyFusedRowwiseQuantizedSparseLengthsSum(
      getResult(), getData(), getIndices(), getOffsets(), getWeights(),
      getUseFP16Accumulation());
}

bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::verify() const {
return verifyFusedRowwiseQuantizedSparseLengthsSum(
getResult(), getData(), getIndices(), getLengths(), getWeights(),
@@ -2574,6 +2574,30 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
break;
}

case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsInstKind: {
auto *N = cast<EmbeddingBagByteRowwiseOffsetsInst>(I);
auto *dest = N->getDest();
auto *data = N->getData();
auto *weights = N->getWeights();
auto *indices = N->getIndices();
auto *offsets = N->getOffsets();
// Emit addresses for all operands.
auto *destPtr = emitValueAddress(builder, dest);
auto *dataPtr = emitValueAddress(builder, data);
auto *weightsPtr = emitValueAddress(builder, weights);
auto *indicesPtr = emitValueAddress(builder, indices);
auto *offsetsPtr = emitValueAddress(builder, offsets);
// Emit the scalar size arguments: number of bags, total index count, and
// the per-row element counts of the (fused) input and the output.
auto *segments = emitConstSizeT(builder, offsets->dims()[0]);
auto *numIndices = emitConstSizeT(builder, indices->dims()[0]);
auto *inLineSize = emitConstSizeT(builder, data->size() / data->dims()[0]);
auto *outLineSize = emitConstSizeT(builder, dest->size() / dest->dims()[0]);
// Call the libjit kernel specialized on the destination element type.
auto *F = getFunction("embedding_bag_byte_rowwise_offsets",
dest->getElementType());
createCall(builder, F,
{destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
numIndices, inLineSize, outLineSize});
break;
}

case Kinded::Kind::SparseToDenseInstKind: {
auto *STDI = llvm::cast<SparseToDenseInst>(I);
auto *indices = STDI->getIndices();

0 comments on commit 49b1a15

Please sign in to comment.
You can’t perform that action at this time.