Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
package io.substrait.isthmus;

import io.substrait.extension.SimpleExtension;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Rel;
import io.substrait.type.NamedStruct;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.LookupCalciteSchema;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.Table;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;

/**
* Converts between Substrait {@link Rel}s and Calcite {@link RelNode}s.
Expand Down Expand Up @@ -99,6 +106,79 @@ public RelNode convert(Rel rel) {
return rel.accept(converter);
}

/**
 * Converts a Substrait {@link Plan.Root} to a Calcite {@link RelRoot}.
 *
 * <p>The input relation of the root is converted first. Its row type is then rebuilt with the
 * final field names carried by the {@link Plan.Root}, and the resulting {@link RelDataType} is
 * used as the row type of the returned {@link RelRoot}.
 *
 * <p>TODO: revisit this code when support for WriteRel is added to substrait-java
 *
 * <p>TODO: this code assumes that Apache Calcite knows how to properly alias hierarchical field
 * names, which is currently not the case (Calcite version 1.39.0)
 *
 * @param root {@link Plan.Root} to convert
 * @return {@link RelRoot}
 */
public RelRoot convert(Plan.Root root) {
  final RelNode convertedInput = convert(root.getInput());

  // Renaming starts at the first name; only the renamed type (right side of the pair) is needed.
  final Pair<Integer, RelDataType> renamed =
      renameFields(convertedInput.getRowType(), root.getNames(), 0);

  return RelRoot.of(convertedInput, renamed.right, SqlKind.SELECT);
}

/**
 * Produces a new {@link RelDataType} from the given {@link RelDataType} by recursively applying
 * the given names in depth-first order.
 *
 * <p>Every field of a ROW or STRUCTURED type consumes one name, immediately followed by the names
 * consumed by that field's own type. ARRAY, MULTISET and MAP types consume no name themselves;
 * only their element, key and value types are visited (key before value). Any other type is
 * returned unchanged and consumes no name.
 *
 * <p>The collection kind (ARRAY vs. MULTISET) and the nullability of each rebuilt type are
 * carried over from the source type.
 *
 * @param type the source {@link RelDataType} to rename
 * @param names the names to use for renaming
 * @param currentIndex the current index within the list of names
 * @return a pair holding the index of the next unused name (left) and the renamed {@link
 *     RelDataType} (right)
 * @throws IllegalArgumentException if {@code names} holds fewer names than {@code type} requires
 */
private Pair<Integer, RelDataType> renameFields(
    RelDataType type, List<String> names, Integer currentIndex) {
  switch (type.getSqlTypeName()) {
    case ROW:
    case STRUCTURED:
      int nextIndex = currentIndex;
      final List<String> newFieldNames = new ArrayList<>();
      final List<RelDataType> renamedFields = new ArrayList<>();
      for (RelDataTypeField field : type.getFieldList()) {
        if (nextIndex >= names.size()) {
          throw new IllegalArgumentException(
              String.format(
                  "Not enough names to rename type %s: only %d name(s) were provided",
                  type, names.size()));
        }
        newFieldNames.add(names.get(nextIndex));
        // the names of nested fields directly follow the name of the field itself
        final Pair<Integer, RelDataType> renamedField =
            renameFields(field.getType(), names, nextIndex + 1);
        renamedFields.add(renamedField.right);
        nextIndex = renamedField.left;
      }

      return Pair.of(
          nextIndex,
          preserveNullability(
              type,
              typeFactory.createStructType(type.getStructKind(), renamedFields, newFieldNames)));
    case ARRAY:
      final Pair<Integer, RelDataType> renamedArrayElement =
          renameFields(type.getComponentType(), names, currentIndex);

      return Pair.of(
          renamedArrayElement.left,
          preserveNullability(
              type, typeFactory.createArrayType(renamedArrayElement.right, -1L)));
    case MULTISET:
      final Pair<Integer, RelDataType> renamedMultisetElement =
          renameFields(type.getComponentType(), names, currentIndex);

      // rebuild as a multiset so that the collection kind of the source type is kept
      return Pair.of(
          renamedMultisetElement.left,
          preserveNullability(
              type, typeFactory.createMultisetType(renamedMultisetElement.right, -1L)));
    case MAP:
      final Pair<Integer, RelDataType> renamedKeyType =
          renameFields(type.getKeyType(), names, currentIndex);
      // names for the value type start where the key type stopped consuming names
      final Pair<Integer, RelDataType> renamedValueType =
          renameFields(type.getValueType(), names, renamedKeyType.left);

      return Pair.of(
          renamedValueType.left,
          preserveNullability(
              type, typeFactory.createMapType(renamedKeyType.right, renamedValueType.right)));
    default:
      return Pair.of(currentIndex, type);
  }
}

/**
 * Applies the nullability of {@code source} to {@code rebuilt}, so that a type rebuilt through
 * the type factory does not drop the nullability of the type it was derived from.
 */
private RelDataType preserveNullability(RelDataType source, RelDataType rebuilt) {
  return typeFactory.createTypeWithNullability(rebuilt, source.isNullable());
}

private static class NamedStructGatherer extends AbstractRelVisitor<Void, RuntimeException> {
Map<List<String>, NamedStruct> tableMap;

Expand Down
163 changes: 163 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan.Root;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Set;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.junit.jupiter.api.Test;

public class SubstraitToCalciteTest extends PlanTestBase {
final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);

@Test
void testConvertRootSingleColumn() {
Iterable<Type> types = List.of(TypeCreator.REQUIRED.STRING);
Root root =
ImmutableRoot.builder()
.input(substraitBuilder.namedScan(List.of("stores"), List.of("s"), types))
.addNames("store")
.build();

RelRoot relRoot = converter.convert(root);

assertEquals(root.getNames(), relRoot.fields.rightList());
}

@Test
void testConvertRootMultipleColumns() {
Iterable<Type> types = List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING);
Root root =
ImmutableRoot.builder()
.input(substraitBuilder.namedScan(List.of("stores"), List.of("s_store_id", "s"), types))
.addNames("s_store_id", "store")
.build();

RelRoot relRoot = converter.convert(root);

assertEquals(root.getNames(), relRoot.fields.rightList());
}

Comment thread
vbarua marked this conversation as resolved.
@Test
void testConvertRootStructField() {
final Type structType =
TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING);
Iterable<Type> types = List.of(structType);
Root root =
ImmutableRoot.builder()
.input(
substraitBuilder.namedScan(
List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types))
.addNames("store", "store_id", "store_name")
.build();

assertEquals(List.of("store", "store_id", "store_name"), root.getNames());

RelRoot relRoot = converter.convert(root);

// Apache Calcite's RelRoot.fields only contains the top level field names
assertEquals(List.of("store"), relRoot.fields.rightList());

// the sub field names are stored within RelRoot.validatedRowType
assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames());

RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType();
assertEquals(List.of("store_id", "store_name"), storeFieldDataType.getFieldNames());
}

@Test
void testConvertRootArrayWithStructField() {
final Type structType =
TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING);
final Type arrayType = TypeCreator.REQUIRED.list(structType);
Set<Type> types = Set.of(arrayType);
Root root =
ImmutableRoot.builder()
.input(
substraitBuilder.namedScan(
List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types))
.addNames("store", "store_id", "store_name")
.build();

RelRoot relRoot = converter.convert(root);

// Apache Calcite's RelRoot.fields only contains the top level field names
assertEquals(List.of("store"), relRoot.fields.rightList());

// the hierarchical structure is stored within RelRoot.validatedRowType
assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames());

RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType();
assertEquals(SqlTypeName.ARRAY, storeFieldDataType.getSqlTypeName());

final RelDataType arrayElementType = storeFieldDataType.getComponentType();
assertEquals(SqlTypeName.ROW, arrayElementType.getSqlTypeName());
assertEquals(List.of("store_id", "store_name"), arrayElementType.getFieldNames());
}

@Test
void testConvertRootMapWithStructValues() {
final Type structType =
TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING);
final Type mapValueType = TypeCreator.REQUIRED.map(TypeCreator.REQUIRED.I64, structType);
Set<Type> types = Set.of(mapValueType);
Root root =
ImmutableRoot.builder()
.input(
substraitBuilder.namedScan(
List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types))
.addNames("store", "store_id", "store_name")
.build();

final RelRoot relRoot = converter.convert(root);

// Apache Calcite's RelRoot.fields only contains the top level field names
assertEquals(List.of("store"), relRoot.fields.rightList());

// the hierarchical structure is stored within RelRoot.validatedRowType
assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames());

final RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType();
assertEquals(SqlTypeName.MAP, storeFieldDataType.getSqlTypeName());

final RelDataType mapValueDataType = storeFieldDataType.getValueType();
assertEquals(SqlTypeName.ROW, mapValueDataType.getSqlTypeName());
assertEquals(List.of("store_id", "store_name"), mapValueDataType.getFieldNames());
}

@Test
void testConvertRootMapWithStructKeys() {
final Type structType =
TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING);
final Type mapKeyType = TypeCreator.REQUIRED.map(structType, TypeCreator.REQUIRED.I64);
Set<Type> types = Set.of(mapKeyType);
Root root =
ImmutableRoot.builder()
.input(
substraitBuilder.namedScan(
List.of("stores"), List.of("s", "s_store_id", "s_store_name"), types))
.addNames("store", "store_id", "store_name")
.build();

RelRoot relRoot = converter.convert(root);

// Apache Calcite's RelRoot.fields only contains the top level field names
assertEquals(List.of("store"), relRoot.fields.rightList());

// the hierarchical structure is stored within RelRoot.validatedRowType
assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames());

RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType();
assertEquals(SqlTypeName.MAP, storeFieldDataType.getSqlTypeName());

final RelDataType mapKeyDataType = storeFieldDataType.getKeyType();
assertEquals(SqlTypeName.ROW, mapKeyDataType.getSqlTypeName());
assertEquals(List.of("store_id", "store_name"), mapKeyDataType.getFieldNames());
}
}