Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup literal #1019

Merged
merged 3 commits into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 12 additions & 147 deletions presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,17 @@
*/
package io.prestosql.metadata;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Ordering;
import com.google.common.primitives.Primitives;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.slice.Slice;
import io.prestosql.block.BlockSerdeUtil;
import io.prestosql.operator.aggregation.ApproximateCountDistinctAggregation;
import io.prestosql.operator.aggregation.ApproximateDoublePercentileAggregations;
import io.prestosql.operator.aggregation.ApproximateDoublePercentileArrayAggregations;
Expand Down Expand Up @@ -159,12 +154,9 @@
import io.prestosql.operator.window.SqlWindowFunction;
import io.prestosql.operator.window.WindowFunctionSupplier;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockEncodingSerde;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.DynamicFilters;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
Expand Down Expand Up @@ -203,8 +195,6 @@

import javax.annotation.concurrent.ThreadSafe;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
Expand All @@ -222,7 +212,10 @@
import static io.prestosql.metadata.FunctionKind.AGGREGATE;
import static io.prestosql.metadata.FunctionKind.SCALAR;
import static io.prestosql.metadata.FunctionKind.WINDOW;
import static io.prestosql.metadata.LiteralFunction.LITERAL_FUNCTION_NAME;
import static io.prestosql.metadata.LiteralFunction.getLiteralFunctionSignature;
import static io.prestosql.metadata.Signature.internalOperator;
import static io.prestosql.metadata.Signature.mangleOperatorName;
import static io.prestosql.metadata.SignatureBinder.applyBoundVariables;
import static io.prestosql.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION;
import static io.prestosql.operator.aggregation.ChecksumAggregationFunction.CHECKSUM_AGGREGATION;
Expand Down Expand Up @@ -293,7 +286,6 @@
import static io.prestosql.operator.scalar.RowNotEqualOperator.ROW_NOT_EQUAL;
import static io.prestosql.operator.scalar.RowToJsonCast.ROW_TO_JSON;
import static io.prestosql.operator.scalar.RowToRowCast.ROW_TO_ROW_CAST;
import static io.prestosql.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty;
import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL;
import static io.prestosql.operator.scalar.TryCastFunction.TRY_CAST;
import static io.prestosql.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
Expand All @@ -302,11 +294,7 @@
import static io.prestosql.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL;
import static io.prestosql.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static io.prestosql.spi.StandardErrorCode.FUNCTION_NOT_FOUND;
import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.TypeSignature.parseTypeSignature;
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.prestosql.type.DecimalCasts.BIGINT_TO_DECIMAL_CAST;
Expand Down Expand Up @@ -358,26 +346,20 @@
@ThreadSafe
public class FunctionRegistry
{
private static final String MAGIC_LITERAL_FUNCTION_PREFIX = "$literal$";
private static final String OPERATOR_PREFIX = "$operator$";

// hack: java classes for types that can be used with magic literals
private static final Set<Class<?>> SUPPORTED_LITERAL_TYPES = ImmutableSet.of(long.class, double.class, Slice.class, boolean.class);

private final Metadata metadata;
private final TypeCoercion typeCoercion;
private final LoadingCache<Signature, SpecializedFunctionKey> specializedFunctionKeyCache;
private final LoadingCache<SpecializedFunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final LoadingCache<SpecializedFunctionKey, InternalAggregationFunction> specializedAggregationCache;
private final LoadingCache<SpecializedFunctionKey, WindowFunctionSupplier> specializedWindowCache;
private final MagicLiteralFunction magicLiteralFunction;
private final LiteralFunction literalFunction;
private volatile FunctionMap functions = new FunctionMap();

public FunctionRegistry(Metadata metadata, FeaturesConfig featuresConfig)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.typeCoercion = new TypeCoercion(metadata::getType);
this.magicLiteralFunction = new MagicLiteralFunction(metadata.getBlockEncodingSerde());
this.literalFunction = new LiteralFunction();

specializedFunctionKeyCache = CacheBuilder.newBuilder()
.maximumSize(1000)
Expand Down Expand Up @@ -727,9 +709,9 @@ Signature resolveFunction(QualifiedName name, List<TypeSignatureProvider> parame
message = format("Unexpected parameters (%s) for function %s. Expected: %s", parameters, name, expected);
}

if (name.getSuffix().startsWith(MAGIC_LITERAL_FUNCTION_PREFIX)) {
if (name.getSuffix().startsWith(LITERAL_FUNCTION_NAME)) {
// extract type from function name
String typeName = name.getSuffix().substring(MAGIC_LITERAL_FUNCTION_PREFIX.length());
String typeName = name.getSuffix().substring(LITERAL_FUNCTION_NAME.length());

// lookup the type
Type type = metadata.getType(parseTypeSignature(typeName));
Expand All @@ -738,7 +720,7 @@ Signature resolveFunction(QualifiedName name, List<TypeSignatureProvider> parame
checkArgument(parameterTypes.size() == 1, "Expected one argument to literal function, but got %s", parameterTypes);
metadata.getType(parameterTypes.get(0).getTypeSignature());

return getMagicLiteralFunctionSignature(type);
return getLiteralFunctionSignature(type);
}

throw new PrestoException(FUNCTION_NOT_FOUND, message);
Expand Down Expand Up @@ -1024,10 +1006,10 @@ private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature)
}

// TODO: this is a hack and should be removed
if (signature.getName().startsWith(MAGIC_LITERAL_FUNCTION_PREFIX)) {
if (signature.getName().startsWith(LITERAL_FUNCTION_NAME)) {
List<TypeSignature> parameterTypes = signature.getArgumentTypes();
// extract type from function name
String typeName = signature.getName().substring(MAGIC_LITERAL_FUNCTION_PREFIX.length());
String typeName = signature.getName().substring(LITERAL_FUNCTION_NAME.length());

// lookup the type
Type type = metadata.getType(parseTypeSignature(typeName));
Expand All @@ -1038,7 +1020,7 @@ private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature)
requireNonNull(parameterType, format("Type %s not found", parameterTypes.get(0)));

return new SpecializedFunctionKey(
magicLiteralFunction,
literalFunction,
BoundVariables.builder()
.setTypeVariable("T", parameterType)
.setTypeVariable("R", type)
Expand Down Expand Up @@ -1087,7 +1069,7 @@ public Signature resolveOperator(OperatorType operatorType, List<? extends Type>

public Signature getCoercion(TypeSignature fromType, TypeSignature toType)
{
Signature signature = internalOperator(OperatorType.CAST.name(), toType, ImmutableList.of(fromType));
Signature signature = internalOperator(OperatorType.CAST, toType, ImmutableList.of(fromType));
try {
getScalarFunctionImplementation(signature);
}
Expand All @@ -1100,63 +1082,6 @@ public Signature getCoercion(TypeSignature fromType, TypeSignature toType)
return signature;
}

public static Type typeForMagicLiteral(Type type)
{
Class<?> clazz = type.getJavaType();
clazz = Primitives.unwrap(clazz);

if (clazz == long.class) {
return BIGINT;
}
if (clazz == double.class) {
return DOUBLE;
}
if (!clazz.isPrimitive()) {
if (type instanceof VarcharType) {
return type;
}
else {
return VARBINARY;
}
}
if (clazz == boolean.class) {
return BOOLEAN;
}
throw new IllegalArgumentException("Unhandled Java type: " + clazz.getName());
}

public static Signature getMagicLiteralFunctionSignature(Type type)
{
TypeSignature argumentType = typeForMagicLiteral(type).getTypeSignature();

return new Signature(MAGIC_LITERAL_FUNCTION_PREFIX + type.getTypeSignature(),
SCALAR,
type.getTypeSignature(),
argumentType);
}

public static boolean isSupportedLiteralType(Type type)
{
return SUPPORTED_LITERAL_TYPES.contains(type.getJavaType());
}

public static String mangleOperatorName(OperatorType operatorType)
{
return mangleOperatorName(operatorType.name());
}

public static String mangleOperatorName(String operatorName)
{
return OPERATOR_PREFIX + operatorName;
}

@VisibleForTesting
public static OperatorType unmangleOperator(String mangledName)
{
checkArgument(mangledName.startsWith(OPERATOR_PREFIX), "%s is not a mangled operator name", mangledName);
return OperatorType.valueOf(mangledName.substring(OPERATOR_PREFIX.length()));
}

private Optional<List<Type>> toTypes(List<TypeSignatureProvider> typeSignatureProviders)
{
ImmutableList.Builder<Type> resultBuilder = ImmutableList.builder();
Expand Down Expand Up @@ -1248,64 +1173,4 @@ public String toString()
.toString();
}
}

private static class MagicLiteralFunction
extends SqlScalarFunction
{
private final BlockEncodingSerde blockEncodingSerde;

public MagicLiteralFunction(BlockEncodingSerde blockEncodingSerde)
{
super(new Signature(MAGIC_LITERAL_FUNCTION_PREFIX, FunctionKind.SCALAR, TypeSignature.parseTypeSignature("R"), TypeSignature.parseTypeSignature("T")));
this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
}

@Override
public boolean isHidden()
{
return true;
}

@Override
public boolean isDeterministic()
{
return true;
}

@Override
public String getDescription()
{
return "magic literal";
}

@Override
public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, Metadata metadata)
{
Type parameterType = boundVariables.getTypeVariable("T");
Type type = boundVariables.getTypeVariable("R");

MethodHandle methodHandle = null;
if (parameterType.getJavaType() == type.getJavaType()) {
methodHandle = MethodHandles.identity(parameterType.getJavaType());
}

if (parameterType.getJavaType() == Slice.class) {
if (type.getJavaType() == Block.class) {
methodHandle = BlockSerdeUtil.READ_BLOCK.bindTo(blockEncodingSerde);
}
}

checkArgument(methodHandle != null,
"Expected type %s to use (or can be converted into) Java type %s, but Java type is %s",
type,
parameterType.getJavaType(),
type.getJavaType());

return new ScalarFunctionImplementation(
false,
ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
methodHandle,
isDeterministic());
}
}
}
Loading