Skip to content

Commit

Permalink
Improve and add tests for scalar validation
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum committed Jul 8, 2016
1 parent b72db45 commit 8ef8309
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 36 deletions.
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.type.SqlType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
Expand Down Expand Up @@ -66,7 +67,13 @@ private static List<ScalarHeaderAndMethods> findScalarsInFunctionDefinitionClass
checkArgument(!classHeaders.isEmpty(), "Class that defines function must be annotated with @ScalarFunction or @ScalarOperator.");

for (ScalarImplementationHeader header : classHeaders) {
builder.add(new ScalarHeaderAndMethods(header, findPublicMethodsWithAnnotation(annotated, SqlType.class)));
Set<Method> methods = findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class);
checkArgument(!methods.isEmpty(), "Parametric class %s does not have any annotated methods", annotated.getName());
for (Method method : methods) {
checkArgument(method.getAnnotation(ScalarFunction.class) == null, "Parametric class method [%s] is annotated with @ScalarFunction", method);
checkArgument(method.getAnnotation(ScalarOperator.class) == null, "Parametric class method [%s] is annotated with @ScalarOperator", method);
}
builder.add(new ScalarHeaderAndMethods(header, methods));
}

return builder.build();
Expand All @@ -75,9 +82,11 @@ private static List<ScalarHeaderAndMethods> findScalarsInFunctionDefinitionClass
private static List<ScalarHeaderAndMethods> findScalarsInFunctionSetClass(Class<?> annotated)
{
ImmutableList.Builder<ScalarHeaderAndMethods> builder = ImmutableList.builder();
for (Method method : findPublicMethodsWithAnnotation(annotated, SqlType.class)) {
for (Method method : findPublicMethodsWithAnnotation(annotated, SqlType.class, ScalarFunction.class, ScalarOperator.class)) {
checkArgument((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null),
"Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method);
for (ScalarImplementationHeader header : ScalarImplementationHeader.fromAnnotatedElement(method)) {
builder.add(new ScalarHeaderAndMethods(header, ImmutableList.of(method)));
builder.add(new ScalarHeaderAndMethods(header, ImmutableSet.of(method)));
}
}
return builder.build();
Expand Down Expand Up @@ -142,14 +151,17 @@ private static Map<Set<TypeParameter>, Constructor<?>> findConstructors(Class<?>
return builder.build();
}

private static List<Method> findPublicMethodsWithAnnotation(Class<?> clazz, Class<?> annotationClass)
@SafeVarargs
private static Set<Method> findPublicMethodsWithAnnotation(Class<?> clazz, Class<? extends Annotation>... annotationClasses)
{
ImmutableList.Builder<Method> methods = ImmutableList.builder();
for (Method method : clazz.getMethods()) {
ImmutableSet.Builder<Method> methods = ImmutableSet.builder();
for (Method method : clazz.getDeclaredMethods()) {
for (Annotation annotation : method.getAnnotations()) {
if (annotationClass.isInstance(annotation)) {
checkArgument(Modifier.isPublic(method.getModifiers()), "%s annotated with %s must be public", method.getName(), annotationClass.getSimpleName());
methods.add(method);
for (Class<?> annotationClass : annotationClasses) {
if (annotationClass.isInstance(annotation)) {
checkArgument(Modifier.isPublic(method.getModifiers()), "Method [%s] annotated with @%s must be public", method, annotationClass.getSimpleName());
methods.add(method);
}
}
}
}
Expand All @@ -159,9 +171,9 @@ private static List<Method> findPublicMethodsWithAnnotation(Class<?> clazz, Clas
private static class ScalarHeaderAndMethods
{
private final ScalarImplementationHeader header;
private final List<Method> methods;
private final Set<Method> methods;

public ScalarHeaderAndMethods(ScalarImplementationHeader header, List<Method> methods)
public ScalarHeaderAndMethods(ScalarImplementationHeader header, Set<Method> methods)
{
this.header = requireNonNull(header);
this.methods = requireNonNull(methods);
Expand All @@ -172,7 +184,7 @@ public ScalarImplementationHeader getHeader()
return header;
}

public List<Method> getMethods()
public Set<Method> getMethods()
{
return methods;
}
Expand Down
Expand Up @@ -290,12 +290,15 @@ private Parser(String functionName, Method method, Map<Set<TypeParameter>, Const
}

SqlType returnType = method.getAnnotation(SqlType.class);
requireNonNull(returnType, format("%s is missing @SqlType annotation", method));
checkArgument(returnType != null, format("Method [%s] is missing @SqlType annotation", method));
this.returnType = parseTypeSignature(returnType.value(), literalParameters);

Class<?> actualReturnType = method.getReturnType();
if (Primitives.isWrapperType(actualReturnType)) {
checkArgument(nullable, "Method %s has return value with type %s that is missing @Nullable", method, actualReturnType);
checkArgument(nullable, "Method [%s] has wrapper return type %s but is missing @Nullable", method, actualReturnType.getSimpleName());
}
else if (actualReturnType.isPrimitive()) {
checkArgument(!nullable, "Method [%s] annotated with @Nullable has primitive return type %s", method, actualReturnType.getSimpleName());
}

Stream.of(method.getAnnotationsByType(Constraint.class))
Expand Down Expand Up @@ -324,34 +327,34 @@ private void parseArguments(Method method)
continue;
}
if (containsMetaParameter(annotations)) {
checkArgument(annotations.length == 1, "Meta parameters may only have a single annotation");
checkArgument(argumentTypes.isEmpty(), "Meta parameter must come before parameters");
checkArgument(annotations.length == 1, "Meta parameters may only have a single annotation [%s]", method);
checkArgument(argumentTypes.isEmpty(), "Meta parameter must come before parameters [%s]", method);
Annotation annotation = annotations[0];
if (annotation instanceof TypeParameter) {
checkArgument(typeParameters.contains(annotation), "Injected type parameters must be declared with @TypeParameter annotation on the method");
checkArgument(typeParameters.contains(annotation), "Injected type parameters must be declared with @TypeParameter annotation on the method [%s]", method);
}
dependencies.add(parseDependency(annotation));
}
else {
SqlType type = null;
boolean nullableArgument = false;
for (Annotation annotation : annotations) {
if (annotation instanceof SqlType) {
type = (SqlType) annotation;
}
if (annotation instanceof Nullable) {
nullableArgument = true;
}
}
requireNonNull(type, format("@SqlType annotation missing for argument to %s", method));
SqlType type = Stream.of(annotations)
.filter(SqlType.class::isInstance)
.map(SqlType.class::cast)
.findFirst()
.orElseThrow(() -> new IllegalArgumentException(format("Method [%s] is missing @SqlType annotation for parameter", method)));
boolean nullableArgument = Stream.of(annotations).anyMatch(Nullable.class::isInstance);

if (Primitives.isWrapperType(parameterType)) {
checkArgument(nullableArgument, "Method %s has parameter with type %s that is missing @Nullable", method, parameterType);
checkArgument(nullableArgument, "Method [%s] has parameter with wrapper type %s that is missing @Nullable", method, parameterType.getSimpleName());
}
else if (parameterType.isPrimitive()) {
checkArgument(!nullableArgument, "Method [%s] has parameter with primitive type %s annotated with @Nullable", method, parameterType.getSimpleName());
}

if (typeParameterNames.contains(type.value()) && !(parameterType == Object.class && nullableArgument)) {
// Infer specialization on this type parameter. We don't do this for @Nullable Object because it could match a type like BIGINT
Class<?> specialization = specializedTypeParameters.get(type.value());
Class<?> nativeParameterType = Primitives.unwrap(parameterType);
checkArgument(specialization == null || specialization.equals(nativeParameterType), "%s has conflicting specializations %s and %s", type.value(), specialization, nativeParameterType);
checkArgument(specialization == null || specialization.equals(nativeParameterType), "Method [%s] type %s has conflicting specializations %s and %s", method, type.value(), specialization, nativeParameterType);
specializedTypeParameters.put(type.value(), nativeParameterType);
}
argumentNativeContainerTypes.add(parameterType);
Expand All @@ -369,14 +372,14 @@ private Optional<MethodHandle> getConstructor(Method method, Map<Set<TypeParamet
}

Constructor<?> constructor = constructors.get(typeParameters);
requireNonNull(constructor, format("%s is an instance method and requires a public constructor to be declared with %s type parameters", method.getName(), typeParameters));
checkArgument(constructor != null, "Method [%s] is an instance method and requires a public constructor to be declared with %s type parameters", method, typeParameters);
for (int i = 0; i < constructor.getParameterCount(); i++) {
Annotation[] annotations = constructor.getParameterAnnotations()[i];
checkArgument(containsMetaParameter(annotations), "Constructors may only have meta parameters");
checkArgument(annotations.length == 1, "Meta parameters may only have a single annotation");
checkArgument(containsMetaParameter(annotations), "Constructors may only have meta parameters [%s]", constructor);
checkArgument(annotations.length == 1, "Meta parameters may only have a single annotation [%s]", constructor);
Annotation annotation = annotations[0];
if (annotation instanceof TypeParameter) {
checkArgument(typeParameters.contains(annotation), "Injected type parameters must be declared with @TypeParameter annotation on the constructor");
checkArgument(typeParameters.contains(annotation), "Injected type parameters must be declared with @TypeParameter annotation on the constructor [%s]", constructor);
}
constructorDependencies.add(parseDependency(annotation));
}
Expand All @@ -396,9 +399,10 @@ private Map<String, Class<?>> getDeclaredSpecializedTypeParameters(Method method
.map(TypeParameter::value)
.collect(toImmutableSet());
for (TypeParameterSpecialization specialization : typeParameterSpecializations) {
checkArgument(typeParameterNames.contains(specialization.name()), "%s does not match any declared type parameters (%s)", specialization.name(), typeParameters);
checkArgument(typeParameterNames.contains(specialization.name()), "%s does not match any declared type parameters (%s) [%s]", specialization.name(), typeParameters, method);
Class<?> existingSpecialization = specializedTypeParameters.get(specialization.name());
checkArgument(existingSpecialization == null || existingSpecialization.equals(specialization.nativeContainerType()), "%s has conflicting specializations %s and %s", specialization.name(), existingSpecialization, specialization.nativeContainerType());
checkArgument(existingSpecialization == null || existingSpecialization.equals(specialization.nativeContainerType()),
"%s has conflicting specializations %s and %s [%s]", specialization.name(), existingSpecialization, specialization.nativeContainerType(), method);
specializedTypeParameters.put(specialization.name(), specialization.nativeContainerType());
}
return specializedTypeParameters;
Expand Down
Expand Up @@ -144,6 +144,15 @@ protected void registerScalar(Class<?> clazz)
metadata.getFunctionRegistry().addFunctions(functions);
}

protected void registerParametricScalar(Class<?> clazz)
{
Metadata metadata = functionAssertions.getMetadata();
List<SqlFunction> functions = new FunctionListBuilder()
.scalar(clazz)
.getFunctions();
metadata.getFunctionRegistry().addFunctions(functions);
}

protected SqlDecimal decimal(String decimalString)
{
DecimalParseResult parseResult = Decimals.parseIncludeLeadingZerosInPrecision(decimalString);
Expand Down

0 comments on commit 8ef8309

Please sign in to comment.