Skip to content

Commit

Permalink
Add ROW constructor
Browse files Browse the repository at this point in the history
It doesn't set the fields names, so these rows are "anonymous". The row
can be named by doing something like
"select cast(row(1, 2) as row(a bigint, b bigint))"
  • Loading branch information
lacbs authored and cberner committed May 9, 2016
1 parent 8fa3fe9 commit 57d666f
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 10 deletions.
Expand Up @@ -40,6 +40,7 @@
import static com.facebook.presto.sql.relational.Signatures.IN;
import static com.facebook.presto.sql.relational.Signatures.IS_NULL;
import static com.facebook.presto.sql.relational.Signatures.NULL_IF;
import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR;
import static com.facebook.presto.sql.relational.Signatures.SWITCH;
import static com.facebook.presto.sql.relational.Signatures.TRY;

Expand Down Expand Up @@ -114,6 +115,9 @@ public BytecodeNode visitCall(CallExpression call, final Scope scope)
case DEREFERENCE:
generator = new DereferenceCodeGenerator();
break;
case ROW_CONSTRUCTOR:
generator = new RowConstructorCodeGenerator();
break;
default:
generator = new FunctionCallCodeGenerator();
}
Expand Down
@@ -0,0 +1,77 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.relational.RowExpression;

import java.util.List;

import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance;
import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant;
import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType;

public class RowConstructorCodeGenerator
implements BytecodeGenerator
{
@Override
public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorContext context, Type rowType, List<RowExpression> arguments)
{
BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType.toString());
CallSiteBinder binder = context.getCallSiteBinder();
Scope scope = context.getScope();
List<Type> types = rowType.getTypeParameters();

block.comment("BlockBuilder blockBuilder = new InterleavedBlockBuilder(types, new BlockBuilderStatus(), 1);");
Variable blockBuilder = scope.createTempVariable(BlockBuilder.class);
Binding typesBinding = binder.bind(types, List.class);
block.append(blockBuilder.set(
newInstance(InterleavedBlockBuilder.class, loadConstant(typesBinding), newInstance(BlockBuilderStatus.class), constantInt(1))));

for (int i = 0; i < arguments.size(); ++i) {
Type fieldType = types.get(i);
Class<?> javaType = fieldType.getJavaType();
if (javaType == void.class) {
block.comment(i + "-th field type of row is undefined");
block.append(blockBuilder.invoke("appendNull", BlockBuilder.class).pop());
}
else {
Variable field = scope.createTempVariable(javaType);
block.comment("Generate + " + i + "-th field of row");
block.append(context.generate(arguments.get(i)));
block.putVariable(field);
block.append(new IfStatement()
.condition(context.wasNull())
.ifTrue(blockBuilder.invoke("appendNull", BlockBuilder.class).pop())
.ifFalse(constantType(binder, fieldType).writeValue(blockBuilder, field).pop()));
}
}
block.comment("put (Block) blockBuilder.build(); wasNull = false;");
block.append(blockBuilder.invoke("build", Block.class));
block.append(context.wasNull().set(constantFalse()));
return block;
}
}
Expand Up @@ -27,6 +27,7 @@
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.AnalysisContext;
Expand Down Expand Up @@ -1037,7 +1038,25 @@ protected Object visitArrayConstructor(ArrayConstructor node, Object context)
@Override
protected Object visitRow(Row node, Object context)
{
throw new PrestoException(NOT_SUPPORTED, "Row expressions not yet supported");
RowType rowType = checkType(expressionTypes.get(node), RowType.class, "type");
List<Type> parameterTypes = rowType.getTypeParameters();
List<Expression> arguments = node.getItems();

int cardinality = arguments.size();
List<Object> values = new ArrayList<>(cardinality);
for (Expression argument : arguments) {
values.add(process(argument, context));
}
if (hasUnresolvedValue(values)) {
return new Row(toExpressions(values, parameterTypes));
}
else {
BlockBuilder blockBuilder = new InterleavedBlockBuilder(parameterTypes, new BlockBuilderStatus(), cardinality);
for (int i = 0; i < cardinality; ++i) {
writeNativeValue(parameterTypes.get(i), blockBuilder, values.get(i));
}
return blockBuilder.build();
}
}

@Override
Expand Down
Expand Up @@ -36,6 +36,7 @@
import static com.facebook.presto.metadata.Signature.internalScalarFunction;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.tree.ArrayConstructor.ARRAY_CONSTRUCTOR;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;

public final class Signatures
{
Expand All @@ -49,6 +50,7 @@ public final class Signatures
public static final String IN = "IN";
public static final String TRY = "TRY";
public static final String DEREFERENCE = "DEREFERENCE";
public static final String ROW_CONSTRUCTOR = "ROW_CONSTRUCTOR";

private Signatures()
{
Expand Down Expand Up @@ -158,6 +160,11 @@ public static Signature inSignature()
return internalScalarFunction(IN, StandardTypes.BOOLEAN);
}

public static Signature rowConstructorSignature(Type returnType, List<Type> argumentTypes)
{
return internalScalarFunction(ROW_CONSTRUCTOR, returnType.getTypeSignature(), argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
}

// **************** functions that need to do special null handling ****************
public static Signature isNullSignature(Type argumentType)
{
Expand Down
Expand Up @@ -53,6 +53,7 @@
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
Expand Down Expand Up @@ -97,6 +98,7 @@
import static com.facebook.presto.sql.relational.Signatures.likeSignature;
import static com.facebook.presto.sql.relational.Signatures.logicalExpressionSignature;
import static com.facebook.presto.sql.relational.Signatures.nullIfSignature;
import static com.facebook.presto.sql.relational.Signatures.rowConstructorSignature;
import static com.facebook.presto.sql.relational.Signatures.subscriptSignature;
import static com.facebook.presto.sql.relational.Signatures.switchSignature;
import static com.facebook.presto.sql.relational.Signatures.tryCastSignature;
Expand Down Expand Up @@ -619,5 +621,18 @@ protected RowExpression visitArrayConstructor(ArrayConstructor node, Void contex
.collect(toImmutableList());
return call(arrayConstructorSignature(types.get(node), argumentTypes), types.get(node), arguments);
}

@Override
protected RowExpression visitRow(Row node, Void context)
{
List<RowExpression> arguments = node.getItems().stream()
.map(value -> process(value, context))
.collect(toImmutableList());
Type returnType = types.get(node);
List<Type> argumentTypes = node.getItems().stream()
.map(value -> types.get(value))
.collect(toImmutableList());
return call(rowConstructorSignature(returnType, argumentTypes), returnType, arguments);
}
}
}
Expand Up @@ -41,6 +41,7 @@
import static com.facebook.presto.sql.relational.Signatures.IN;
import static com.facebook.presto.sql.relational.Signatures.IS_NULL;
import static com.facebook.presto.sql.relational.Signatures.NULL_IF;
import static com.facebook.presto.sql.relational.Signatures.ROW_CONSTRUCTOR;
import static com.facebook.presto.sql.relational.Signatures.SWITCH;
import static com.facebook.presto.sql.relational.Signatures.TRY;
import static com.facebook.presto.sql.relational.Signatures.TRY_CAST;
Expand Down Expand Up @@ -133,7 +134,8 @@ public RowExpression visitCall(CallExpression call, Void context)
case "AND":
case "OR":
case IN:
case DEREFERENCE: {
case DEREFERENCE:
case ROW_CONSTRUCTOR: {
List<RowExpression> arguments = call.getArguments().stream()
.map(argument -> argument.accept(this, null))
.collect(toImmutableList());
Expand Down
Expand Up @@ -1136,6 +1136,26 @@ public void testArrayConstructor()
"array_constructor((bound_long + 0), (unbound_long + 1), NULL)");
}

@Test
public void testRowConstructor()
{
optimize("ROW(NULL)");
optimize("ROW(1)");
optimize("ROW(unbound_long + 0)");
optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)");
optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)");
optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0)]");
optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]");
optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))");
optimize("ROW(unbound_string, bound_string)");

optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0)]");
optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0), ROW(unbound_string, unbound_double)]");

optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]");
optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]");
}

@Test(expectedExceptions = PrestoException.class)
public void testArraySubscriptConstantNegativeIndex()
{
Expand Down
Expand Up @@ -60,14 +60,16 @@ public void testRowTypeLookup()
public void testRowToJson()
throws Exception
{
assertFunction("CAST(test_row(1, 2) AS JSON)", JSON, "[1,2]");
assertFunction("CAST(test_row(1, CAST(NULL AS INTEGER)) AS JSON)", JSON, "[1,null]");
assertFunction("CAST(test_row(1, 2.0) AS JSON)", JSON, "[1,2.0]");
assertFunction("CAST(test_row(1.0, 2.5) AS JSON)", JSON, "[1.0,2.5]");
assertFunction("CAST(test_row(1.0, 'kittens') AS JSON)", JSON, "[1.0,\"kittens\"]");
assertFunction("CAST(test_row(TRUE, FALSE) AS JSON)", JSON, "[true,false]");
assertFunction("CAST(test_row(from_unixtime(1)) AS JSON)", JSON, "[\"" + new SqlTimestamp(1000, TEST_SESSION.getTimeZoneKey()) + "\"]");
assertFunction("CAST(test_row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])) AS JSON)", JSON, "[false,[1,2],{\"1\":2.0,\"3\":4.0}]");
assertFunction("CAST(ROW(1, 2) AS JSON)", JSON, "[1,2]");
assertFunction("CAST(CAST(ROW(1, 2) AS ROW(a BIGINT, b BIGINT)) AS JSON)", JSON, "[1,2]");
assertFunction("CAST(ROW(1, NULL) AS JSON)", JSON, "[1,null]");
assertFunction("CAST(ROW(1, CAST(NULL AS INTEGER)) AS JSON)", JSON, "[1,null]");
assertFunction("CAST(ROW(1, 2.0) AS JSON)", JSON, "[1,2.0]");
assertFunction("CAST(ROW(1.0, 2.5) AS JSON)", JSON, "[1.0,2.5]");
assertFunction("CAST(ROW(1.0, 'kittens') AS JSON)", JSON, "[1.0,\"kittens\"]");
assertFunction("CAST(ROW(TRUE, FALSE) AS JSON)", JSON, "[true,false]");
assertFunction("CAST(ROW(from_unixtime(1)) AS JSON)", JSON, "[\"" + new SqlTimestamp(1000, TEST_SESSION.getTimeZoneKey()) + "\"]");
assertFunction("CAST(ROW(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])) AS JSON)", JSON, "[false,[1,2],{\"1\":2.0,\"3\":4.0}]");
}

@Test
Expand All @@ -93,6 +95,16 @@ public void testFieldAccessor()
assertFunction("test_row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])).col1", new ArrayType(INTEGER), ImmutableList.of(1, 2));
assertFunction("test_row(FALSE, ARRAY [1, 2], MAP(ARRAY[1, 3], ARRAY[2.0, 4.0])).col2", new MapType(INTEGER, DOUBLE), ImmutableMap.of(1, 2.0, 3, 4.0));
assertFunction("test_row(1.0, ARRAY[test_row(31, 4.1), test_row(32, 4.2)], test_row(3, 4.0)).col1[2].col0", INTEGER, 32);

// Using ROW constructor
assertFunction("CAST(ROW(1, 2) AS ROW(a BIGINT, b DOUBLE)).a", BIGINT, 1L);
assertFunction("CAST(ROW(1, 2) AS ROW(a BIGINT, b DOUBLE)).b", DOUBLE, 2.0);
assertFunction("CAST(ROW(CAST(ROW('aa') AS ROW(a VARCHAR))) AS ROW(a ROW(a VARCHAR))).a.a", VARCHAR, "aa");
assertFunction("CAST(ROW(ROW('ab')) AS ROW(a ROW(b VARCHAR))).a.b", VARCHAR, "ab");
assertFunction("CAST(ROW(ARRAY[NULL]) AS ROW(a ARRAY(BIGINT))).a", new ArrayType(BIGINT), Arrays.asList((Integer) null));

// Row type is not case sensitive
assertFunction("CAST(ROW(1) AS ROW(A BIGINT)).A", BIGINT, 1L);
}

@Test
Expand All @@ -104,6 +116,14 @@ public void testRowCast()
assertFunction("cast(test_row(2, cast(null as double)) as row(aa bigint, bb double)).bb", DOUBLE, null);
assertFunction("cast(test_row(2, 'test_str') as row(aa bigint, bb varchar)).bb", VARCHAR, "test_str");

try {
assertFunction("CAST(ROW(1, 2) AS ROW(a BIGINT, A DOUBLE)).a", BIGINT, 1L);
fail("fields in Row are case insensitive");
}
catch (RuntimeException e) {
// Expected
}

// there are totally 7 field names
String longFieldNameCast = "CAST(test_row(1.2, ARRAY[test_row(233, 6.9)], test_row(1000, 6.3)) AS ROW(%s VARCHAR, %s ARRAY(ROW(%s VARCHAR, %s VARCHAR)), %s ROW(%s VARCHAR, %s VARCHAR))).%s[1].%s";
int fieldCount = 7;
Expand Down Expand Up @@ -169,5 +189,8 @@ public void testRowEquality()
assertFunction("test_row(TRUE, ARRAY [1]) != test_row(TRUE, ARRAY [1])", BOOLEAN, false);
assertFunction("test_row(TRUE, ARRAY [1]) != test_row(TRUE, ARRAY [1,2])", BOOLEAN, true);
assertFunction("test_row(1.0, ARRAY [1,2,3], test_row(2,2.0)) != test_row(1.0, ARRAY [1,2,3], test_row(1,2.0))", BOOLEAN, true);

assertFunction("ROW(1, 2) = ROW(1, 2)", BOOLEAN, true);
assertFunction("ROW(2, 1) != ROW(1, 2)", BOOLEAN, true);
}
}
Expand Up @@ -50,6 +50,7 @@
import com.facebook.presto.sql.tree.RenameTable;
import com.facebook.presto.sql.tree.ResetSession;
import com.facebook.presto.sql.tree.Rollback;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SampledRelation;
import com.facebook.presto.sql.tree.Select;
import com.facebook.presto.sql.tree.SelectItem;
Expand Down Expand Up @@ -831,6 +832,22 @@ protected Void visitCall(Call node, Integer indent)
return null;
}

@Override
protected Void visitRow(Row node, Integer indent)
{
builder.append("ROW(");
boolean firstItem = true;
for (Expression item : node.getItems()) {
if (!firstItem) {
builder.append(", ");
}
process(item, indent);
firstItem = false;
}
builder.append(")");
return null;
}

@Override
protected Void visitStartTransaction(StartTransaction node, Integer indent)
{
Expand Down
Expand Up @@ -500,6 +500,22 @@ public void testArrays()
assertQuery("SELECT CARDINALITY(a) FROM (SELECT ARRAY[orderkey, orderkey + 1] AS a FROM orders ORDER BY orderkey) t", "SELECT 2 FROM orders");
}

@Test
public void testRows()
throws Exception
{
// Using JSON_FORMAT(CAST(_ AS JSON)) because H2 does not support ROW type
assertQuery("SELECT JSON_FORMAT(CAST(ROW(1 + 2, CONCAT('a', 'b')) AS JSON))", "SELECT '[3,\"ab\"]'");
assertQuery("SELECT JSON_FORMAT(CAST(ROW(a + b) AS JSON)) FROM (VALUES (1, 2)) AS t(a, b)", "SELECT '[3]'");
assertQuery("SELECT JSON_FORMAT(CAST(ROW(1, ROW(9, a, ARRAY[], NULL), ROW(1, 2)) AS JSON)) FROM (VALUES ('a')) t(a)", "SELECT '[1,[9,\"a\",[],null],[1,2]]'");
assertQuery("SELECT JSON_FORMAT(CAST(ROW(ROW(ROW(ROW(ROW(a, b), c), d), e), f) AS JSON)) FROM (VALUES (ROW(0, 1), 2, '3', NULL, ARRAY[5], ARRAY[])) t(a, b, c, d, e, f)",
"SELECT '[[[[[[0,1],2],\"3\"],null],[5]],[]]'");
assertQuery("SELECT JSON_FORMAT(CAST(ARRAY_AGG(ROW(a, b)) AS JSON)) FROM (VALUES (1, 2), (3, 4), (5, 6)) t(a, b)", "SELECT '[[1,2],[3,4],[5,6]]'");
assertQuery("SELECT CONTAINS(ARRAY_AGG(ROW(a, b)), ROW(1, 2)) FROM (VALUES (1, 2), (3, 4), (5, 6)) t(a, b)", "SELECT TRUE");
assertQuery("SELECT JSON_FORMAT(CAST(ARRAY_AGG(ROW(c, d)) AS JSON)) FROM (VALUES (ARRAY[1, 3, 5], ARRAY[2, 4, 6])) AS t(a, b) CROSS JOIN UNNEST(a, b) AS u(c, d)",
"SELECT '[[1,2],[3,4],[5,6]]'");
}

@Test
public void testMaps()
throws Exception
Expand Down

0 comments on commit 57d666f

Please sign in to comment.