Skip to content

Commit

Permalink
Make maps comparable
Browse files Browse the repository at this point in the history
  • Loading branch information
bcrfb authored and martint committed Feb 24, 2015
1 parent d36dd06 commit 09d9ae1
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 0 deletions.
Expand Up @@ -155,7 +155,9 @@
import static com.facebook.presto.operator.scalar.Least.LEAST;
import static com.facebook.presto.operator.scalar.MapCardinalityFunction.MAP_CARDINALITY;
import static com.facebook.presto.operator.scalar.MapConstructor.MAP_CONSTRUCTOR;
import static com.facebook.presto.operator.scalar.MapEqualOperator.MAP_EQUAL;
import static com.facebook.presto.operator.scalar.MapKeys.MAP_KEYS;
import static com.facebook.presto.operator.scalar.MapNotEqualOperator.MAP_NOT_EQUAL;
import static com.facebook.presto.operator.scalar.MapSubscriptOperator.MAP_SUBSCRIPT;
import static com.facebook.presto.operator.scalar.MapToJsonCast.MAP_TO_JSON;
import static com.facebook.presto.operator.scalar.MapValues.MAP_VALUES;
Expand Down Expand Up @@ -311,6 +313,7 @@ public FunctionInfo load(SpecializedFunctionKey key)
.functions(ARRAY_HASH_CODE, ARRAY_EQUAL, ARRAY_NOT_EQUAL, ARRAY_LESS_THAN, ARRAY_LESS_THAN_OR_EQUAL, ARRAY_GREATER_THAN, ARRAY_GREATER_THAN_OR_EQUAL)
.functions(ARRAY_CONCAT_FUNCTION, ARRAY_TO_ELEMENT_CONCAT_FUNCTION, ELEMENT_TO_ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_CARDINALITY, ARRAY_SORT_FUNCTION, ARRAY_TO_JSON, JSON_TO_ARRAY)
.functions(MAP_EQUAL, MAP_NOT_EQUAL)
.functions(MAP_CONSTRUCTOR, MAP_CARDINALITY, MAP_SUBSCRIPT, MAP_TO_JSON, JSON_TO_MAP, MAP_KEYS, MAP_VALUES, MAP_AGG)
.function(IDENTITY_CAST)
.function(ARBITRARY_AGGREGATION)
Expand Down
@@ -0,0 +1,170 @@
package com.facebook.presto.operator.scalar;
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import com.facebook.presto.metadata.FunctionInfo;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.ParametricOperator;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.LinkedHashMap;
import java.util.Map;

import static com.facebook.presto.metadata.FunctionRegistry.operatorInfo;
import static com.facebook.presto.metadata.OperatorType.EQUAL;
import static com.facebook.presto.metadata.OperatorType.HASH_CODE;
import static com.facebook.presto.metadata.Signature.comparableTypeParameter;
import static com.facebook.presto.spi.StandardErrorCode.INTERNAL_ERROR;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.type.TypeJsonUtils.getObjectMap;
import static com.facebook.presto.type.TypeJsonUtils.castKey;
import static com.facebook.presto.type.TypeJsonUtils.castValue;
import static com.facebook.presto.util.Reflection.methodHandle;

public class MapEqualOperator
extends ParametricOperator
{
public static final MapEqualOperator MAP_EQUAL = new MapEqualOperator();
private static final TypeSignature RETURN_TYPE = parseTypeSignature(StandardTypes.BOOLEAN);

private MapEqualOperator()
{
super(EQUAL, ImmutableList.of(comparableTypeParameter("K"), comparableTypeParameter("V")), StandardTypes.BOOLEAN, ImmutableList.of("map<K,V>", "map<K,V>"));
}

@Override
public FunctionInfo specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type keyType = types.get("K");
Type valueType = types.get("V");

Type type = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(keyType.getTypeSignature(), valueType.getTypeSignature()), ImmutableList.of());
TypeSignature typeSignature = type.getTypeSignature();

MethodHandle keyEqualsFunction = functionRegistry.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType)).getMethodHandle();
MethodHandle keyHashcodeFunction = functionRegistry.resolveOperator(HASH_CODE, ImmutableList.of(keyType)).getMethodHandle();
MethodHandle valueEqualsFunction = functionRegistry.resolveOperator(EQUAL, ImmutableList.of(valueType, valueType)).getMethodHandle();

MethodHandle methodHandle = methodHandle(MapEqualOperator.class, "equals", MethodHandle.class, MethodHandle.class, MethodHandle.class, Type.class, Type.class, Slice.class, Slice.class);
MethodHandle method = methodHandle.bindTo(keyEqualsFunction).bindTo(keyHashcodeFunction).bindTo(valueEqualsFunction).bindTo(keyType).bindTo(valueType);
return operatorInfo(EQUAL, RETURN_TYPE, ImmutableList.of(typeSignature, typeSignature), method, true, ImmutableList.of(false, false));
}

public static Boolean equals(MethodHandle keyEqualsFunction, MethodHandle keyHashcodeFunction, MethodHandle valueEqualsFunction, Type keyType, Type valueType, Slice left, Slice right)
{
Map<String, Object> leftMap = getObjectMap(left);
Map<String, Object> rightMap = getObjectMap(right);

Map<KeyWrapper, Object> wrappedLeftMap = new LinkedHashMap<>();
for (Map.Entry<String, Object> entry : leftMap.entrySet()) {
wrappedLeftMap.put(new KeyWrapper(castKey(keyType, entry.getKey()), keyEqualsFunction, keyHashcodeFunction), entry.getValue());
}

Map<KeyWrapper, Object> wrappedRightMap = new LinkedHashMap<>();
for (Map.Entry<String, Object> entry : rightMap.entrySet()) {
wrappedRightMap.put(new KeyWrapper(castKey(keyType, entry.getKey()), keyEqualsFunction, keyHashcodeFunction), entry.getValue());
}

if (wrappedLeftMap.size() != wrappedRightMap.size()) {
return false;
}

for (Map.Entry<KeyWrapper, Object> entry : wrappedRightMap.entrySet()) {
KeyWrapper key = entry.getKey();
if (!wrappedLeftMap.containsKey(key)) {
return false;
}

Object leftValue = wrappedLeftMap.get(key);
if (leftValue == null) {
return null;
}

Object rightValue = entry.getValue();
if (rightValue == null) {
return null;
}

try {
Boolean result = (Boolean) valueEqualsFunction.invoke(castValue(valueType, leftValue), castValue(valueType, rightValue));
if (result == null) {
return null;
}
else if (!result) {
return false;
}
}
catch (Throwable t) {
Throwables.propagateIfInstanceOf(t, Error.class);
Throwables.propagateIfInstanceOf(t, PrestoException.class);

throw new PrestoException(INTERNAL_ERROR, t);
}
}
return true;
}

private static final class KeyWrapper
{
private final Object key;
private final MethodHandle hashCode;
private final MethodHandle equals;

public KeyWrapper(Object key, MethodHandle equals, MethodHandle hashCode)
{
this.key = key;
this.equals = equals;
this.hashCode = hashCode;
}

@Override
public int hashCode()
{
try {
return Long.hashCode((long) hashCode.invoke(key));
}
catch (Throwable t) {
Throwables.propagateIfInstanceOf(t, Error.class);
Throwables.propagateIfInstanceOf(t, PrestoException.class);

throw new PrestoException(INTERNAL_ERROR, t);
}
}

@Override
public boolean equals(Object obj)
{
if (obj == null || !getClass().equals(obj.getClass())) {
return false;
}
KeyWrapper other = (KeyWrapper) obj;
try {
return (Boolean) equals.invoke(key, other.key);
}
catch (Throwable t) {
Throwables.propagateIfInstanceOf(t, Error.class);
Throwables.propagateIfInstanceOf(t, PrestoException.class);

throw new PrestoException(INTERNAL_ERROR, t);
}
}
}
}
@@ -0,0 +1,75 @@
package com.facebook.presto.operator.scalar;
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import com.facebook.presto.metadata.FunctionInfo;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.ParametricOperator;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.Map;

import static com.facebook.presto.metadata.FunctionRegistry.operatorInfo;
import static com.facebook.presto.metadata.OperatorType.EQUAL;
import static com.facebook.presto.metadata.OperatorType.HASH_CODE;
import static com.facebook.presto.metadata.OperatorType.NOT_EQUAL;
import static com.facebook.presto.metadata.Signature.comparableTypeParameter;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.util.Reflection.methodHandle;

public class MapNotEqualOperator
extends ParametricOperator
{
public static final MapNotEqualOperator MAP_NOT_EQUAL = new MapNotEqualOperator();
private static final TypeSignature RETURN_TYPE = parseTypeSignature(StandardTypes.BOOLEAN);

private MapNotEqualOperator()
{
super(NOT_EQUAL, ImmutableList.of(comparableTypeParameter("K"), comparableTypeParameter("V")), StandardTypes.BOOLEAN, ImmutableList.of("map<K,V>", "map<K,V>"));
}

@Override
public FunctionInfo specialize(Map<String, Type> types, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type keyType = types.get("K");
Type valueType = types.get("V");

Type type = typeManager.getParameterizedType(StandardTypes.MAP, ImmutableList.of(keyType.getTypeSignature(), valueType.getTypeSignature()), ImmutableList.of());
TypeSignature typeSignature = type.getTypeSignature();

MethodHandle keyEqualsFunction = functionRegistry.resolveOperator(EQUAL, ImmutableList.of(keyType, keyType)).getMethodHandle();
MethodHandle keyHashcodeFunction = functionRegistry.resolveOperator(HASH_CODE, ImmutableList.of(keyType)).getMethodHandle();
MethodHandle valueEqualsFunction = functionRegistry.resolveOperator(EQUAL, ImmutableList.of(valueType, valueType)).getMethodHandle();

MethodHandle methodHandle = methodHandle(MapNotEqualOperator.class, "notEqual", MethodHandle.class, MethodHandle.class, MethodHandle.class, Type.class, Type.class, Slice.class, Slice.class);
MethodHandle method = methodHandle.bindTo(keyEqualsFunction).bindTo(keyHashcodeFunction).bindTo(valueEqualsFunction).bindTo(keyType).bindTo(valueType);
return operatorInfo(NOT_EQUAL, RETURN_TYPE, ImmutableList.of(typeSignature, typeSignature), method, true, ImmutableList.of(false, false));
}

public static Boolean notEqual(MethodHandle keyEqualsFunction, MethodHandle keyHashcodeFunction, MethodHandle valueEqualsFunction, Type keyType, Type valueType, Slice left, Slice right)
{
Boolean equals = MapEqualOperator.equals(keyEqualsFunction, keyHashcodeFunction, valueEqualsFunction, keyType, valueType, left, right);
if (equals == null) {
return null;
}

return !equals;
}
}
Expand Up @@ -49,6 +49,7 @@ public final class TypeJsonUtils
private static final JsonFactory JSON_FACTORY = new JsonFactory().disable(CANONICALIZE_FIELD_NAMES);
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapperProvider().get();
private static final CollectionType COLLECTION_TYPE = OBJECT_MAPPER.getTypeFactory().constructCollectionType(List.class, Object.class);
private static final com.fasterxml.jackson.databind.type.MapType MAP_TYPE = OBJECT_MAPPER.getTypeFactory().constructMapType(Map.class, String.class, Object.class);

private TypeJsonUtils() {}

Expand Down Expand Up @@ -156,6 +157,16 @@ public static List<Object> getObjectList(Slice slice)
}
}

public static Map<String, Object> getObjectMap(Slice slice)
{
try {
return OBJECT_MAPPER.readValue(slice.getInput(), MAP_TYPE);
}
catch (IOException e) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e);
}
}

public static Block createBlock(Type type, Object element)
{
BlockBuilder blockBuilder = type.createBlockBuilder(new BlockBuilderStatus());
Expand All @@ -182,6 +193,51 @@ else if (javaType == Slice.class) {
return blockBuilder.build();
}

public static Object castKey(Type type, String key)
{
Class<?> javaType = type.getJavaType();

if (key == null) {
return null;
}
else if (javaType == boolean.class) {
return Boolean.valueOf(key);
}
else if (javaType == long.class) {
return Long.parseLong(key);
}
else if (javaType == double.class) {
return Double.parseDouble(key);
}
else if (javaType == Slice.class) {
return Slices.utf8Slice(key);
}
else {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Unexpected type %s", javaType.getName()));
}
}

public static Object castValue(Type type, Object value)
{
Class<?> javaType = type.getJavaType();

if (value == null) {
return null;
}
else if (javaType == boolean.class || javaType == double.class) {
return value;
}
else if (javaType == long.class) {
return ((Number) value).longValue();
}
else if (javaType == Slice.class) {
return Slices.utf8Slice(value.toString());
}
else {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Unexpected type %s", javaType.getName()));
}
}

public static Object getValue(Block input, Type type, int position)
{
if (input.isNull(position)) {
Expand Down

0 comments on commit 09d9ae1

Please sign in to comment.