Adding SLS node instead of reducing to SLWS (#3039)

Summary:
Currently, SLS is implemented by creating a Splat node of constant 1.0 weights and lowering to SLWS. Avoiding these unnecessary weights gives better performance.
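At the graph level the change is small; here is a minimal sketch contrasting the two approaches, condensed from the createSparseLengthsSum and lowerSparseLengthsSumNode changes below:

    // Before: createSparseLengthsSum() eagerly built SLWS over a Splat of 1.0s.
    auto ty = getParent()->uniqueTypeWithNewShape(data.getType(),
                                                  {indices.dims()[0]});
    auto ones = createSplat(name.str() + ".ones", ty, 1.0);
    return createSparseLengthsWeightedSum(name, data, ones, indices, lengths);

    // After: a first-class node. Backends that implement SLS natively skip the
    // weights entirely; all others still get the Splat + SLWS form through the
    // new lowerSparseLengthsSumNode() pass.
    return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths));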
Work in progress on this stack:
- add a Habana implementation
- add an OpenCL implementation
- do the same for the fused operator (SparseLengthsSumFused8BitRowwise)

Documentation:
Pull Request resolved: #3039

Reviewed By: bertmaher

Differential Revision: D15751807

Pulled By: bertmaher

fbshipit-source-id: 53bc1a5f07a4517c3ff2853c5a46729cadd1a227
mortzur authored and facebook-github-bot committed Jun 11, 2019
1 parent 09c73d4 commit 2605951a3fabe63c23e3da159b8b71864a547458
@@ -704,10 +704,10 @@ class Function final : public Named {
   /// first Lengths[0] slices are aggregated to Result[0], next Lengths[1]
   /// slices are aggregated to Result[1], etc. I.e. sum(Lengths) must be equal
   /// to len(Indices).
-  SparseLengthsWeightedSumNode *createSparseLengthsSum(llvm::StringRef name,
-                                                       NodeValue data,
-                                                       NodeValue indices,
-                                                       NodeValue lengths);
+  SparseLengthsSumNode *createSparseLengthsSum(llvm::StringRef name,
+                                               NodeValue data,
+                                               NodeValue indices,
+                                               NodeValue lengths);
 
   /// Same as SparseLengthsSum, but i-th slice is multiplied by weights[i].
   /// len(weights) must be equal to len(indices).
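For concreteness, a small worked example of the semantics described above (values chosen arbitrarily):

    Data    = [[1, 2], [3, 4], [5, 6], [7, 8]]  // 4 rows, lineSize = 2
    Indices = [2, 0, 1, 2]                      // len(Indices) == sum(Lengths)
    Lengths = [2, 0, 2]                         // 3 segments
    Result[0] = Data[2] + Data[0] = [6, 8]
    Result[1] = [0, 0]                          // empty segment stays zero
    Result[2] = Data[1] + Data[2] = [8, 10]     // Result shape: [3, 2]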
@@ -81,6 +81,15 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
         {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
          ElemKind::BoolTy});
 
+  case Kinded::Kind::SparseLengthsSumNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+               {ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx,
+                                     SparseLengthsSumNode::LengthsIdx}) &&
+           (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
+            ElemKind::Int64ITy) &&
+           (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
+            ElemKind::Int32ITy);
+
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
     return NI.allInputsAndOutputsHaveSameElemKind(
         {ElemKind::FloatTy},
@@ -312,9 +321,13 @@ bool CPUBackend::isOpSupported(const NodeInfo &NI) const {
 }
 
 bool CPUBackend::shouldLower(const Node *N) const {
-  if (N->getKind() == Kinded::Kind::ConvolutionNodeKind)
+  switch (N->getKind()) {
+  case Kinded::Kind::ConvolutionNodeKind:
+  case Kinded::Kind::SparseLengthsSumNodeKind:
     return false;
-  return true;
+  default:
+    return true;
+  }
 }
 
 std::unique_ptr<CompiledFunction> CPUBackend::createCompiledFunction(
@@ -1126,6 +1126,22 @@ void libjit_lengths_to_ranges_i32(int32_t *ranges, const int32_t *lengths,
   }
 }
 
+void libjit_sparse_lengths_sum_f(float *dest, float *data, size_t *indices,
+                                 int32_t *lengths, size_t segments,
+                                 size_t lineSize) {
+  memset(dest, 0, segments * lineSize * sizeof(float));
+  size_t curIndex = 0;
+  for (size_t i = 0; i < segments; i++) {
+    for (int32_t j = 0; j < lengths[i]; j++) {
+      size_t line = indices[curIndex];
+      for (size_t k = 0; k < lineSize; k++) {
+        dest[i * lineSize + k] += data[line * lineSize + k];
+      }
+      curIndex++;
+    }
+  }
+}
+
 void libjit_sparse_lengths_weighted_sum_f(float *dest, float *data,
                                           float *weights, size_t *indices,
                                           int32_t *lengths, size_t segments,
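As a sanity check on libjit_sparse_lengths_sum_f above, here is a minimal hypothetical harness (not part of the commit; assumes it is compiled together with the kernel), reusing the worked example from earlier:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Declaration matching the kernel in the hunk above.
    void libjit_sparse_lengths_sum_f(float *dest, float *data, size_t *indices,
                                     int32_t *lengths, size_t segments,
                                     size_t lineSize);

    int main() {
      float data[] = {1, 2, 3, 4, 5, 6, 7, 8}; // 4 rows, lineSize = 2
      size_t indices[] = {2, 0, 1, 2};
      int32_t lengths[] = {2, 0, 2};           // 3 segments
      float dest[6];
      libjit_sparse_lengths_sum_f(dest, data, indices, lengths, 3, 2);
      for (size_t i = 0; i < 6; i += 2)
        printf("[%g, %g]\n", dest[i], dest[i + 1]); // [6, 8] [0, 0] [8, 10]
    }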
@@ -170,6 +170,16 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
            (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
             ElemKind::Int8QTy);
 
+  case Kinded::Kind::SparseLengthsSumNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
+               {SparseLengthsSumNode::IndicesIdx,
+                SparseLengthsSumNode::LengthsIdx}) &&
+           (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
+            ElemKind::Int64ITy) &&
+           (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
+            ElemKind::Int32ITy);
+
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
     return NI.allInputsAndOutputsHaveSameElemKind(
         {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
@@ -382,7 +392,11 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
 }
 
 bool Interpreter::shouldLower(const Node *N) const {
-  if (N->getKind() == Kinded::Kind::ConvolutionNodeKind)
+  switch (N->getKind()) {
+  case Kinded::Kind::ConvolutionNodeKind:
+  case Kinded::Kind::SparseLengthsSumNodeKind:
     return false;
-  return true;
+  default:
+    return true;
+  }
 }
@@ -238,6 +238,10 @@ class BoundInterpreterFunction {
   template <typename ElemTy>
   void fwdGatherRangesInstImpl(const GatherRangesInst *I);
 
+  void fwdSparseLengthsSumInstI8Impl(const SparseLengthsSumInst *I);
+  template <typename ElemTy>
+  void fwdSparseLengthsSumInstFloatImpl(const SparseLengthsSumInst *I);
+
   void
   fwdSparseLengthsWeightedSumInstI8Impl(const SparseLengthsWeightedSumInst *I);
   template <typename ElemTy>
@@ -2312,6 +2312,102 @@ void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) {
                             I->getData()->getElementType(), I)
 }
 
+void BoundInterpreterFunction::fwdSparseLengthsSumInstI8Impl(
+    const SparseLengthsSumInst *I) {
+
+  auto out = getTensor(I->getDest());
+  auto data = getTensor(I->getData());
+  auto indices = getTensor(I->getIndices());
+  auto lengths = getTensor(I->getLengths());
+
+  out->zero();
+
+  auto IH = indices->getHandle<int64_t>();
+  auto LH = lengths->getHandle<int32_t>();
+
+  size_t segments = lengths->dims()[0];
+  size_t totalLength = 0;
+  for (size_t i = 0; i < segments; i++) {
+    totalLength += LH.raw(i);
+  }
+  assert(totalLength <= indices->dims()[0] &&
+         "sum(Lengths) must be equal to len(Indices)");
+
+  size_t lineSize = data->size() / data->dims()[0];
+
+  auto DH = data->getHandle<int8_t>();
+  auto OH = out->getHandle<int8_t>();
+
+  auto TQP = [](Tensor *T) {
+    return TensorQuantizationParams{T->getType().getScale(),
+                                    T->getType().getOffset()};
+  };
+
+  size_t curIdx = 0;
+  for (size_t i = 0; i < segments; i++) {
+    std::vector<float> accum(lineSize, 0.0f);
+    for (int32_t j = 0; j < LH.raw(i); j++) {
+      size_t offsetIn = IH.raw(curIdx) * lineSize;
+      for (size_t k = 0; k < lineSize; k++) {
+        accum[k] += quantization::dequantize(DH.raw(offsetIn++), TQP(data));
+      }
+      curIdx++;
+    }
+    size_t offsetOut = i * lineSize;
+    for (size_t k = 0; k < lineSize; k++) {
+      OH.raw(offsetOut++) = quantization::quantize(accum[k], TQP(out));
+    }
+  }
+}
+
+template <typename ElemTy>
+void BoundInterpreterFunction::fwdSparseLengthsSumInstFloatImpl(
+    const SparseLengthsSumInst *I) {
+  staticAssertFloatingPointType(ElemTy);
+
+  auto out = getTensor(I->getDest());
+  auto data = getTensor(I->getData());
+  auto indices = getTensor(I->getIndices());
+  auto lengths = getTensor(I->getLengths());
+
+  out->zero();
+
+  auto IH = indices->getHandle<int64_t>();
+  auto LH = lengths->getHandle<int32_t>();
+
+  size_t segments = lengths->dims()[0];
+  size_t totalLength = 0;
+  for (size_t i = 0; i < segments; i++) {
+    totalLength += LH.raw(i);
+  }
+  assert(totalLength <= indices->dims()[0] &&
+         "sum(Lengths) must be equal to len(Indices)");
+
+  size_t lineSize = data->size() / data->dims()[0];
+
+  auto DH = data->getHandle<ElemTy>();
+  auto OH = out->getHandle<ElemTy>();
+
+  size_t curIdx = 0;
+  for (size_t i = 0; i < segments; i++) {
+    for (size_t j = 0, e = LH.raw(i); j < e; j++) {
+      size_t offsetIn = IH.raw(curIdx++) * lineSize;
+      size_t offsetOut = i * lineSize;
+      for (size_t k = 0; k < lineSize; k++)
+        OH.raw(offsetOut++) += DH.raw(offsetIn++);
+    }
+  }
+}
+
+void BoundInterpreterFunction::fwdSparseLengthsSumInst(
+    const SparseLengthsSumInst *I) {
+  if (I->getDest()->getType()->isQuantizedType()) {
+    return fwdSparseLengthsSumInstI8Impl(I);
+  }
+  dispatchFloatingPointImpl(fwdSparseLengthsSumInstFloatImpl,
+                            I->getData()->getElementType(), I);
+}
+
 template <typename ElemTy>
 void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl(
     const SparseLengthsWeightedSumInst *I) {
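The Int8 path above dequantizes each addend, accumulates in float, and requantizes once per output segment; fwdSparseLengthsSumInst then routes quantized dests to it and uses dispatchFloatingPointImpl to instantiate the float template for FloatTy/Float16Ty. A worked example of the arithmetic, assuming Glow's usual affine mapping real = scale * (q - offset):

    // data scale = 0.5, offset = 0; one segment sums quantized rows {10, 6}:
    //   dequantize:          0.5 * (10 - 0) = 5.0 and 0.5 * (6 - 0) = 3.0
    //   accumulate in float: 5.0 + 3.0 = 8.0
    //   requantize (out scale = 1.0, offset = 0): round(8.0 / 1.0) + 0 = 8
    // Accumulating in float avoids int8 overflow and costs one rescale per
    // output element instead of one per addend.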
@@ -253,7 +253,6 @@ void ExecutionEngine::compile(Function *F, CompilationContext &cctx,
          "A function with this name has already been compiled.");
 
   EXIT_ON_ERR(::glow::optimizeFunction(F, *backend_, cctx));
-
   for (const Node &N : F->getNodes()) {
     CHECK(backend_->isOpSupported(N))
         << "Backend must support all nodes after high-level optimizations but "
@@ -1374,13 +1374,15 @@ LengthsSumNode *Function::createLengthsSum(llvm::StringRef name, NodeValue data,
   return addNode(new LengthsSumNode(name, outTy, data, lengths));
 }
 
-SparseLengthsWeightedSumNode *
-Function::createSparseLengthsSum(llvm::StringRef name, NodeValue data,
-                                 NodeValue indices, NodeValue lengths) {
-  auto ty =
-      getParent()->uniqueTypeWithNewShape(data.getType(), {indices.dims()[0]});
-  auto ones = createSplat(name.str() + ".ones", ty, 1.0);
-  return createSparseLengthsWeightedSum(name, data, ones, indices, lengths);
+SparseLengthsSumNode *Function::createSparseLengthsSum(llvm::StringRef name,
+                                                       NodeValue data,
+                                                       NodeValue indices,
+                                                       NodeValue lengths) {
+  auto inDims = data.dims();
+  ShapeVector outDims(inDims.begin(), inDims.end());
+  outDims[0] = lengths.dims()[0];
+  auto outTy = getParent()->uniqueTypeWithNewShape(data.getType(), outDims);
+  return addNode(new SparseLengthsSumNode(name, outTy, data, indices, lengths));
 }
 
 SparseLengthsWeightedSumNode *
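A hypothetical caller of the new API (illustrative names and shapes, not from the commit; assumes a Module `mod` and a Function *F):

    // Shapes follow the worked example: data {4, 2}, indices {4},
    // lengths {3} -> result {3, 2}.
    auto *data = mod.createPlaceholder(ElemKind::FloatTy, {4, 2}, "data", false);
    auto *indices =
        mod.createPlaceholder(ElemKind::Int64ITy, {4}, "indices", false);
    auto *lengths =
        mod.createPlaceholder(ElemKind::Int32ITy, {3}, "lengths", false);
    auto *SLS = F->createSparseLengthsSum("SLS", data, indices, lengths);
    auto *save = F->createSave("save", SLS);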
@@ -376,6 +376,20 @@ static bool verifyRegression(NodeValue src, NodeValue dest,
          checkSameType(dest, expected, dest.getNode());
 }
 
+static bool verifySparseLengthsSum(NodeValue dest, NodeValue data,
+                                   NodeValue indices, NodeValue lengths) {
+  bool isValid = checkType(dest, data.getElementType(), dest.getNode());
+  isValid &= checkType(indices, ElemKind::Int64ITy, dest.getNode());
+  isValid &= checkType(lengths, ElemKind::Int32ITy, dest.getNode());
+  isValid &=
+      expectCompareTrue("Indices must be a 1D vector", indices.dims().size(),
+                        size_t(1), dest.getNode());
+  isValid &=
+      expectCompareTrue("Lengths must be a 1D vector", lengths.dims().size(),
+                        size_t(1), dest.getNode());
+  return isValid;
+}
+
 static bool verifySparseLengthsWeightedSum(NodeValue dest, NodeValue data,
                                            NodeValue weights, NodeValue indices,
                                            NodeValue lengths) {
@@ -920,6 +934,11 @@ bool BatchedReduceMeanNode::verify() const {
   return isValid;
 }
 
+bool SparseLengthsSumNode::verify() const {
+  return verifySparseLengthsSum(getResult(), getData(), getIndices(),
+                                getLengths());
+}
+
 bool SparseLengthsWeightedSumNode::verify() const {
   return verifySparseLengthsWeightedSum(getResult(), getData(), getWeights(),
                                         getIndices(), getLengths());
@@ -2231,6 +2231,24 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder,
     break;
   }
 
+  case Kinded::Kind::SparseLengthsSumInstKind: {
+    auto *SI = cast<SparseLengthsSumInst>(I);
+    auto *dest = SI->getDest();
+    auto *data = SI->getData();
+    auto *indices = SI->getIndices();
+    auto *lengths = SI->getLengths();
+    auto *destPtr = emitValueAddress(builder, dest);
+    auto *dataPtr = emitValueAddress(builder, data);
+    auto *indicesPtr = emitValueAddress(builder, indices);
+    auto *lengthsPtr = emitValueAddress(builder, lengths);
+    auto *segments = emitConstSizeT(builder, lengths->dims()[0]);
+    auto *lineSize = emitConstSizeT(builder, data->size() / data->dims()[0]);
+    auto *F = getFunction("sparse_lengths_sum", dest->getElementType());
+    createCall(builder, F,
+               {destPtr, dataPtr, indicesPtr, lengthsPtr, segments, lineSize});
+    break;
+  }
+
   case Kinded::Kind::SparseLengthsWeightedSumInstKind: {
     auto *SI = cast<SparseLengthsWeightedSumInst>(I);
     auto *dest = SI->getDest();
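Presumably getFunction("sparse_lengths_sum", ...) resolves, by the libjit naming convention, to the "_f"-suffixed kernel shown earlier for float dests, so the jitted code amounts to:

    // Conceptual equivalent of the generated call (float case):
    libjit_sparse_lengths_sum_f(destPtr, dataPtr, indicesPtr, lengthsPtr,
                                segments, lineSize);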
@@ -794,6 +794,18 @@ static void lowerBatchMatMulNode(Function *F, CompilationContext &cctx,
   replaceAllUsesOfWith(cctx.loweredInfoMap, BMMN.getResult(), RN);
 }
 
+static void lowerSparseLengthsSumNode(Function *F, CompilationContext &cctx,
+                                      const SparseLengthsSumNode &SLSN) {
+  auto ty = F->getParent()->uniqueTypeWithNewShape(
+      SLSN.getData().getType(), {SLSN.getIndices().dims()[0]});
+  auto *ones = F->createSplat(SLSN.getName().str() + ".ones", ty, 1.0);
+  auto *SLWSN = F->createSparseLengthsWeightedSum(
+      SLSN.getName().str(), SLSN.getData(), ones, SLSN.getIndices(),
+      SLSN.getLengths());
+
+  replaceAllUsesOfWith(cctx.loweredInfoMap, SLSN.getResult(), SLWSN);
+}
+
 /// Lowers \p node given Function \p. \p cctx contains a mapping of loweredMap
 /// that will log the lowering info of what was replaced by what via output
 /// names.
@@ -849,6 +861,8 @@ static void lowerNode(Function *F, Node *node, CompilationContext &cctx) {
     lowerReplaceNaNNode(F, cctx, *RN);
   } else if (auto *BMMN = dyn_cast<BatchMatMulNode>(node)) {
     lowerBatchMatMulNode(F, cctx, *BMMN);
+  } else if (auto *SLSN = dyn_cast<SparseLengthsSumNode>(node)) {
+    lowerSparseLengthsSumNode(F, cctx, *SLSN);
   }
 }

@@ -1237,7 +1237,7 @@ TEST_F(Habana, SparseLengthsSum) {
   synTensor sls1Inputs[] = {dataT, i1T, i11T, sbT};
   ns_SparseLengthsSum::Params sls1Params;
   sls1Params.mode = SEPARATE_SC_ZP;
-  synCreateGenericNode(sls1Inputs, &save1T, 4, 1, (void *)&sls1Params,
+  synCreateGenericNode(sls1Inputs, &save1T, 3, 1, (void *)&sls1Params,
                        "sparse_lengths_sum_f32", "sls1", nullptr, nullptr);
 
   // Compile graph.
@@ -1190,15 +1190,13 @@ TEST(onnx, importSparseLengthsSum) {
   // Verify structure: PH, PH -> SparseLengthsSum -> Save -> PH.
   //                       PH -> Splat /
   ASSERT_EQ(mod.getPlaceholders().size(), 4);
-  ASSERT_EQ(F->getNodes().size(), 3);
+  ASSERT_EQ(F->getNodes().size(), 2);
   auto *save = getSaveNodeFromDest(output);
-  auto *LS =
-      llvm::dyn_cast<SparseLengthsWeightedSumNode>(save->getInput().getNode());
+  auto *LS = llvm::dyn_cast<SparseLengthsSumNode>(save->getInput().getNode());
   ASSERT_TRUE(LS);
   ASSERT_TRUE(llvm::isa<Placeholder>(LS->getData()));
   ASSERT_TRUE(llvm::isa<Placeholder>(LS->getIndices()));
   ASSERT_TRUE(llvm::isa<Placeholder>(LS->getLengths()));
-  ASSERT_TRUE(llvm::isa<SplatNode>(LS->getWeights()));
 }
 
 /// Test loading LengthsSum from an ONNX model.
@@ -5082,6 +5082,7 @@ TEST_P(OperatorTest, SparseLengthsSum) {
   };
 
   auto *R = F_->createSparseLengthsSum("SLS", data, indices, lengths);
+
   auto *S = F_->createSave("save", R);
   bindings_.allocate(S->getPlaceholder());
 
@@ -225,6 +225,18 @@ int main(int argc, char **argv) {
       .autoVerify(VerifyKind::SameElementType,
                   {"Lengths", "ElemKind::Int32ITy"});
 
+  BB.newInstr("SparseLengthsSum")
+      .addOperand("Dest", OperandKind::Out)
+      .addOperand("Data", OperandKind::In)
+      .addOperand("Indices", OperandKind::In)
+      .addOperand("Lengths", OperandKind::In)
+      .autoIRGen()
+      .autoVerify(VerifyKind::SameElementType, {"Dest", "Data"})
+      .autoVerify(VerifyKind::SameElementType,
+                  {"Indices", "ElemKind::Int64ITy"})
+      .autoVerify(VerifyKind::SameElementType,
+                  {"Lengths", "ElemKind::Int32ITy"});
+
   BB.newInstr("SparseLengthsWeightedSum")
       .addOperand("Dest", OperandKind::Out)
       .addOperand("Data", OperandKind::In)
