Skip to content

Commit

Permalink
TEIID-3903 fix for multi-source count *
Browse files Browse the repository at this point in the history
  • Loading branch information
shawkins committed Jan 14, 2016
1 parent 656d69c commit e6f487d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
Expand Up @@ -188,13 +188,15 @@ private void assignOutputElements(PlanNode root, List<Expression> outputElements
if (NodeEditor.findParent(root, NodeConstants.Types.PROJECT, NodeConstants.Types.SOURCE) != null) {
//there's a chance that partial projection was used. we are not a defacto project node
//take credit for creating anything that is not an element symbol
List<Expression> filteredElements = new ArrayList<Expression>(outputElements.size());
LinkedHashSet<Expression> filteredElements = new LinkedHashSet<Expression>();
for (Expression element : outputElements) {
if(element instanceof ElementSymbol) {
filteredElements.add(element);
} else {
filteredElements.addAll(ElementCollectorVisitor.getElements(element, false));
}
}
outputElements = filteredElements;
outputElements = new ArrayList<Expression>(filteredElements);
}
}
assignOutputElements(root.getLastChild(), outputElements, metadata, capFinder, rules, analysisRecord, context);
Expand Down
Expand Up @@ -142,7 +142,7 @@ public PlanNode execute(PlanNode plan,
}
Set<Expression> stagedGroupingSymbols = new LinkedHashSet<Expression>();
stagedGroupingSymbols.addAll(groupingExpressions);
aggregates = stageAggregates(groupNode, metadata, stagedGroupingSymbols, aggregates);
aggregates = stageAggregates(groupNode, metadata, stagedGroupingSymbols, aggregates, false);
if (aggregates.isEmpty() && stagedGroupingSymbols.isEmpty()) {
continue;
}
Expand Down Expand Up @@ -275,7 +275,7 @@ private void pushGroupNodeOverUnion(QueryMetadataInterface metadata, Capabilitie

List<AggregateSymbol> copy = new ArrayList<AggregateSymbol>(aggregates);
aggregates.clear();
Map<AggregateSymbol, Expression> aggMap = buildAggregateMap(copy, metadata, aggregates);
Map<AggregateSymbol, Expression> aggMap = buildAggregateMap(copy, metadata, aggregates, false);

boolean shouldPushdown = false;
List<Boolean> pushdownList = new ArrayList<Boolean>(unionChildren.size());
Expand Down Expand Up @@ -754,7 +754,7 @@ private void pushGroupNode(PlanNode groupNode,
}*/

if (aggregates != null) {
aggregates = stageAggregates(groupNode, metadata, stagedGroupingSymbols, aggregates);
aggregates = stageAggregates(groupNode, metadata, stagedGroupingSymbols, aggregates, true);
} else {
aggregates = new ArrayList<AggregateSymbol>(1);
}
Expand Down Expand Up @@ -829,7 +829,7 @@ private void addEmptyFilter(Collection<AggregateSymbol> aggregates,
Set<AggregateSymbol> stageAggregates(PlanNode groupNode,
QueryMetadataInterface metadata,
Set<Expression> stagedGroupingSymbols,
Collection<AggregateSymbol> aggregates) throws TeiidComponentException, QueryPlannerException {
Collection<AggregateSymbol> aggregates, boolean join) throws TeiidComponentException, QueryPlannerException {
//remove any aggregates that are computed over a group by column
for (final Iterator<AggregateSymbol> iterator = aggregates.iterator(); iterator.hasNext();) {
final AggregateSymbol symbol = iterator.next();
Expand All @@ -849,7 +849,7 @@ Set<AggregateSymbol> stageAggregates(PlanNode groupNode,
Set<AggregateSymbol> newAggs = new HashSet<AggregateSymbol>();
Map<AggregateSymbol, Expression> aggMap;
try {
aggMap = buildAggregateMap(aggregates, metadata, newAggs);
aggMap = buildAggregateMap(aggregates, metadata, newAggs, join);
} catch (QueryResolverException e) {
throw new QueryPlannerException(QueryPlugin.Event.TEIID30266, e);
}
Expand Down Expand Up @@ -1101,7 +1101,7 @@ private <T extends Expression> Map<PlanNode, List<T>> createNodeMapping(PlanNode
}

private static Map<AggregateSymbol, Expression> buildAggregateMap(Collection<? extends AggregateSymbol> aggregateExpressions,
QueryMetadataInterface metadata, Set<AggregateSymbol> nestedAggregates) throws QueryResolverException,
QueryMetadataInterface metadata, Set<AggregateSymbol> nestedAggregates, boolean join) throws QueryResolverException,
TeiidComponentException {
Map<AggregateSymbol, Expression> aggMap = new LinkedHashMap<AggregateSymbol, Expression>();
for (AggregateSymbol partitionAgg : aggregateExpressions) {
Expand All @@ -1112,19 +1112,21 @@ private static Map<AggregateSymbol, Expression> buildAggregateMap(Collection<? e
if (aggFunction == Type.COUNT) {
//COUNT(x) -> IFNULL(CONVERT(SUM(COUNT(x)), INTEGER), 0)
AggregateSymbol newAgg = null;
if (partitionAgg.getArgs().length == 0) {
if (partitionAgg.getArgs().length == 0 && join) {
//count * case (if on the inner side of an outer join)
Function ifnull = new Function(FunctionLibrary.IFNULL, new Expression[] {partitionAgg, new Constant(1, DataTypeManager.DefaultDataClasses.INTEGER)});
newAgg = new AggregateSymbol(NonReserved.SUM, false, ifnull);
} else {
newAgg = new AggregateSymbol(NonReserved.SUM, false, partitionAgg);
}
// Build conversion function to convert SUM (which returns LONG) back to INTEGER
Function convertFunc = new Function(FunctionLibrary.CONVERT, new Expression[] {newAgg, new Constant(DataTypeManager.getDataTypeName(partitionAgg.getType()))});
Function ifnull = new Function(FunctionLibrary.IFNULL, new Expression[] {convertFunc, new Constant(0, DataTypeManager.DefaultDataClasses.INTEGER)});
ResolverVisitor.resolveLanguageObject(ifnull, metadata);
Function func = new Function(FunctionLibrary.CONVERT, new Expression[] {newAgg, new Constant(DataTypeManager.getDataTypeName(partitionAgg.getType()))});
if (join) {
func = new Function(FunctionLibrary.IFNULL, new Expression[] {func, new Constant(0, DataTypeManager.DefaultDataClasses.INTEGER)});
}
ResolverVisitor.resolveLanguageObject(func, metadata);

newExpression = ifnull;
newExpression = func;
nestedAggregates.add(partitionAgg);
} else if (aggFunction == Type.AVG) {
//AVG(x) -> SUM(SUM(x)) / SUM(COUNT(x))
Expand Down
Expand Up @@ -41,6 +41,7 @@
import org.teiid.query.metadata.TempMetadataAdapter;
import org.teiid.query.metadata.TempMetadataStore;
import org.teiid.query.optimizer.QueryOptimizer;
import org.teiid.query.optimizer.TestAggregatePushdown;
import org.teiid.query.optimizer.TestOptimizer;
import org.teiid.query.optimizer.capabilities.BasicSourceCapabilities;
import org.teiid.query.optimizer.capabilities.CapabilitiesFinder;
Expand Down Expand Up @@ -619,5 +620,25 @@ public ProcessorPlan helpTestMultiSourcePlan(QueryMetadataInterface metadata, St
helpTestMultiSourcePlan(metadata, userSql, multiModel, sources, dataMgr, expected, RealMetadataFactory.exampleMultiBindingVDB(), null, new Options().implicitMultiSourceJoin(false), new BasicSourceCapabilities());
assertEquals(3, dataMgr.getCommandHistory().size());
}

@Test public void testCountStar() throws Exception {
QueryMetadataInterface metadata = RealMetadataFactory.exampleMultiBinding();

final String userSql = "SELECT count(*) FROM MultiModel.Phys limit 100"; //$NON-NLS-1$
final String multiModel = "MultiModel"; //$NON-NLS-1$
final int sources = 2;
final List<?>[] expected = new List<?>[] {
Arrays.asList(4),
};
final HardcodedDataManager dataMgr = new HardcodedDataManager(metadata);
dataMgr.addData("SELECT COUNT(*) FROM Phys AS g_0", //$NON-NLS-1$
new List<?>[] {
Arrays.asList(2)});

BasicSourceCapabilities bsc = TestAggregatePushdown.getAggregateCapabilities();
bsc.setFunctionSupport("ifnull", true);

helpTestMultiSourcePlan(metadata, userSql, multiModel, sources, dataMgr, expected, RealMetadataFactory.exampleMultiBindingVDB(), null, new Options().implicitMultiSourceJoin(false), bsc);
}

}

0 comments on commit e6f487d

Please sign in to comment.