Skip to content

Commit

Permalink
Make contains() parametric
Browse files Browse the repository at this point in the history
  • Loading branch information
MonzBunz authored and cberner committed Mar 18, 2015
1 parent e66ee2e commit 35d641c
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 58 deletions.
Expand Up @@ -133,6 +133,7 @@
import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION; import static com.facebook.presto.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR; import static com.facebook.presto.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR;
import static com.facebook.presto.operator.scalar.ArrayEqualOperator.ARRAY_EQUAL; import static com.facebook.presto.operator.scalar.ArrayEqualOperator.ARRAY_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayContains.ARRAY_CONTAINS;
import static com.facebook.presto.operator.scalar.ArrayGreaterThanOperator.ARRAY_GREATER_THAN; import static com.facebook.presto.operator.scalar.ArrayGreaterThanOperator.ARRAY_GREATER_THAN;
import static com.facebook.presto.operator.scalar.ArrayGreaterThanOrEqualOperator.ARRAY_GREATER_THAN_OR_EQUAL; import static com.facebook.presto.operator.scalar.ArrayGreaterThanOrEqualOperator.ARRAY_GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.operator.scalar.ArrayHashCodeOperator.ARRAY_HASH_CODE; import static com.facebook.presto.operator.scalar.ArrayHashCodeOperator.ARRAY_HASH_CODE;
Expand Down Expand Up @@ -301,6 +302,7 @@ public FunctionInfo load(SpecializedFunctionKey key)
.scalar(ArrayFunctions.class) .scalar(ArrayFunctions.class)
.scalar(CombineHashFunction.class) .scalar(CombineHashFunction.class)
.scalar(JsonOperators.class) .scalar(JsonOperators.class)
.functions(ARRAY_CONTAINS)
.functions(ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN, ARRAY_GREATER_THAN_OR_EQUAL) .functions(ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN, ARRAY_GREATER_THAN_OR_EQUAL)
.functions(ARRAY_CONCAT_FUNCTION, ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONCAT_FUNCTION, ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_CARDINALITY, ARRAY_SORT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_CARDINALITY, ARRAY_SORT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY)
Expand Down
@@ -0,0 +1,129 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.metadata.FunctionInfo;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.ParametricScalar;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.Map;

import static com.facebook.presto.metadata.Signature.comparableTypeParameter;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.type.TypeUtils.createBlock;
import static com.facebook.presto.type.TypeUtils.parameterizedTypeName;
import static com.facebook.presto.type.TypeUtils.readStructuralBlock;
import static com.facebook.presto.util.Reflection.methodHandle;

public final class ArrayContains
extends ParametricScalar
{
public static final ArrayContains ARRAY_CONTAINS = new ArrayContains();
private static final TypeSignature RETURN_TYPE = parseTypeSignature(StandardTypes.BOOLEAN);
private static final String FUNCTION_NAME = "contains";
private static final Signature SIGNATURE = new Signature(FUNCTION_NAME, ImmutableList.of(comparableTypeParameter("T")), StandardTypes.BOOLEAN, ImmutableList.of("array<T>", "T"), false, false);

@Override
public Signature getSignature()
{
return SIGNATURE;
}

@Override
public boolean isHidden()
{
return false;
}

@Override
public boolean isDeterministic()
{
return true;
}

@Override
public String getDescription()
{
return "Determines whether given value exists in the array";
}

@Override
public FunctionInfo specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type type = types.get("T");
TypeSignature valueType = type.getTypeSignature();
TypeSignature arrayType = parameterizedTypeName(StandardTypes.ARRAY, valueType);
MethodHandle methodHandle = methodHandle(ArrayContains.class, "contains", Type.class, Slice.class, type.getJavaType());
Signature signature = new Signature(FUNCTION_NAME, RETURN_TYPE, arrayType, valueType);

return new FunctionInfo(signature, getDescription(), isHidden(), methodHandle.bindTo(type), isDeterministic(), false, ImmutableList.of(false, false));
}

public static boolean contains(Type type, Slice slice, Slice value)
{
Block arrayBlock = readStructuralBlock(slice);
Block valueBlock = createBlock(type, value);
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
if (type.equalTo(arrayBlock, i, valueBlock, 0)) {
return true;
}
}
return false;
}

public static boolean contains(Type type, Slice slice, long value)
{
Block arrayBlock = readStructuralBlock(slice);
Block valueBlock = createBlock(type, value);
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
if (type.equalTo(arrayBlock, i, valueBlock, 0)) {
return true;
}
}
return false;
}

public static boolean contains(Type type, Slice slice, boolean value)
{
Block arrayBlock = readStructuralBlock(slice);
Block valueBlock = createBlock(type, value);
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
if (type.equalTo(arrayBlock, i, valueBlock, 0)) {
return true;
}
}
return false;
}

public static boolean contains(Type type, Slice slice, double value)
{
Block arrayBlock = readStructuralBlock(slice);
Block valueBlock = createBlock(type, value);
for (int i = 0; i < arrayBlock.getPositionCount(); i++) {
if (type.equalTo(arrayBlock, i, valueBlock, 0)) {
return true;
}
}
return false;
}
}
Expand Up @@ -13,24 +13,14 @@
*/ */
package com.facebook.presto.operator.scalar; package com.facebook.presto.operator.scalar;


import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.VariableWidthBlockBuilder; import com.facebook.presto.spi.block.VariableWidthBlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.SqlType; import com.facebook.presto.type.SqlType;
import io.airlift.slice.Slice; import io.airlift.slice.Slice;


import javax.annotation.Nullable;

import static com.facebook.presto.type.TypeUtils.readStructuralBlock;
import static com.facebook.presto.type.TypeUtils.buildStructuralSlice; import static com.facebook.presto.type.TypeUtils.buildStructuralSlice;
import static com.facebook.presto.type.TypeUtils.createBlock;


public final class ArrayFunctions public final class ArrayFunctions
{ {
Expand All @@ -45,51 +35,4 @@ public static Slice arrayConstructor()
BlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), 0); BlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), 0);
return buildStructuralSlice(blockBuilder); return buildStructuralSlice(blockBuilder);
} }

@Nullable
@ScalarFunction
@SqlType(StandardTypes.BOOLEAN)
public static Boolean contains(@SqlType("array<bigint>") Slice slice, @SqlType(StandardTypes.BIGINT) long value)
{
return arrayContains(slice, BigintType.BIGINT, value);
}

@Nullable
@ScalarFunction
@SqlType(StandardTypes.BOOLEAN)
public static Boolean contains(@SqlType("array<boolean>") Slice slice, @SqlType(StandardTypes.BOOLEAN) boolean value)
{
return arrayContains(slice, BooleanType.BOOLEAN, value);
}

@Nullable
@ScalarFunction
@SqlType(StandardTypes.BOOLEAN)
public static Boolean contains(@SqlType("array<double>") Slice slice, @SqlType(StandardTypes.DOUBLE) double value)
{
return arrayContains(slice, DoubleType.DOUBLE, value);
}

@Nullable
@ScalarFunction
@SqlType(StandardTypes.BOOLEAN)
public static Boolean contains(@SqlType("array<varchar>") Slice slice, @SqlType(StandardTypes.VARCHAR) Slice value)
{
return arrayContains(slice, VarcharType.VARCHAR, value);
}

private static Boolean arrayContains(Slice slice, Type type, Object value)
{
Block block = readStructuralBlock(slice);
Block valueBlock = createBlock(type, value);

//TODO: This could be quite slow, it should use parametric equals
for (int i = 0; i < block.getPositionCount(); i++) {
if (type.equalTo(block, i, valueBlock, 0)) {
return true;
}
}

return false;
}
} }

0 comments on commit 35d641c

Please sign in to comment.