diff --git a/src/common/algorithm/Sampler.h b/src/common/algorithm/Sampler.h new file mode 100644 index 00000000000..ebcde5b433a --- /dev/null +++ b/src/common/algorithm/Sampler.h @@ -0,0 +1,179 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#ifndef COMMON_ALGORITHM_SAMPLER_H_ +#define COMMON_ALGORITHM_SAMPLER_H_ + +#include +#include +#include +#include +#include +#include + +namespace nebula { +namespace algorithm { +template +T UniformRandom() { + static_assert(std::is_floating_point::value, "Only support float point type"); +#if defined(__clang__) + static std::default_random_engine e(std::time(nullptr)); + static std::uniform_real_distribution u(0., 1.); +#elif defined(__GNUC__) || defined(__GNUG__) + static thread_local std::default_random_engine e(std::time(nullptr)); + static thread_local std::uniform_real_distribution u(0., 1.); +#endif + return u(e); +} + +template +void Normalization(std::vector& distribution) { + static_assert(std::is_floating_point::value, "Only support float point type"); + T norm_sum = 0.0f; + for (auto& dist : distribution) { + norm_sum += dist; + } + if (norm_sum <= FLT_EPSILON && !distribution.empty()) { + for (size_t i = 0; i < distribution.size(); ++i) { + distribution[i] = 1.0f / static_cast(distribution.size()); + } + return; + } + for (size_t i = 0; i < distribution.size(); ++i) { + distribution[i] /= norm_sum; + } +} + +// https://en.wikipedia.org/wiki/Alias_method +template +class AliasSampler { + public: + static_assert(std::is_floating_point::value, "Only support float point type"); + using AliasType = uint32_t; + bool Init(std::vector& distribution); + inline bool Init(const std::vector& distribution); + AliasType Sample() const; + inline size_t Size() const; + + private: + std::vector prob_; + std::vector alias_; +}; + +template +bool AliasSampler::Init(std::vector& distribution) { + // normalization sum of distribution to 1 + 
Normalization(distribution); + + prob_.resize(distribution.size()); + alias_.resize(distribution.size()); + std::vector smaller, larger; + smaller.reserve(distribution.size()); + larger.reserve(distribution.size()); + + for (size_t i = 0; i < distribution.size(); ++i) { + prob_[i] = distribution[i] * distribution.size(); + if (prob_[i] < 1.0) { + smaller.push_back(i); + } else { + larger.push_back(i); + } + } + // Construct the probability and alias tables + AliasType small, large; + while (!smaller.empty() && !larger.empty()) { + small = smaller.back(); + smaller.pop_back(); + large = larger.back(); + larger.pop_back(); + alias_[small] = large; + prob_[large] = prob_[large] + prob_[small] - 1.0; + if (prob_[large] < 1.0) { + smaller.push_back(large); + } else { + larger.push_back(large); + } + } + while (!smaller.empty()) { + small = smaller.back(); + smaller.pop_back(); + prob_[small] = 1.0; + } + while (!larger.empty()) { + large = larger.back(); + larger.pop_back(); + prob_[large] = 1.0; + } + return true; +} + +template +bool AliasSampler::Init(const std::vector& distribution) { + std::vector dist = distribution; + return Init(dist); +} + +template +typename AliasSampler::AliasType AliasSampler::Sample() const { + AliasType roll = floor(prob_.size() * UniformRandom()); + bool coin = UniformRandom() < prob_[roll]; + return coin ? 
roll : alias_[roll]; +} + +template +size_t AliasSampler::Size() const { + return prob_.size(); +} + +/** + * binary sample in accumulation weights + */ +template +size_t BinarySampleAcc(const std::vector& accumulate_weights) { + if (accumulate_weights.empty()) { + return 0; + } + T rnd = UniformRandom() * accumulate_weights.back(); + size_t low = 0, high = accumulate_weights.size() - 1, mid = 0; + while (low <= high) { + mid = ((high - low) >> 1) + low; + if (rnd < accumulate_weights[mid]) { + if (mid == 0) { + return mid; + } + high = mid - 1; + if (high >= 0 && rnd >= accumulate_weights[high]) { + // rnd in [mid-1, mid) + return mid; + } + } else { + low = mid + 1; + if (low < accumulate_weights.size() && rnd < accumulate_weights[low]) { + // rnd in [mid, mid+1) + return low; + } + } + } + return mid; +} + +/** + * binary sample in weights + */ +template +size_t BinarySample(const std::vector& weights) { + std::vector accumulate_weights(weights.size(), 0.0f); + T cur_weight = 0.0f; + for (size_t i = 0; i < weights.size(); ++i) { + cur_weight += weights[i]; + accumulate_weights[i] = cur_weight; + } + Normalization(accumulate_weights); + return BinarySampleAcc(accumulate_weights); +} + +} // namespace algorithm +} // namespace nebula +#endif diff --git a/src/graph/context/ast/CypherAstContext.h b/src/graph/context/ast/CypherAstContext.h index a71f682453d..dcd9be2579c 100644 --- a/src/graph/context/ast/CypherAstContext.h +++ b/src/graph/context/ast/CypherAstContext.h @@ -11,8 +11,8 @@ #include "common/expression/Expression.h" #include "common/expression/PathBuildExpression.h" #include "graph/context/ast/AstContext.h" +#include "graph/planner/plan/Query.h" #include "parser/MatchSentence.h" - namespace nebula { namespace graph { enum class CypherClauseKind : uint8_t { @@ -22,6 +22,7 @@ enum class CypherClauseKind : uint8_t { kWhere, kReturn, kOrderBy, + kSampling, kPagination, kYield, kShortestPath, @@ -142,6 +143,12 @@ struct OrderByClauseContext final : 
CypherClauseContextBase { std::vector> indexedOrderFactors; }; +struct SamplingClauseContext final : CypherClauseContextBase { + SamplingClauseContext() : CypherClauseContextBase(CypherClauseKind::kSampling) {} + + std::vector indexedSamplingFactors; +}; + struct PaginationContext final : CypherClauseContextBase { PaginationContext() : CypherClauseContextBase(CypherClauseKind::kPagination) {} @@ -176,6 +183,7 @@ struct YieldClauseContext final : CypherClauseContextBase { struct ReturnClauseContext final : CypherClauseContextBase { ReturnClauseContext() : CypherClauseContextBase(CypherClauseKind::kReturn) {} + std::unique_ptr sampling; std::unique_ptr order; std::unique_ptr pagination; std::unique_ptr yield; @@ -184,6 +192,7 @@ struct ReturnClauseContext final : CypherClauseContextBase { struct WithClauseContext final : CypherClauseContextBase { WithClauseContext() : CypherClauseContextBase(CypherClauseKind::kWith) {} + std::unique_ptr sampling; std::unique_ptr order; std::unique_ptr pagination; std::unique_ptr where; diff --git a/src/graph/executor/CMakeLists.txt b/src/graph/executor/CMakeLists.txt index 0b2f00a5936..8c98f8fe034 100644 --- a/src/graph/executor/CMakeLists.txt +++ b/src/graph/executor/CMakeLists.txt @@ -26,6 +26,7 @@ nebula_add_library( query/UnwindExecutor.cpp query/SortExecutor.cpp query/TopNExecutor.cpp + query/SamplingExecutor.cpp query/IndexScanExecutor.cpp query/SetExecutor.cpp query/UnionExecutor.cpp diff --git a/src/graph/executor/Executor.cpp b/src/graph/executor/Executor.cpp index 4656559c06d..eeeb6be8c43 100644 --- a/src/graph/executor/Executor.cpp +++ b/src/graph/executor/Executor.cpp @@ -86,6 +86,7 @@ #include "graph/executor/query/ProjectExecutor.h" #include "graph/executor/query/RollUpApplyExecutor.h" #include "graph/executor/query/SampleExecutor.h" +#include "graph/executor/query/SamplingExecutor.h" #include "graph/executor/query/ScanEdgesExecutor.h" #include "graph/executor/query/ScanVerticesExecutor.h" #include 
"graph/executor/query/SortExecutor.h" @@ -180,6 +181,9 @@ Executor *Executor::makeExecutor(QueryContext *qctx, const PlanNode *node) { case PlanNode::Kind::kTopN: { return pool->makeAndAdd(node, qctx); } + case PlanNode::Kind::kSampling: { + return pool->makeAndAdd(node, qctx); + } case PlanNode::Kind::kFilter: { return pool->makeAndAdd(node, qctx); } diff --git a/src/graph/executor/query/SamplingExecutor.cpp b/src/graph/executor/query/SamplingExecutor.cpp new file mode 100644 index 00000000000..0d658196ee2 --- /dev/null +++ b/src/graph/executor/query/SamplingExecutor.cpp @@ -0,0 +1,113 @@ +// Copyright (c) 2020 vesoft inc. All rights reserved. +// +// This source code is licensed under Apache 2.0 License. + +#include "graph/executor/query/SamplingExecutor.h" + +#include "common/algorithm/Sampler.h" +#include "graph/planner/plan/Query.h" + +namespace nebula { +namespace graph { + +using WeightType = float; + +folly::Future SamplingExecutor::execute() { + SCOPED_TIMER(&execTime_); + auto *sampling = asNode(node()); + Result result = ectx_->getResult(sampling->inputVar()); + auto *iter = result.iterRef(); + if (UNLIKELY(iter == nullptr)) { + return Status::Error("Internal error: nullptr iterator in sampling executor"); + } + if (UNLIKELY(!result.iter()->isSequentialIter())) { + std::stringstream ss; + ss << "Internal error: Sampling executor does not supported " << iter->kind(); + return Status::Error(ss.str()); + } + auto &factors = sampling->factors(); + auto size = iter->size(); + if (size <= 0) { + iter->clear(); + return finish(ResultBuilder().value(result.valuePtr()).iter(std::move(result).iter()).build()); + } + auto colNames = result.value().getDataSet().colNames; + DataSet dataset(std::move(colNames)); + for (auto factor : factors) { + if (factor.count <= 0) { + iter->clear(); + return finish( + ResultBuilder().value(result.valuePtr()).iter(std::move(result).iter()).build()); + } + if (factor.samplingType == SamplingFactor::SamplingType::BINARY) { + 
executeBinarySample(iter, factor.colIdx, factor.count, dataset); + } else { + executeAliasSample(iter, factor.colIdx, factor.count, dataset); + } + } + return finish( + ResultBuilder().value(Value(std::move(dataset))).iter(Iterator::Kind::kSequential).build()); +} + +template +void SamplingExecutor::executeBinarySample(Iterator *iter, + size_t index, + size_t count, + DataSet &list) { + auto uIter = static_cast(iter); + std::vector accumulateWeights; + auto it = uIter->begin(); + WeightType v; + while (it != uIter->end()) { + v = 1.0; + if ((*it)[index].type() == Value::Type::FLOAT) { + v = static_cast((*it)[index].getFloat()); + } else if ((*it)[index].type() == Value::Type::INT) { + v = static_cast((*it)[index].getInt()); + } + if (!accumulateWeights.empty()) { + v += accumulateWeights.back(); + } + accumulateWeights.emplace_back(std::move(v)); + ++it; + } + nebula::algorithm::Normalization(accumulateWeights); + auto beg = uIter->begin(); + for (size_t i = 0; i < count; ++i) { + auto idx = nebula::algorithm::BinarySampleAcc(accumulateWeights); + list.emplace_back(*(beg + idx)); + } + uIter->clear(); +} + +template +void SamplingExecutor::executeAliasSample(Iterator *iter, + size_t index, + size_t count, + DataSet &list) { + auto uIter = static_cast(iter); + std::vector weights; + auto it = uIter->begin(); + WeightType v; + while (it != uIter->end()) { + v = 1.0; + if ((*it)[index].type() == Value::Type::FLOAT) { + v = static_cast((*it)[index].getFloat()); + } else if ((*it)[index].type() == Value::Type::INT) { + v = static_cast((*it)[index].getInt()); + } + weights.emplace_back(std::move(v)); + ++it; + } + nebula::algorithm::AliasSampler sampler_; + sampler_.Init(weights); + auto beg = uIter->begin(); + for (size_t i = 0; i < count; ++i) { + auto idx = sampler_.Sample(); + list.emplace_back(*(beg + idx)); + } + uIter->clear(); +} + +} // namespace graph +} // namespace nebula diff --git a/src/graph/executor/query/SamplingExecutor.h 
b/src/graph/executor/query/SamplingExecutor.h new file mode 100644 index 00000000000..0c5dfca0c34 --- /dev/null +++ b/src/graph/executor/query/SamplingExecutor.h @@ -0,0 +1,28 @@ +// Copyright (c) 2020 vesoft inc. All rights reserved. +// +// This source code is licensed under Apache 2.0 License. + +#ifndef GRAPH_EXECUTOR_QUERY_SAMPLINGEXECUTOR_H_ +#define GRAPH_EXECUTOR_QUERY_SAMPLINGEXECUTOR_H_ + +#include "graph/executor/Executor.h" +namespace nebula { +namespace graph { + +class SamplingExecutor final : public Executor { + public: + SamplingExecutor(const PlanNode *node, QueryContext *qctx) + : Executor("SamplingExecutor", node, qctx) {} + + folly::Future execute() override; + + private: + template + void executeBinarySample(Iterator *iter, size_t index, size_t count, DataSet &list); + template + void executeAliasSample(Iterator *iter, size_t index, size_t count, DataSet &list); +}; + +} // namespace graph +} // namespace nebula +#endif // GRAPH_EXECUTOR_QUERY_SAMPLINGEXECUTOR_H_ diff --git a/src/graph/planner/CMakeLists.txt b/src/graph/planner/CMakeLists.txt index 8aa88ad1a93..b7d1883a95d 100644 --- a/src/graph/planner/CMakeLists.txt +++ b/src/graph/planner/CMakeLists.txt @@ -14,6 +14,7 @@ nebula_add_library( match/UnwindClausePlanner.cpp match/ReturnClausePlanner.cpp match/OrderByClausePlanner.cpp + match/SamplingClausePlanner.cpp match/YieldClausePlanner.cpp match/PaginationPlanner.cpp match/WhereClausePlanner.cpp diff --git a/src/graph/planner/match/SamplingClausePlanner.cpp b/src/graph/planner/match/SamplingClausePlanner.cpp new file mode 100644 index 00000000000..bbc900d4a4f --- /dev/null +++ b/src/graph/planner/match/SamplingClausePlanner.cpp @@ -0,0 +1,31 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +#include "graph/planner/match/SamplingClausePlanner.h" + +#include "graph/planner/plan/Query.h" + +namespace nebula { +namespace graph { +StatusOr SamplingClausePlanner::transform(CypherClauseContextBase* clauseCtx) { + if (clauseCtx->kind != CypherClauseKind::kSampling) { + return Status::Error("Not a valid context for SamplingClausePlanner."); + } + auto* samplingCtx = static_cast(clauseCtx); + + SubPlan samplingPlan; + NG_RETURN_IF_ERROR(buildSampling(samplingCtx, samplingPlan)); + return samplingPlan; +} + +Status SamplingClausePlanner::buildSampling(SamplingClauseContext* octx, SubPlan& subplan) { + auto* currentRoot = subplan.root; + auto* sampling = Sampling::make(octx->qctx, currentRoot, octx->indexedSamplingFactors); + subplan.root = sampling; + subplan.tail = sampling; + return Status::OK(); +} +} // namespace graph +} // namespace nebula diff --git a/src/graph/planner/match/SamplingClausePlanner.h b/src/graph/planner/match/SamplingClausePlanner.h new file mode 100644 index 00000000000..fb4eeb9d16d --- /dev/null +++ b/src/graph/planner/match/SamplingClausePlanner.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +#ifndef GRAPH_PLANNER_MATCH_SAMPLINGCLAUSEPLANNER_H_ +#define GRAPH_PLANNER_MATCH_SAMPLINGCLAUSEPLANNER_H_ + +#include "graph/planner/match/CypherClausePlanner.h" + +namespace nebula { +namespace graph { +// The SamplingClausePlanner generates plan for the sampling clause; +class SamplingClausePlanner final : public CypherClausePlanner { + public: + SamplingClausePlanner() = default; + + StatusOr transform(CypherClauseContextBase* clauseCtx) override; + + Status buildSampling(SamplingClauseContext* octx, SubPlan& subplan); +}; +} // namespace graph +} // namespace nebula +#endif // GRAPH_PLANNER_MATCH_SAMPLINGCLAUSEPLANNER_H_ diff --git a/src/graph/planner/plan/PlanNode.cpp b/src/graph/planner/plan/PlanNode.cpp index d9b9ebd8a7e..e10ebaf0247 100644 --- a/src/graph/planner/plan/PlanNode.cpp +++ b/src/graph/planner/plan/PlanNode.cpp @@ -88,6 +88,8 @@ const char* PlanNode::toString(PlanNode::Kind kind) { return "Limit"; case Kind::kSample: return "Sample"; + case Kind::kSampling: + return "Sampling"; case Kind::kAggregate: return "Aggregate"; case Kind::kSelect: diff --git a/src/graph/planner/plan/PlanNode.h b/src/graph/planner/plan/PlanNode.h index eb1991f1aa5..87f26800e06 100644 --- a/src/graph/planner/plan/PlanNode.h +++ b/src/graph/planner/plan/PlanNode.h @@ -58,6 +58,7 @@ class PlanNode { kTopN, kLimit, kSample, + kSampling, kAggregate, kDedup, kAssign, diff --git a/src/graph/planner/plan/Query.cpp b/src/graph/planner/plan/Query.cpp index 6dd9c037f48..96ab77826e0 100644 --- a/src/graph/planner/plan/Query.cpp +++ b/src/graph/planner/plan/Query.cpp @@ -501,6 +501,28 @@ void Sort::cloneMembers(const Sort& p) { factors_ = std::move(factors); } +std::unique_ptr Sampling::explain() const { + auto desc = SingleInputNode::explain(); + addDescription("factors", folly::toJson(util::toJson(factorsString())), desc.get()); + return desc; +} + +PlanNode* Sampling::clone() const { + auto* newSampling = Sampling::make(qctx_, nullptr); + newSampling->cloneMembers(*this); + 
return newSampling; +} + +void Sampling::cloneMembers(const Sampling& p) { + SingleInputNode::cloneMembers(p); + + std::vector factors; + for (const auto& factor : p.factors()) { + factors.emplace_back(factor); + } + factors_ = std::move(factors); +} + // Get constant count value int64_t Limit::count(QueryContext* qctx) const { if (count_ == nullptr) { diff --git a/src/graph/planner/plan/Query.h b/src/graph/planner/plan/Query.h index 784a22d18b7..e4b1e547a87 100644 --- a/src/graph/planner/plan/Query.h +++ b/src/graph/planner/plan/Query.h @@ -1143,6 +1143,58 @@ class Sort final : public SingleInputNode { std::vector> factors_; }; +struct SamplingParams { + size_t colIdx; + size_t count; + SamplingFactor::SamplingType samplingType; + + SamplingParams() = default; + SamplingParams(size_t col_idx, size_t c, SamplingFactor::SamplingType st) + : colIdx(col_idx), count(c), samplingType(st) {} +}; + +// Sampling the given record set. +class Sampling final : public SingleInputNode { + public: + static Sampling* make(QueryContext* qctx, + PlanNode* input, + std::vector factors = {}) { + return qctx->objPool()->makeAndAdd(qctx, input, std::move(factors)); + } + + const std::vector& factors() const { + return factors_; + } + + PlanNode* clone() const override; + std::unique_ptr explain() const override; + + private: + friend ObjectPool; + Sampling(QueryContext* qctx, PlanNode* input, std::vector factors) + : SingleInputNode(qctx, Kind::kSampling, input) { + factors_ = std::move(factors); + } + + std::vector> factorsString() const { + auto cols = colNames(); + std::vector> result; + for (auto& factor : factors_) { + std::string colName = cols[factor.colIdx]; + std::string order = + factor.samplingType == SamplingFactor::SamplingType::BINARY ? 
"BINARY" : "ALIAS"; + std::vector temp = {colName, std::to_string(factor.count), order}; + result.emplace_back(temp); + } + return result; + } + + void cloneMembers(const Sampling&); + + private: + std::vector factors_; +}; + // Output the records with the given limitation. class Limit final : public SingleInputNode { public: diff --git a/src/graph/service/PermissionCheck.cpp b/src/graph/service/PermissionCheck.cpp index db5b9fd834a..ad005385175 100644 --- a/src/graph/service/PermissionCheck.cpp +++ b/src/graph/service/PermissionCheck.cpp @@ -138,6 +138,7 @@ namespace graph { case Sentence::Kind::kLookup: case Sentence::Kind::kYield: case Sentence::Kind::kOrderBy: + case Sentence::Kind::kSampling: case Sentence::Kind::kFetchVertices: case Sentence::Kind::kFetchEdges: case Sentence::Kind::kFindPath: diff --git a/src/graph/validator/CMakeLists.txt b/src/graph/validator/CMakeLists.txt index 9c8c7f66fbf..f1f22b0aeec 100644 --- a/src/graph/validator/CMakeLists.txt +++ b/src/graph/validator/CMakeLists.txt @@ -24,6 +24,7 @@ nebula_add_library( YieldValidator.cpp ExplainValidator.cpp GroupByValidator.cpp + SamplingValidator.cpp FindPathValidator.cpp LookupValidator.cpp MatchValidator.cpp diff --git a/src/graph/validator/MatchValidator.cpp b/src/graph/validator/MatchValidator.cpp index 7071cd85219..a06ee03a74f 100644 --- a/src/graph/validator/MatchValidator.cpp +++ b/src/graph/validator/MatchValidator.cpp @@ -525,6 +525,13 @@ Status MatchValidator::validateReturn(MatchReturn *ret, NG_RETURN_IF_ERROR(validatePagination(ret->skip(), ret->limit(), *paginationCtx)); retClauseCtx.pagination = std::move(paginationCtx); + if (ret->samplingFactors() != nullptr) { + auto samplingCtx = getContext(); + NG_RETURN_IF_ERROR( + validateSampling(ret->samplingFactors(), retClauseCtx.yield->yieldColumns, *samplingCtx)); + retClauseCtx.sampling = std::move(samplingCtx); + } + if (ret->orderFactors() != nullptr) { auto orderByCtx = getContext(); NG_RETURN_IF_ERROR( @@ -897,6 +904,47 @@ Status 
MatchValidator::validateOrderBy(const OrderFactors *factors, return Status::OK(); } +// Check validity of sampling options. +// Disable duplicate columns, +// check expression of column (only constant expression and label expression) +Status MatchValidator::validateSampling(const SamplingFactors *factors, + const YieldColumns *yieldColumns, + SamplingClauseContext &samplingCtx) const { + if (factors != nullptr) { + std::vector inputColList; + inputColList.reserve(yieldColumns->columns().size()); + for (auto *col : yieldColumns->columns()) { + inputColList.emplace_back(col->name()); + } + std::unordered_map inputColIndices; + for (auto i = 0u; i < inputColList.size(); i++) { + if (!inputColIndices.emplace(inputColList[i], i).second) { + return Status::SemanticError("Duplicated columns not allowed: %s", inputColList[i].c_str()); + } + } + + for (auto &factor : factors->factors()) { + if (factor->count() < 0) { + return Status::SemanticError("Sampling count can not be negative"); + } + auto factorExpr = factor->expr(); + if (ExpressionUtils::isEvaluableExpr(factorExpr, qctx_)) continue; + if (factorExpr->kind() != Expression::Kind::kLabel) { + return Status::SemanticError("Only column name can be used as sampling item"); + } + auto &name = static_cast(factor->expr())->name(); + auto iter = inputColIndices.find(name); + if (iter == inputColIndices.end()) { + return Status::SemanticError("Column `%s' not found", name.c_str()); + } + samplingCtx.indexedSamplingFactors.emplace_back( + SamplingParams(iter->second, factor->count(), factor->samplingType())); + } + } + + return Status::OK(); +} + +// Validate group by and fill group by context. 
Status MatchValidator::validateGroup(YieldClauseContext &yieldCtx) { auto cols = yieldCtx.yieldColumns->columns(); diff --git a/src/graph/validator/MatchValidator.h b/src/graph/validator/MatchValidator.h index c78de9d3e7a..f20eee8b410 100644 --- a/src/graph/validator/MatchValidator.h +++ b/src/graph/validator/MatchValidator.h @@ -54,6 +54,10 @@ class MatchValidator final : public Validator { const YieldColumns *yieldColumns, OrderByClauseContext &orderByCtx) const; + Status validateSampling(const SamplingFactors *factors, + const YieldColumns *yieldColumns, + SamplingClauseContext &samplingCtx) const; + Status validateGroup(YieldClauseContext &yieldCtx); Status validateYield(YieldClauseContext &yieldCtx); diff --git a/src/graph/validator/SamplingValidator.cpp b/src/graph/validator/SamplingValidator.cpp new file mode 100644 index 00000000000..2f7ebd567a1 --- /dev/null +++ b/src/graph/validator/SamplingValidator.cpp @@ -0,0 +1,83 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#include "graph/validator/SamplingValidator.h" + +#include "parser/TraverseSentences.h" + +namespace nebula { +namespace graph { +Status SamplingValidator::validateImpl() { + auto sentence = static_cast(sentence_); + auto &factors = sentence->factors(); + // Check expression type, collect properties, fill index of order by column in + // input columns. 
+ for (auto &factor : factors) { + if (factor->count() < 0) { + return Status::SyntaxError("sampling `%ld' is illegal", factor->count()); + } + if (factor->expr()->kind() == Expression::Kind::kInputProperty) { + auto expr = static_cast(factor->expr()); + NG_RETURN_IF_ERROR(deduceExprType(expr)); + NG_RETURN_IF_ERROR(deduceProps(expr, exprProps_)); + const auto &cols = inputCols(); + auto &name = expr->prop(); + auto eq = [&](const ColDef &col) { return col.name == name; }; + auto iter = std::find_if(cols.cbegin(), cols.cend(), eq); + size_t colIdx = std::distance(cols.cbegin(), iter); + colSamplingTypes_.emplace_back( + SamplingParams(colIdx, factor->count(), factor->samplingType())); + } else if (factor->expr()->kind() == Expression::Kind::kVarProperty) { + auto expr = static_cast(factor->expr()); + NG_RETURN_IF_ERROR(deduceExprType(expr)); + NG_RETURN_IF_ERROR(deduceProps(expr, exprProps_)); + const auto &cols = vctx_->getVar(expr->sym()); + auto &name = expr->prop(); + auto eq = [&](const ColDef &col) { return col.name == name; }; + auto iter = std::find_if(cols.cbegin(), cols.cend(), eq); + size_t colIdx = std::distance(cols.cbegin(), iter); + colSamplingTypes_.emplace_back( + SamplingParams(colIdx, factor->count(), factor->samplingType())); + } else { + return Status::SemanticError("Order by with invalid expression `%s'", + factor->expr()->toString().c_str()); + } + } + + // only one Input/Variable is ok. 
+ if (!exprProps_.inputProps().empty() && !exprProps_.varProps().empty()) { + return Status::SemanticError("Not support both input and variable."); + } else if (!exprProps_.inputProps().empty()) { + outputs_ = inputCols(); + } else if (!exprProps_.varProps().empty()) { + if (!userDefinedVarNameList_.empty()) { + if (userDefinedVarNameList_.size() != 1) { + return Status::SemanticError("Multiple user defined vars are not supported yet."); + } + userDefinedVarName_ = *userDefinedVarNameList_.begin(); + outputs_ = vctx_->getVar(userDefinedVarName_); + } + } + + return Status::OK(); +} + +Status SamplingValidator::toPlan() { + auto *plan = qctx_->plan(); + auto *samplingNode = Sampling::make(qctx_, plan->root(), std::move(colSamplingTypes_)); + std::vector colNames; + for (auto &col : outputs_) { + colNames.emplace_back(col.name); + } + samplingNode->setColNames(std::move(colNames)); + if (!userDefinedVarName_.empty()) { + samplingNode->setInputVar(userDefinedVarName_); + } + root_ = samplingNode; + tail_ = root_; + return Status::OK(); +} +} // namespace graph +} // namespace nebula diff --git a/src/graph/validator/SamplingValidator.h b/src/graph/validator/SamplingValidator.h new file mode 100644 index 00000000000..a34598fddcc --- /dev/null +++ b/src/graph/validator/SamplingValidator.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +#ifndef GRAPH_VALIDATOR_SAMPLINGVALIDATOR_H_ +#define GRAPH_VALIDATOR_SAMPLINGVALIDATOR_H_ + +#include "graph/planner/plan/Query.h" +#include "graph/validator/Validator.h" + +namespace nebula { +namespace graph { +class SamplingValidator final : public Validator { + public: + SamplingValidator(Sentence* sentence, QueryContext* context) : Validator(sentence, context) { + setNoSpaceRequired(); + } + + private: + Status validateImpl() override; + + Status toPlan() override; + + private: + std::vector colSamplingTypes_; + std::string userDefinedVarName_; +}; +} // namespace graph +} // namespace nebula +#endif // GRAPH_VALIDATOR_SAMPLINGVALIDATOR_H_ diff --git a/src/graph/validator/Validator.cpp b/src/graph/validator/Validator.cpp index d76d350c851..75abcf744b3 100644 --- a/src/graph/validator/Validator.cpp +++ b/src/graph/validator/Validator.cpp @@ -31,6 +31,7 @@ #include "graph/validator/OrderByValidator.h" #include "graph/validator/PipeValidator.h" #include "graph/validator/ReportError.h" +#include "graph/validator/SamplingValidator.h" #include "graph/validator/SequentialValidator.h" #include "graph/validator/SetValidator.h" #include "graph/validator/UnwindValidator.h" @@ -72,6 +73,8 @@ std::unique_ptr Validator::makeValidator(Sentence* sentence, QueryCon return std::make_unique(sentence, context); case Sentence::Kind::kOrderBy: return std::make_unique(sentence, context); + case Sentence::Kind::kSampling: + return std::make_unique(sentence, context); case Sentence::Kind::kYield: return std::make_unique(sentence, context); case Sentence::Kind::kGroupBy: diff --git a/src/parser/MatchSentence.cpp b/src/parser/MatchSentence.cpp index 9ff07e88d0d..e06b2d00cfe 100644 --- a/src/parser/MatchSentence.cpp +++ b/src/parser/MatchSentence.cpp @@ -48,6 +48,12 @@ std::string WithClause::toString() const { buf += returnItems_->toString(); + if (samplingFactors_ != nullptr) { + buf += " "; + buf += "SAMPLING "; + buf += samplingFactors_->toString(); + } + if (orderFactors_ 
!= nullptr) { buf += " "; buf += "ORDER BY "; @@ -101,6 +107,12 @@ std::string MatchReturn::toString() const { buf += returnItems_->toString(); + if (samplingFactors_ != nullptr) { + buf += " "; + buf += "SAMPLING "; + buf += samplingFactors_->toString(); + } + if (orderFactors_ != nullptr) { buf += " "; buf += "ORDER BY "; diff --git a/src/parser/MatchSentence.h b/src/parser/MatchSentence.h index 239c52c991e..a3563a0ddbc 100644 --- a/src/parser/MatchSentence.h +++ b/src/parser/MatchSentence.h @@ -61,12 +61,14 @@ class MatchReturnItems final { class MatchReturn final { public: MatchReturn(MatchReturnItems* returnItems = nullptr, + SamplingFactors* samplingFactors = nullptr, OrderFactors* orderFactors = nullptr, Expression* skip = nullptr, Expression* limit = nullptr, bool distinct = false) { returnItems_.reset(returnItems); orderFactors_.reset(orderFactors); + samplingFactors_.reset(samplingFactors); skip_ = skip; limit_ = limit; isDistinct_ = distinct; @@ -100,12 +102,21 @@ class MatchReturn final { return orderFactors_.get(); } + SamplingFactors* samplingFactors() { + return samplingFactors_.get(); + } + + const SamplingFactors* samplingFactors() const { + return samplingFactors_.get(); + } + std::string toString() const; private: std::unique_ptr returnItems_; bool isDistinct_{false}; std::unique_ptr orderFactors_; + std::unique_ptr samplingFactors_; Expression* skip_{nullptr}; Expression* limit_{nullptr}; }; @@ -210,6 +221,7 @@ class UnwindClause final : public ReadingClause { class WithClause final : public ReadingClause { public: explicit WithClause(MatchReturnItems* returnItems, + SamplingFactors* samplingFactors = nullptr, OrderFactors* orderFactors = nullptr, Expression* skip = nullptr, Expression* limit = nullptr, @@ -218,6 +230,7 @@ class WithClause final : public ReadingClause { : ReadingClause(Kind::kWith) { returnItems_.reset(returnItems); orderFactors_.reset(orderFactors); + samplingFactors_.reset(samplingFactors); skip_ = skip; limit_ = limit; 
where_.reset(where); @@ -240,6 +253,14 @@ class WithClause final : public ReadingClause { return orderFactors_.get(); } + SamplingFactors* samplingFactors() { + return samplingFactors_.get(); + } + + const SamplingFactors* samplingFactors() const { + return samplingFactors_.get(); + } + Expression* skip() { return skip_; } @@ -273,6 +294,7 @@ private: std::unique_ptr returnItems_; std::unique_ptr orderFactors_; + std::unique_ptr samplingFactors_; Expression* skip_{nullptr}; Expression* limit_{nullptr}; std::unique_ptr where_; diff --git a/src/parser/Sentence.h b/src/parser/Sentence.h index ce36cabb72f..fcf718c3e8e 100644 --- a/src/parser/Sentence.h +++ b/src/parser/Sentence.h @@ -98,6 +98,7 @@ class Sentence { kRevoke, kChangePassword, kOrderBy, + kSampling, kShowConfigs, kSetConfig, kGetConfig, diff --git a/src/parser/TraverseSentences.cpp b/src/parser/TraverseSentences.cpp index fcd12a862ea..638458365ce 100644 --- a/src/parser/TraverseSentences.cpp +++ b/src/parser/TraverseSentences.cpp @@ -145,6 +145,33 @@ std::string OrderBySentence::toString() const { return folly::stringPrintf("ORDER BY %s", orderFactors_->toString().c_str()); } +std::string SamplingFactor::toString() const { + switch (sampling_type_) { + case BINARY: + return folly::stringPrintf("%s %ld BINARY,", expr_->toString().c_str(), count_); + case ALIAS: + return folly::stringPrintf("%s %ld ALIAS,", expr_->toString().c_str(), count_); + default: + LOG(FATAL) << "Unknown Sampling Type: " << sampling_type_; + } +} + +std::string SamplingFactors::toString() const { + std::string buf; + buf.reserve(256); + for (auto &factor : factors_) { + buf += factor->toString(); + } + if (!buf.empty()) { + buf.resize(buf.size() - 1); + } + return buf; +} + +std::string SamplingSentence::toString() const { + return folly::stringPrintf("SAMPLING %s", samplingFactors_->toString().c_str()); +} + +std::string FetchVerticesSentence::toString() const { + std::string buf; 
buf.reserve(256); diff --git a/src/parser/TraverseSentences.h b/src/parser/TraverseSentences.h index b4ef8bb11f0..62eda23f2ad 100644 --- a/src/parser/TraverseSentences.h +++ b/src/parser/TraverseSentences.h @@ -301,6 +301,81 @@ class OrderBySentence final : public Sentence { std::unique_ptr orderFactors_; }; +class SamplingFactor final { + public: + enum SamplingType : uint8_t { BINARY, ALIAS }; + + SamplingFactor(Expression* expr, int64_t count, SamplingType sp) { + expr_ = expr; + count_ = count; + sampling_type_ = sp; + } + + Expression* expr() { + return expr_; + } + + void setExpr(Expression* expr) { + expr_ = expr; + } + + int64_t count() { + return count_; + } + + SamplingType samplingType() { + return sampling_type_; + } + + std::string toString() const; + + private: + Expression* expr_{nullptr}; + int64_t count_; + SamplingType sampling_type_; +}; + +class SamplingFactors final { + public: + void addFactor(SamplingFactor* factor) { + factors_.emplace_back(factor); + } + + auto& factors() { + return factors_; + } + + const auto& factors() const { + return factors_; + } + + std::string toString() const; + + private: + std::vector> factors_; +}; + +class SamplingSentence final : public Sentence { + public: + explicit SamplingSentence(SamplingFactors* factors) { + samplingFactors_.reset(factors); + kind_ = Kind::kSampling; + } + + auto& factors() { + return samplingFactors_->factors(); + } + + const auto& factors() const { + return samplingFactors_->factors(); + } + + std::string toString() const override; + + private: + std::unique_ptr samplingFactors_; +}; + class FetchVerticesSentence final : public Sentence { public: FetchVerticesSentence(NameLabelList* tags, VertexIDList* vidList, YieldClause* clause) { diff --git a/src/parser/parser.yy b/src/parser/parser.yy index b9708c060a4..798e64da930 100644 --- a/src/parser/parser.yy +++ b/src/parser/parser.yy @@ -118,6 +118,8 @@ using namespace nebula; nebula::IndexParamItem *index_param_item; nebula::OrderFactor 
*order_factor; nebula::OrderFactors *order_factors; + nebula::SamplingFactor *sampling_factor; + nebula::SamplingFactors *sampling_factors; nebula::meta::cpp2::ConfigModule config_module; nebula::meta::cpp2::ListHostType list_host_type; nebula::ConfigRowItem *config_row_item; @@ -185,7 +187,7 @@ using namespace nebula; %token KW_GET KW_DECLARE KW_GRAPH KW_META KW_STORAGE KW_AGENT %token KW_TTL KW_TTL_DURATION KW_TTL_COL KW_DATA KW_STOP %token KW_FETCH KW_PROP KW_UPDATE KW_UPSERT KW_WHEN -%token KW_ORDER KW_ASC KW_LIMIT KW_SAMPLE KW_OFFSET KW_ASCENDING KW_DESCENDING +%token KW_ORDER KW_ASC KW_LIMIT KW_SAMPLE KW_OFFSET KW_ASCENDING KW_DESCENDING KW_SAMPLING KW_BINARY KW_ALIAS %token KW_DISTINCT KW_ALL KW_OF %token KW_BALANCE KW_LEADER KW_RESET KW_PLAN %token KW_SHORTEST KW_PATH KW_NOLOOP KW_SHORTESTPATH KW_ALLSHORTESTPATHS @@ -296,6 +298,8 @@ using namespace nebula; %type index_param_item %type order_factor %type order_factors +%type sampling_factor +%type sampling_factors %type config_module_enum %type list_host_type %type show_config_item get_config_item set_config_item @@ -338,6 +342,7 @@ using namespace nebula; %type reading_clauses reading_with_clause reading_with_clauses %type match_step_range %type match_order_by +%type match_sampling %type text_search_argument %type base_text_search_argument %type fuzzy_text_search_argument @@ -392,7 +397,7 @@ using namespace nebula; %type traverse_sentence unwind_sentence %type go_sentence match_sentence lookup_sentence find_path_sentence get_subgraph_sentence -%type group_by_sentence order_by_sentence limit_sentence +%type group_by_sentence order_by_sentence limit_sentence sampling_sentence %type fetch_sentence fetch_vertices_sentence fetch_edges_sentence %type set_sentence piped_sentence assignment_sentence match_sentences %type yield_sentence use_sentence @@ -579,6 +584,7 @@ unreserved_keyword | KW_DIVIDE { $$ = new std::string("divide"); } | KW_RENAME { $$ = new std::string("rename"); } | KW_CLEAR { $$ = new 
std::string("clear"); } + | KW_SAMPLING { $$ = new std::string("sampling"); } ; expression @@ -1669,11 +1675,11 @@ unwind_sentence ; with_clause - : KW_WITH match_return_items match_order_by match_skip match_limit where_clause { - $$ = new WithClause($2, $3, $4, $5, $6, false/*distinct*/); + : KW_WITH match_return_items match_sampling match_order_by match_skip match_limit where_clause { + $$ = new WithClause($2, $3, $4, $5, $6, $7, false/*distinct*/); } - | KW_WITH KW_DISTINCT match_return_items match_order_by match_skip match_limit where_clause { - $$ = new WithClause($3, $4, $5, $6, $7, true); + | KW_WITH KW_DISTINCT match_return_items match_sampling match_order_by match_skip match_limit where_clause { + $$ = new WithClause($3, $4, $5, $6, $7, $8, true); } ; @@ -1951,11 +1957,11 @@ match_edge_type_list ; match_return - : KW_RETURN match_return_items match_order_by match_skip match_limit { - $$ = new MatchReturn($2, $3, $4, $5); + : KW_RETURN match_return_items match_sampling match_order_by match_skip match_limit { + $$ = new MatchReturn($2, $3, $4, $5, $6); } - | KW_RETURN KW_DISTINCT match_return_items match_order_by match_skip match_limit { - $$ = new MatchReturn($3, $4, $5, $6, true); + | KW_RETURN KW_DISTINCT match_return_items match_sampling match_order_by match_skip match_limit { + $$ = new MatchReturn($3, $4, $5, $6, $7, true); } ; @@ -1979,6 +1985,15 @@ match_order_by } ; +match_sampling + : %empty { + $$ = nullptr; + } + | KW_SAMPLING sampling_factors { + $$ = $2; + } + ; + match_skip : %empty { $$ = nullptr; @@ -2239,6 +2254,36 @@ order_by_sentence } ; +sampling_factor + : expression legal_integer { + $$ = new SamplingFactor($1, $2, SamplingFactor::BINARY); + } + | expression legal_integer KW_BINARY { + $$ = new SamplingFactor($1, $2, SamplingFactor::BINARY); + } + | expression legal_integer KW_ALIAS { + $$ = new SamplingFactor($1, $2, SamplingFactor::ALIAS); + } + ; + +sampling_factors + : sampling_factor { + auto factors = new SamplingFactors(); + 
factors->addFactor($1); + $$ = factors; + } + | sampling_factors COMMA sampling_factor { + $1->addFactor($3); + $$ = $1; + } + ; + +sampling_sentence + : KW_SAMPLING sampling_factors { + $$ = new SamplingSentence($2); + } + ; + fetch_vertices_sentence : KW_FETCH KW_PROP KW_ON name_label_list vid_list yield_clause { $$ = new FetchVerticesSentence($4, $5, $6); @@ -2943,6 +2988,7 @@ traverse_sentence | go_sentence { $$ = $1; } | lookup_sentence { $$ = $1; } | group_by_sentence { $$ = $1; } + | sampling_sentence { $$ = $1; } | order_by_sentence { $$ = $1; } | fetch_sentence { $$ = $1; } | find_path_sentence { $$ = $1; } diff --git a/src/parser/scanner.lex b/src/parser/scanner.lex index 283302135ab..64bef233839 100644 --- a/src/parser/scanner.lex +++ b/src/parser/scanner.lex @@ -167,6 +167,7 @@ LABEL_FULL_WIDTH {CN_EN_FULL_WIDTH}{CN_EN_NUM_FULL_WIDTH}* "GET" { return TokenType::KW_GET; } "OF" { return TokenType::KW_OF; } "ORDER" { return TokenType::KW_ORDER; } +"SAMPLING" { return TokenType::KW_SAMPLING; } "INGEST" { return TokenType::KW_INGEST; } "COMPACT" { return TokenType::KW_COMPACT; } "FLUSH" { return TokenType::KW_FLUSH; } @@ -174,6 +175,8 @@ LABEL_FULL_WIDTH {CN_EN_FULL_WIDTH}{CN_EN_NUM_FULL_WIDTH}* "ASC" { return TokenType::KW_ASC; } "ASCENDING" { return TokenType::KW_ASCENDING; } "DESCENDING" { return TokenType::KW_DESCENDING; } +"BINARY" { return TokenType::KW_BINARY; } +"ALIAS" { return TokenType::KW_ALIAS; } "DISTINCT" { return TokenType::KW_DISTINCT; } "FETCH" { return TokenType::KW_FETCH; } "PROP" { return TokenType::KW_PROP; } diff --git a/src/parser/test/ParserTest.cpp b/src/parser/test/ParserTest.cpp index 9bc976cc1a0..a12d5957a70 100644 --- a/src/parser/test/ParserTest.cpp +++ b/src/parser/test/ParserTest.cpp @@ -1761,6 +1761,14 @@ TEST_F(ParserTest, UnreservedKeywords) { auto result = parse(query); ASSERT_TRUE(result.ok()) << result.status(); } + { + std::string query = + "GO FROM \"123\" OVER like YIELD $$.tag1.EMAIL, like.users," + "like._src, 
like._dst, like.type, $^.tag2.SPACE " + "| SAMPLING $-.SPACE 5 binary"; + auto result = parse(query); + ASSERT_TRUE(result.ok()) << result.status(); + } { std::string query = "GO FROM UUID() OVER like YIELD $$.tag1.EMAIL, like.users," @@ -1851,6 +1859,11 @@ TEST_F(ParserTest, Agg) { auto result = parse(query); ASSERT_TRUE(result.ok()) << result.status(); } + { + std::string query = "SAMPLING $-.id 5 binary"; + auto result = parse(query); + ASSERT_TRUE(result.ok()) << result.status(); + } { std::string query = "GO FROM \"1\" OVER friend " diff --git a/src/parser/test/ScannerTest.cpp b/src/parser/test/ScannerTest.cpp index 6e6c4a32e51..84a4b15ad58 100644 --- a/src/parser/test/ScannerTest.cpp +++ b/src/parser/test/ScannerTest.cpp @@ -350,6 +350,12 @@ TEST(Scanner, Basic) { CHECK_SEMANTIC_TYPE("ORDER", TokenType::KW_ORDER), CHECK_SEMANTIC_TYPE("Order", TokenType::KW_ORDER), CHECK_SEMANTIC_TYPE("order", TokenType::KW_ORDER), + CHECK_SEMANTIC_TYPE("sampling", TokenType::KW_SAMPLING), + CHECK_SEMANTIC_TYPE("Sampling", TokenType::KW_SAMPLING), + CHECK_SEMANTIC_TYPE("SAMPLING", TokenType::KW_SAMPLING), + CHECK_SEMANTIC_TYPE("binary", TokenType::KW_BINARY), + CHECK_SEMANTIC_TYPE("Binary", TokenType::KW_BINARY), + CHECK_SEMANTIC_TYPE("BINARY", TokenType::KW_BINARY), CHECK_SEMANTIC_TYPE("ASC", TokenType::KW_ASC), CHECK_SEMANTIC_TYPE("Asc", TokenType::KW_ASC), CHECK_SEMANTIC_TYPE("asc", TokenType::KW_ASC),