Skip to content

Commit

Permalink
Reprepare invalid and missing prepared statements.
Browse files Browse the repository at this point in the history
We now reprepare (retry) a statement that is contextually invalid or cannot be found on the server.

[resolves #271]

Signed-off-by: Mark Paluch <mpaluch@vmware.com>
  • Loading branch information
mp911de committed Jun 23, 2023
1 parent 617ea48 commit 4bf95d6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 31 deletions.
89 changes: 66 additions & 23 deletions src/main/java/io/r2dbc/mssql/RpcQueryMessageFlow.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018-2022 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,15 +25,7 @@
import io.r2dbc.mssql.message.ClientMessage;
import io.r2dbc.mssql.message.Message;
import io.r2dbc.mssql.message.TransactionDescriptor;
import io.r2dbc.mssql.message.token.AbstractDoneToken;
import io.r2dbc.mssql.message.token.AbstractInfoToken;
import io.r2dbc.mssql.message.token.ColumnMetadataToken;
import io.r2dbc.mssql.message.token.DoneInProcToken;
import io.r2dbc.mssql.message.token.DoneProcToken;
import io.r2dbc.mssql.message.token.ErrorToken;
import io.r2dbc.mssql.message.token.ReturnValue;
import io.r2dbc.mssql.message.token.RowToken;
import io.r2dbc.mssql.message.token.RpcRequest;
import io.r2dbc.mssql.message.token.*;
import io.r2dbc.mssql.message.type.Collation;
import io.r2dbc.mssql.util.Assert;
import io.r2dbc.mssql.util.Operators;
Expand All @@ -46,6 +38,7 @@
import reactor.util.Loggers;

import javax.annotation.processing.Completion;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.Consumer;
import java.util.function.Predicate;
Expand Down Expand Up @@ -224,29 +217,30 @@ static Flux<Message> exchange(PreparedStatementCache statementCache, Client clie
Assert.requireNonNull(query, "Query must not be null");

Sinks.Many<ClientMessage> outbound = Sinks.many().unicast().onBackpressureBuffer();
CursorState state = new CursorState();

int handle = statementCache.getHandle(query, binding);

boolean needsPrepare;
AtomicBoolean retryReprepare = new AtomicBoolean(true);
AtomicBoolean needsPrepare = new AtomicBoolean(false);

Flux<ClientMessage> messageProducer;

if (handle == PreparedStatementCache.UNPREPARED) {
messageProducer = Flux.defer(() -> {
outbound.emitNext(spCursorPrepExec(PreparedStatementCache.UNPREPARED, query, binding, client.getRequiredCollation(),
client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST);
client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST);
return outbound.asFlux();
});

needsPrepare = true;
needsPrepare.set(true);
} else {
messageProducer = Flux.defer(() -> {
outbound.emitNext(spCursorExec(handle, binding, client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST);
return outbound.asFlux();
});
needsPrepare = false;
needsPrepare.set(false);
}

CursorState state = new CursorState();
Flux<Message> exchange = client.exchange(messageProducer, isFinalToken(state));
OnCursorComplete cursorComplete = new OnCursorComplete();

Expand All @@ -258,7 +252,7 @@ static Flux<Message> exchange(PreparedStatementCache statementCache, Client clie

ReturnValue returnValue = (ReturnValue) message;

emit = handleSpCursorReturnValue(statementCache, codecs, query, binding, state, needsPrepare, returnValue);
emit = handleSpCursorReturnValue(statementCache, codecs, query, binding, state, needsPrepare.get(), returnValue);

if (!emit) {
returnValue.release();
Expand All @@ -267,6 +261,27 @@ static Flux<Message> exchange(PreparedStatementCache statementCache, Client clie

state.update(message);

if (message instanceof ErrorToken) {
if (isPreparedStatementNotFound(((ErrorToken) message).getNumber()) && retryReprepare.compareAndSet(true, false)) {
logger.debug("Prepared statement no longer valid: {}", handle);
state.update(Phase.PREPARE_RETRY);
}
}

if (state.phase == Phase.PREPARE_RETRY) {
emit = false;
}

if (DoneProcToken.isDone(message) && state.phase == Phase.PREPARE_RETRY) {

logger.debug("Attempting to re-prepare statement: {}", query);
needsPrepare.set(true);
state.update(Phase.NONE);
outbound.emitNext(spCursorPrepExec(PreparedStatementCache.UNPREPARED, query, binding, client.getRequiredCollation(),
client.getTransactionDescriptor()), Sinks.EmitFailureHandler.FAIL_FAST);
return;
}

handleMessage(client, fetchSize, outbound, state, message, sink, cursorComplete, emit);
})
.filter(FILTER_PREDICATE);
Expand All @@ -277,6 +292,21 @@ static Flux<Message> exchange(PreparedStatementCache statementCache, Client clie
.transform(it -> Operators.discardOnCancel(it, state::cancel).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)).takeUntilOther(cursorComplete.takeUntil());
}

/**
* Check whether the error indicates a prepared statement requiring reprepare.
* <p>
* <ul><li>586: The prepared statement handle %d is not valid in this context. Please verify that current database, user
* default schema ANSI_NULLS and QUOTED_IDENTIFIER set options are not changed since the handle is prepared.</li>
* <li>8179: Could not find prepared statement with handle %d.</li>
* </ul>
*
* @param errorNumber
* @return
*/
private static boolean isPreparedStatementNotFound(long errorNumber) {
return errorNumber == 8179 || errorNumber == 586;
}

private static boolean handleSpCursorReturnValue(PreparedStatementCache statementCache, Codecs codecs, String query, Binding binding, CursorState state, boolean needsPrepare,
ReturnValue returnValue) {

Expand Down Expand Up @@ -356,7 +386,7 @@ private static void handleMessage(Client client, int fetchSize, Consumer<ClientM

if (AbstractDoneToken.isAttentionAck(message)) {

state.phase = Phase.CLOSED;
state.update(Phase.CLOSED);
sink.next(message);
return;
}
Expand All @@ -370,7 +400,7 @@ private static void handleMessage(Client client, int fetchSize, Consumer<ClientM
}

if (state.hasSeenError) {
state.phase = Phase.ERROR;
state.update(Phase.ERROR);
}

if (DoneProcToken.isDone(message)) {
Expand All @@ -386,19 +416,19 @@ static void onDone(Client client, int fetchSize, Consumer<ClientMessage> request

completion.run();

state.phase = Phase.CLOSED;
state.update(Phase.CLOSED);
return;
}

if (phase == Phase.NONE || phase == Phase.FETCHING) {

if (((state.hasMore && phase == Phase.NONE) || state.hasSeenRows) && state.wantsMore()) {
if (phase == Phase.NONE) {
state.phase = Phase.FETCHING;
state.update(Phase.FETCHING);
}
requests.accept(spCursorFetch(state.cursorId, FETCH_NEXT, fetchSize, client.getTransactionDescriptor()));
} else {
state.phase = Phase.CLOSING;
state.update(Phase.CLOSING);
// TODO: spCursorClose should happen also if a subscriber cancels its subscription.
requests.accept(spCursorClose(state.cursorId, client.getTransactionDescriptor()));
}
Expand Down Expand Up @@ -628,6 +658,8 @@ static class CursorState {

volatile boolean cancelRequested;

volatile ErrorToken errorToken;

Phase phase = Phase.NONE;

boolean wantsMore() {
Expand All @@ -644,12 +676,23 @@ void update(Message it) {
}

if (it instanceof ErrorToken) {
this.errorToken = (ErrorToken) it;
this.hasSeenError = true;
}
}

public void update(Phase newPhase) {

this.phase = newPhase;

if (newPhase == Phase.PREPARE_RETRY) {
errorToken = null;
hasSeenError = false;
}
}

enum Phase {
NONE, FETCHING, CLOSING, CLOSED, ERROR
NONE, FETCHING, PREPARE_RETRY, CLOSING, CLOSED, ERROR
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,48 @@ void shouldEmitSingleResultForCursoredExecution() {
assertThat(rowCounter).hasValue(3);
}

@Test
void shouldRepreparePreparedStatement() {

shouldExecuteBatch();

connection.createStatement("SET ANSI_NULLS ON")
.execute()
.flatMap(MssqlResult::getRowsUpdated)
.as(StepVerifier::create)
.verifyComplete();

Flux.from(connection.createStatement("SELECT first_name FROM r2dbc_example where id != @P0")
.fetchSize(2)
.bind("P0", 99)
.execute())
.flatMap(result -> {

return result.map((row, rowMetadata) -> new Object());
})
.as(StepVerifier::create)
.expectNextCount(3)
.verifyComplete();

connection.createStatement("SET ANSI_NULLS OFF")
.execute()
.flatMap(MssqlResult::getRowsUpdated)
.as(StepVerifier::create)
.verifyComplete();

Flux.from(connection.createStatement("SELECT first_name FROM r2dbc_example where id != @P0")
.fetchSize(2)
.bind("P0", 99)
.execute())
.flatMap(result -> {

return result.map((row, rowMetadata) -> new Object());
})
.as(StepVerifier::create)
.expectNextCount(3)
.verifyComplete();
}

@Test
void shouldRunStatementWithMultipleResults() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ void shouldCachePreparedStatementHandle() {
value.skipBytes(1); // skip maxlen byte

TestClient testClient = TestClient.builder()
.assertNextRequestWith(it -> {
assertThat(it).isInstanceOf(RpcRequest.class);
RpcRequest request = (RpcRequest) it;
assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec);
})
.thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(),
value))
.build();
.assertNextRequestWith(it -> {
assertThat(it).isInstanceOf(RpcRequest.class);
RpcRequest request = (RpcRequest) it;
assertThat(request.getProcId()).isEqualTo(RpcRequest.Sp_CursorPrepExec);
})
.thenRespond(new ReturnValue(0, null, (byte) 0, Types.integer(),
value))
.build();

String sql = "SELECT * from FOO where firstname = @firstname";
ParametrizedMssqlStatement statement = new ParametrizedMssqlStatement(testClient, this.connectionOptions, sql);
Expand Down

0 comments on commit 4bf95d6

Please sign in to comment.