Skip to content

Commit

Permalink
Fixed up aggregator polysis!!! Now passes all tests!!
Browse files Browse the repository at this point in the history
  • Loading branch information
azreika committed Nov 2, 2020
1 parent efc5586 commit b40da9f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
18 changes: 17 additions & 1 deletion src/ast/Aggregator.h
Expand Up @@ -26,6 +26,7 @@
#include "souffle/utility/MiscUtil.h"
#include "souffle/utility/StreamUtil.h"
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -82,8 +83,12 @@ class Aggregator : public Argument {
}

Aggregator* clone() const override {
return new Aggregator(
auto* copy = new Aggregator(
baseOperator, souffle::clone(targetExpression), souffle::clone(body), getSrcLoc());
if (finalTranslatorType.has_value()) {
copy->setFinalType(finalTranslatorType.value());
}
return copy;
}

void apply(const NodeMapper& map) override {
Expand All @@ -95,6 +100,14 @@ class Aggregator : public Argument {
}
}

void setFinalType(AggregateOp newType) {
finalTranslatorType = newType;
}

std::optional<AggregateOp> getFinalType() const {
return finalTranslatorType;
}

protected:
void print(std::ostream& os) const override {
os << baseOperator;
Expand All @@ -119,6 +132,9 @@ class Aggregator : public Argument {

/** Body literal of sub-query */
VecOwn<Literal> body;

// TODO (azreika): remove after refactoring translator
std::optional<AggregateOp> finalTranslatorType;
};

} // namespace souffle::ast
14 changes: 10 additions & 4 deletions src/ast/analysis/Type.cpp
Expand Up @@ -947,6 +947,7 @@ bool TypeAnalysis::hasInvalidPolymorphicNumericConstantType(const NumericConstan
}

AggregateOp TypeAnalysis::getPolymorphicOperator(const Aggregator* aggr) const {
assert(contains(aggregatorType, aggr) && "aggregator does not have a set type");
return aggregatorType.at(aggr);
}

Expand Down Expand Up @@ -1042,9 +1043,6 @@ void TypeAnalysis::run(const TranslationUnit& translationUnit) {
auto isUnsigned = [&](const Argument* argument) {
return isOfKind(getTypes(argument), TypeAttribute::Unsigned);
};
auto isSigned = [&](const Argument* argument) {
return isOfKind(getTypes(argument), TypeAttribute::Signed);
};
auto setAggregatorType = [&](const Aggregator& aggr, TypeAttribute attr) {
auto overloadedType = convertOverloadedAggregator(aggr.getBaseOperator(), attr);
if (contains(aggregatorType, &aggr) && aggregatorType.at(&aggr) == overloadedType) return;
Expand All @@ -1058,9 +1056,17 @@ void TypeAnalysis::run(const TranslationUnit& translationUnit) {
setAggregatorType(aggregator, TypeAttribute::Float);
} else if (isUnsigned(targetExpression)) {
setAggregatorType(aggregator, TypeAttribute::Unsigned);
} else if (isSigned(targetExpression)) {
} else {
setAggregatorType(aggregator, TypeAttribute::Signed);
}
} else {
if (contains(aggregatorType, &aggregator)) {
assert(aggregatorType.at(&aggregator) == aggregator.getBaseOperator() &&
"unexpected aggr type");
return;
}
changed = true;
aggregatorType[&aggregator] = aggregator.getBaseOperator();
}
});
}
Expand Down
3 changes: 3 additions & 0 deletions src/ast2ram/AstToRamTranslator.cpp
Expand Up @@ -1032,6 +1032,9 @@ void AstToRamTranslator::translateProgram(const ast::TranslationUnit& translatio
visitDepthFirst(*program, [&](const ast::NumericConstant& nc) {
const_cast<ast::NumericConstant&>(nc).setFinalType(polyAnalysis->getInferredType(&nc));
});
visitDepthFirst(*program, [&](const ast::Aggregator& aggr) {
const_cast<ast::Aggregator&>(aggr).setFinalType(polyAnalysis->getOverloadedOperator(&aggr));
});

// determine the sips to use
std::string sipsChosen = "all-bound";
Expand Down
3 changes: 1 addition & 2 deletions src/ast2ram/ClauseTranslator.cpp
Expand Up @@ -187,8 +187,7 @@ Own<ram::Statement> ClauseTranslator::translateClause(
auto expr = translator.translateValue(agg->getTargetExpression(), *valueIndex);

// add Ram-Aggregation layer
const auto* polyAnalysis = translator.getPolymorphicObjectsAnalysis();
op = mk<ram::Aggregate>(std::move(op), polyAnalysis->getOverloadedOperator(agg),
op = mk<ram::Aggregate>(std::move(op), agg->getFinalType().value(),
translator.translateRelation(atom), expr ? std::move(expr) : mk<ram::UndefValue>(),
aggCond ? std::move(aggCond) : mk<ram::True>(), level);
} else if (const auto* func = dynamic_cast<const ast::IntrinsicFunctor*>(cur)) {
Expand Down

0 comments on commit b40da9f

Please sign in to comment.