diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index fda8c13a8659..234484814d39 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -180,6 +180,7 @@ import static com.facebook.presto.metadata.FunctionKind.SCALAR; import static com.facebook.presto.metadata.FunctionKind.WINDOW; import static com.facebook.presto.metadata.Signature.internalOperator; +import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION; import static com.facebook.presto.operator.aggregation.ChecksumAggregationFunction.CHECKSUM_AGGREGATION; import static com.facebook.presto.operator.aggregation.CountColumn.COUNT_COLUMN; @@ -783,52 +784,43 @@ public WindowFunctionSupplier getWindowFunctionImplementation(Signature signatur { checkArgument(signature.getKind() == WINDOW || signature.getKind() == AGGREGATE, "%s is not a window function", signature); checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); - Iterable candidates = functions.get(QualifiedName.of(signature.getName())); - // search for exact match - for (SqlFunction operator : candidates) { - Type returnType = typeManager.getType(signature.getReturnType()); - List argumentTypes = resolveTypes(signature.getArgumentTypes(), typeManager); - Optional boundVariables = new SignatureBinder(typeManager, operator.getSignature(), false) - .bindVariables(argumentTypes, returnType); - if (boundVariables.isPresent()) { - try { - return specializedWindowCache.getUnchecked(new SpecializedFunctionKey(operator, boundVariables.get(), argumentTypes.size())); - } - catch (UncheckedExecutionException e) { - throw Throwables.propagate(e.getCause()); - } - } + + try { + return specializedWindowCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throw Throwables.propagate(e.getCause()); } - throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("%s not found", signature)); } public InternalAggregationFunction getAggregateFunctionImplementation(Signature signature) { checkArgument(signature.getKind() == AGGREGATE || signature.getKind() == APPROXIMATE_AGGREGATE, "%s is not an aggregate function", signature); checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); - Iterable candidates = functions.get(QualifiedName.of(signature.getName())); - // search for exact match - for (SqlFunction operator : candidates) { - Type returnType = typeManager.getType(signature.getReturnType()); - List argumentTypes = resolveTypes(signature.getArgumentTypes(), typeManager); - Optional boundVariables = new SignatureBinder(typeManager, operator.getSignature(), false) - .bindVariables(argumentTypes, returnType); - if (boundVariables.isPresent()) { - try { - return specializedAggregationCache.getUnchecked(new SpecializedFunctionKey(operator, boundVariables.get(), signature.getArgumentTypes().size())); - } - catch (UncheckedExecutionException e) { - throw Throwables.propagate(e.getCause()); - } - } + + try { + return specializedAggregationCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throw Throwables.propagate(e.getCause()); } - throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("%s not found", signature)); } public ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) { checkArgument(signature.getKind() == SCALAR, "%s is not a scalar function", signature); checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + + try { + return specializedScalarCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throw Throwables.propagate(e.getCause()); + } + } + + private SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) + { Iterable candidates = functions.get(QualifiedName.of(signature.getName())); // search for exact match Type returnType = typeManager.getType(signature.getReturnType()); @@ -837,12 +829,7 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(Signature si Optional boundVariables = new SignatureBinder(typeManager, candidate.getSignature(), false) .bindVariables(argumentTypes, returnType); if (boundVariables.isPresent()) { - try { - return specializedScalarCache.getUnchecked(new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypes.size())); - } - catch (UncheckedExecutionException e) { - throw Throwables.propagate(e.getCause()); - } + return new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypes.size()); } } @@ -850,33 +837,28 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(Signature si // so do a second pass allowing "type only" coercions for (SqlFunction candidate : candidates) { SignatureBinder binder = new SignatureBinder(typeManager, candidate.getSignature(), true); - Optional boundSignature = binder.bind(argumentTypes, returnType); - if (boundSignature.isPresent()) { - if (!typeManager.isTypeOnlyCoercion(typeManager.getType(boundSignature.get().getReturnType()), returnType)) { - continue; - } - boolean nonTypeOnlyCoercion = false; - for (int i = 0; i < argumentTypes.size(); i++) { - Type expectedType = typeManager.getType(boundSignature.get().getArgumentTypes().get(i)); - if (!typeManager.isTypeOnlyCoercion(argumentTypes.get(i), expectedType)) { - nonTypeOnlyCoercion = true; - break; - } - } - if (nonTypeOnlyCoercion) { - continue; - } + Optional boundVariables = binder.bindVariables(argumentTypes, returnType); + if (!boundVariables.isPresent()) { + continue; } + Signature boundSignature = applyBoundVariables(candidate.getSignature(), boundVariables.get(), argumentTypes.size()); - Optional boundVariables = binder.bindVariables(argumentTypes, returnType); - if (boundVariables.isPresent()) { - try { - return specializedScalarCache.getUnchecked(new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypes.size())); - } - catch (UncheckedExecutionException e) { - throw Throwables.propagate(e.getCause()); + if (!typeManager.isTypeOnlyCoercion(typeManager.getType(boundSignature.getReturnType()), returnType)) { + continue; + } + boolean nonTypeOnlyCoercion = false; + for (int i = 0; i < argumentTypes.size(); i++) { + Type expectedType = typeManager.getType(boundSignature.getArgumentTypes().get(i)); + if (!typeManager.isTypeOnlyCoercion(argumentTypes.get(i), expectedType)) { + nonTypeOnlyCoercion = true; + break; } } + if (nonTypeOnlyCoercion) { + continue; + } + + return new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypes.size()); } // TODO: this is a hack and should be removed @@ -894,20 +876,13 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(Signature si Type parameterType = typeManager.getType(parameterTypes.get(0)); requireNonNull(parameterType, format("Type %s not found", parameterTypes.get(0))); - SpecializedFunctionKey specializedFunctionKey = new SpecializedFunctionKey( + return new SpecializedFunctionKey( magicLiteralFunction, BoundVariables.builder() .setTypeVariable("T", parameterType) .setTypeVariable("R", type) .build(), 1); - - try { - return specializedScalarCache.getUnchecked(specializedFunctionKey); - } - catch (UncheckedExecutionException e) { - throw Throwables.propagate(e.getCause()); - } } throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("%s not found", signature));