Skip to content

Commit

Permalink
TEIID-5528 adding array literal insert and update support to pg
Browse files Browse the repository at this point in the history
  • Loading branch information
shawkins committed Nov 12, 2018
1 parent d844c76 commit 5b70509
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,5 +222,14 @@ public void setBooleanNullable(boolean booleanNullable) {
public boolean hasTypeMapping(int type) {
return this.typeMapping.containsKey(type) || this.typeModifier.containsKey(type);
}

/**
* Return the direct type mapping for a given type code
* @param code
* @return
*/
public String getSimpleTypeMapping(int code) {
return typeMapping.get(code);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,26 @@ public void visit(Function obj) {

@Override
public void visit(Parameter obj) {
addBinding(obj);
}

/**
* Add a bind ? value to the sql string and to the binding value list
* @param value should be an {@link Argument}, {@link Parameter}, or {@link Argument}
*/
protected void addBinding(LanguageObject value) {
this.prepared = true;
buffer.append(UNDEFINED_PARAM);
preparedValues.add(obj);
preparedValues.add(value);
usingBinding = true;
}
}

/**
* @see org.teiid.language.visitor.SQLStringVisitor#visit(org.teiid.language.Literal)
*/
public void visit(Literal obj) {
if (this.prepared && ((replaceWithBinding && obj.isBindEligible()) || TranslatedCommand.isBindEligible(obj))) {
buffer.append(UNDEFINED_PARAM);
preparedValues.add(obj);
usingBinding = true;
addBinding(obj);
} else {
translateSQLType(obj.getType(), obj.getValue(), buffer);
}
Expand Down Expand Up @@ -460,8 +467,8 @@ protected void appendBaseName(NamedTable obj) {
@Override
public void substitute(Argument arg, StringBuilder builder, int index) {
if (this.prepared && arg.getExpression() instanceof Literal) {
buffer.append('?');
this.preparedValues.add(arg);
buffer.append(UNDEFINED_PARAM);
this.preparedValues.add(arg);
} else {
visit(arg);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags and
* the COPYRIGHT.txt file distributed with this work.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.teiid.translator.jdbc.postgresql;

import org.teiid.language.Array;
import org.teiid.language.ColumnReference;
import org.teiid.language.DerivedColumn;
import org.teiid.language.Expression;
import org.teiid.language.Literal;
import org.teiid.language.SQLConstants;
import org.teiid.language.With;
import org.teiid.language.WithItem;
import org.teiid.translator.TypeFacility;
import org.teiid.translator.jdbc.ConvertModifier;
import org.teiid.translator.jdbc.SQLConversionVisitor;

public class PostgreSQLConversionVisitor
extends SQLConversionVisitor {

private PostgreSQLExecutionFactory postgreSQLExecutionFactory;

public PostgreSQLConversionVisitor(PostgreSQLExecutionFactory ef) {
super(ef);
this.postgreSQLExecutionFactory = ef;
}

@Override
protected void appendWithKeyword(With obj) {
super.appendWithKeyword(obj);
for (WithItem with : obj.getItems()) {
if (with.isRecusive()) {
buffer.append(SQLConstants.Tokens.SPACE);
buffer.append(SQLConstants.Reserved.RECURSIVE);
break;
}
}
}

/**
* Some literals in the select need a cast to prevent being seen as the unknown/string type
*/
@Override
public void visit(DerivedColumn obj) {
if (obj.getExpression() instanceof Literal) {
String castType = null;
if (obj.getExpression().getType() == TypeFacility.RUNTIME_TYPES.STRING) {
castType = "bpchar"; //$NON-NLS-1$
} else if (obj.getExpression().getType() == TypeFacility.RUNTIME_TYPES.VARBINARY) {
castType = "bytea"; //$NON-NLS-1$
}
if (castType != null) {
obj.setExpression(postgreSQLExecutionFactory.getLanguageFactory().createFunction("cast", //$NON-NLS-1$
new Expression[] {obj.getExpression(), postgreSQLExecutionFactory.getLanguageFactory().createLiteral(castType, TypeFacility.RUNTIME_TYPES.STRING)},
TypeFacility.RUNTIME_TYPES.STRING));
}
} else if (obj.isProjected() && obj.getExpression() instanceof ColumnReference) {
ColumnReference elem = (ColumnReference)obj.getExpression();
if (elem.getMetadataObject() != null) {
String nativeType = elem.getMetadataObject().getNativeType();
if (TypeFacility.RUNTIME_TYPES.STRING.equals(elem.getType())
&& elem.getMetadataObject() != null
&& nativeType != null
&& nativeType.equalsIgnoreCase(PostgreSQLExecutionFactory.UUID_TYPE)) {
obj.setExpression(postgreSQLExecutionFactory.getLanguageFactory().createFunction("cast", //$NON-NLS-1$
new Expression[] {obj.getExpression(), postgreSQLExecutionFactory.getLanguageFactory().createLiteral("varchar", TypeFacility.RUNTIME_TYPES.STRING)}, //$NON-NLS-1$
TypeFacility.RUNTIME_TYPES.STRING));
}
}
}
super.visit(obj);
}

@Override
public void visit(Array array) {
boolean allLiterals = true;
Class<?> baseType = array.getBaseType();
//the pg driver expects only values that are convertible to string
//we could introduce some conversions, but for now we'll just fail
//some cases- there's also potential issue with date time as this logic
//won't consider the database timezone setting
if (!baseType.isArray() &&
postgreSQLExecutionFactory.convertModifier.getSimpleTypeMapping(ConvertModifier.getCode(baseType)) != null) {
for (Expression ex : array.getExpressions()) {
if (!(ex instanceof Literal)) {
allLiterals = false;
break;
}
}
if (allLiterals) {
//TODO: this could be pushed to the language bridge factory
//to just push a literal array
addBinding(new Literal(array, array.getType()));
return;
}
}
//mixed or lob case
//TODO: if this is used in the context specifically of an array, rather than
//a row value, then this will fail
super.visit(array);
}

/*
@Override
public void visit(In obj) {
//TODO: array binding TEIID-3537
super.visit(obj);
}
*/
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,15 @@
import org.teiid.core.types.BinaryType;
import org.teiid.core.types.ClobImpl;
import org.teiid.core.types.JsonType;
import org.teiid.language.*;
import org.teiid.language.AggregateFunction;
import org.teiid.language.Array;
import org.teiid.language.Expression;
import org.teiid.language.Function;
import org.teiid.language.LanguageObject;
import org.teiid.language.Like;
import org.teiid.language.Like.MatchMode;
import org.teiid.language.Limit;
import org.teiid.language.Literal;
import org.teiid.language.SQLConstants.NonReserved;
import org.teiid.language.visitor.SQLStringVisitor;
import org.teiid.logging.LogConstants;
Expand Down Expand Up @@ -221,6 +228,7 @@ public List<?> translate(Function function) {

@Override
public List<?> translate(Function function) {
//TODO: doesn't work for all array expressions
return Arrays.asList(function.getParameters().get(0), '[', function.getParameters().get(1), ']');
}
});
Expand Down Expand Up @@ -888,53 +896,7 @@ public void loadedTemporaryTable(String tableName,

@Override
public SQLConversionVisitor getSQLConversionVisitor() {
return new SQLConversionVisitor(this) {
@Override
protected void appendWithKeyword(With obj) {
super.appendWithKeyword(obj);
for (WithItem with : obj.getItems()) {
if (with.isRecusive()) {
buffer.append(SQLConstants.Tokens.SPACE);
buffer.append(SQLConstants.Reserved.RECURSIVE);
break;
}
}
}

/**
* Some literals in the select need a cast to prevent being seen as the unknown/string type
*/
@Override
public void visit(DerivedColumn obj) {
if (obj.getExpression() instanceof Literal) {
String castType = null;
if (obj.getExpression().getType() == TypeFacility.RUNTIME_TYPES.STRING) {
castType = "bpchar"; //$NON-NLS-1$
} else if (obj.getExpression().getType() == TypeFacility.RUNTIME_TYPES.VARBINARY) {
castType = "bytea"; //$NON-NLS-1$
}
if (castType != null) {
obj.setExpression(getLanguageFactory().createFunction("cast", //$NON-NLS-1$
new Expression[] {obj.getExpression(), getLanguageFactory().createLiteral(castType, TypeFacility.RUNTIME_TYPES.STRING)},
TypeFacility.RUNTIME_TYPES.STRING));
}
} else if (obj.isProjected() && obj.getExpression() instanceof ColumnReference) {
ColumnReference elem = (ColumnReference)obj.getExpression();
if (elem.getMetadataObject() != null) {
String nativeType = elem.getMetadataObject().getNativeType();
if (TypeFacility.RUNTIME_TYPES.STRING.equals(elem.getType())
&& elem.getMetadataObject() != null
&& nativeType != null
&& nativeType.equalsIgnoreCase(UUID_TYPE)) {
obj.setExpression(getLanguageFactory().createFunction("cast", //$NON-NLS-1$
new Expression[] {obj.getExpression(), getLanguageFactory().createLiteral("varchar", TypeFacility.RUNTIME_TYPES.STRING)}, //$NON-NLS-1$
TypeFacility.RUNTIME_TYPES.STRING));
}
}
}
super.visit(obj);
}
};
return new PostgreSQLConversionVisitor(this);
}

public void setPostGisVersion(String postGisVersion) {
Expand Down Expand Up @@ -1025,6 +987,24 @@ public void bindValue(PreparedStatement stmt, Object param,
|| paramType == TypeFacility.RUNTIME_TYPES.GEOGRAPHY)) {
//the blob sql type causes a failure with nulls
paramType = TypeFacility.RUNTIME_TYPES.VARBINARY;
} else if (param instanceof Array) {
//pg allows for the direct binding of certain arrays
Array array = (Array)param;
Connection c = stmt.getConnection();
int code = ConvertModifier.getCode(array.getBaseType());
String nativeType = convertModifier.getSimpleTypeMapping(code);
int index = nativeType.indexOf('(');
if (index > 0) {
nativeType = nativeType.substring(0, index);
}
Object[] values = new Object[array.getExpressions().size()];
for (int j = 0; j < values.length; j++) {
Expression ex = array.getExpressions().get(j);
values[j] = ((Literal)ex).getValue();
}
java.sql.Array value = c.createArrayOf(nativeType, values);
stmt.setArray(i, value);
return;
}
super.bindValue(stmt, param, paramType, i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import static org.junit.Assert.*;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.Arrays;

import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;
import org.teiid.core.util.SimpleMock;
import org.teiid.language.Array;
import org.teiid.language.Expression;
import org.teiid.language.Literal;
import org.teiid.translator.SourceSystemFunctions;
import org.teiid.translator.TranslatorException;
import org.teiid.translator.jdbc.TranslationHelper;
Expand Down Expand Up @@ -692,5 +698,41 @@ public void helpTestVisitor(String vdb, String input, String expectedOutput) thr
String output = "SELECT SmallA.StringNum FROM SmallA WHERE (SmallA.StringKey ilike 'a_') = TRUE AND (SmallA.StringKey ~* '_b') <> TRUE"; //$NON-NLS-1$
TranslationHelper.helpTestVisitor(TranslationHelper.BQT_VDB, input, output, TRANSLATOR);
}

@Test public void testArrayInsert() throws Exception {
String input = "insert into t (id, b) values ('a', (true,false))"; //$NON-NLS-1$
String output = "INSERT INTO t (id, b) VALUES ('a', ?)"; //$NON-NLS-1$

helpTestVisitor("create foreign table t (id string, b boolean[])",
input,
output);
}

@Test public void testArrayUpdate() throws Exception {
String input = "update t set b = (null, 'a') where id = 'b'"; //$NON-NLS-1$
String output = "UPDATE t SET b = ? WHERE t.id = 'b'"; //$NON-NLS-1$

helpTestVisitor("create foreign table t (id string, b string[])",
input,
output);
}

@Test public void testArrayComparison() throws Exception {
String input = "SELECT id from t where (id,b) = ('a', true)"; //$NON-NLS-1$
String output = "SELECT t.id FROM t WHERE (t.id, t.b) = ('a', TRUE)"; //$NON-NLS-1$

helpTestVisitor("create foreign table t (id string, b boolean)",
input,
output);
}

@Test public void testArrayBind() throws Exception {
PreparedStatement ps = Mockito.mock(PreparedStatement.class);
Connection c = Mockito.mock(Connection.class);
Mockito.stub(ps.getConnection()).toReturn(c);
TRANSLATOR.bindValue(ps, new Array(String.class,
Arrays.asList((Expression)new Literal("a", String.class))), String[].class, 1);
Mockito.verify(c, Mockito.times(1)).createArrayOf("varchar", new Object[] {"a"});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,11 @@ org.teiid.language.Expression translate(Constant constant) {
if (baseType == null) {
baseType = c.getType();
} else if (!baseType.equals(c.getType())) {
baseType = DataTypeManager.DefaultDataClasses.OBJECT;
if (baseType == DataTypeManager.DefaultDataClasses.NULL) {
baseType = c.getType();
} else if (c.getType() != DataTypeManager.DefaultDataClasses.NULL) {
baseType = DataTypeManager.DefaultDataClasses.OBJECT;
}
}
}
return new org.teiid.language.Array(baseType, translateExpressionList(vals));
Expand Down
10 changes: 10 additions & 0 deletions engine/src/main/java/org/teiid/query/rewriter/QueryRewriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,16 @@ private Expression rewriteExpressionDirect(Expression expression) throws TeiidCo
windowFrame.setEnd(null);
}
}
} else if (expression instanceof Array) {
Array array = (Array)expression;
boolean foundAny = false;
for (Expression ex : array.getExpressions()) {
if(!isConstantConvert(ex)) {
foundAny = true;
break;
}
}
isBindEligible = foundAny;
}
rewriteExpressions(expression);
}
Expand Down

0 comments on commit 5b70509

Please sign in to comment.