Skip to content

Commit

Permalink
Add transform_key lambda function
Browse files Browse the repository at this point in the history
  • Loading branch information
wenleix committed Feb 16, 2017
1 parent eb68300 commit e854d23
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 1 deletion.
11 changes: 11 additions & 0 deletions presto-docs/src/main/sphinx/functions/lambda.rst
Expand Up @@ -53,6 +53,17 @@ Lambda Functions
SELECT transform(ARRAY ['x', 'abc', 'z'], x -> x || '0'); -- ['x0', 'abc0', 'z0']
SELECT transform(ARRAY [ARRAY [1, NULL, 2], ARRAY[3, NULL]], a -> filter(a, x -> x IS NOT NULL)); -- [[1, 2], [3]]

.. function:: transform_key(map<K1,V>, function<K1,V,K2>) -> MAP<K2,V>

Returns a map that applies ``function`` to each entry of ``map`` and transforms the keys::

SELECT transform_key(MAP(ARRAY[], ARRAY[]), (k, v) -> k + 1); -- {}
SELECT transform_key(MAP(ARRAY [1, 2, 3], ARRAY ['a', 'b', 'c']), (k, v) -> k + 1); -- {2 -> a, 3 -> b, 4 -> c}
SELECT transform_key(MAP(ARRAY ['a', 'b', 'c'], ARRAY [1, 2, 3]), (k, v) -> v * v); -- {1 -> 1, 4 -> 2, 9 -> 3}
SELECT transform_key(MAP(ARRAY ['a', 'b'], ARRAY [1, 2]), (k, v) -> k || CAST(v as VARCHAR)); -- {a1 -> 1, b2 -> 2}
SELECT transform_key(MAP(ARRAY [1, 2], ARRAY [1.0, 1.4]), -- {one -> 1.0, two -> 1.4}
(k, v) -> MAP(ARRAY[1, 2], ARRAY['one', 'two'])[k]);

.. function:: reduce(array<T>, initialState S, inputFunction<S,T,S>, outputFunction<S,R>) -> R

Returns a single value reduced from ``array``. ``inputFunction`` will
Expand Down
5 changes: 5 additions & 0 deletions presto-docs/src/main/sphinx/functions/map.rst
Expand Up @@ -46,6 +46,11 @@ Map Functions

See :func:`map_filter`.

.. function:: transform_key(map<K1,V>, function) -> MAP<K2,V>
:noindex:

See :func:`transform_key`.

.. function:: map_keys(x<K,V>) -> array<K>

Returns all the keys in the map ``x``.
Expand Down
Expand Up @@ -223,6 +223,7 @@
import static com.facebook.presto.operator.scalar.MapFilterFunction.MAP_FILTER_FUNCTION;
import static com.facebook.presto.operator.scalar.MapHashCodeOperator.MAP_HASH_CODE;
import static com.facebook.presto.operator.scalar.MapToJsonCast.MAP_TO_JSON;
import static com.facebook.presto.operator.scalar.MapTransformKeyFunction.MAP_TRANSFORM_KEY_FUNCTION;
import static com.facebook.presto.operator.scalar.MathFunctions.DECIMAL_CEILING_FUNCTIONS;
import static com.facebook.presto.operator.scalar.MathFunctions.DECIMAL_FLOOR_FUNCTION;
import static com.facebook.presto.operator.scalar.MathFunctions.DECIMAL_MOD_FUNCTION;
Expand Down Expand Up @@ -550,7 +551,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key)
.functions(DECIMAL_ROUND_FUNCTIONS)
.function(DECIMAL_TRUNCATE_FUNCTION)
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_FILTER_FUNCTION, ARRAY_REDUCE_FUNCTION)
.functions(MAP_FILTER_FUNCTION)
.functions(MAP_FILTER_FUNCTION, MAP_TRANSFORM_KEY_FUNCTION)
.function(TRY_CAST);

builder.function(new ArrayAggregationFunction(featuresConfig.isLegacyArrayAgg()));
Expand Down
@@ -0,0 +1,133 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.operator.scalar;

import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.operator.aggregation.TypedSet;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.metadata.Signature.typeVariable;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.type.TypeUtils.readNativeValue;
import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.util.Reflection.methodHandle;
import static java.lang.String.format;

public final class MapTransformKeyFunction
extends SqlScalarFunction
{
public static final MapTransformKeyFunction MAP_TRANSFORM_KEY_FUNCTION = new MapTransformKeyFunction();

private static final MethodHandle METHOD_HANDLE = methodHandle(
MapTransformKeyFunction.class,
"transform",
Type.class,
Type.class,
Type.class,
ConnectorSession.class,
Block.class,
MethodHandle.class);

private MapTransformKeyFunction()
{
super(new Signature(
"transform_key",
FunctionKind.SCALAR,
ImmutableList.of(typeVariable("K1"), typeVariable("K2"), typeVariable("V")),
ImmutableList.of(),
parseTypeSignature("map(K2,V)"),
ImmutableList.of(parseTypeSignature("map(K1,V)"), parseTypeSignature("function(K1,V,K2)")),
false));
}

@Override
public boolean isHidden()
{
return false;
}

@Override
public boolean isDeterministic()
{
return false;
}

@Override
public String getDescription()
{
return "apply lambda to each entry of the map and transform the key";
}

@Override
public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type keyType = boundVariables.getTypeVariable("K1");
Type transformedKeyType = boundVariables.getTypeVariable("K2");
Type valueType = boundVariables.getTypeVariable("V");
return new ScalarFunctionImplementation(
false,
ImmutableList.of(false, false),
METHOD_HANDLE.bindTo(keyType).bindTo(transformedKeyType).bindTo(valueType),
isDeterministic());
}

public static Block transform(Type keyType, Type transformedKeyType, Type valueType, ConnectorSession session, Block block, MethodHandle function)
{
int positionCount = block.getPositionCount();
BlockBuilder resultBuilder = new InterleavedBlockBuilder(ImmutableList.of(transformedKeyType, valueType), new BlockBuilderStatus(), positionCount);
TypedSet typedSet = new TypedSet(transformedKeyType, positionCount / 2);

for (int position = 0; position < positionCount; position += 2) {
Object key = readNativeValue(keyType, block, position);
Object value = readNativeValue(valueType, block, position + 1);
Object transformedKey;
try {
transformedKey = function.invoke(key, value);
}
catch (Throwable throwable) {
throw Throwables.propagate(throwable);
}

if (transformedKey == null) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null");
}

writeNativeValue(transformedKeyType, resultBuilder, transformedKey);
valueType.appendTo(block, position + 1, resultBuilder);

if (typedSet.contains(resultBuilder, position)) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Duplicate keys (%s) are not allowed", transformedKeyType.getObjectValue(session, resultBuilder, position)));
}
typedSet.add(resultBuilder, position);
}
return resultBuilder.build();
}
}

0 comments on commit e854d23

Please sign in to comment.