Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pgwire): fix spurious error when executing "create table" SQL from Rust #2385

Merged
merged 17 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

import static io.questdb.cairo.sql.OperationFuture.QUERY_COMPLETE;
import static io.questdb.cutlass.pgwire.PGOids.*;
import static io.questdb.std.datetime.millitime.DateFormatUtils.PG_DATE_MILLI_TIME_Z_FORMAT;
import static io.questdb.std.datetime.millitime.DateFormatUtils.PG_DATE_MILLI_TIME_Z_PRINT_FORMAT;
import static io.questdb.std.datetime.millitime.DateFormatUtils.PG_DATE_Z_FORMAT;

/**
Expand All @@ -67,6 +67,8 @@ public class PGConnectionContext implements IOContext, Mutable, WriterSource {
public static final String TAG_COPY = "COPY";
public static final String TAG_INSERT = "INSERT";
public static final String TAG_UPDATE = "UPDATE";
// create as select tag
public static final String TAG_CTAS = "CTAS";
public static final char STATUS_IN_TRANSACTION = 'T';
public static final char STATUS_IN_ERROR = 'E';
public static final char STATUS_IDLE = 'I';
Expand Down Expand Up @@ -633,7 +635,7 @@ private void appendDateColumn(Record record, int columnIndex) {
final long longValue = record.getDate(columnIndex);
if (longValue != Numbers.LONG_NaN) {
final long a = responseAsciiSink.skip();
PG_DATE_MILLI_TIME_Z_FORMAT.format(longValue, null, null, responseAsciiSink);
PG_DATE_MILLI_TIME_Z_PRINT_FORMAT.format(longValue, null, null, responseAsciiSink);
responseAsciiSink.putLenEx(a);
} else {
responseAsciiSink.setNullValue();
Expand Down Expand Up @@ -1172,6 +1174,9 @@ private void configurePreparedStatement(CharSequence statementName) throws BadPr
if (index > -1) {
wrapper = namedStatementWrapperPool.pop();
wrapper.queryText = Chars.toString(queryText);
// COPY 'id' CANCEL; queries shouldn't be compiled multiple times, but it's fine to compile
// COPY 'x' FROM ...; queries multiple times since the import is executed lazily
wrapper.alreadyExecuted = (queryTag == TAG_OK || queryTag == TAG_CTAS || (queryTag == TAG_COPY && typesAndSelect == null));
namedStatementMap.putAt(index, Chars.toString(statementName), wrapper);
this.activeBindVariableTypes = wrapper.bindVariableTypes;
this.activeSelectColumnTypes = wrapper.selectColumnTypes;
Expand Down Expand Up @@ -1901,12 +1906,11 @@ private void processClose(long lo, long msgLimit) throws BadProtocolException {
final CharSequence statementName = getStatementName(lo, hi);
if (statementName != null) {
final int index = namedStatementMap.keyIndex(statementName);
// do not freak out if client is closing statement we don't have
// we could have reported error to client before statement was created
if (index < 0) {
namedStatementWrapperPool.push(namedStatementMap.valueAt(index));
namedStatementMap.removeAt(index);
} else {
LOG.error().$("invalid statement name [value=").$(statementName).$(']').$();
throw BadProtocolException.INSTANCE;
}
}
break;
Expand Down Expand Up @@ -1937,7 +1941,7 @@ private void processCompiledQuery(CompiledQuery cq) throws SqlException {

switch (cq.getType()) {
case CompiledQuery.CREATE_TABLE_AS_SELECT:
queryTag = TAG_SELECT;
queryTag = TAG_CTAS;
rowCount = cq.getAffectedRowsCount();
break;
case CompiledQuery.SELECT:
Expand Down Expand Up @@ -2178,6 +2182,10 @@ private void processInitialMessage(long address, int len) throws PeerDisconnecte
}

private void processParse(long address, long lo, long msgLimit, @Transient SqlCompiler compiler) throws BadProtocolException, SqlException {
// make sure there are no left-over sync actions
// we are starting a new iteration of the parse
syncActions.clear();

// 'Parse'
//message length
long hi = getStringLength(lo, msgLimit, "bad prepared statement name length");
Expand Down Expand Up @@ -2244,12 +2252,19 @@ private void processQuery(long lo, long limit, @Transient SqlCompiler compiler)

if (Chars.utf8Decode(lo, limit - 1, e)) {
queryText = characterStore.toImmutable();
compiler.compileBatch(queryText, sqlExecutionContext, batchCallback);
try {
compiler.compileBatch(queryText, sqlExecutionContext, batchCallback);
// we need to continue parsing receive buffer even if we errored out
// this is because PG client might expect separate responses to everything it sent
} catch (SqlException ex) {
prepareError(ex.getPosition(), ex.getFlyweightMessage(), 0);
} catch (CairoException ex) {
prepareError(0, ex.getFlyweightMessage(), ex.getErrno());
}
} else {
LOG.error().$("invalid UTF8 bytes in parse query").$();
throw BadProtocolException.INSTANCE;
}

sendReadyForNewQuery();
}

Expand Down Expand Up @@ -2497,9 +2512,11 @@ private void setupVariableSettersFromWrapper(
this.activeBindVariableTypes = wrapper.bindVariableTypes;
this.parsePhaseBindVariableCount = wrapper.bindVariableTypes.size();
this.activeSelectColumnTypes = wrapper.selectColumnTypes;
if (compileQuery(compiler) && typesAndSelect != null) {
if (!wrapper.alreadyExecuted && compileQuery(compiler) && typesAndSelect != null) {
buildSelectColumnTypes();
}
// We'll have to compile/execute the statement next time.
wrapper.alreadyExecuted = false;
}

private void shiftReceiveBuffer(long readOffsetBeforeParse) {
Expand Down Expand Up @@ -2549,6 +2566,8 @@ public static class NamedStatementWrapper implements Mutable {
public final IntList bindVariableTypes = new IntList();
public final IntList selectColumnTypes = new IntList();
public CharSequence queryText = null;
// Used for statements that are executed as a part of compilation (PREPARE), such as DDLs.
public boolean alreadyExecuted = false;

@Override
public void clear() {
Expand Down
17 changes: 13 additions & 4 deletions core/src/main/java/io/questdb/cutlass/pgwire/PGWireServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,19 @@ public static PGWireServer create(
sharedWorkerPool,
log,
cairoEngine,
(conf, engine, workerPool, local, functionFactoryCache1, snapshotAgent1, metrics1) -> {
PGConnectionContextFactory contextFactory = new PGConnectionContextFactory(engine, conf, workerPool.getWorkerCount());
return new PGWireServer(conf, engine, workerPool, local, functionFactoryCache1, snapshotAgent1, contextFactory);
},
(conf, engine, workerPool, local, cache, agent, m) -> new PGWireServer(
conf,
engine,
workerPool,
local,
cache,
agent,
new PGConnectionContextFactory(
engine,
conf,
workerPool.getWorkerCount()
)
),
functionFactoryCache,
snapshotAgent,
metrics
Expand Down
139 changes: 57 additions & 82 deletions core/src/main/java/io/questdb/griffin/SqlCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
import io.questdb.griffin.model.*;
import io.questdb.log.Log;
import io.questdb.log.LogFactory;
import io.questdb.mp.MPSequence;
import io.questdb.mp.RingQueue;
import io.questdb.network.PeerDisconnectedException;
import io.questdb.network.PeerIsSlowToReadException;
import io.questdb.std.*;
Expand Down Expand Up @@ -1882,102 +1880,79 @@ private long copyOrderedStrTimestamp(TableWriter writer, RecordCursor cursor, Re
private RecordCursorFactory executeCopy0(SqlExecutionContext executionContext, CopyModel model) throws SqlException {
try {
int workerCount = executionContext.getWorkerCount();
ExpressionNode fileNameNode = model.getFileName();
final CharSequence fileName = fileNameNode != null ? GenericLexer.assertNoDots(GenericLexer.unquote(fileNameNode.token), fileNameNode.position) : null;
if (workerCount < 1) {
throw SqlException.$(0, "Invalid worker count set [value=").put(workerCount).put("]");
if (workerCount < 1) {
throw SqlException.$(0, "Invalid worker count set [value=").put(workerCount).put("]");
}
if (model.isCancel()) {
cancelTextImport(model);
return null;
} else {
if (model.getTimestampColumnName() == null &&
((model.getPartitionBy() != -1 && model.getPartitionBy() != PartitionBy.NONE))) {
throw SqlException.$(-1, "invalid option used for import without a designated timestamp (format or partition by)");
}
if (model.isCancel()) {
addTextImportRequest(model, null);
return null;
} else {
if (model.getTimestampColumnName() == null &&
((model.getPartitionBy() != -1 && model.getPartitionBy() != PartitionBy.NONE))) {
throw SqlException.$(-1, "invalid option used for import without a designated timestamp (format or partition by)");
}
if (model.getTimestampFormat() == null) {
model.setTimestampFormat("yyyy-MM-ddTHH:mm:ss.SSSUUUZ");
}
if (model.getDelimiter() < 0) {
model.setDelimiter((byte) ',');
}
long importId = addTextImportRequest(model, fileName);
return new CopyFactory(importId);
if (model.getTimestampFormat() == null) {
model.setTimestampFormat("yyyy-MM-ddTHH:mm:ss.SSSUUUZ");
}
if (model.getDelimiter() < 0) {
model.setDelimiter((byte) ',');
}
return compileTextImport(model);
}
} catch (TextImportException | TextException e) {
LOG.error().$((Throwable) e).$();
throw SqlException.$(0, e.getMessage());
}
}

private long addTextImportRequest(CopyModel model, @Nullable CharSequence fileName) throws SqlException {
final RingQueue<TextImportRequestTask> textImportRequestQueue = messageBus.getTextImportRequestQueue();
final MPSequence textImportRequestPubSeq = messageBus.getTextImportRequestPubSeq();
private void cancelTextImport(CopyModel model) throws SqlException {
assert model.isCancel();

final TextImportExecutionContext textImportExecutionContext = engine.getTextImportExecutionContext();
final AtomicBooleanCircuitBreaker circuitBreaker = textImportExecutionContext.getCircuitBreaker();

long inProgressImportId = textImportExecutionContext.getActiveImportId();
if (model.isCancel()) {
// The cancellation is based on the best effort, so we don't worry about potential races with imports.
if (inProgressImportId == TextImportExecutionContext.INACTIVE) {
throw SqlException.$(0, "No active import to cancel.");
}
long importId;
try {
CharSequence idString = model.getTarget().token;
int start = 0;
int end = idString.length();
if (Chars.isQuoted(idString)) {
start = 1;
end--;
}
importId = Numbers.parseHexLong(idString, start, end);
} catch (NumericException e) {
throw SqlException.$(0, "Provided id has invalid format.");
}
if (inProgressImportId == importId) {
circuitBreaker.cancel();
return -1;
} else {
throw SqlException.$(0, "Active import has different id.");
}
// The cancellation is based on the best effort, so we don't worry about potential races with imports.
if (inProgressImportId == TextImportExecutionContext.INACTIVE) {
throw SqlException.$(0, "No active import to cancel.");
}
long importId;
try {
CharSequence idString = model.getTarget().token;
int start = 0;
int end = idString.length();
if (Chars.isQuoted(idString)) {
start = 1;
end--;
}
importId = Numbers.parseHexLong(idString, start, end);
} catch (NumericException e) {
throw SqlException.$(0, "Provided id has invalid format.");
}
if (inProgressImportId == importId) {
circuitBreaker.cancel();
} else {
if (inProgressImportId == TextImportExecutionContext.INACTIVE) {
long processingCursor = textImportRequestPubSeq.next();
if (processingCursor > -1) {
assert fileName != null;

final TextImportRequestTask task = textImportRequestQueue.get(processingCursor);
final CharSequence tableName = GenericLexer.unquote(model.getTarget().token);

long importId = textImportExecutionContext.assignActiveImportId();
task.of(
importId,
Chars.toString(tableName),
Chars.toString(fileName),
model.isHeader(),
Chars.toString(model.getTimestampColumnName()),
model.getDelimiter(),
Chars.toString(model.getTimestampFormat()),
model.getPartitionBy(),
model.getAtomicity()
);

circuitBreaker.reset();
textImportRequestPubSeq.done(processingCursor);
return importId;
} else {
throw SqlException.$(0, "Unable to process the import request. Another import request may be in progress.");
}
} else {
throw SqlException.$(0, "Another import request is in progress. ")
.put("[activeImportId=")
.put(inProgressImportId)
.put(']');
}
throw SqlException.$(0, "Active import has different id.");
}
}

// Builds a CopyFactory for a parallel text import (COPY 'x' FROM 'file' ...).
// NOTE(review): recovered from a PR diff view; only this hunk is visible — surrounding
// class context (messageBus, engine fields) is assumed from usage, not verified here.
private CopyFactory compileTextImport(CopyModel model) throws SqlException {
// Cancellation requests (COPY 'id' CANCEL) are routed elsewhere (cancelTextImport);
// this path only handles actual import statements.
assert !model.isCancel();

// Target table name comes from the model's target token, with surrounding quotes stripped.
final CharSequence tableName = GenericLexer.unquote(model.getTarget().token);
final ExpressionNode fileNameNode = model.getFileName();
// assertNoDots rejects path components like ".." — presumably a path-traversal guard
// on the import file name; TODO confirm against GenericLexer docs.
final CharSequence fileName = fileNameNode != null ? GenericLexer.assertNoDots(GenericLexer.unquote(fileNameNode.token), fileNameNode.position) : null;
// A COPY ... FROM statement must carry a file name; a null here indicates a parser bug.
assert fileName != null;

// The factory captures immutable copies of the names (Chars.toString) so the returned
// CopyFactory does not alias the compiler's reusable CharSequence buffers. Per the PR
// discussion, the import itself is executed lazily by the factory, not at compile time.
return new CopyFactory(
messageBus,
engine.getTextImportExecutionContext(),
Chars.toString(tableName),
Chars.toString(fileName),
model
);
}

/**
* Sets insertCount to number of copied rows.
*/
Expand Down