Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push filter down HashInnerJoin rule #4956

Merged
merged 15 commits into from
Dec 6, 2022
26 changes: 1 addition & 25 deletions src/graph/context/Symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,6 @@ bool SymbolTable::deleteWrittenBy(const std::string& varName, PlanNode* node) {
if (var == vars_.end()) {
return false;
}
for (auto& alias : var->second->colNames) {
auto found = aliasGeneratedBy_.find(alias);
if (found != aliasGeneratedBy_.end()) {
if (found->second == varName) {
aliasGeneratedBy_.erase(alias);
}
}
}
var->second->writtenBy.erase(node);
return true;
}
Expand All @@ -106,6 +98,7 @@ bool SymbolTable::updateWrittenBy(const std::string& oldVar,
}

Variable* SymbolTable::getVar(const std::string& varName) {
DCHECK(!varName.empty()) << "the variable name is empty";
auto var = vars_.find(varName);
if (var == vars_.end()) {
return nullptr;
Expand All @@ -114,22 +107,5 @@ Variable* SymbolTable::getVar(const std::string& varName) {
}
}

// Registers `varName` as the generator of each alias in `aliases`.
// An alias that already has a recorded generator keeps its existing entry, so
// only the first variable registered for an alias is remembered.
void SymbolTable::setAliasGeneratedBy(const std::vector<std::string>& aliases,
                                      const std::string& varName) {
  for (const auto& aliasName : aliases) {
    if (aliasGeneratedBy_.find(aliasName) != aliasGeneratedBy_.end()) {
      // Already recorded: do not overwrite the first generator.
      continue;
    }
    aliasGeneratedBy_.emplace(aliasName, varName);
  }
}

// Looks up the variable recorded as the generator of `alias`.
// Returns the recorded variable name, or an error status when no entry exists
// for the alias.
StatusOr<std::string> SymbolTable::getAliasGeneratedBy(const std::string& alias) {
  auto entry = aliasGeneratedBy_.find(alias);
  if (entry != aliasGeneratedBy_.end()) {
    return entry->second;
  }
  return Status::Error("Not found a variable that generates the alias: %s", alias.c_str());
}
} // namespace graph
} // namespace nebula
6 changes: 0 additions & 6 deletions src/graph/context/Symbols.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ class SymbolTable final {

Variable* getVar(const std::string& varName);

void setAliasGeneratedBy(const std::vector<std::string>& aliases, const std::string& varName);

StatusOr<std::string> getAliasGeneratedBy(const std::string& alias);

std::string toString() const;

private:
Expand All @@ -85,8 +81,6 @@ class SymbolTable final {
ExecutionContext* ectx_{nullptr};
// var name -> variable
std::unordered_map<std::string, Variable*> vars_;
// alias -> first variable that generate the alias
std::unordered_map<std::string, std::string> aliasGeneratedBy_;
};

} // namespace graph
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nebula_add_library(
rule/PushFilterDownAggregateRule.cpp
rule/PushFilterDownProjectRule.cpp
rule/PushFilterDownLeftJoinRule.cpp
rule/PushFilterDownHashInnerJoinRule.cpp
rule/PushFilterDownInnerJoinRule.cpp
rule/PushFilterDownNodeRule.cpp
rule/PushFilterDownScanVerticesRule.cpp
Expand Down
174 changes: 157 additions & 17 deletions src/graph/optimizer/OptGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,81 @@ OptGroup::OptGroup(OptContext *ctx) noexcept : ctx_(ctx) {
DCHECK(ctx != nullptr);
}

Status OptGroup::validateSubPlan(const OptGroupNode *gn,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find the call to this function. Did I miss something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Actually, I don't use it in this PR except to debug some crashes. I want to use it when refactoring the GO planner, but for now I need to consider loop/select plan nodes, which is a little tedious!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ACK.

const OptRule *rule,
const std::vector<OptGroup *> &patternLeaves) const {
auto &deps = DCHECK_NOTNULL(gn)->dependencies();

auto checkDepGroup = [this, gn, rule, &patternLeaves](const OptGroup *depGroup) -> Status {
auto iter = std::find(patternLeaves.begin(), patternLeaves.end(), depGroup);
if (iter != patternLeaves.end()) {
return Status::OK();
}
if (!depGroup) {
return Status::Error("Could not find the dependent group in pattern leaves");
}
if (depGroup->groupNodes_.size() != 1U || depGroup->groupNodesReferenced_.size() != 1U) {
return Status::Error(
"Invalid sub-plan generated when applying the rule: %s, "
"planNode: %s, numGroupNodes: %lu, numGroupNodesRef: %lu",
rule->toString().c_str(),
PlanNode::toString(gn->node()->kind()),
depGroup->groupNodes_.size(),
depGroup->groupNodesReferenced_.size());
}
Comment on lines +54 to +65
Copy link
Contributor

@czpmango czpmango Dec 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better not to pass optimizer errors to the user.

It might change this way:

DCHECK_NOTNULL(depGroup) << "Could not find the dependent group in pattern leaves";

or

// The `depGroup` should not be null bcz ...
DCHECK_NOTNULL(depGroup);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have converted the error message in QueryInstance. This return exists only to avoid a crash in the release version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

return validateSubPlan(depGroup->groupNodes_.front(), rule, patternLeaves);
};

switch (deps.size()) {
case 0: {
auto iter = std::find(patternLeaves.begin(), patternLeaves.end(), nullptr);
if (iter == patternLeaves.end()) {
return Status::Error("Invalid sub-plan generated by the rule: %s, planNode: %s",
rule->toString().c_str(),
PlanNode::toString(gn->node()->kind()));
}
break;
}
case 1: {
NG_RETURN_IF_ERROR(checkDepGroup(deps[0]));
break;
}
case 2: {
NG_RETURN_IF_ERROR(checkDepGroup(deps[0]));
NG_RETURN_IF_ERROR(checkDepGroup(deps[1]));
break;
}
default: {
return Status::Error("Invalid dependencies of opt group node: %lu", deps.size());
}
}
return Status::OK();
}

// Checks the consistency of this group while `rule` is being applied:
//  1. a group that is referenced by some OptGroupNode must not be empty;
//  2. every OptGroupNode in the group must pass its own validate();
//  3. every OptGroupNode's plan node must write to the group's output variable.
// Returns the first error found, or Status::OK() when all checks pass.
Status OptGroup::validate(const OptRule *rule) const {
  const bool isReferenced = !groupNodesReferenced_.empty();
  if (isReferenced && groupNodes_.empty()) {
    return Status::Error(
        "The OptGroup has no any OptGroupNode but used by other OptGroupNode "
        "when applying the rule: %s, numGroupNodesRef: %lu",
        rule->toString().c_str(),
        groupNodesReferenced_.size());
  }

  for (auto *member : groupNodes_) {
    NG_RETURN_IF_ERROR(member->validate(rule));

    const auto &memberOutputVar = member->node()->outputVar();
    if (memberOutputVar == outputVar_) {
      continue;
    }
    return Status::Error(
        "The output columns of the OptGroupNode is different from OptGroup "
        "when applying the rule: %s, %s vs. %s",
        rule->toString().c_str(),
        memberOutputVar.c_str(),
        outputVar_.c_str());
  }

  return Status::OK();
}

void OptGroup::addGroupNode(OptGroupNode *groupNode) {
DCHECK(groupNode != nullptr);
DCHECK(groupNode->group() == this);
DCHECK_EQ(this, DCHECK_NOTNULL(groupNode)->group());
if (outputVar_.empty()) {
outputVar_ = groupNode->node()->outputVar();
} else {
Expand All @@ -54,13 +126,9 @@ void OptGroup::addGroupNode(OptGroupNode *groupNode) {
}

OptGroupNode *OptGroup::makeGroupNode(PlanNode *node) {
if (outputVar_.empty()) {
outputVar_ = node->outputVar();
} else {
DCHECK_EQ(outputVar_, node->outputVar());
}
groupNodes_.emplace_back(OptGroupNode::create(ctx_, node, this));
return groupNodes_.back();
auto *gn = OptGroupNode::create(ctx_, node, this);
addGroupNode(gn);
return gn;
}

Status OptGroup::explore(const OptRule *rule) {
Expand All @@ -69,9 +137,12 @@ Status OptGroup::explore(const OptRule *rule) {
}
setExplored(rule);

// TODO(yee): the opt group maybe in the loop body branch
// DCHECK(isRootGroup_ || !groupNodesReferenced_.empty())
// << "Current group should be referenced by other group nodes before optimization";

for (auto iter = groupNodes_.begin(); iter != groupNodes_.end();) {
auto groupNode = *iter;
DCHECK(groupNode != nullptr);
auto *groupNode = DCHECK_NOTNULL(*iter);
if (groupNode->isExplored(rule)) {
++iter;
continue;
Expand All @@ -80,23 +151,33 @@ Status OptGroup::explore(const OptRule *rule) {
NG_RETURN_IF_ERROR(groupNode->explore(rule));

// Find more equivalents
std::vector<OptGroup *> boundary;
std::vector<OptGroup *> leaves;
auto status = rule->match(ctx_, groupNode);
if (!status.ok()) {
++iter;
continue;
}
ctx_->setChanged(true);
auto matched = std::move(status).value();
matched.collectBoundary(boundary);
matched.collectPatternLeaves(leaves);
auto resStatus = rule->transform(ctx_, matched);
NG_RETURN_IF_ERROR(resStatus);
auto result = std::move(resStatus).value();

for (auto *gn : result.newGroupNodes) {
auto it = std::find(groupNodes_.begin(), groupNodes_.end(), gn);
if (it != groupNodes_.end()) {
return Status::Error("Should not push the OptGroupNode %s into the group in the rule: %s",
gn->node()->toString().c_str(),
rule->toString().c_str());
}
}

// In some cases, we can apply optimization rules even if the control flow and data flow are
// inconsistent. For now, let the optimization rules themselves guarantee correctness.
if (result.eraseAll) {
for (auto gnode : groupNodes_) {
gnode->node()->releaseSymbols();
gnode->release();
}
groupNodes_.clear();
for (auto ngn : result.newGroupNodes) {
Expand All @@ -114,13 +195,16 @@ Status OptGroup::explore(const OptRule *rule) {
}

if (result.eraseCurr) {
(*iter)->node()->releaseSymbols();
(*iter)->release();
iter = groupNodes_.erase(iter);
} else {
++iter;
}
}

DCHECK(!groupNodes_.empty())
<< "Should have at least one group node after optimizing the current group";

return Status::OK();
}

Expand All @@ -137,6 +221,7 @@ Status OptGroup::exploreUntilMaxRound(const OptRule *rule) {
}

std::pair<double, const OptGroupNode *> OptGroup::findMinCostGroupNode() const {
DCHECK(!groupNodes_.empty()) << "There is no any group nodes in opt group";
double minCost = std::numeric_limits<double>::max();
const OptGroupNode *minGroupNode = nullptr;
for (auto &groupNode : groupNodes_) {
Expand All @@ -155,8 +240,19 @@ double OptGroup::getCost() const {

const PlanNode *OptGroup::getPlan() const {
const OptGroupNode *minGroupNode = findMinCostGroupNode().second;
DCHECK(minGroupNode != nullptr);
return minGroupNode->getPlan();
return DCHECK_NOTNULL(minGroupNode)->getPlan();
}

void OptGroup::deleteRefGroupNode(const OptGroupNode *node) {
groupNodesReferenced_.erase(node);
if (groupNodesReferenced_.empty()) {
// Cleanup all opt group nodes in current opt group if it's NOT referenced by any other opt
// group nodes
for (auto *n : groupNodes_) {
n->release();
}
groupNodes_.clear();
}
}

OptGroupNode *OptGroupNode::create(OptContext *ctx, PlanNode *node, const OptGroup *group) {
Expand Down Expand Up @@ -224,5 +320,49 @@ const PlanNode *OptGroupNode::getPlan() const {
return node_;
}

void OptGroupNode::release() {
node_->releaseSymbols();
for (auto *dep : dependencies_) {
dep->deleteRefGroupNode(this);
}
}

// Validates this group node while `rule` is being applied. The checks are, in order:
//  - the group node wraps a plan node and belongs to a group;
//  - only a Loop plan node may own body groups, and each body group must validate;
//  - a group node without dependencies must wrap a Start or Argument plan node;
//    otherwise every dependency group must validate.
// Returns the first error found, or Status::OK().
Status OptGroupNode::validate(const OptRule *rule) const {
  if (node_ == nullptr) {
    return Status::Error("The OptGroupNode does not have plan node when applying the rule: %s",
                         rule->toString().c_str());
  }
  if (group_ == nullptr) {
    return Status::Error(
        "The OptGroupNode does not have the right OptGroup when applying the rule: %s",
        rule->toString().c_str());
  }

  if (!bodies_.empty()) {
    const bool isLoop = node_->kind() == PlanNode::Kind::kLoop;
    if (!isLoop) {
      return Status::Error(
          "The plan node is not Loop in OptGroupNode when applying the rule: %s, planNode: %s",
          rule->toString().c_str(),
          PlanNode::toString(node_->kind()));
    }
    for (auto *bodyGroup : bodies_) {
      NG_RETURN_IF_ERROR(bodyGroup->validate(rule));
    }
  }

  if (!dependencies_.empty()) {
    // Non-leaf node: recurse into each dependency group.
    for (auto *depGroup : dependencies_) {
      NG_RETURN_IF_ERROR(depGroup->validate(rule));
    }
    return Status::OK();
  }

  // Leaf node: only Start and Argument are accepted.
  const bool isLeafKind =
      node_->kind() == PlanNode::Kind::kStart || node_->kind() == PlanNode::Kind::kArgument;
  if (!isLeafKind) {
    return Status::Error(
        "The leaf plan node is not Start or Argument in OptGroupNode when applying the rule: %s, "
        "planNode: %s",
        rule->toString().c_str(),
        PlanNode::toString(node_->kind()));
  }
  return Status::OK();
}

} // namespace opt
} // namespace nebula
28 changes: 28 additions & 0 deletions src/graph/optimizer/OptGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,38 @@ class OptGroup final {
return outputVar_;
}

// Records `node` as a group node that references this group. Adding a node
// that is already recorded has no effect.
void addRefGroupNode(const OptGroupNode *node) {
  groupNodesReferenced_.emplace(node);
}

void deleteRefGroupNode(const OptGroupNode *node);

// Marks this group as the root group by setting the isRootGroup_ flag.
void setRootGroup() {
isRootGroup_ = true;
}

Status validate(const OptRule *rule) const;

private:
friend ObjectPool;
explicit OptGroup(OptContext *ctx) noexcept;

static constexpr int16_t kMaxExplorationRound = 128;

std::pair<double, const OptGroupNode *> findMinCostGroupNode() const;
Status validateSubPlan(const OptGroupNode *gn,
const OptRule *rule,
const std::vector<OptGroup *> &patternLeaves) const;

OptContext *ctx_{nullptr};
std::list<OptGroupNode *> groupNodes_;
std::vector<const OptRule *> exploredRules_;
// The output variable should be same across the whole group.
std::string outputVar_;

bool isRootGroup_{false};
// Save the OptGroupNode which references this OptGroup
std::unordered_set<const OptGroupNode *> groupNodesReferenced_;
};

class OptGroupNode final {
Expand All @@ -69,6 +88,7 @@ class OptGroupNode final {

// Appends `dep` to this group node's dependencies and registers this group
// node as a referrer of `dep`.
void dependsOn(OptGroup *dep) {
  dependencies_.push_back(dep);
  dep->addRefGroupNode(this);
}

const std::vector<OptGroup *> &dependencies() const {
Expand All @@ -77,6 +97,9 @@ class OptGroupNode final {

// Replaces the dependency list with `deps` and registers this group node as a
// referrer of each new dependency.
// NOTE(review): groups in the previous dependency list are not unregistered
// here — confirm callers only use this on nodes without prior dependencies.
void setDeps(std::vector<OptGroup *> deps) {
  dependencies_ = deps;
  for (OptGroup *depGroup : dependencies_) {
    depGroup->addRefGroupNode(this);
  }
}

void addBody(OptGroup *body) {
Expand Down Expand Up @@ -109,6 +132,11 @@ class OptGroupNode final {
double getCost() const;
const graph::PlanNode *getPlan() const;

// Release the opt group node from its opt group
void release();

Status validate(const OptRule *rule) const;

private:
friend ObjectPool;
OptGroupNode(graph::PlanNode *node, const OptGroup *group) noexcept;
Expand Down
Loading