Skip to content

Commit

Permalink
Formula Validation using ExpressionParser (#1998)
Browse files Browse the repository at this point in the history
* Formula Validation using ExpressionParser

* Remove extra lines

* Review Comments

* Review Comments

* Rebase Fix

* Override canNest method for DailyAverageScorePerPeriod

Co-authored-by: Rishi Agarwal <rishi.agarwal@verizonmedia.com>
  • Loading branch information
rishi-aga and Rishi Agarwal committed Apr 14, 2021
1 parent a62e387 commit 699e5df
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 112 deletions.
Expand Up @@ -136,7 +136,7 @@ public static String getFieldAlias(String tableAlias, String fieldName) {
* @param alias alias
* @return True if is null or empty
*/
private static boolean nullOrEmpty(String alias) {
public static boolean nullOrEmpty(String alias) {
return alias == null || alias.equals("");
}

Expand Down
Expand Up @@ -67,7 +67,7 @@
*
* @return metric formula
*/
String value();
String value() default "";

/**
* Function which constructs a projection for this given metric.
Expand Down
Expand Up @@ -5,23 +5,22 @@
*/
package com.yahoo.elide.datastores.aggregation.metadata;

import com.yahoo.elide.core.type.Type;
import com.yahoo.elide.datastores.aggregation.annotation.DimensionFormula;
import com.yahoo.elide.datastores.aggregation.annotation.MetricFormula;
import com.yahoo.elide.datastores.aggregation.core.JoinPath;
import com.yahoo.elide.datastores.aggregation.metadata.models.Column;
import com.yahoo.elide.datastores.aggregation.query.ColumnProjection;
import com.yahoo.elide.datastores.aggregation.query.MetricProjection;
import com.yahoo.elide.datastores.aggregation.query.Queryable;
import com.yahoo.elide.datastores.aggregation.queryengines.sql.expression.ExpressionParser;
import com.yahoo.elide.datastores.aggregation.queryengines.sql.expression.Reference;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Collectors;

/**
* FormulaValidator check whether a column defined with {@link MetricFormula} or
* {@link DimensionFormula} has reference loop. If so, throw out exception.
*/
public class FormulaValidator extends ColumnVisitor<Void> {
public class FormulaValidator extends ExpressionParser {
private final LinkedHashSet<String> visited = new LinkedHashSet<>();

private static String getColumnId(Queryable parent, ColumnProjection column) {
Expand All @@ -33,57 +32,16 @@ public FormulaValidator(MetaDataStore metaDataStore) {
}

@Override
protected Void visitFormulaMetric(Queryable parent, MetricProjection metric) {
return visitFormulaColumn(parent, metric);
}

@Override
protected Void visitFormulaDimension(Queryable parent, ColumnProjection dimension) {
return visitFormulaColumn(parent, dimension);
}

@Override
protected Void visitFieldDimension(Queryable parent, ColumnProjection dimension) {
return null;
}
public List<Reference> parse(Queryable source, ColumnProjection column) {
String columnId = getColumnId(source, column);

/**
* For a FORMULA column, we just need to check all source columns in the formula expression.
*
* @param column a column defined with {@link MetricFormula} or {@link DimensionFormula}
* @return null
*/
private Void visitFormulaColumn(Queryable source, ColumnProjection column) {
if (visited.contains(getColumnId(source, column))) {
if (!visited.add(columnId)) {
throw new IllegalArgumentException(referenceLoopMessage(visited, source, column));
}
List<Reference> references = parse(source, column.getExpression());
visited.remove(columnId);

Type<?> tableClass = dictionary.getEntityClass(source.getName(), source.getVersion());

visited.add(getColumnId(source, column));
for (String reference : resolveFormulaReferences(column.getExpression())) {

//Column is from a query instead of a table. Nothing to validate.
if (source != source.getSource()) {
continue;
} else if (reference.contains(".")) {
JoinPath joinToPath = new JoinPath(tableClass, metaDataStore, reference);
Column joinToColumn = getColumn(joinToPath);
if (joinToColumn != null) {
visitColumn(joinToColumn.getTable().toQueryable(), joinToColumn.toProjection());
}
} else {
ColumnProjection referenceColumn = source.getColumnProjection(reference);

// if the reference is to a logical column, check it
if (referenceColumn != null && !reference.equals(column.getName())) {
visitColumn(source, referenceColumn);
}
}
}
visited.remove(getColumnId(source, column));

return null;
return references;
}

/**
Expand Down
Expand Up @@ -113,19 +113,17 @@ protected Column(Table table, String fieldName, EntityDictionary dictionary) {

if (dictionary.attributeOrRelationAnnotationExists(tableClass, fieldName, MetricFormula.class)) {
columnType = FORMULA;
expression = dictionary
.getAttributeOrRelationAnnotation(tableClass, MetricFormula.class, fieldName).value();
MetricFormula metricFormula = dictionary.getAttributeOrRelationAnnotation(tableClass, MetricFormula.class,
fieldName);
this.expression = metricFormula.value();
this.arguments = Arrays.stream(metricFormula.arguments())
.map(argument -> new Argument(getId(), argument))
.collect(Collectors.toCollection(LinkedHashSet::new));
} else if (dictionary.attributeOrRelationAnnotationExists(tableClass, fieldName, DimensionFormula.class)) {
columnType = FORMULA;
expression = dictionary
.getAttributeOrRelationAnnotation(tableClass, DimensionFormula.class, fieldName).value();
DimensionFormula dimensionFormula = dictionary.getAttributeOrRelationAnnotation(tableClass,
DimensionFormula.class, fieldName);
this.expression = dimensionFormula.value();
this.arguments = Arrays.stream(dimensionFormula.arguments())
.map(argument -> new Argument(getId(), argument))
.collect(Collectors.toCollection(LinkedHashSet::new));
Expand Down
Expand Up @@ -32,18 +32,38 @@ public Metric(Table table, String fieldName, EntityDictionary dictionary) {
super(table, fieldName, dictionary);
Type<?> tableClass = dictionary.getEntityClass(table.getName(), table.getVersion());

MetricFormula formula = dictionary.getAttributeOrRelationAnnotation(
tableClass,
MetricFormula.class,
fieldName);

if (formula != null) {
this.metricProjectionMaker = dictionary.getInjector().instantiate(formula.maker());
dictionary.getInjector().inject(this.metricProjectionMaker);

} else {
throw new IllegalStateException("Trying to construct metric field "
+ getId() + " without @MetricFormula.");
MetricFormula formula = dictionary.getAttributeOrRelationAnnotation(tableClass, MetricFormula.class, fieldName);

verfiyFormula(formula);

this.metricProjectionMaker = dictionary.getInjector().instantiate(formula.maker());
dictionary.getInjector().inject(this.metricProjectionMaker);
}

private void verfiyFormula(MetricFormula formula) {
if (formula == null) {
throw new IllegalStateException("Trying to construct metric field " + getId() + " without @MetricFormula.");
}

String defaultValue;
Class<?> defaultMaker;

try {
defaultValue = (String) MetricFormula.class.getDeclaredMethod("value").getDefaultValue();
defaultMaker = (Class<?>) MetricFormula.class.getDeclaredMethod("maker").getDefaultValue();
} catch (NoSuchMethodException | SecurityException e) {
throw new IllegalStateException("Error encountered while constructing metric field: " + getId()
+ ". " + e.getMessage());
}

if (formula.value().equals(defaultValue) && formula.maker().equals(defaultMaker)) {
throw new IllegalStateException("Trying to construct metric field " + getId()
+ " with default values. Provide either value or maker in @MetricFormula.");
}

if (!formula.value().equals(defaultValue) && !formula.maker().equals(defaultMaker)) {
throw new IllegalStateException("Trying to construct metric field " + getId()
+ " with value and maker. Provide either one in @MetricFormula, both are not allowed.");
}
}
}
Expand Up @@ -186,6 +186,14 @@ default Queryable getRoot() {
return current.getSource();
}

/**
* Determines if this queryable is root table.
* @return true if this queryable is root table.
*/
default boolean isRoot() {
return this == this.getRoot();
}

/**
* Returns the depth of the nesting of this Queryable.
* @return 0 for unnested. Positive integer for nested..
Expand Down
Expand Up @@ -6,6 +6,9 @@

package com.yahoo.elide.datastores.aggregation.queryengines.sql.expression;

import static com.yahoo.elide.core.utils.TypeHelper.nullOrEmpty;
import static com.yahoo.elide.datastores.aggregation.queryengines.sql.metadata.SQLReferenceTable.PERIOD;

import com.yahoo.elide.core.Path;
import com.yahoo.elide.core.dictionary.EntityDictionary;
import com.yahoo.elide.core.type.Type;
Expand All @@ -14,10 +17,16 @@
import com.yahoo.elide.datastores.aggregation.metadata.enums.ColumnType;
import com.yahoo.elide.datastores.aggregation.query.ColumnProjection;
import com.yahoo.elide.datastores.aggregation.query.Queryable;

import com.github.jknack.handlebars.EscapingStrategy;
import com.github.jknack.handlebars.Handlebars;
import com.github.jknack.handlebars.Helper;
import com.github.jknack.handlebars.Options;
import com.github.jknack.handlebars.Template;
import com.google.common.base.Preconditions;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -29,15 +38,44 @@
public class ExpressionParser {

private static final Pattern REFERENCE_PARENTHESES = Pattern.compile("\\{\\{(.+?)}}");
private static final String SQL_HELPER_PREFIX = "sql ";

private MetaDataStore metaDataStore;
private EntityDictionary dictionary;
private final Handlebars handlebars = new Handlebars()
.with(EscapingStrategy.NOOP)
.registerHelper("sql", new Helper<Object>() {

@Override
public Object apply(final Object context, final Options options) throws IOException {
String from = options.hash("from");
String column = options.hash("column");
int argsIndex = column.indexOf('[');

// Remove args from column
column = argsIndex == -1 ? column : column.substring(0, argsIndex);
// Prefix column with join table name
column = nullOrEmpty(from) ? column : from + PERIOD + column;

return column;
}
});

public ExpressionParser(MetaDataStore store) {
this.dictionary = store.getMetadataDictionary();
this.metaDataStore = store;
}

/**
* Parses the column or join expression and returns the list of discovered references.
* @param source The source table where the column or join expression lives.
* @param column {@link ColumnProjection}
* @return A list of discovered references.
*/
public List<Reference> parse(Queryable source, ColumnProjection column) {
return parse(source, column.getExpression());
}

/**
* Parses the column or join expression and returns the list of discovered references.
* @param source The source table where the column or join expression lives.
Expand All @@ -50,6 +88,17 @@ public List<Reference> parse(Queryable source, String expression) {
List<Reference> results = new ArrayList<>();

for (String referenceName : referenceNames) {

// Change `sql from='joinName' column='columnName[a1:v1][a2:v2]'` to `joinName.columnName`
if (referenceName.startsWith(SQL_HELPER_PREFIX)) {
try {
Template template = handlebars.compileInline(toFormulaReference(referenceName));
referenceName = template.apply(Collections.emptyMap());
} catch (IOException e) {
throw new IllegalStateException(e.getMessage());
}
}

if (referenceName.startsWith("$$")) {
continue;
}
Expand Down Expand Up @@ -91,7 +140,7 @@ private Reference buildReferenceFromField(Queryable source, String fieldName) {
.builder()
.source(source)
.column(column)
.references(parse(source, column.getExpression()))
.references(parse(source, column))
.build();
}
}
Expand Down Expand Up @@ -141,4 +190,14 @@ private static List<String> resolveFormulaReferences(String expression) {

return references;
}

/**
* Convert a resolved formula reference back to a reference presented in formula format.
*
* @param reference referenced field
* @return formula reference, <code>{{reference}}</code>
*/
private static String toFormulaReference(String reference) {
return "{{" + reference + "}}";
}
}
Expand Up @@ -69,13 +69,19 @@ protected SQLReferenceTable(MetaDataStore metaDataStore, Set<Queryable> queryabl
this.metaDataStore = metaDataStore;
this.dictionary = this.metaDataStore.getMetadataDictionary();

queryables.stream().forEach(queryable -> {
Queryable next = queryable;
do {
resolveAndStoreAllReferencesAndJoins(next);
next = next.getSource();
} while (next.isNested());
});
queryables
.stream()
// If Queryable is root, then its SQLTable.
// We need to store references only for SQLTable and Nested Queries (Queryable -> Queryable -> SQLTable).
// In case of Query -> SQLTable. Query doesn't know about all logical references.
.filter(queryable -> queryable.isNested() || queryable.isRoot())
.forEach(queryable -> {
Queryable next = queryable;
do {
resolveAndStoreAllReferencesAndJoins(next);
next = next.getSource();
} while (next.isNested());
});
}

/**
Expand Down Expand Up @@ -138,7 +144,7 @@ public void resolveAndStoreAllReferencesAndJoins(Queryable queryable) {

queryable.getColumnProjections().forEach(column -> {
// validate that there is no reference loop
validator.visitColumn(queryable, column);
validator.parse(queryable, column);

String fieldName = column.getName();

Expand Down
Expand Up @@ -41,6 +41,7 @@ default String toSQL(Queryable source, SQLReferenceTable lookupTable) {
default boolean canNest(Queryable source, SQLReferenceTable lookupTable) {
SQLDialect dialect = source.getConnectionDetails().getDialect();
String sql = toSQL(source.getSource(), lookupTable);

SyntaxVerifier verifier = new SyntaxVerifier(dialect);
boolean canNest = verifier.verify(sql);
if (! canNest) {
Expand Down Expand Up @@ -69,7 +70,7 @@ default Pair<ColumnProjection, Set<ColumnProjection>> nest(Queryable source,
outerProjection = withExpression(getExpression(), inProjection);
innerProjections = joinProjections.stream().collect(Collectors.toCollection(LinkedHashSet::new));
} else {
outerProjection = withExpression("{{" + this.getSafeAlias() + "}}", isProjected());
outerProjection = withExpression("{{$" + this.getSafeAlias() + "}}", isProjected());
innerProjections = new LinkedHashSet<>(Arrays.asList(this));
}

Expand Down
Expand Up @@ -94,6 +94,7 @@ public String toSQL(Queryable source, SQLReferenceTable lookupTable) {

@Override
public boolean canNest(Queryable source, SQLReferenceTable lookupTable) {

if (lookupTable.getResolvedJoinProjections(source.getSource(), name).size() > 0) {
//We currently don't support nesting metrics with joins.
//A join could be part of the aggregation (inner) or post aggregation (outer) expression.
Expand Down Expand Up @@ -159,7 +160,7 @@ public Pair<ColumnProjection, Set<ColumnProjection>> nest(Queryable source,
dialect.getBeginQuote()
+ "?(" + getAggregationLabelPrefix(dialect.getCalciteDialect()) + "\\w+)"
+ dialect.getEndQuote()
+ "?", "{{$1}}");
+ "?", "{{\\$" + "$1" + "}}");

boolean inProjection = source.getColumnProjection(name, arguments) != null;

Expand Down
Expand Up @@ -117,7 +117,7 @@ public Pair<ColumnProjection, Set<ColumnProjection>> nest(Queryable source,
.grain(new TimeDimensionGrain(
this.getName(),
grain.getGrain()))
.expression("{{" + this.getSafeAlias() + "}}")
.expression("{{$" + this.getSafeAlias() + "}}")
.projected(isProjected())
.arguments(arguments)
.timeZone(timeZone)
Expand Down

0 comments on commit 699e5df

Please sign in to comment.