Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quote row type field names #1963

Merged
merged 5 commits into from Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 3 additions & 16 deletions presto-client/src/main/java/io/prestosql/client/RowFieldName.java
Expand Up @@ -23,15 +23,12 @@
public class RowFieldName
{
private final String name;
private final boolean delimited;

@JsonCreator
public RowFieldName(
@JsonProperty("name") String name,
@JsonProperty("delimited") boolean delimited)
@JsonProperty("name") String name)
{
this.name = requireNonNull(name, "name is null");
this.delimited = delimited;
}

@JsonProperty
Expand All @@ -40,12 +37,6 @@ public String getName()
return name;
}

@JsonProperty
public boolean isDelimited()
{
return delimited;
}

@Override
public boolean equals(Object o)
{
Expand All @@ -58,22 +49,18 @@ public boolean equals(Object o)

RowFieldName other = (RowFieldName) o;

return Objects.equals(this.name, other.name) &&
Objects.equals(this.delimited, other.delimited);
return Objects.equals(this.name, other.name);
}

@Override
public String toString()
{
if (!isDelimited()) {
return name;
}
return '"' + name.replace("\"", "\"\"") + '"';
}

@Override
public int hashCode()
{
return Objects.hash(name, delimited);
return Objects.hash(name);
}
}
Expand Up @@ -48,8 +48,8 @@ public void testJsonRoundTrip()
assertJsonRoundTrip(new ClientTypeSignature(
"row",
ImmutableList.of(
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("foo", false)), bigint)),
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("bar", false)), bigint)))));
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("foo")), bigint)),
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("bar")), bigint)))));
}

@Test
Expand All @@ -64,8 +64,8 @@ public void testStringSerialization()
ClientTypeSignature row = new ClientTypeSignature(
StandardTypes.ROW,
ImmutableList.of(
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("foo", false)), bigint)),
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("bar", false)), bigint))));
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("foo")), bigint)),
ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(Optional.of(new RowFieldName("bar")), bigint))));
assertEquals(row.toString(), "row(foo bigint,bar bigint)");
}

Expand Down
Expand Up @@ -117,7 +117,7 @@ private static ClientTypeSignatureParameter toClientTypeSignatureParameter(TypeS
case NAMED_TYPE:
return ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(
parameter.getNamedTypeSignature().getFieldName().map(value ->
new RowFieldName(value.getName(), value.isDelimited())),
new RowFieldName(value.getName())),
toClientTypeSignature(parameter.getNamedTypeSignature().getTypeSignature())));
case LONG:
return ClientTypeSignatureParameter.ofLong(parameter.getLongLiteral());
Expand Down
Expand Up @@ -236,7 +236,7 @@ private static TypeSignature getTypeSignature(TypeInfo typeInfo)
// Users can't work around this by casting in their queries because Presto parser always lower case types.
// TODO: This is a hack. Presto engine should be able to handle identifiers in a case insensitive way where necessary.
String rowFieldName = structFieldNames.get(i).toLowerCase(Locale.US);
typeSignatureBuilder.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(Optional.of(new RowFieldName(rowFieldName, false)), typeSignature)));
typeSignatureBuilder.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(Optional.of(new RowFieldName(rowFieldName)), typeSignature)));
}
return new TypeSignature(StandardTypes.ROW, typeSignatureBuilder.build());
}
Expand Down
Expand Up @@ -283,9 +283,9 @@ public abstract class AbstractTestHive
private static final Type ARRAY_TYPE = arrayType(createUnboundedVarcharType());
private static final Type MAP_TYPE = mapType(createUnboundedVarcharType(), BIGINT);
private static final Type ROW_TYPE = rowType(ImmutableList.of(
new NamedTypeSignature(Optional.of(new RowFieldName("f_string", false)), createUnboundedVarcharType().getTypeSignature()),
new NamedTypeSignature(Optional.of(new RowFieldName("f_bigint", false)), BIGINT.getTypeSignature()),
new NamedTypeSignature(Optional.of(new RowFieldName("f_boolean", false)), BOOLEAN.getTypeSignature())));
new NamedTypeSignature(Optional.of(new RowFieldName("f_string")), createUnboundedVarcharType().getTypeSignature()),
new NamedTypeSignature(Optional.of(new RowFieldName("f_bigint")), BIGINT.getTypeSignature()),
new NamedTypeSignature(Optional.of(new RowFieldName("f_boolean")), BOOLEAN.getTypeSignature())));

private static final List<ColumnMetadata> CREATE_TABLE_COLUMNS = ImmutableList.<ColumnMetadata>builder()
.add(new ColumnMetadata("id", BIGINT))
Expand Down Expand Up @@ -356,7 +356,7 @@ public abstract class AbstractTestHive
private static RowType toRowType(List<ColumnMetadata> columns)
{
return rowType(columns.stream()
.map(col -> new NamedTypeSignature(Optional.of(new RowFieldName(format("f_%s", col.getName()), false)), col.getType().getTypeSignature()))
.map(col -> new NamedTypeSignature(Optional.of(new RowFieldName(format("f_%s", col.getName()))), col.getType().getTypeSignature()))
.collect(toImmutableList()));
}

Expand Down
Expand Up @@ -94,7 +94,7 @@ public void testTypeTranslator()
assertInvalidTypeTranslation(
RowType.anonymous(ImmutableList.of(INTEGER, VARBINARY)),
NOT_SUPPORTED.toErrorCode(),
"Anonymous row type is not supported in Hive. Please give each field a name: row(integer,varbinary)");
"Anonymous row type is not supported in Hive. Please give each field a name: row(integer, varbinary)");
}

private void assertTypeTranslation(Type type, HiveType hiveType)
Expand Down
Expand Up @@ -1029,11 +1029,7 @@ public ResolvedFunction resolveOperator(OperatorType operatorType, List<? extend
}
catch (PrestoException e) {
if (e.getErrorCode().getCode() == FUNCTION_NOT_FOUND.toErrorCode().getCode()) {
throw new OperatorNotFoundException(
operatorType,
argumentTypes.stream()
.map(Type::getTypeSignature)
.collect(toImmutableList()));
throw new OperatorNotFoundException(operatorType, argumentTypes);
}
else {
throw e;
Expand All @@ -1052,7 +1048,7 @@ public ResolvedFunction getCoercion(OperatorType operatorType, Type fromType, Ty
}
catch (PrestoException e) {
if (e.getErrorCode().getCode() == FUNCTION_IMPLEMENTATION_MISSING.toErrorCode().getCode()) {
throw new OperatorNotFoundException(operatorType, ImmutableList.of(fromType.getTypeSignature()), toType.getTypeSignature());
throw new OperatorNotFoundException(operatorType, ImmutableList.of(fromType), toType.getTypeSignature());
}
throw e;
}
Expand Down
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;

import java.util.List;
Expand All @@ -31,25 +32,25 @@ public class OperatorNotFoundException
{
private final OperatorType operatorType;
private final TypeSignature returnType;
private final List<TypeSignature> argumentTypes;
private final List<Type> argumentTypes;

public OperatorNotFoundException(OperatorType operatorType, List<? extends TypeSignature> argumentTypes)
public OperatorNotFoundException(OperatorType operatorType, List<? extends Type> argumentTypes)
{
super(OPERATOR_NOT_FOUND, formatErrorMessage(operatorType, argumentTypes, Optional.empty()));
this.operatorType = requireNonNull(operatorType, "operatorType is null");
this.returnType = null;
this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null"));
}

public OperatorNotFoundException(OperatorType operatorType, List<? extends TypeSignature> argumentTypes, TypeSignature returnType)
public OperatorNotFoundException(OperatorType operatorType, List<? extends Type> argumentTypes, TypeSignature returnType)
{
super(OPERATOR_NOT_FOUND, formatErrorMessage(operatorType, argumentTypes, Optional.of(returnType)));
this.operatorType = requireNonNull(operatorType, "operatorType is null");
this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null"));
this.returnType = requireNonNull(returnType, "returnType is null");
}

private static String formatErrorMessage(OperatorType operatorType, List<? extends TypeSignature> argumentTypes, Optional<TypeSignature> returnType)
private static String formatErrorMessage(OperatorType operatorType, List<? extends Type> argumentTypes, Optional<TypeSignature> returnType)
{
String operatorString;
switch (operatorType) {
Expand All @@ -74,7 +75,7 @@ public TypeSignature getReturnType()
return returnType;
}

public List<TypeSignature> getArgumentTypes()
public List<Type> getArgumentTypes()
{
return argumentTypes;
}
Expand Down
Expand Up @@ -194,7 +194,7 @@ else if (c == ')') {
verify(tokenStart >= 0, "Expect tokenStart to be non-negative");
verify(delimitedColumnName != null, "Expect delimitedColumnName to be non-null");
fields.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(
Optional.of(new RowFieldName(delimitedColumnName, true)),
Optional.of(new RowFieldName(delimitedColumnName)),
parseTypeSignature(signature.substring(tokenStart, i).trim(), literalParameters))));
delimitedColumnName = null;
tokenStart = -1;
Expand All @@ -204,7 +204,7 @@ else if (c == ',' && bracketLevel == 1) {
verify(tokenStart >= 0, "Expect tokenStart to be non-negative");
verify(delimitedColumnName != null, "Expect delimitedColumnName to be non-null");
fields.add(TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(
Optional.of(new RowFieldName(delimitedColumnName, true)),
Optional.of(new RowFieldName(delimitedColumnName)),
parseTypeSignature(signature.substring(tokenStart, i).trim(), literalParameters))));
delimitedColumnName = null;
tokenStart = -1;
Expand Down Expand Up @@ -238,7 +238,7 @@ private static TypeSignatureParameter parseTypeOrNamedType(String typeOrNamedTyp
String firstPart = typeOrNamedType.substring(0, split);
if (IDENTIFIER_PATTERN.matcher(firstPart).matches()) {
return TypeSignatureParameter.namedTypeParameter(new NamedTypeSignature(
Optional.of(new RowFieldName(firstPart, false)),
Optional.of(new RowFieldName(firstPart)),
parseTypeSignature(typeOrNamedType.substring(split + 1).trim(), literalParameters)));
}

Expand Down
Expand Up @@ -594,7 +594,7 @@ private static ClientTypeSignatureParameter toClientTypeSignatureParameter(TypeS
case NAMED_TYPE:
return ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(
parameter.getNamedTypeSignature().getFieldName().map(value ->
new RowFieldName(value.getName(), value.isDelimited())),
new RowFieldName(value.getName())),
toClientTypeSignature(parameter.getNamedTypeSignature().getTypeSignature())));
case LONG:
return ClientTypeSignatureParameter.ofLong(parameter.getLongLiteral());
Expand Down
Expand Up @@ -110,7 +110,7 @@ private static TypeSignature toTypeSignature(RowDataType type)
.map(field -> namedTypeParameter(new NamedTypeSignature(
field.getName()
.map(TypeSignatureTranslator::canonicalize)
.map(value -> new RowFieldName(value, false)),
.map(value -> new RowFieldName(value)),
toTypeSignature(field.getType()))))
.collect(toImmutableList());

Expand Down
Expand Up @@ -89,7 +89,7 @@ public Expression rewriteSubscriptExpression(SubscriptExpression node, Void cont

// Do not cast if Row fields are named
if (fieldName.isPresent()) {
result = new DereferenceExpression(base, new Identifier(fieldName.get()));
result = new DereferenceExpression(base, new Identifier(fieldName.get(), true));
}
else {
// Cast to Row with named fields
Expand Down
Expand Up @@ -181,8 +181,8 @@ public static Object evaluateConstantExpression(Expression expression, Type expe
Type actualType = analyzer.getExpressionTypes().get(NodeRef.of(expression));
if (!new TypeCoercion(metadata::getType).canCoerce(actualType, expectedType)) {
throw semanticException(TYPE_MISMATCH, expression, format("Cannot cast type %s to %s",
actualType.getTypeSignature(),
expectedType.getTypeSignature()));
actualType.getDisplayName(),
expectedType.getDisplayName()));
}

Map<NodeRef<Expression>, Type> coercions = ImmutableMap.<NodeRef<Expression>, Type>builder()
Expand Down