Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Nivras committed Dec 23, 2021
1 parent cb2b292 commit dc5d038
Show file tree
Hide file tree
Showing 6 changed files with 591 additions and 20 deletions.
12 changes: 6 additions & 6 deletions src/storage/exec/IndexAggregateNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ nebula::cpp2::ErrorCode IndexAggregateNode::init(InitContext& ctx) {
for (const auto& statInfo : statInfos_) {
ctx.statColumns.insert(statInfo.first);
}
auto ret = children_[0]->init(ctx);
if (ret != nebula::cpp2::ErrorCode::SUCCEEDED) {
return ret;
}
initStatValue();
retColMap_.clear();
retColMap_ = ctx.retColMap;
return children_[0]->init(ctx);
return nebula::cpp2::ErrorCode::SUCCEEDED;
}

void IndexAggregateNode::initStatValue() {
Expand All @@ -45,10 +49,8 @@ void IndexAggregateNode::initStatValue() {

void IndexAggregateNode::addStatValue(const Value& value, ColumnStat& stat) {
switch (stat.statType_) {
case cpp2::StatType::SUM:
case cpp2::StatType::AVG: {
case cpp2::StatType::SUM: {
stat.sum_ = stat.sum_ + value;
stat.count_ = stat.count_ + 1;
break;
}
case cpp2::StatType::COUNT: {
Expand Down Expand Up @@ -86,8 +88,6 @@ Row IndexAggregateNode::calculateStats() {
result.values.emplace_back(stat.sum_);
} else if (stat.statType_ == cpp2::StatType::COUNT) {
result.values.emplace_back(stat.count_);
} else if (stat.statType_ == cpp2::StatType::AVG) {
result.values.emplace_back(stat.sum_ / stat.count_);
} else if (stat.statType_ == cpp2::StatType::MAX) {
result.values.emplace_back(stat.max_);
} else if (stat.statType_ == cpp2::StatType::MIN) {
Expand Down
2 changes: 1 addition & 1 deletion src/storage/exec/IndexProjectionNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ nebula::cpp2::ErrorCode IndexProjectionNode::init(InitContext& ctx) {
ctx.requiredColumns.insert(col);
}
for (auto& col : ctx.statColumns) {
if (ctx.requiredColumns.find(col) != ctx.requiredColumns.end()) {
if (ctx.requiredColumns.find(col) == ctx.requiredColumns.end()) {
ctx.requiredColumns.insert(col);
requiredColumns_.push_back(col);
}
Expand Down
53 changes: 43 additions & 10 deletions src/storage/index/LookupProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ void LookupProcessor::runInSingleThread(const std::vector<PartitionID>& parts,
datasetList.emplace_back(std::move(dataset));
codeList.emplace_back(code);
}
if (statsDataSet_.colNames.size() > 0) {
if (statTypes_.size() > 0) {
auto indexAgg = dynamic_cast<IndexAggregateNode*>(plan.get());
statsDataSet_.emplace_back(std::move(indexAgg->calculateStats()));
}
Expand All @@ -228,7 +228,7 @@ void LookupProcessor::runInSingleThread(const std::vector<PartitionID>& parts,
void LookupProcessor::runInMultipleThread(const std::vector<PartitionID>& parts,
std::unique_ptr<IndexNode> plan) {
std::vector<std::unique_ptr<IndexNode>> planCopy = reproducePlan(plan.get(), parts.size());
using ReturnType = std::tuple<PartitionID, ::nebula::cpp2::ErrorCode, std::deque<Row>>;
using ReturnType = std::tuple<PartitionID, ::nebula::cpp2::ErrorCode, std::deque<Row>, Row>;
std::vector<folly::Future<ReturnType>> futures;
for (size_t i = 0; i < parts.size(); i++) {
futures.emplace_back(folly::via(
Expand All @@ -251,28 +251,34 @@ void LookupProcessor::runInMultipleThread(const std::vector<PartitionID>& parts,
if (UNLIKELY(profileDetailFlag_)) {
profilePlan(plan.get());
}
return {part, code, dataset};
Row statResult;
if (code == nebula::cpp2::ErrorCode::SUCCEEDED && statTypes_.size() > 0) {
auto indexAgg = dynamic_cast<IndexAggregateNode*>(plan.get());
statResult = indexAgg->calculateStats();
}
return {part, code, dataset, statResult};
}));
}
folly::collectAll(futures).via(executor_).thenTry([this, &plan](auto&& t) {
folly::collectAll(futures).via(executor_).thenTry([this](auto&& t) {
CHECK(!t.hasException());
const auto& tries = t.value();
std::vector<Row> statResults;
for (size_t j = 0; j < tries.size(); j++) {
CHECK(!tries[j].hasException());
auto& [partId, code, dataset] = tries[j].value();
auto& [partId, code, dataset, statResult] = tries[j].value();
if (code == ::nebula::cpp2::ErrorCode::SUCCEEDED) {
for (auto& row : dataset) {
resultDataSet_.emplace_back(std::move(row));
}
} else {
handleErrorCode(code, context_->spaceId(), partId);
}
}
if (statsDataSet_.colNames.size() > 0) {
auto indexAgg = dynamic_cast<IndexAggregateNode*>(plan.get());
statsDataSet_.emplace_back(std::move(indexAgg->calculateStats()));
statResults.emplace_back(std::move(statResult));
}
DLOG(INFO) << "finish";
// IndexAggregateNode has been copyed and each part get it's own aggregate info,
// we need to merge it
this->mergeStatsResult(statResults);
this->onProcessFinished();
this->onFinished();
});
Expand All @@ -286,6 +292,7 @@ LookupProcessor::handleStatProps(const std::vector<cpp2::StatProp>& statProps) {

for (size_t statIdx = 0; statIdx < statProps.size(); statIdx++) {
const auto& statProp = statProps[statIdx];
statTypes_.emplace_back(statProp.get_stat());
auto exp = Expression::decode(pool, *statProp.prop_ref());
if (exp == nullptr) {
return nebula::cpp2::ErrorCode::E_INVALID_STAT_TYPE;
Expand All @@ -301,7 +308,8 @@ LookupProcessor::handleStatProps(const std::vector<cpp2::StatProp>& statProps) {
if (edgeName != context_->edgeName_) {
return nebula::cpp2::ErrorCode::E_EDGE_NOT_FOUND;
}
if (exp->kind() == Expression::Kind::kEdgeProperty) {
if (exp->kind() == Expression::Kind::kEdgeProperty && propName != kSrc &&
propName != kDst) {
auto edgeSchema =
env_->schemaMan_->getEdgeSchema(context_->spaceId(), context_->edgeType_);
auto field = edgeSchema->field(propName);
Expand Down Expand Up @@ -346,6 +354,31 @@ LookupProcessor::handleStatProps(const std::vector<cpp2::StatProp>& statProps) {
return statInfos;
}

void LookupProcessor::mergeStatsResult(const std::vector<Row>& statsResult) {
if (statsResult.size() == 0 || statTypes_.size() == 0) {
return;
}

Row result;
for (size_t statIdx = 0; statIdx < statTypes_.size(); statIdx++) {
Value value = statsResult[0].values[statIdx];
for (size_t resIdx = 1; resIdx < statsResult.size(); resIdx++) {
const auto& currType = statTypes_[statIdx];
if (currType == cpp2::StatType::SUM || currType == cpp2::StatType::COUNT) {
value = value + statsResult[resIdx].values[statIdx];
} else if (currType == cpp2::StatType::MAX) {
value = value > statsResult[resIdx].values[statIdx] ? value
: statsResult[resIdx].values[statIdx];
} else if (currType == cpp2::StatType::MIN) {
value = value < statsResult[resIdx].values[statIdx] ? value
: statsResult[resIdx].values[statIdx];
}
}
result.values.emplace_back(std::move(value));
}
statsDataSet_.emplace_back(std::move(result));
}

std::vector<std::unique_ptr<IndexNode>> LookupProcessor::reproducePlan(IndexNode* root,
size_t count) {
std::vector<std::unique_ptr<IndexNode>> ret(count);
Expand Down
3 changes: 2 additions & 1 deletion src/storage/index/LookupProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ class LookupProcessor : public BaseProcessor<cpp2::LookupIndexResp> {
std::vector<std::unique_ptr<IndexNode>> reproducePlan(IndexNode* root, size_t count);
ErrorOr<nebula::cpp2::ErrorCode, std::vector<std::pair<std::string, cpp2::StatType>>>
handleStatProps(const std::vector<cpp2::StatProp>& statProps);
void mergeStatsResult(const std::vector<Row>& statsResult);
folly::Executor* executor_{nullptr};
std::unique_ptr<PlanContext> planContext_;
std::unique_ptr<RuntimeContext> context_;
nebula::DataSet resultDataSet_;
nebula::DataSet statsDataSet_;
std::vector<nebula::DataSet> partResults_;
std::vector<std::string> statColumnName_;
std::vector<cpp2::StatType> statTypes_;
};
} // namespace storage
} // namespace nebula
Loading

0 comments on commit dc5d038

Please sign in to comment.