Skip to content

Commit 1e7cf33

Browse files
divangDivang SharmaDavid-Engel
committed
Fixed session recovery for Azure SQL DB in redirect mode connected with AAD auth (#2668)
* In case of federated authentication request, do not terminate the connection to resolve Fed auth connection resiliency issue * In place of serverSupportsDNSCaching, routingInfo value is a better choice to take decision on connection termination. If endpoint is routed then no need to terminate the current connection. * Added mock test for connectionReconveryCheck() method --------- Co-authored-by: Divang Sharma <divangsharma@microsoft.com> Co-authored-by: David Engel <davidengel@microsoft.com>
1 parent da36c13 commit 1e7cf33

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7199,27 +7199,33 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ
71997199

72007200
LogonProcessor logonProcessor = new LogonProcessor(authentication);
72017201
TDSReader tdsReader;
7202+
sessionRecovery.setConnectionRecoveryPossible(false);
72027203
do {
72037204
tdsReader = logonCommand.startResponse();
7204-
sessionRecovery.setConnectionRecoveryPossible(false);
72057205
TDSParser.parse(tdsReader, logonProcessor);
72067206
} while (!logonProcessor.complete(logonCommand, tdsReader));
72077207

7208-
if (sessionRecovery.isReconnectRunning() && !sessionRecovery.isConnectionRecoveryPossible()) {
7208+
connectionReconveryCheck(sessionRecovery.isReconnectRunning(), sessionRecovery.isConnectionRecoveryPossible(),
7209+
routingInfo);
7210+
7211+
if (connectRetryCount > 0 && !sessionRecovery.isReconnectRunning()) {
7212+
sessionRecovery.getSessionStateTable().setOriginalCatalog(sCatalog);
7213+
sessionRecovery.getSessionStateTable().setOriginalCollation(databaseCollation);
7214+
sessionRecovery.getSessionStateTable().setOriginalLanguage(sLanguage);
7215+
}
7216+
}
7217+
7218+
private void connectionReconveryCheck(boolean isReconnectRunning, boolean isConnectionRecoveryPossible,
7219+
ServerPortPlaceHolder routingDetails) throws SQLServerException {
7220+
if (isReconnectRunning && !isConnectionRecoveryPossible && routingDetails == null) {
72097221
if (connectionlogger.isLoggable(Level.WARNING)) {
72107222
connectionlogger.warning(this.toString()
72117223
+ "SessionRecovery feature extension ack was not sent by the server during reconnection.");
72127224
}
72137225
terminate(SQLServerException.DRIVER_ERROR_INVALID_TDS,
72147226
SQLServerException.getErrString("R_crClientNoRecoveryAckFromLogin"));
72157227
}
7216-
if (connectRetryCount > 0 && !sessionRecovery.isReconnectRunning()) {
7217-
sessionRecovery.getSessionStateTable().setOriginalCatalog(sCatalog);
7218-
sessionRecovery.getSessionStateTable().setOriginalCollation(databaseCollation);
7219-
sessionRecovery.getSessionStateTable().setOriginalLanguage(sLanguage);
7220-
}
72217228
}
7222-
72237229
/* --------------- JDBC 3.0 ------------- */
72247230

72257231
/**

src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,21 @@
99
import static org.junit.jupiter.api.Assertions.assertThrows;
1010
import static org.junit.jupiter.api.Assertions.assertTrue;
1111
import static org.junit.jupiter.api.Assertions.fail;
12+
import static org.mockito.ArgumentMatchers.anyInt;
13+
import static org.mockito.ArgumentMatchers.anyString;
14+
import static org.mockito.ArgumentMatchers.eq;
15+
import static org.mockito.Mockito.doNothing;
16+
import static org.mockito.Mockito.doReturn;
17+
import static org.mockito.Mockito.mock;
18+
import static org.mockito.Mockito.never;
19+
import static org.mockito.Mockito.spy;
20+
import static org.mockito.Mockito.times;
21+
import static org.mockito.Mockito.verify;
1222

1323
import java.io.IOException;
1424
import java.io.Reader;
1525
import java.lang.management.ManagementFactory;
26+
import java.lang.reflect.Method;
1627
import java.sql.Clob;
1728
import java.sql.Connection;
1829
import java.sql.DriverManager;
@@ -29,6 +40,7 @@
2940
import java.util.concurrent.ExecutorService;
3041
import java.util.concurrent.Executors;
3142
import java.util.concurrent.Future;
43+
import java.util.logging.Level;
3244
import java.util.logging.Logger;
3345

3446
import javax.sql.ConnectionEvent;
@@ -60,6 +72,9 @@ public class SQLServerConnectionTest extends AbstractTest {
6072

6173
String randomServer = RandomUtil.getIdentifier("Server");
6274

75+
SQLServerConnection mockConnection;
76+
Logger mockLogger;
77+
6378
@BeforeAll
6479
public static void setupTests() throws Exception {
6580
setConnection();
@@ -1400,6 +1415,49 @@ public void testManagedIdentityWithEncryptStrict() {
14001415
} catch (SQLException e) {
14011416
fail("Connection failed: " + e.getMessage());
14021417
}
1403-
}
1418+
}
1419+
1420+
public Method mockedConnectionRecoveryCheck() throws Exception {
1421+
mockConnection = spy(new SQLServerConnection("test"));
1422+
mockLogger = mock(Logger.class);
1423+
doReturn(true).when(mockLogger).isLoggable(Level.WARNING);
1424+
doNothing().when(mockConnection).terminate(anyInt(), anyString());
1425+
1426+
Method method = SQLServerConnection.class.getDeclaredMethod("connectionReconveryCheck", boolean.class,
1427+
boolean.class, ServerPortPlaceHolder.class);
1428+
method.setAccessible(true);
1429+
return method;
1430+
}
1431+
1432+
@Test
1433+
void testConnectionRecoveryCheckThrowsWhenAllConditionsMet() throws Exception {
1434+
Method method = mockedConnectionRecoveryCheck();
1435+
method.invoke(mockConnection, true, false, null);
1436+
verify(mockConnection, times(1)).terminate(eq(SQLServerException.DRIVER_ERROR_INVALID_TDS),
1437+
eq(SQLServerException.getErrString("R_crClientNoRecoveryAckFromLogin")));
1438+
}
1439+
1440+
@Test
1441+
void testConnectionRecoveryCheckDoesNotThrowWhenNotReconnectRunning() throws Exception {
1442+
Method method = mockedConnectionRecoveryCheck();
1443+
method.invoke(mockConnection, false, false, null);
1444+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1445+
}
1446+
1447+
@Test
1448+
void testConnectionRecoveryCheckDoesNotThrowWhenRecoveryPossible() throws Exception {
1449+
Method method = mockedConnectionRecoveryCheck();
1450+
method.invoke(mockConnection, true, true, null);
1451+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1452+
}
1453+
1454+
@Test
1455+
void testConnectionRecoveryCheckDoesNotThrowWhenRoutingDetailsNotNull() throws Exception {
1456+
Method method = mockedConnectionRecoveryCheck();
1457+
ServerPortPlaceHolder routingDetails = mock(ServerPortPlaceHolder.class);
1458+
method.setAccessible(true);
1459+
method.invoke(mockConnection, true, false, routingDetails);
1460+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1461+
}
14041462

14051463
}

0 commit comments

Comments
 (0)