Skip to content

Commit

Permalink
Push filter down HashInnerJoin rule (#4956)
Browse files Browse the repository at this point in the history
* Push filter down hash inner join

* rewrite argument input var

* Fix bug

* Enhancement

* Disable the cleanup in optimizer

* More DCHECK

* More checks in optimizer

* Use pattern leaves to validate the sub-plan

* Rename BiInnerJoin to hash join

* cleanup

* Add validator for OptGroup

* Fuck the bug

* cancel DCHECK

* Fix tck tests
  • Loading branch information
yixinglu committed Dec 6, 2022
1 parent fef6560 commit 7d27f32
Show file tree
Hide file tree
Showing 29 changed files with 908 additions and 184 deletions.
26 changes: 1 addition & 25 deletions src/graph/context/Symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,6 @@ bool SymbolTable::deleteWrittenBy(const std::string& varName, PlanNode* node) {
if (var == vars_.end()) {
return false;
}
for (auto& alias : var->second->colNames) {
auto found = aliasGeneratedBy_.find(alias);
if (found != aliasGeneratedBy_.end()) {
if (found->second == varName) {
aliasGeneratedBy_.erase(alias);
}
}
}
var->second->writtenBy.erase(node);
return true;
}
Expand All @@ -106,6 +98,7 @@ bool SymbolTable::updateWrittenBy(const std::string& oldVar,
}

Variable* SymbolTable::getVar(const std::string& varName) {
DCHECK(!varName.empty()) << "the variable name is empty";
auto var = vars_.find(varName);
if (var == vars_.end()) {
return nullptr;
Expand All @@ -114,22 +107,5 @@ Variable* SymbolTable::getVar(const std::string& varName) {
}
}

void SymbolTable::setAliasGeneratedBy(const std::vector<std::string>& aliases,
const std::string& varName) {
for (auto& alias : aliases) {
if (aliasGeneratedBy_.count(alias) == 0) {
aliasGeneratedBy_.emplace(alias, varName);
}
}
}

StatusOr<std::string> SymbolTable::getAliasGeneratedBy(const std::string& alias) {
auto found = aliasGeneratedBy_.find(alias);
if (found == aliasGeneratedBy_.end()) {
return Status::Error("Not found a variable that generates the alias: %s", alias.c_str());
} else {
return found->second;
}
}
} // namespace graph
} // namespace nebula
6 changes: 0 additions & 6 deletions src/graph/context/Symbols.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ class SymbolTable final {

Variable* getVar(const std::string& varName);

void setAliasGeneratedBy(const std::vector<std::string>& aliases, const std::string& varName);

StatusOr<std::string> getAliasGeneratedBy(const std::string& alias);

std::string toString() const;

private:
Expand All @@ -85,8 +81,6 @@ class SymbolTable final {
ExecutionContext* ectx_{nullptr};
// var name -> variable
std::unordered_map<std::string, Variable*> vars_;
// alias -> first variable that generate the alias
std::unordered_map<std::string, std::string> aliasGeneratedBy_;
};

} // namespace graph
Expand Down
1 change: 1 addition & 0 deletions src/graph/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nebula_add_library(
rule/PushFilterDownAggregateRule.cpp
rule/PushFilterDownProjectRule.cpp
rule/PushFilterDownLeftJoinRule.cpp
rule/PushFilterDownHashInnerJoinRule.cpp
rule/PushFilterDownInnerJoinRule.cpp
rule/PushFilterDownNodeRule.cpp
rule/PushFilterDownScanVerticesRule.cpp
Expand Down
174 changes: 157 additions & 17 deletions src/graph/optimizer/OptGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,81 @@ OptGroup::OptGroup(OptContext *ctx) noexcept : ctx_(ctx) {
DCHECK(ctx != nullptr);
}

Status OptGroup::validateSubPlan(const OptGroupNode *gn,
const OptRule *rule,
const std::vector<OptGroup *> &patternLeaves) const {
auto &deps = DCHECK_NOTNULL(gn)->dependencies();

auto checkDepGroup = [this, gn, rule, &patternLeaves](const OptGroup *depGroup) -> Status {
auto iter = std::find(patternLeaves.begin(), patternLeaves.end(), depGroup);
if (iter != patternLeaves.end()) {
return Status::OK();
}
if (!depGroup) {
return Status::Error("Could not find the dependent group in pattern leaves");
}
if (depGroup->groupNodes_.size() != 1U || depGroup->groupNodesReferenced_.size() != 1U) {
return Status::Error(
"Invalid sub-plan generated when applying the rule: %s, "
"planNode: %s, numGroupNodes: %lu, numGroupNodesRef: %lu",
rule->toString().c_str(),
PlanNode::toString(gn->node()->kind()),
depGroup->groupNodes_.size(),
depGroup->groupNodesReferenced_.size());
}
return validateSubPlan(depGroup->groupNodes_.front(), rule, patternLeaves);
};

switch (deps.size()) {
case 0: {
auto iter = std::find(patternLeaves.begin(), patternLeaves.end(), nullptr);
if (iter == patternLeaves.end()) {
return Status::Error("Invalid sub-plan generated by the rule: %s, planNode: %s",
rule->toString().c_str(),
PlanNode::toString(gn->node()->kind()));
}
break;
}
case 1: {
NG_RETURN_IF_ERROR(checkDepGroup(deps[0]));
break;
}
case 2: {
NG_RETURN_IF_ERROR(checkDepGroup(deps[0]));
NG_RETURN_IF_ERROR(checkDepGroup(deps[1]));
break;
}
default: {
return Status::Error("Invalid dependencies of opt group node: %lu", deps.size());
}
}
return Status::OK();
}

Status OptGroup::validate(const OptRule *rule) const {
if (groupNodes_.empty() && !groupNodesReferenced_.empty()) {
return Status::Error(
"The OptGroup has no any OptGroupNode but used by other OptGroupNode "
"when applying the rule: %s, numGroupNodesRef: %lu",
rule->toString().c_str(),
groupNodesReferenced_.size());
}
for (auto *gn : groupNodes_) {
NG_RETURN_IF_ERROR(gn->validate(rule));
if (gn->node()->outputVar() != outputVar_) {
return Status::Error(
"The output columns of the OptGroupNode is different from OptGroup "
"when applying the rule: %s, %s vs. %s",
rule->toString().c_str(),
gn->node()->outputVar().c_str(),
outputVar_.c_str());
}
}
return Status::OK();
}

void OptGroup::addGroupNode(OptGroupNode *groupNode) {
DCHECK(groupNode != nullptr);
DCHECK(groupNode->group() == this);
DCHECK_EQ(this, DCHECK_NOTNULL(groupNode)->group());
if (outputVar_.empty()) {
outputVar_ = groupNode->node()->outputVar();
} else {
Expand All @@ -54,13 +126,9 @@ void OptGroup::addGroupNode(OptGroupNode *groupNode) {
}

OptGroupNode *OptGroup::makeGroupNode(PlanNode *node) {
if (outputVar_.empty()) {
outputVar_ = node->outputVar();
} else {
DCHECK_EQ(outputVar_, node->outputVar());
}
groupNodes_.emplace_back(OptGroupNode::create(ctx_, node, this));
return groupNodes_.back();
auto *gn = OptGroupNode::create(ctx_, node, this);
addGroupNode(gn);
return gn;
}

Status OptGroup::explore(const OptRule *rule) {
Expand All @@ -69,9 +137,12 @@ Status OptGroup::explore(const OptRule *rule) {
}
setExplored(rule);

// TODO(yee): the opt group maybe in the loop body branch
// DCHECK(isRootGroup_ || !groupNodesReferenced_.empty())
// << "Current group should be referenced by other group nodes before optimization";

for (auto iter = groupNodes_.begin(); iter != groupNodes_.end();) {
auto groupNode = *iter;
DCHECK(groupNode != nullptr);
auto *groupNode = DCHECK_NOTNULL(*iter);
if (groupNode->isExplored(rule)) {
++iter;
continue;
Expand All @@ -80,23 +151,33 @@ Status OptGroup::explore(const OptRule *rule) {
NG_RETURN_IF_ERROR(groupNode->explore(rule));

// Find more equivalents
std::vector<OptGroup *> boundary;
std::vector<OptGroup *> leaves;
auto status = rule->match(ctx_, groupNode);
if (!status.ok()) {
++iter;
continue;
}
ctx_->setChanged(true);
auto matched = std::move(status).value();
matched.collectBoundary(boundary);
matched.collectPatternLeaves(leaves);
auto resStatus = rule->transform(ctx_, matched);
NG_RETURN_IF_ERROR(resStatus);
auto result = std::move(resStatus).value();

for (auto *gn : result.newGroupNodes) {
auto it = std::find(groupNodes_.begin(), groupNodes_.end(), gn);
if (it != groupNodes_.end()) {
return Status::Error("Should not push the OptGroupNode %s into the group in the rule: %s",
gn->node()->toString().c_str(),
rule->toString().c_str());
}
}

// In some cases, we can apply optimization rules even if the control flow and data flow are
// inconsistent. For now, let the optimization rules themselves guarantee correctness.
if (result.eraseAll) {
for (auto gnode : groupNodes_) {
gnode->node()->releaseSymbols();
gnode->release();
}
groupNodes_.clear();
for (auto ngn : result.newGroupNodes) {
Expand All @@ -114,13 +195,16 @@ Status OptGroup::explore(const OptRule *rule) {
}

if (result.eraseCurr) {
(*iter)->node()->releaseSymbols();
(*iter)->release();
iter = groupNodes_.erase(iter);
} else {
++iter;
}
}

DCHECK(!groupNodes_.empty())
<< "Should have at least one group node after optimizing the current group";

return Status::OK();
}

Expand All @@ -137,6 +221,7 @@ Status OptGroup::exploreUntilMaxRound(const OptRule *rule) {
}

std::pair<double, const OptGroupNode *> OptGroup::findMinCostGroupNode() const {
DCHECK(!groupNodes_.empty()) << "There is no any group nodes in opt group";
double minCost = std::numeric_limits<double>::max();
const OptGroupNode *minGroupNode = nullptr;
for (auto &groupNode : groupNodes_) {
Expand All @@ -155,8 +240,19 @@ double OptGroup::getCost() const {

const PlanNode *OptGroup::getPlan() const {
const OptGroupNode *minGroupNode = findMinCostGroupNode().second;
DCHECK(minGroupNode != nullptr);
return minGroupNode->getPlan();
return DCHECK_NOTNULL(minGroupNode)->getPlan();
}

void OptGroup::deleteRefGroupNode(const OptGroupNode *node) {
groupNodesReferenced_.erase(node);
if (groupNodesReferenced_.empty()) {
// Cleanup all opt group nodes in current opt group if it's NOT referenced by any other opt
// group nodes
for (auto *n : groupNodes_) {
n->release();
}
groupNodes_.clear();
}
}

OptGroupNode *OptGroupNode::create(OptContext *ctx, PlanNode *node, const OptGroup *group) {
Expand Down Expand Up @@ -224,5 +320,49 @@ const PlanNode *OptGroupNode::getPlan() const {
return node_;
}

void OptGroupNode::release() {
node_->releaseSymbols();
for (auto *dep : dependencies_) {
dep->deleteRefGroupNode(this);
}
}

Status OptGroupNode::validate(const OptRule *rule) const {
if (!node_) {
return Status::Error("The OptGroupNode does not have plan node when applying the rule: %s",
rule->toString().c_str());
}
if (!group_) {
return Status::Error(
"The OptGroupNode does not have the right OptGroup when applying the rule: %s",
rule->toString().c_str());
}
if (!bodies_.empty()) {
if (node_->kind() != PlanNode::Kind::kLoop) {
return Status::Error(
"The plan node is not Loop in OptGroupNode when applying the rule: %s, planNode: %s",
rule->toString().c_str(),
PlanNode::toString(node_->kind()));
}
for (auto *g : bodies_) {
NG_RETURN_IF_ERROR(g->validate(rule));
}
}
if (dependencies_.empty()) {
if (node_->kind() != PlanNode::Kind::kStart && node_->kind() != PlanNode::Kind::kArgument) {
return Status::Error(
"The leaf plan node is not Start or Argument in OptGroupNode when applying the rule: %s, "
"planNode: %s",
rule->toString().c_str(),
PlanNode::toString(node_->kind()));
}
} else {
for (auto *g : dependencies_) {
NG_RETURN_IF_ERROR(g->validate(rule));
}
}
return Status::OK();
}

} // namespace opt
} // namespace nebula
28 changes: 28 additions & 0 deletions src/graph/optimizer/OptGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,38 @@ class OptGroup final {
return outputVar_;
}

void addRefGroupNode(const OptGroupNode *node) {
groupNodesReferenced_.insert(node);
}

void deleteRefGroupNode(const OptGroupNode *node);

void setRootGroup() {
isRootGroup_ = true;
}

Status validate(const OptRule *rule) const;

private:
friend ObjectPool;
explicit OptGroup(OptContext *ctx) noexcept;

static constexpr int16_t kMaxExplorationRound = 128;

std::pair<double, const OptGroupNode *> findMinCostGroupNode() const;
Status validateSubPlan(const OptGroupNode *gn,
const OptRule *rule,
const std::vector<OptGroup *> &patternLeaves) const;

OptContext *ctx_{nullptr};
std::list<OptGroupNode *> groupNodes_;
std::vector<const OptRule *> exploredRules_;
// The output variable should be same across the whole group.
std::string outputVar_;

bool isRootGroup_{false};
// Save the OptGroupNode which references this OptGroup
std::unordered_set<const OptGroupNode *> groupNodesReferenced_;
};

class OptGroupNode final {
Expand All @@ -69,6 +88,7 @@ class OptGroupNode final {

void dependsOn(OptGroup *dep) {
dependencies_.emplace_back(dep);
dep->addRefGroupNode(this);
}

const std::vector<OptGroup *> &dependencies() const {
Expand All @@ -77,6 +97,9 @@ class OptGroupNode final {

void setDeps(std::vector<OptGroup *> deps) {
dependencies_ = deps;
for (auto *dep : deps) {
dep->addRefGroupNode(this);
}
}

void addBody(OptGroup *body) {
Expand Down Expand Up @@ -109,6 +132,11 @@ class OptGroupNode final {
double getCost() const;
const graph::PlanNode *getPlan() const;

// Release the opt group node from its opt group
void release();

Status validate(const OptRule *rule) const;

private:
friend ObjectPool;
OptGroupNode(graph::PlanNode *node, const OptGroup *group) noexcept;
Expand Down
Loading

0 comments on commit 7d27f32

Please sign in to comment.