Refactor UnaliasSymbolReferences and SymbolMapper
In UnaliasSymbolReferences:
Introduce a context, including the inherited mapping for correlation symbols.
Introduce a return type for the visit() methods, carrying:
    - the rewritten plan,
    - symbol mappings (both inherited and derived).
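
A minimal sketch of what such a return type could look like, assuming a simple value class; the name PlanAndMappings and its accessors are illustrative, since the UnaliasSymbolReferences changes themselves are not part of this excerpt:

    import com.google.common.collect.ImmutableMap;
    import io.prestosql.sql.planner.Symbol;
    import io.prestosql.sql.planner.plan.PlanNode;

    import java.util.Map;

    import static java.util.Objects.requireNonNull;

    // Illustrative value type returned by a visit() method: the rewritten
    // plan fragment plus the symbol mappings (inherited and derived) that
    // the parent node needs in order to rewrite its own symbol references.
    public class PlanAndMappings
    {
        private final PlanNode root;
        private final Map<Symbol, Symbol> mappings;

        public PlanAndMappings(PlanNode root, Map<Symbol, Symbol> mappings)
        {
            this.root = requireNonNull(root, "root is null");
            this.mappings = ImmutableMap.copyOf(requireNonNull(mappings, "mappings is null"));
        }

        public PlanNode getRoot()
        {
            return root;
        }

        public Map<Symbol, Symbol> getMappings()
        {
            return mappings;
        }
    }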

Previously, a single global symbol mapping was updated every time
a new symbol mapping was defined. As a result, once a mapping for
a symbol was defined, it was applied to every occurrence of that
symbol encountered later.
This caused errors when the same (duplicate) symbols were reused
in different branches of the plan.
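
A simplified, self-contained illustration of that failure mode; plain strings stand in for planner symbols, and none of the names below come from the Presto code:

    import java.util.HashMap;
    import java.util.Map;

    public class GlobalMappingPitfall
    {
        public static void main(String[] args)
        {
            // A single mutable mapping shared by the whole plan traversal.
            Map<String, String> globalMapping = new HashMap<>();

            // While rewriting the left branch, "expr" is unaliased to "expr_0".
            globalMapping.put("expr", "expr_0");

            // Later, the right branch reuses the name "expr" for an unrelated
            // symbol. The stale entry still applies, so the reference is
            // silently redirected to the left branch's symbol.
            System.out.println(globalMapping.getOrDefault("expr", "expr")); // expr_0 -- wrong for this branch

            // With context-scoped mappings, the right branch would start from
            // its inherited mappings only and leave this "expr" untouched.
        }
    }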

Symbol mapping is now performed only by SymbolMapper.
Previously, some mapping methods were defined in
UnaliasSymbolReferences, and some code was duplicated between
UnaliasSymbolReferences and SymbolMapper.
kasiafi authored and martint committed Jun 30, 2020
1 parent 769bfcb commit ce94183
Showing 3 changed files with 898 additions and 521 deletions.
@@ -2,7 +2,7 @@ local exchange (GATHER, SINGLE, [])
remote exchange (GATHER, SINGLE, [])
final aggregation over (concat_306, d_year_307, w_city_302, w_country_305, w_county_303, w_state_304, w_warehouse_name_300, w_warehouse_sq_ft_301)
local exchange (REPARTITION, HASH, ["concat_306", "d_year_307", "w_city_302", "w_country_305", "w_county_303", "w_state_304", "w_warehouse_name_300", "w_warehouse_sq_ft_301"])
partial aggregation over (concat_603, d_year, w_city, w_country, w_county, w_state, w_warehouse_name, w_warehouse_sq_ft)
partial aggregation over (concat_632, d_year, w_city, w_country, w_county, w_state, w_warehouse_name, w_warehouse_sq_ft)
final aggregation over (d_year, w_city, w_country, w_county, w_state, w_warehouse_name, w_warehouse_sq_ft)
local exchange (GATHER, SINGLE, [])
remote exchange (REPARTITION, HASH, ["d_year", "w_city", "w_country", "w_county", "w_state", "w_warehouse_name", "w_warehouse_sq_ft"])
@@ -24,7 +24,7 @@ local exchange (GATHER, SINGLE, [])
local exchange (GATHER, SINGLE, [])
remote exchange (REPLICATE, BROADCAST, [])
scan warehouse
partial aggregation over (concat_647, d_year_136, w_city_124, w_country_128, w_county_125, w_state_126, w_warehouse_name_118, w_warehouse_sq_ft_119)
partial aggregation over (concat_676, d_year_136, w_city_124, w_country_128, w_county_125, w_state_126, w_warehouse_name_118, w_warehouse_sq_ft_119)
final aggregation over (d_year_136, w_city_124, w_country_128, w_county_125, w_state_126, w_warehouse_name_118, w_warehouse_sq_ft_119)
local exchange (GATHER, SINGLE, [])
remote exchange (REPARTITION, HASH, ["d_year_136", "w_city_124", "w_country_128", "w_county_125", "w_state_126", "w_warehouse_name_118", "w_warehouse_sq_ft_119"])
@@ -21,20 +21,25 @@
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AggregationNode.Aggregation;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticAggregationsDescriptor;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.SymbolReference;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -43,6 +48,7 @@

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.prestosql.sql.planner.plan.AggregationNode.groupingSets;
import static java.util.Objects.requireNonNull;

@@ -55,6 +61,12 @@ public SymbolMapper(Map<Symbol, Symbol> mapping)
this.mapping = ImmutableMap.copyOf(requireNonNull(mapping, "mapping is null"));
}

public Map<Symbol, Symbol> getMapping()
{
return mapping;
}

// Return the canonical mapping for the symbol.
public Symbol map(Symbol symbol)
{
while (mapping.containsKey(symbol) && !mapping.get(symbol).equals(symbol)) {
@@ -63,6 +75,21 @@ public Symbol map(Symbol symbol)
return symbol;
}

public List<Symbol> map(List<Symbol> symbols)
{
return symbols.stream()
.map(this::map)
.collect(toImmutableList());
}

public List<Symbol> mapAndDistinct(List<Symbol> symbols)
{
return symbols.stream()
.map(this::map)
.distinct()
.collect(toImmutableList());
}

public Expression map(Expression expression)
{
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
@@ -115,14 +142,73 @@ private Aggregation map(Aggregation aggregation)
aggregation.getMask().map(this::map));
}

public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId)
public GroupIdNode map(GroupIdNode node, PlanNode source)
{
return new TopNNode(
newNodeId,
Map<Symbol, Symbol> newGroupingMappings = new HashMap<>();
ImmutableList.Builder<List<Symbol>> newGroupingSets = ImmutableList.builder();

for (List<Symbol> groupingSet : node.getGroupingSets()) {
ImmutableList.Builder<Symbol> newGroupingSet = ImmutableList.builder();
for (Symbol output : groupingSet) {
Symbol newOutput = map(output);
newGroupingMappings.putIfAbsent(
newOutput,
map(node.getGroupingColumns().get(output)));
newGroupingSet.add(newOutput);
}
newGroupingSets.add(newGroupingSet.build());
}

return new GroupIdNode(
node.getId(),
source,
node.getCount(),
map(node.getOrderingScheme()),
node.getStep());
newGroupingSets.build(),
newGroupingMappings,
mapAndDistinct(node.getAggregationArguments()),
map(node.getGroupIdSymbol()));
}

public WindowNode map(WindowNode node, PlanNode source)
{
ImmutableMap.Builder<Symbol, WindowNode.Function> newFunctions = ImmutableMap.builder();
node.getWindowFunctions().forEach((symbol, function) -> {
List<Expression> newArguments = function.getArguments().stream()
.map(this::map)
.collect(toImmutableList());
WindowNode.Frame newFrame = map(function.getFrame());

newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls()));
});

return new WindowNode(
node.getId(),
source,
mapAndDistinct(node.getSpecification()),
newFunctions.build(),
node.getHashSymbol().map(this::map),
node.getPrePartitionedInputs().stream()
.map(this::map)
.collect(toImmutableSet()),
node.getPreSortedOrderPrefix());
}

private WindowNode.Frame map(WindowNode.Frame frame)
{
return new WindowNode.Frame(
frame.getType(),
frame.getStartType(),
frame.getStartValue().map(this::map),
frame.getEndType(),
frame.getEndValue().map(this::map),
frame.getOriginalStartValue(),
frame.getOriginalEndValue());
}

private WindowNode.Specification mapAndDistinct(WindowNode.Specification specification)
{
return new WindowNode.Specification(
mapAndDistinct(specification.getPartitionBy()),
specification.getOrderingScheme().map(this::map));
}

public LimitNode map(LimitNode node, PlanNode source)
@@ -135,30 +221,30 @@ public LimitNode map(LimitNode node, PlanNode source)
node.isPartial());
}

public TableWriterNode map(TableWriterNode node, PlanNode source)
public OrderingScheme map(OrderingScheme orderingScheme)
{
return map(node, source, node.getId());
ImmutableList.Builder<Symbol> newSymbols = ImmutableList.builder();
ImmutableMap.Builder<Symbol, SortOrder> newOrderings = ImmutableMap.builder();
Set<Symbol> added = new HashSet<>(orderingScheme.getOrderBy().size());
for (Symbol symbol : orderingScheme.getOrderBy()) {
Symbol canonical = map(symbol);
if (added.add(canonical)) {
newSymbols.add(canonical);
newOrderings.put(canonical, orderingScheme.getOrdering(symbol));
}
}
return new OrderingScheme(newSymbols.build(), newOrderings.build());
}

public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newNodeId)
public DistinctLimitNode map(DistinctLimitNode node, PlanNode source)
{
// Intentionally does not use canonicalizeAndDistinct as that would remove columns
ImmutableList<Symbol> columns = node.getColumns().stream()
.map(this::map)
.collect(toImmutableList());

return new TableWriterNode(
newNodeId,
return new DistinctLimitNode(
node.getId(),
source,
node.getTarget(),
map(node.getRowCountSymbol()),
map(node.getFragmentSymbol()),
columns,
node.getColumnNames(),
node.getNotNullColumnSymbols(),
node.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)),
node.getStatisticsAggregation().map(this::map),
node.getStatisticsAggregationDescriptor().map(this::map));
node.getLimit(),
node.isPartial(),
mapAndDistinct(node.getDistinctSymbols()),
node.getHashSymbol().map(this::map));
}

public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source)
@@ -167,80 +253,101 @@ public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source)
node.getId(),
source,
node.getTarget(),
node.getRowCountSymbol(),
map(node.getRowCountSymbol()),
node.isRowCountEnabled(),
node.getDescriptor().map(this::map));
}

public TableFinishNode map(TableFinishNode node, PlanNode source)
public TableWriterNode map(TableWriterNode node, PlanNode source)
{
return new TableFinishNode(
node.getId(),
return map(node, source, node.getId());
}

public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newId)
{
// Intentionally does not use mapAndDistinct on columns as that would remove columns
return new TableWriterNode(
newId,
source,
node.getTarget(),
map(node.getRowCountSymbol()),
map(node.getFragmentSymbol()),
map(node.getColumns()),
node.getColumnNames(),
node.getNotNullColumnSymbols(),
node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())),
node.getStatisticsAggregation().map(this::map),
node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map)));
}

private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source)
public PartitioningScheme map(PartitioningScheme scheme, List<Symbol> sourceLayout)
{
return new PartitioningScheme(
scheme.getPartitioning().translate(this::map),
mapAndDistinct(source.getOutputSymbols()),
mapAndDistinct(sourceLayout),
scheme.getHashColumn().map(this::map),
scheme.isReplicateNullsAndAny(),
scheme.getBucketToPartition());
}

public TableFinishNode map(TableFinishNode node, PlanNode source)
{
return new TableFinishNode(
node.getId(),
source,
node.getTarget(),
map(node.getRowCountSymbol()),
node.getStatisticsAggregation().map(this::map),
node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map)));
}

private StatisticAggregations map(StatisticAggregations statisticAggregations)
{
Map<Symbol, Aggregation> aggregations = statisticAggregations.getAggregations().entrySet().stream()
.collect(toImmutableMap(entry -> map(entry.getKey()), entry -> map(entry.getValue())));
return new StatisticAggregations(aggregations, mapAndDistinct(statisticAggregations.getGroupingSymbols()));
}

private StatisticAggregationsDescriptor<Symbol> map(StatisticAggregationsDescriptor<Symbol> descriptor)
public RowNumberNode map(RowNumberNode node, PlanNode source)
{
return descriptor.map(this::map);
return new RowNumberNode(
node.getId(),
source,
mapAndDistinct(node.getPartitionBy()),
node.isOrderSensitive(),
map(node.getRowNumberSymbol()),
node.getMaxRowCountPerPartition(),
node.getHashSymbol().map(this::map));
}

private List<Symbol> map(List<Symbol> outputs)
public TopNRowNumberNode map(TopNRowNumberNode node, PlanNode source)
{
return outputs.stream()
.map(this::map)
.collect(toImmutableList());
return new TopNRowNumberNode(
node.getId(),
source,
mapAndDistinct(node.getSpecification()),
map(node.getRowNumberSymbol()),
node.getMaxRowCountPerPartition(),
node.isPartial(),
node.getHashSymbol().map(this::map));
}

private List<Symbol> mapAndDistinct(List<Symbol> outputs)
public TopNNode map(TopNNode node, PlanNode source)
{
Set<Symbol> added = new HashSet<>();
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
for (Symbol symbol : outputs) {
Symbol canonical = map(symbol);
if (added.add(canonical)) {
builder.add(canonical);
}
}
return builder.build();
return map(node, source, node.getId());
}

private OrderingScheme map(OrderingScheme orderingScheme)
public TopNNode map(TopNNode node, PlanNode source, PlanNodeId nodeId)
{
ImmutableList.Builder<Symbol> symbols = ImmutableList.builder();
ImmutableMap.Builder<Symbol, SortOrder> orderings = ImmutableMap.builder();
Set<Symbol> seenCanonicals = new HashSet<>(orderingScheme.getOrderBy().size());
for (Symbol symbol : orderingScheme.getOrderBy()) {
Symbol canonical = map(symbol);
if (seenCanonicals.add(canonical)) {
symbols.add(canonical);
orderings.put(canonical, orderingScheme.getOrdering(symbol));
}
}
return new OrderingScheme(symbols.build(), orderings.build());
return new TopNNode(
nodeId,
source,
node.getCount(),
map(node.getOrderingScheme()),
node.getStep());
}

public static SymbolMapper.Builder builder()
public static Builder builder()
{
return new Builder();
}
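
For reference, a short usage sketch of the mapping behavior shown in the SymbolMapper diff above. The SymbolMapper constructor, map(), and mapAndDistinct() appear in the diff; the symbol names and the assumed io.prestosql.sql.planner.optimizations package are illustrative:

    import com.google.common.collect.ImmutableList;
    import com.google.common.collect.ImmutableMap;
    import io.prestosql.sql.planner.Symbol;
    import io.prestosql.sql.planner.optimizations.SymbolMapper;

    import java.util.List;

    public class SymbolMapperExample
    {
        public static void main(String[] args)
        {
            // Mappings may chain: a -> b and b -> c.
            SymbolMapper mapper = new SymbolMapper(ImmutableMap.of(
                    new Symbol("a"), new Symbol("b"),
                    new Symbol("b"), new Symbol("c")));

            // map() follows the chain to the canonical symbol: a -> b -> c.
            Symbol canonical = mapper.map(new Symbol("a")); // c

            // mapAndDistinct() canonicalizes and drops symbols that collapse
            // onto the same canonical symbol: [a, b, d] -> [c, d].
            List<Symbol> outputs = mapper.mapAndDistinct(
                    ImmutableList.of(new Symbol("a"), new Symbol("b"), new Symbol("d")));

            System.out.println(canonical + " " + outputs);
        }
    }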
