Add fb::embedding_bag_byte_rowwise_offsets (#3792)
Summary:
* Add loader for `fb::embedding_bag_byte_rowwise_offsets`
* Add `EmbeddingBagByteRowwiseOffsetsNode`, similar to `FusedRowwiseQuantizedSparseLengthsWeightedSum` but taking offsets instead of lengths (see the sketch below)
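
The offsets-based encoding is equivalent to the lengths-based encoding used by the existing node; a minimal sketch of the correspondence (example values are illustrative only):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical example: 7 indices grouped into 3 bags.
// lengths encoding (FusedRowwiseQuantizedSparseLengthsWeightedSum): each
// entry is the size of one segment.
std::vector<int32_t> lengths = {2, 3, 2};
// offsets encoding (EmbeddingBagByteRowwiseOffsets): each entry is the
// start of one segment; segment i covers [offsets[i], offsets[i + 1]),
// and the last segment runs to the end of the indices array (7 here).
std::vector<int32_t> offsets = {0, 2, 5};
```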

Documentation:
doxygen
Pull Request resolved: #3792

Test Plan:
Added operator test
Not sure yet how to add a test for the loader

Differential Revision: D18594773

Pulled By: jackm321

fbshipit-source-id: 505a3fd8843eb8ebc469644af8082fbf2f9ec7f9
jackm321 authored and facebook-github-bot committed Nov 21, 2019
1 parent d80ef34 commit 49b1a15
Showing 19 changed files with 447 additions and 43 deletions.
14 changes: 14 additions & 0 deletions include/glow/Graph/Graph.h
@@ -914,6 +914,20 @@ class Function final : public Named {
NodeValue weights, NodeValue indices,
NodeValue offsets);


/// Create an EmbeddingBagByteRowwiseOffsetsNode node.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, NodeValue data, NodeValue weights,
NodeValue indices, NodeValue offsets, bool useFP16Accumulation = false);

/// Same as \ref createEmbeddingBagByteRowwiseOffsets(), but
/// expects float input \p data, which is rowwise-quantized and fused
/// internally. \p fusedElemKind represents the element kind to use for the
/// final fused rowwise-quantized data.
EmbeddingBagByteRowwiseOffsetsNode *createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
NodeValue offsets, ElemKind fusedElemKind = ElemKind::UInt8FusedQTy,
bool useFP16Accumulation = false);

/// Same as \ref createSparseLengthsWeightedSum(), but with \p outTy
/// specified.
SparseLengthsWeightedSumNode *
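A minimal usage sketch of the two overloads declared above; the module/function pointers, placeholder names, and shapes are hypothetical, and `rwqData` stands in for an already-fused `UInt8FusedQTy` constant:

```cpp
// Sketch only: assumes Module *mod, Function *F, and a Constant *rwqData
// holding fused rowwise-quantized (UInt8FusedQTy) embedding rows.
auto *weights = mod->createPlaceholder(ElemKind::FloatTy, {7}, "weights",
                                       /* isTrainable */ false);
auto *indices = mod->createPlaceholder(ElemKind::Int64ITy, {7}, "indices",
                                       /* isTrainable */ false);
auto *offsets = mod->createPlaceholder(ElemKind::Int32ITy, {3}, "offsets",
                                       /* isTrainable */ false);

// Overload taking already-quantized data:
auto *EBB = F->createEmbeddingBagByteRowwiseOffsets(
    "ebb", rwqData, weights, indices, offsets);

// Overload taking a float Tensor, which is rowwise-quantized and fused
// internally with the requested fusedElemKind:
Tensor floatData(ElemKind::FloatTy, {10, 8});
auto *EBB2 = F->createEmbeddingBagByteRowwiseOffsets(
    "ebb2", floatData, weights, indices, offsets, ElemKind::UInt8FusedQTy,
    /* useFP16Accumulation */ false);
```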
12 changes: 12 additions & 0 deletions lib/Backends/CPU/CPUBackend.cpp
@@ -310,6 +310,18 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
{ElemKind::FloatTy}, {LengthsSumNode::LengthsIdx}) &&
(NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);


case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx) ==
ElemKind::UInt8FusedQTy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::FloatTy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
ElemKind::Int64ITy) &&
(NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
ElemKind::Int32ITy) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::FloatTy);

case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
return (NI.getInElemTy(
FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
25 changes: 25 additions & 0 deletions lib/Backends/CPU/libjit/libjit.cpp
@@ -1458,6 +1458,31 @@ void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
}
}


void libjit_embedding_bag_byte_rowwise_offsets_f(
float *dest, int8_t *data, float *weights, size_t *indices,
int32_t *offsets, size_t segments, size_t numIndices, size_t inLineSize,
size_t outLineSize) {
memset(dest, 0, segments * outLineSize * sizeof(float));
for (size_t i = 0; i < segments; i++) {
size_t start = offsets[i];
size_t end = i == segments - 1 ? numIndices : offsets[i + 1];
for (size_t j = start; j < end; j++) {
const float weight = weights[j];
const size_t line = indices[j];
const int8_t *currRowScaleOffsetPtr =
data + ((line + 1) * inLineSize) - 2 * sizeof(float);
float scale, offset;
memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
for (size_t k = 0; k < outLineSize; k++) {
const float fData =
(scale * (uint8_t)(data[line * inLineSize + k])) + offset;
dest[i * outLineSize + k] += weight * fData;
}
}
}
}

void libjit_sparse_to_dense_f(float *dest, const size_t *indices,
const float *values, size_t numIndices,
size_t destSize, size_t valueSize) {
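Each fused row stores the quantized payload followed by a 4-byte float scale and a 4-byte float offset, so `inLineSize == outLineSize + 2 * sizeof(float)`. A standalone sketch of the per-element dequantization the kernel performs (not the library API):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch only: dequantize element k of fused row `line`, mirroring the
// pointer arithmetic in libjit_embedding_bag_byte_rowwise_offsets_f.
static float dequantRowElem(const int8_t *data, size_t line,
                            size_t inLineSize, size_t k) {
  // The scale/offset pair lives in the last 8 bytes of the row.
  const int8_t *scaleOffsetPtr =
      data + (line + 1) * inLineSize - 2 * sizeof(float);
  float scale, offset;
  std::memcpy(&scale, scaleOffsetPtr, sizeof(float));
  std::memcpy(&offset, scaleOffsetPtr + sizeof(float), sizeof(float));
  return scale * (uint8_t)data[line * inLineSize + k] + offset;
}
```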
3 changes: 3 additions & 0 deletions lib/Backends/CPU/tests/CPUOperatorTest.cpp
@@ -96,6 +96,9 @@ std::set<std::string> glow::backendTestBlacklist = {
"back/0", "back/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_" "FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_"
"back2/0", "back2/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0", "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0", "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0", "FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16/0",
4 changes: 4 additions & 0 deletions lib/Backends/Habana/tests/HabanaOperatorTest.cpp
@@ -164,6 +164,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"back2/0", "back2/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat/0", "FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16/0", "FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"FusedRWQSLSAllZeroLengths_Float/0", "FusedRWQSLSAllZeroLengths_Float/0",
"FusedRWQSLSAllZeroLengths_Float16/0", "FusedRWQSLSAllZeroLengths_Float16/0",
"GatherDataFloat16IdxInt32/0", "GatherDataFloat16IdxInt32/0",
25 changes: 25 additions & 0 deletions lib/Backends/Interpreter/Interpreter.cpp
@@ -321,6 +321,31 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
ElemKind::Int32ITy);


case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
if (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) !=
ElemKind::Int64ITy ||
NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) !=
ElemKind::Int32ITy) {
return false;
}

switch (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx)) {
case ElemKind::UInt4FusedFP16QTy:
case ElemKind::UInt8FusedFP16QTy:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::Float16Ty) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::Float16Ty);
case ElemKind::UInt8FusedQTy:
return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
ElemKind::FloatTy) &&
(NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
ElemKind::FloatTy);
default:
return false;
}
}

case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
if (NI.getInElemTy(
FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
4 changes: 4 additions & 0 deletions lib/Backends/Interpreter/InterpreterFunction.h
@@ -316,6 +316,10 @@ class BoundInterpreterFunction {
template <typename T, typename AccumT>
void fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I);

template <typename T, typename AccumT>
void fwdEmbeddingBagByteRowwiseOffsetsImpl(
const EmbeddingBagByteRowwiseOffsetsInst *I);
///@}
};


76 changes: 76 additions & 0 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -3523,6 +3523,82 @@ void BoundInterpreterFunction::
}
}


template <typename T, typename AccumT>
void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl(
const EmbeddingBagByteRowwiseOffsetsInst *I) {
auto *out = getTensor(I->getDest());
auto *data = getTensor(I->getData());
auto *weights = getTensor(I->getWeights());
auto *indices = getTensor(I->getIndices());
auto *offsets = getTensor(I->getOffsets());

out->zero();

auto IH = indices->getHandle<int64_t>();
auto OFFH = offsets->getHandle<int32_t>();

size_t segments = offsets->dims()[0];
size_t numIndices = indices->dims()[0];

const bool using4BitQuantization =
data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;

const size_t outLineSize = out->size() / out->dims()[0];

auto DH = data->getHandle<uint8_t>();
auto WH = weights->getHandle<T>();
auto OH = out->getHandle<T>();

for (size_t i = 0; i < segments; i++) {
std::vector<AccumT> accum(outLineSize, 0.0f);
size_t start = OFFH.raw(i);
size_t end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
for (size_t j = start; j < end; j++) {
const float weight = static_cast<float>(WH.raw(j));
const size_t rowIdx = IH.raw(j);
T scale, offset;
std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx);
for (size_t k = 0; k < outLineSize; k++) {
float d = 0.0f;
if (!using4BitQuantization) {
d = quantization::dequantizeWithFloatOffset(
DH.at({rowIdx, k}), static_cast<float>(scale),
static_cast<float>(offset));
} else {
const bool isMSB = (k % 2 == 1);
d = quantization::dequantize4BitWithFloatOffset(
DH.at({rowIdx, k / 2}), static_cast<float>(scale),
static_cast<float>(offset), isMSB);
}
accum[k] += d * weight;
}
}
// Accumulation in AccumT complete, now copy back to output with cast to T.
size_t offsetOut = i * outLineSize;
for (size_t k = 0; k < outLineSize; k++) {
OH.raw(offsetOut++) = static_cast<T>(accum[k]);
}
}
}

void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst(
const EmbeddingBagByteRowwiseOffsetsInst *I) {
switch (I->getDest()->getElementType()) {
case ElemKind::FloatTy:
fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float>(I);
break;
case ElemKind::Float16Ty:
if (I->getUseFP16Accumulation()) {
fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t>(I);
} else {
fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float>(I);
}
break;
default:
llvm_unreachable("Type is not supported");
}
}

void BoundInterpreterFunction::fwdLengthsToRangesInst(
const LengthsToRangesInst *I) {
auto ranges = getTensor(I->getDest())->getHandle<int32_t>();
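Read together with `Interpreter::isOpSupported` above, the dispatch reduces to a small set of combinations: `UInt8FusedQTy` data pairs with `FloatTy` weights/results, the FP16-fused kinds pair with `Float16Ty`, and only `Float16Ty` results honor the accumulation flag. A standalone sketch of that selection logic (hypothetical helper, not the Glow API):

```cpp
// Sketch only: accumulation-type selection mirroring the dispatch above.
enum class Accum { Float, Float16 };

Accum pickAccum(bool destIsFloat16, bool useFP16Accumulation) {
  if (!destIsFloat16) {
    return Accum::Float; // FloatTy results always accumulate in float.
  }
  return useFP16Accumulation ? Accum::Float16 : Accum::Float;
}
```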
7 changes: 6 additions & 1 deletion lib/Backends/NNPI/tests/NNPIOperatorTest.cpp
@@ -94,6 +94,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"EmbeddingBag_1D_Float16/0", "EmbeddingBag_1D_Float16/0",
"EmbeddingBag_2D_Float/0", "EmbeddingBag_2D_Float/0",
"EmbeddingBag_2D_Float16/0", "EmbeddingBag_2D_Float16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16_back_",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"SparseToDense/0", "SparseToDense/0",
"SparseToDenseMask1/0", "SparseToDenseMask1/0",
"SparseToDenseMask2/0", "SparseToDenseMask2/0",
@@ -144,7 +148,8 @@ struct EmulatorOnlyTests {
"Exp/0", "Exp/0",
"FloatArgMaxKeepDim/0", "FloatArgMaxKeepDim/0",
"FloatArgMaxNoKeepDim/0", "FloatArgMaxNoKeepDim/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_" "FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_"
"back_",
"to_back2/0", "to_back2/0",
"GroupDilatedConvolution/0", "GroupDilatedConvolution/0",
"less_int32Cases/0", "less_int32Cases/0",
4 changes: 4 additions & 0 deletions lib/Backends/OpenCL/tests/OpenCLOperatorTest.cpp
@@ -178,6 +178,10 @@ std::set<std::string> glow::backendTestBlacklist = {
"back/0", "back/0",
"FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_" "FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_"
"back2/0", "back2/0",
"EmbeddingBagByteRowwiseOffsets_Float/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat/0",
"EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16/0",
"EmbeddingBagByteRowwiseOffsets_ConvertedFloat16/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float/0", "FusedRowwiseQuantizedSparseLengthsSum_Float/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0", "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat/0",
"FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0", "FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16/0",
1 change: 1 addition & 0 deletions lib/Exporter/ONNXModelWriter.cpp
@@ -1397,6 +1397,7 @@ DEF_ALL_WRITER_NODE(Regression)
DEF_ALL_WRITER_NODE(RowwiseQuantizedFullyConnected)
DEF_ALL_WRITER_NODE(RowwiseQuantizedSparseLengthsWeightedSum)
DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsSum)
DEF_ALL_WRITER_NODE(EmbeddingBagByteRowwiseOffsets)
DEF_ALL_WRITER_NODE(FusedRowwiseQuantizedSparseLengthsWeightedSum)


Error ONNXModelWriter::writeClip(const ClipNode *node, GraphType &graph) {
59 changes: 40 additions & 19 deletions lib/Graph/Graph.cpp
@@ -1652,19 +1652,6 @@ Function::createSparseLengthsWeightedSum(llvm::StringRef name, TypeRef outTy,
indices, lengths));
}


EmbeddingBagNode *Function::createEmbeddingBag(llvm::StringRef name,
NodeValue data,
NodeValue weights,
NodeValue indices,
NodeValue offsets) {
auto inDims = data.dims();
ShapeVector outDims(inDims.begin(), inDims.end());
outDims[0] = offsets.dims()[0];
auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
return addNode(
new EmbeddingBagNode(name, outTy, data, weights, indices, offsets));
}

RowwiseQuantizedSparseLengthsWeightedSumNode *
Function::createRowwiseQuantizedSparseLengthsWeightedSum(
llvm::StringRef name, Constant *data, Constant *scales, Constant *offsets,
@@ -1755,15 +1742,15 @@ Function::createRowwiseQuantizedSparseLengthsSum(
}


/// Helper used to get specific output type required for
/// createRowwiseQuantizedSparseLengthsSum,
/// createRowwiseQuantizedSparseLengthsWeightedSum, and
/// EmbeddingBagByteRowwiseOffsets. Function \p F is used to get the specific
/// type, using inputs \p data and \p segmentsDim to compute output dimensions.
static TypeRef
getOutputTypeOfFusedRowwiseQuantizedSLS(Function *F, NodeValue data,
llvm::ArrayRef<size_t> segmentsDim) {
ShapeVector outDims(data.dims().begin(), data.dims().end());
outDims[0] = segmentsDim[0];
// The output column count is the same as the input column count, but
// without the extra bytes for the fused scale/offset, as the output is not
// fused.
@@ -1876,6 +1863,40 @@ Function::createFusedRowwiseQuantizedSparseLengthsSum(
name, rwqData, indices, lengths, useFP16Accumulation);
}


EmbeddingBagNode *Function::createEmbeddingBag(llvm::StringRef name,
NodeValue data,
NodeValue weights,
NodeValue indices,
NodeValue offsets) {
auto inDims = data.dims();
ShapeVector outDims(inDims.begin(), inDims.end());
outDims[0] = offsets.dims()[0];
auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
return addNode(
new EmbeddingBagNode(name, outTy, data, weights, indices, offsets));
}

EmbeddingBagByteRowwiseOffsetsNode *
Function::createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, Tensor &data, NodeValue weights, NodeValue indices,
NodeValue offsets, ElemKind fusedElemKind, bool useFP16Accumulation) {
Constant *rwqData =
quantizeDataForFusedRowwiseQuantizedSparseLengthsWeightedSum(
this, data, fusedElemKind);
return createEmbeddingBagByteRowwiseOffsets(name, rwqData, weights, indices,
offsets, useFP16Accumulation);
}

EmbeddingBagByteRowwiseOffsetsNode *
Function::createEmbeddingBagByteRowwiseOffsets(
llvm::StringRef name, NodeValue data, NodeValue weights, NodeValue indices,
NodeValue offsets, bool useFP16Accumulation) {
auto outTy =
getOutputTypeOfFusedRowwiseQuantizedSLS(this, data, offsets.dims());
return addNode(new EmbeddingBagByteRowwiseOffsetsNode(
name, outTy, data, weights, indices, offsets, useFP16Accumulation));
}

LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name,
NodeValue lengths) {
ShapeVector outDims({lengths.dims()[0], 2});
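The result type computed by `getOutputTypeOfFusedRowwiseQuantizedSLS` drops the fused scale/offset bytes from the column count; a small worked example under assumed, illustrative shapes:

```cpp
#include <cstddef>

// Illustrative only: UInt8FusedQTy data of shape {10, 16} stores 8 payload
// bytes plus a 4-byte scale and a 4-byte offset per row. With offsets of
// shape {3}, the node's result is FloatTy of shape {3, 8}.
size_t numSegments = 3;                              // offsets.dims()[0]
size_t inLineSize = 16;                              // fused row width
size_t outLineSize = inLineSize - 2 * sizeof(float); // == 8 columns
```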
6 changes: 6 additions & 0 deletions lib/Graph/Nodes.cpp
@@ -1325,6 +1325,12 @@ static bool verifyFusedRowwiseQuantizedSparseLengthsSum(
return isValid;
}


bool EmbeddingBagByteRowwiseOffsetsNode::verify() const {
return verifyFusedRowwiseQuantizedSparseLengthsSum(
getResult(), getData(), getIndices(), getOffsets(), getWeights(),
getUseFP16Accumulation());
}

bool FusedRowwiseQuantizedSparseLengthsWeightedSumNode::verify() const {
return verifyFusedRowwiseQuantizedSparseLengthsSum(
getResult(), getData(), getIndices(), getLengths(), getWeights(),
24 changes: 24 additions & 0 deletions lib/LLVMIRCodeGen/LLVMIRGen.cpp
@@ -2574,6 +2574,30 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
break;
}


case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsInstKind: {
auto *N = cast<EmbeddingBagByteRowwiseOffsetsInst>(I);
auto *dest = N->getDest();
auto *data = N->getData();
auto *weights = N->getWeights();
auto *indices = N->getIndices();
auto *offsets = N->getOffsets();
auto *destPtr = emitValueAddress(builder, dest);
auto *dataPtr = emitValueAddress(builder, data);
auto *weightsPtr = emitValueAddress(builder, weights);
auto *indicesPtr = emitValueAddress(builder, indices);
auto *offsetsPtr = emitValueAddress(builder, offsets);
auto *segments = emitConstSizeT(builder, offsets->dims()[0]);
auto *numIndices = emitConstSizeT(builder, indices->dims()[0]);
auto *inLineSize = emitConstSizeT(builder, data->size() / data->dims()[0]);
auto *outLineSize = emitConstSizeT(builder, dest->size() / dest->dims()[0]);
auto *F = getFunction("embedding_bag_byte_rowwise_offsets",
dest->getElementType());
createCall(builder, F,
{destPtr, dataPtr, weightsPtr, indicesPtr, offsetsPtr, segments,
numIndices, inLineSize, outLineSize});
break;
}

case Kinded::Kind::SparseToDenseInstKind: {
auto *STDI = llvm::cast<SparseToDenseInst>(I);
auto *indices = STDI->getIndices();
