Skip to content

Commit

Permalink
Rename AggregateFunction#inputs to arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
assaf2 authored and losipiuk committed Feb 23, 2022
1 parent 55d9ac2 commit 083035d
Show file tree
Hide file tree
Showing 32 changed files with 192 additions and 192 deletions.
Expand Up @@ -27,28 +27,28 @@ public class AggregateFunction
{
private final String functionName;
private final Type outputType;
private final List<ConnectorExpression> inputs;
private final List<ConnectorExpression> arguments;
private final List<SortItem> sortItems;
private final boolean isDistinct;
private final Optional<ConnectorExpression> filter;

public AggregateFunction(
String aggregateFunctionName,
Type outputType,
List<ConnectorExpression> inputs,
List<ConnectorExpression> arguments,
List<SortItem> sortItems,
boolean isDistinct,
Optional<ConnectorExpression> filter)
{
if (isDistinct && inputs.isEmpty()) {
throw new IllegalArgumentException("DISTINCT requires inputs");
if (isDistinct && arguments.isEmpty()) {
throw new IllegalArgumentException("DISTINCT requires arguments");
}

this.functionName = requireNonNull(aggregateFunctionName, "aggregateFunctionName is null");
this.outputType = requireNonNull(outputType, "outputType is null");
requireNonNull(inputs, "inputs is null");
requireNonNull(arguments, "arguments is null");
requireNonNull(sortItems, "sortItems is null");
this.inputs = List.copyOf(inputs);
this.arguments = List.copyOf(arguments);
this.sortItems = List.copyOf(sortItems);
this.isDistinct = isDistinct;
this.filter = requireNonNull(filter, "filter is null");
Expand All @@ -59,9 +59,9 @@ public String getFunctionName()
return functionName;
}

public List<ConnectorExpression> getInputs()
public List<ConnectorExpression> getArguments()
{
return inputs;
return arguments;
}

public Type getOutputType()
Expand Down Expand Up @@ -89,7 +89,7 @@ public String toString()
{
return new StringJoiner(", ", AggregateFunction.class.getSimpleName() + "[", "]")
.add("aggregationName='" + functionName + "'")
.add("inputs=" + inputs)
.add("arguments=" + arguments)
.add("outputType=" + outputType)
.add("sortOrder=" + sortItems)
.add("isDistinct=" + isDistinct)
Expand All @@ -111,7 +111,7 @@ public boolean equals(Object o)
AggregateFunction that = (AggregateFunction) o;
return isDistinct == that.isDistinct &&
Objects.equals(functionName, that.functionName) &&
Objects.equals(inputs, that.inputs) &&
Objects.equals(arguments, that.arguments) &&
Objects.equals(outputType, that.outputType) &&
Objects.equals(sortItems, that.sortItems) &&
Objects.equals(filter, that.filter);
Expand All @@ -120,6 +120,6 @@ public boolean equals(Object o)
@Override
public int hashCode()
{
return Objects.hash(functionName, inputs, outputType, sortItems, isDistinct, filter);
return Objects.hash(functionName, arguments, outputType, sortItems, isDistinct, filter);
}
}
Expand Up @@ -54,19 +54,19 @@ public static Pattern<AggregateFunction> basicAggregation()
return Property.property("outputType", AggregateFunction::getOutputType);
}

public static Property<AggregateFunction, ?, List<ConnectorExpression>> inputs()
public static Property<AggregateFunction, ?, List<ConnectorExpression>> arguments()
{
return Property.property("inputs", AggregateFunction::getInputs);
return Property.property("arguments", AggregateFunction::getArguments);
}

public static Property<AggregateFunction, ?, ConnectorExpression> singleInput()
public static Property<AggregateFunction, ?, ConnectorExpression> singleArgument()
{
return Property.optionalProperty("inputs", aggregateFunction -> {
List<ConnectorExpression> inputs = aggregateFunction.getInputs();
if (inputs.size() != 1) {
return Property.optionalProperty("arguments", aggregateFunction -> {
List<ConnectorExpression> arguments = aggregateFunction.getArguments();
if (arguments.size() != 1) {
return Optional.empty();
}
return Optional.of(inputs.get(0));
return Optional.of(arguments.get(0));
});
}

Expand Down
Expand Up @@ -31,7 +31,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
Expand All @@ -45,29 +45,29 @@
public abstract class BaseImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
{
private final Capture<Variable> input;
private final Capture<Variable> argument;

public BaseImplementAvgBigint()
{
this.input = newCapture();
this.argument = newCapture();
}

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(type -> type == BIGINT))
.capturedAs(this.input)));
.capturedAs(this.argument)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(this.input);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(this.argument);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

String columnName = context.getIdentifierQuote().apply(columnHandle.getColumnName());
Expand Down
Expand Up @@ -30,7 +30,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static java.lang.String.format;

Expand All @@ -40,24 +40,24 @@
public class ImplementAvgDecimal
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(DecimalType.class::isInstance))
.capturedAs(INPUT)));
.capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
DecimalType type = (DecimalType) columnHandle.getColumnType();
verify(aggregateFunction.getOutputType().equals(type));

Expand Down
Expand Up @@ -29,7 +29,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
Expand All @@ -41,24 +41,24 @@
public class ImplementAvgFloatingPoint
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
.with(singleArgument().matching(
variable()
.with(expressionType().matching(type -> type == REAL || type == DOUBLE))
.capturedAs(INPUT)));
.capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == columnHandle.getColumnType());

return Optional.of(new JdbcExpression(
Expand Down
Expand Up @@ -27,10 +27,10 @@

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.arguments;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.expressionTypes;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.inputs;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variables;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
Expand All @@ -39,27 +39,27 @@
public class ImplementCorr
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<List<Variable>> INPUTS = newCapture();
private static final Capture<List<Variable>> ARGUMENTS = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("corr"))
.with(inputs().matching(
.with(arguments().matching(
variables()
.matching(expressionTypes(REAL, REAL).or(expressionTypes(DOUBLE, DOUBLE)))
.capturedAs(INPUTS)));
.capturedAs(ARGUMENTS)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
List<Variable> inputs = captures.get(INPUTS);
verify(inputs.size() == 2);
List<Variable> arguments = captures.get(ARGUMENTS);
verify(arguments.size() == 2);

JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(inputs.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(inputs.get(1).getName());
JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(arguments.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(arguments.get(1).getName());
verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType()));

return Optional.of(new JdbcExpression(
Expand Down
Expand Up @@ -31,7 +31,7 @@
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
Expand All @@ -43,7 +43,7 @@
public class ImplementCount
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

private final JdbcTypeHandle bigintTypeHandle;

Expand All @@ -60,14 +60,14 @@ public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("count"))
.with(singleInput().matching(variable().capturedAs(INPUT)));
.with(singleArgument().matching(variable().capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

return Optional.of(new JdbcExpression(
Expand Down
Expand Up @@ -26,9 +26,9 @@
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.arguments;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.inputs;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.util.Objects.requireNonNull;

Expand All @@ -53,7 +53,7 @@ public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("count"))
.with(inputs().equalTo(List.of()));
.with(arguments().equalTo(List.of()));
}

@Override
Expand Down
Expand Up @@ -34,7 +34,7 @@
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.distinct;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.hasFilter;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument;
import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.String.format;
Expand All @@ -46,7 +46,7 @@
public class ImplementCountDistinct
implements AggregateFunctionRule<JdbcExpression>
{
private static final Capture<Variable> INPUT = newCapture();
private static final Capture<Variable> ARGUMENT = newCapture();

private final JdbcTypeHandle bigintTypeHandle;
private final boolean isRemoteCollationSensitive;
Expand All @@ -67,14 +67,14 @@ public Pattern<AggregateFunction> getPattern()
.with(distinct().equalTo(true))
.with(hasFilter().equalTo(false))
.with(functionName().equalTo("count"))
.with(singleInput().matching(variable().capturedAs(INPUT)));
.with(singleArgument().matching(variable().capturedAs(ARGUMENT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

boolean isCaseSensitiveType = columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType;
Expand Down

0 comments on commit 083035d

Please sign in to comment.