Skip to content

Commit

Permalink
Migrate element_at() to new framework
Browse files Browse the repository at this point in the history
  • Loading branch information
cberner committed Feb 23, 2016
1 parent 91b932a commit 20a25f2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 74 deletions.
Expand Up @@ -39,6 +39,7 @@
import com.facebook.presto.operator.aggregation.NumericHistogramAggregation;
import com.facebook.presto.operator.aggregation.RegressionAggregation;
import com.facebook.presto.operator.aggregation.VarianceAggregation;
import com.facebook.presto.operator.scalar.ArrayElementAtFunction;
import com.facebook.presto.operator.scalar.ArrayFunctions;
import com.facebook.presto.operator.scalar.ArrayGreaterThanOperator;
import com.facebook.presto.operator.scalar.ArrayRemoveFunction;
Expand Down Expand Up @@ -150,7 +151,6 @@
import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR;
import static com.facebook.presto.operator.scalar.ArrayContains.ARRAY_CONTAINS;
import static com.facebook.presto.operator.scalar.ArrayDistinctFunction.ARRAY_DISTINCT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayElementAtFunction.ARRAY_ELEMENT_AT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayEqualOperator.ARRAY_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayGreaterThanOrEqualOperator.ARRAY_GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayHashCodeOperator.ARRAY_HASH_CODE;
Expand Down Expand Up @@ -347,12 +347,13 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key)
.functions(IDENTITY_CAST, CAST_FROM_UNKNOWN)
.scalar(ArrayRemoveFunction.class)
.scalar(ArrayGreaterThanOperator.class)
.scalar(ArrayElementAtFunction.class)
.functions(ARRAY_CONTAINS, ARRAY_JOIN, ARRAY_JOIN_WITH_NULL_REPLACEMENT)
.functions(ARRAY_MIN, ARRAY_MAX)
.functions(ARRAY_TO_ARRAY_CAST, ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN_OR_EQUAL)
.functions(ARRAY_CONCAT_FUNCTION, ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION)
.functions(MAP_EQUAL, MAP_NOT_EQUAL, MAP_HASH_CODE)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_ELEMENT_AT_FUNCTION, ARRAY_CARDINALITY, ARRAY_POSITION, ARRAY_SORT_FUNCTION, ARRAY_INTERSECT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY, ARRAY_DISTINCT_FUNCTION, ARRAY_SLICE_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_CARDINALITY, ARRAY_POSITION, ARRAY_SORT_FUNCTION, ARRAY_INTERSECT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY, ARRAY_DISTINCT_FUNCTION, ARRAY_SLICE_FUNCTION)
.functions(MAP_CONSTRUCTOR, MAP_CARDINALITY, MAP_SUBSCRIPT, MAP_TO_JSON, JSON_TO_MAP, MAP_KEYS, MAP_VALUES, MAP_CONCAT_FUNCTION)
.functions(MAP_AGG, MULTIMAP_AGG)
.function(HISTOGRAM)
Expand Down
Expand Up @@ -13,88 +13,37 @@
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.operator.Description;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.facebook.presto.type.SqlType;
import com.google.common.primitives.Ints;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.Map;
import javax.annotation.Nullable;

import static com.facebook.presto.metadata.Signature.typeParameter;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class ArrayElementAtFunction
extends SqlScalarFunction
@ScalarFunction("element_at")
@Description("Get element of array at given index")
public final class ArrayElementAtFunction
{
public static final ArrayElementAtFunction ARRAY_ELEMENT_AT_FUNCTION = new ArrayElementAtFunction();
private static final String FUNCTION_NAME = "element_at";
private static final Map<Class<?>, MethodHandle> METHOD_HANDLES = ImmutableMap.<Class<?>, MethodHandle>builder()
.put(boolean.class, methodHandle(ArrayElementAtFunction.class, "booleanElementAt", Type.class, Block.class, long.class))
.put(long.class, methodHandle(ArrayElementAtFunction.class, "longElementAt", Type.class, Block.class, long.class))
.put(double.class, methodHandle(ArrayElementAtFunction.class, "doubleElementAt", Type.class, Block.class, long.class))
.put(Slice.class, methodHandle(ArrayElementAtFunction.class, "sliceElementAt", Type.class, Block.class, long.class))
.put(void.class, methodHandle(ArrayElementAtFunction.class, "voidElementAt", Type.class, Block.class, long.class))
.build();
private static final MethodHandle OBJECT_METHOD_HANDLE = methodHandle(ArrayElementAtFunction.class, "objectElementAt", Type.class, Block.class, long.class);

public ArrayElementAtFunction()
{
super(FUNCTION_NAME, ImmutableList.of(typeParameter("E")), "E", ImmutableList.of("array(E)", "bigint"));
}

@Override
public boolean isHidden()
{
return false;
}

@Override
public boolean isDeterministic()
{
return true;
}

@Override
public String getDescription()
{
return "Get element of array at given index";
}

@Override
public ScalarFunctionImplementation specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
checkArgument(types.size() == 1, "Expected one type, got %s", types);
Type elementType = types.get("E");

MethodHandle methodHandle;
if (METHOD_HANDLES.containsKey(elementType.getJavaType())) {
methodHandle = METHOD_HANDLES.get(elementType.getJavaType());
}
else {
checkArgument(!elementType.getJavaType().isPrimitive(), "Unsupported primitive type: " + elementType.getJavaType());
methodHandle = OBJECT_METHOD_HANDLE;
}
requireNonNull(methodHandle, "methodHandle is null");
methodHandle = methodHandle.bindTo(elementType);
return new ScalarFunctionImplementation(true, ImmutableList.of(false, false), methodHandle, isDeterministic());
}
private ArrayElementAtFunction() {}

public static void voidElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Void voidElementAt(@SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
checkedIndexToBlockPosition(array, index);
return null;
}

public static Long longElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Long longElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
int position = checkedIndexToBlockPosition(array, index);
if (array.isNull(position)) {
Expand All @@ -104,7 +53,10 @@ public static Long longElementAt(Type elementType, Block array, long index)
return elementType.getLong(array, position);
}

public static Boolean booleanElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Boolean booleanElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
int position = checkedIndexToBlockPosition(array, index);
if (array.isNull(position)) {
Expand All @@ -114,7 +66,10 @@ public static Boolean booleanElementAt(Type elementType, Block array, long index
return elementType.getBoolean(array, position);
}

public static Double doubleElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Double doubleElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
int position = checkedIndexToBlockPosition(array, index);
if (array.isNull(position)) {
Expand All @@ -124,7 +79,10 @@ public static Double doubleElementAt(Type elementType, Block array, long index)
return elementType.getDouble(array, position);
}

public static Slice sliceElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Slice sliceElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
int position = checkedIndexToBlockPosition(array, index);
if (array.isNull(position)) {
Expand All @@ -134,14 +92,17 @@ public static Slice sliceElementAt(Type elementType, Block array, long index)
return elementType.getSlice(array, position);
}

public static Object objectElementAt(Type elementType, Block array, long index)
@TypeParameter("E")
@Nullable
@SqlType("E")
public static Block blockElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index)
{
int position = checkedIndexToBlockPosition(array, index);
if (array.isNull(position)) {
return null;
}

return elementType.getObject(array, position);
return (Block) elementType.getObject(array, position);
}

private static int checkedIndexToBlockPosition(Block block, long index)
Expand Down

0 comments on commit 20a25f2

Please sign in to comment.