Skip to content

Commit

Permalink
Rewrite RowConstructorCodeGenerator to reduce generated code size
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhargavi-Sagi committed Mar 29, 2024
1 parent 12eb467 commit f610d0b
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.block.SqlRow;
import io.trino.spi.type.Type;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
import static io.airlift.bytecode.expression.BytecodeExpressions.newArray;
import static io.airlift.bytecode.expression.BytecodeExpressions.newInstance;
import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType;
Expand All @@ -42,6 +46,7 @@ public class RowConstructorCodeGenerator
{
private final Type rowType;
private final List<RowExpression> arguments;
private static final int MEGAMORPHIC_FIELD_COUNT = 64;

public RowConstructorCodeGenerator(SpecialForm specialForm)
{
Expand All @@ -53,6 +58,10 @@ public RowConstructorCodeGenerator(SpecialForm specialForm)
@Override
public BytecodeNode generateExpression(BytecodeGeneratorContext context)
{
if (arguments.size() > MEGAMORPHIC_FIELD_COUNT) {
return generateExpressionForLargeRows(context);
}

BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType);
CallSiteBinder binder = context.getCallSiteBinder();
Scope scope = context.getScope();
Expand Down Expand Up @@ -88,4 +97,38 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext context)
block.append(context.wasNull().set(constantFalse()));
return block;
}

/**
 * Generates the row constructor bytecode for rows with many fields.
 *
 * <p>Instead of emitting dedicated code per field, the generated code obtains an array of
 * field block builders from {@link RowBlockBuilder}, writes each evaluated argument (or a null)
 * into the builder at the matching index, and finally asks {@link RowBlockBuilder} to assemble
 * a {@link SqlRow} from those builders. Temporary variables that hold evaluated field values
 * are shared between all fields having the same Java type.
 *
 * @param context generator context supplying the scope, call site binder and wasNull flag
 * @return the bytecode that leaves the constructed row on the stack and resets wasNull to false
 */
private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext context)
{
    BytecodeBlock body = new BytecodeBlock().setDescription("Constructor for " + rowType);
    CallSiteBinder callSiteBinder = context.getCallSiteBinder();
    Scope variableScope = context.getScope();
    List<Type> fieldTypes = rowType.getTypeParameters();

    // One block builder per row field, created by the runtime helper
    Variable builders = variableScope.createTempVariable(BlockBuilder[].class);
    body.append(builders.set(invokeStatic(
            RowBlockBuilder.class,
            "createFieldBlockBuildersForSingleRow",
            BlockBuilder[].class,
            constantType(callSiteBinder, rowType))));

    // A single temp variable per distinct Java type is reused for every field of that type
    Map<Class<?>, Variable> valueVariablesByJavaType = new HashMap<>();
    Variable currentBuilder = variableScope.createTempVariable(BlockBuilder.class);

    int fieldCount = arguments.size();
    for (int index = 0; index < fieldCount; index++) {
        Type fieldType = fieldTypes.get(index);
        Class<?> javaType = fieldType.getJavaType();
        Variable value = valueVariablesByJavaType.get(javaType);
        if (value == null) {
            value = variableScope.createTempVariable(javaType);
            valueVariablesByJavaType.put(javaType, value);
        }

        body.append(currentBuilder.set(builders.getElement(constantInt(index))));

        // Reset wasNull before evaluating the argument so the null check below reflects this field only
        body.comment("Clean wasNull and Generate + " + index + "-th field of row")
                .append(context.wasNull().set(constantFalse()))
                .append(context.generate(arguments.get(index)))
                .putVariable(value);

        IfStatement writeField = new IfStatement()
                .condition(context.wasNull())
                .ifTrue(currentBuilder.invoke("appendNull", BlockBuilder.class).pop())
                .ifFalse(constantType(callSiteBinder, fieldType).writeValue(currentBuilder, value).pop());
        body.append(writeField);
    }

    body.append(invokeStatic(RowBlockBuilder.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, builders));
    // The row itself is never null, regardless of the nullness of its last field
    body.append(context.wasNull().set(constantFalse()));
    return body;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package io.trino.spi.block;

import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;

Expand Down Expand Up @@ -61,6 +62,26 @@ private RowBlockBuilder(@Nullable BlockBuilderStatus blockBuilderStatus, BlockBu
this.fieldBlockBuildersList = List.of(fieldBlockBuilders);
}

/**
 * Creates the field block builders needed to build exactly one row of the given type.
 *
 * @param rowType the type of the row; must be a {@link RowType}
 * @return one block builder per type parameter of {@code rowType}
 * @throws IllegalArgumentException if {@code rowType} is not a {@link RowType}
 */
public static BlockBuilder[] createFieldBlockBuildersForSingleRow(Type rowType)
{
    if (rowType instanceof RowType) {
        // No builder status is tracked and a single entry is expected per field
        return createFieldBlockBuilders(rowType.getTypeParameters(), null, 1);
    }
    throw new IllegalArgumentException("Not a row type: " + rowType);
}

/**
 * Builds every field builder and wraps the resulting blocks in a {@link SqlRow}.
 *
 * @param fieldBuilders the field block builders, each expected to hold exactly one position
 * @return a {@link SqlRow} constructed with index 0 over the built field blocks
 * @throws IllegalArgumentException if any built block does not contain exactly one position
 */
public static SqlRow createSqlRowFromFieldBuildersForSingleRow(BlockBuilder[] fieldBuilders)
{
    Block[] builtFields = new Block[fieldBuilders.length];
    int index = 0;
    for (BlockBuilder fieldBuilder : fieldBuilders) {
        Block built = fieldBuilder.build();
        builtFields[index++] = built;

        int positions = built.getPositionCount();
        if (positions != 1) {
            throw new IllegalArgumentException(format("builder must only contain a single position, found: %s positions", positions));
        }
    }
    return new SqlRow(0, builtFields);
}

private static BlockBuilder[] createFieldBlockBuilders(List<Type> fieldTypes, BlockBuilderStatus blockBuilderStatus, int expectedEntries)
{
// Stream API should not be used since constructor can be called in performance sensitive sections
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.trino.plugin.hive.metastore.HiveMetastore;
import io.trino.plugin.hive.metastore.HiveMetastoreFactory;
import io.trino.testing.QueryRunner;
import io.trino.testing.sql.TestTable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.io.IOException;
Expand All @@ -26,6 +28,9 @@

import static io.trino.plugin.iceberg.IcebergQueryRunner.ICEBERG_CATALOG;
import static io.trino.plugin.iceberg.IcebergTestUtils.checkOrcFileSorting;
import static io.trino.tpch.TpchTable.NATION;
import static io.trino.tpch.TpchTable.ORDERS;
import static io.trino.tpch.TpchTable.REGION;
import static org.apache.iceberg.FileFormat.ORC;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
Expand All @@ -48,7 +53,7 @@ protected QueryRunner createQueryRunner()
throws Exception
{
QueryRunner queryRunner = IcebergQueryRunner.builder()
.setInitialTables(REQUIRED_TPCH_TABLES)
.setInitialTables(NATION, ORDERS, REGION)
.setIcebergProperties(ImmutableMap.of(
"iceberg.file-format", format.name(),
"iceberg.register-table-procedure.enabled", "true",
Expand Down Expand Up @@ -108,4 +113,40 @@ protected boolean isFileSorted(Location path, String sortColumnName)
{
return checkOrcFileSorting(fileSystem, path, sortColumnName);
}

/**
 * Runs a MERGE against a very wide table (329 columns) and checks that exactly one row is affected.
 *
 * <p>Two identical single-row tables are created from {@code orders}; merging one into the other
 * on {@code orderkey} therefore takes the WHEN MATCHED branch for that row. Both tables are
 * dropped when the test finishes, whether it passes or fails.
 */
@Test
public void testRowConstructorColumnLimitForMergeQuery()
{
    String[] colNames = {"orderkey", "custkey", "orderstatus", "totalprice", "orderpriority", "clerk", "shippriority", "comment", "orderdate"};
    String[] colTypes = {"bigint", "bigint", "varchar", "decimal(12,2)", "varchar", "varchar", "int", "varchar", "date"};
    StringBuilder tableDefinition = new StringBuilder("(");
    StringBuilder columns = new StringBuilder("(");
    StringBuilder selectQuery = new StringBuilder("select ");
    StringBuilder notMatchedClause = new StringBuilder();
    StringBuilder matchedClause = new StringBuilder();

    // 36 suffixed copies of the 9 base columns, plus the 5 unsuffixed columns appended below: 329 columns
    for (int i = 0; i < 36; i++) {
        for (int j = 0; j < colNames.length; j++) {
            String suffixedName = colNames[j] + "_" + i;
            tableDefinition.append(suffixedName).append(" ").append(colTypes[j]).append(",");
            selectQuery.append(colNames[j]).append(" ").append(suffixedName).append(",");
            columns.append(suffixedName).append(",");
            notMatchedClause.append("s.").append(suffixedName).append(",");
            matchedClause.append(suffixedName).append(" = s.").append(suffixedName).append(",");
        }
    }
    tableDefinition.append("orderkey bigint, custkey bigint, orderstatus varchar, totalprice decimal(12,2), orderpriority varchar) ");
    selectQuery.append("orderkey, custkey, orderstatus, totalprice, orderpriority from orders limit 1 ");
    columns.append("orderkey, custkey, orderstatus, totalprice, orderpriority) ");
    notMatchedClause.append("s.orderkey, s.custkey, s.orderstatus, s.totalprice, s.orderpriority ");
    // NOTE(review): totalprice is assigned from the target (t) while every other column uses the source (s) - confirm this is intentional
    matchedClause.append("orderkey = s.orderkey, custkey = s.custkey, orderstatus = s.orderstatus, totalprice = t.totalprice, orderpriority = s.orderpriority ");

    String definition = tableDefinition.toString();
    String insertColumnsAndSource = columns + " " + selectQuery;

    // try-with-resources drops both tables even when an assertion fails
    try (TestTable table = new TestTable(getQueryRunner()::execute, "test_merge_", definition);
            TestTable mergeTable = new TestTable(getQueryRunner()::execute, "test_table_", definition)) {
        assertUpdate("INSERT INTO " + table.getName() + " " + insertColumnsAndSource, 1);
        assertUpdate("INSERT INTO " + mergeTable.getName() + " " + insertColumnsAndSource, 1);
        assertUpdate("MERGE INTO " + mergeTable.getName() + " t " +
                "USING (select * from " + table.getName() + ") s " +
                "ON (t.orderkey = s.orderkey) " +
                "WHEN MATCHED THEN UPDATE SET " + matchedClause +
                "WHEN NOT MATCHED THEN INSERT VALUES (" + notMatchedClause + ")", 1);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -393,4 +393,14 @@ public void testSelectiveLimit()
"LIMIT 1",
"VALUES -1");
}

/**
 * Checks that a row constructor with 859 fields can be evaluated and yields a non-null value.
 */
@Test
public void testRowConstructorColumnLimit()
{
    // 95 repetitions of the 9 base columns followed by 4 more columns: SELECT row(col1, col2, ....col859) from t
    String baseColumns = "orderkey, custkey, orderstatus, totalprice, orderpriority, clerk, shippriority, comment, orderdate";
    String rowFields = (baseColumns + ", ").repeat(95) + "orderkey, custkey, orderstatus, totalprice";
    @Language("SQL") String query = "SELECT row(" + rowFields + ") FROM (select * from tpch.tiny.orders limit 1) t(" + baseColumns + ")";

    Object onlyValue = getQueryRunner().execute(query).getOnlyValue();
    assertThat(onlyValue).isNotNull();
}
}

0 comments on commit f610d0b

Please sign in to comment.