diff --git a/src/main/java/io/r2dbc/mssql/MssqlConnection.java b/src/main/java/io/r2dbc/mssql/MssqlConnection.java index 6a40f8e..574890b 100644 --- a/src/main/java/io/r2dbc/mssql/MssqlConnection.java +++ b/src/main/java/io/r2dbc/mssql/MssqlConnection.java @@ -43,6 +43,7 @@ * * @author Mark Paluch * @author Hebert Coelho + * @author Nayan Hajratwala * @see MssqlConnection * @see DefaultMssqlResult * @see ErrorDetails @@ -117,7 +118,7 @@ public Mono beginTransaction(TransactionDefinition transactionDefinition) if (mark != null) { String markToUse = sanitize(mark, 128); Assert.isTrue(IDENTIFIER128_PATTERN.matcher(markToUse.substring(0, Math.min(128, markToUse.length()))).matches(), "Transaction names must contain only characters and numbers and" + - " must not exceed 128 characters"); + " must not exceed 128 characters"); builder.append(' ').append("WITH MARK '").append(markToUse).append("'"); } } @@ -412,29 +413,25 @@ private static String renderSetIsolationLevel(IsolationLevel isolationLevel) { return "SET TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql(); } - static String sanitize(final String identifier, final int maxLength) { - String sanitized = identifier - .replace('-', '_') - .replace('.', '_') - .substring(Math.max(0, identifier.length() - maxLength)); + static String sanitize(String identifier, int maxLength) { - if (!Character.isLetterOrDigit(sanitized.charAt(0))) { - sanitized = sanitized.substring(1); - } - return sanitized; + return identifier + .replace('-', '_') + .replace('.', '_') + .substring(0, Math.min(identifier.length(), maxLength)); } private Mono exchange(String sql) { ExceptionFactory factory = ExceptionFactory.withSql(sql); return QueryMessageFlow.exchange(this.client, sql) - .handle(factory::handleErrorResponse) - .then(); + .handle(factory::handleErrorResponse) + .then(); } private Mono useTransactionStatus(Function> function) { return Flux.defer(() -> function.apply(this.client.getTransactionStatus())) - .then(); + .then(); } enum EmptyTransactionDefinition implements TransactionDefinition { diff --git a/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java b/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java index d02466b..b6f8b2c 100644 --- a/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/MssqlConnectionUnitTests.java @@ -26,7 +26,6 @@ import io.r2dbc.mssql.message.token.SqlBatch; import io.r2dbc.spi.IsolationLevel; import io.r2dbc.spi.ValidationDepth; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -38,18 +37,14 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * Unit tests for {@link MssqlConnection}. * * @author Mark Paluch * @author Hebert Coelho + * @author Nayan Hajratwala */ class MssqlConnectionUnitTests { @@ -61,26 +56,26 @@ class MssqlConnectionUnitTests { void shouldBeginTransactionFromInitialState() { TestClient client = - TestClient.builder().expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "BEGIN TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "BEGIN TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.beginTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test void shouldBeginTransactionFromExplicitState() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "BEGIN TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "BEGIN TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.beginTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test @@ -93,8 +88,8 @@ void shouldNotBeginTransactionFromStartedState() { MssqlConnection connection = new MssqlConnection(clientMock, metadata, conectionOptions); connection.beginTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); verify(clientMock, times(2)).getTransactionStatus(); verify(clientMock, atLeast(1)).getContext(); @@ -105,13 +100,13 @@ void shouldNotBeginTransactionFromStartedState() { void shouldCommitFromExplicitTransaction() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "IF @@TRANCOUNT > 0 COMMIT TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "IF @@TRANCOUNT > 0 COMMIT TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.commitTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test @@ -124,8 +119,8 @@ void shouldNotCommitInAutoCommitState() { MssqlConnection connection = new MssqlConnection(clientMock, metadata, conectionOptions); connection.commitTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); verify(clientMock, times(2)).getTransactionStatus(); verify(clientMock, atLeast(1)).getContext(); @@ -136,13 +131,13 @@ void shouldNotCommitInAutoCommitState() { void shouldRollbackFromExplicitTransaction() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.rollbackTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test @@ -155,8 +150,8 @@ void shouldNotRollbackInAutoCommitState() { MssqlConnection connection = new MssqlConnection(clientMock, metadata, conectionOptions); connection.rollbackTransaction() - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); verify(clientMock, times(2)).getTransactionStatus(); verify(clientMock, atLeast(1)).getContext(); @@ -185,7 +180,7 @@ void shouldAllowSavepointNames(String name) { } @ParameterizedTest - @ValueSource(strings = {"", "@", "a'", "a\"", "a[", "a]", "123456789012345678901234567890123"}) + @ValueSource(strings = {"", "@", "a'", "a\"", "a[", "a]"}) void shouldRejectSavepointNames(String name) { Client clientMock = mock(Client.class); @@ -199,13 +194,13 @@ void shouldRejectSavepointNames(String name) { void shouldRollbackTransactionToSavepointFromExplicitTransaction() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "ROLLBACK TRANSACTION foo")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "ROLLBACK TRANSACTION foo")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.rollbackTransactionToSavepoint("foo") - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test @@ -218,8 +213,8 @@ void shouldNotRollbackTransactionToSavepointInAutoCommitState() { MssqlConnection connection = new MssqlConnection(clientMock, metadata, conectionOptions); connection.rollbackTransactionToSavepoint("foo") - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); verify(clientMock, times(2)).getTransactionStatus(); verify(clientMock, atLeast(1)).getContext(); @@ -230,28 +225,28 @@ void shouldNotRollbackTransactionToSavepointInAutoCommitState() { void shouldCreateSavepointFromExplicitTransaction() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "SET IMPLICIT_TRANSACTIONS ON; IF @@TRANCOUNT = 0 " + - "BEGIN BEGIN TRAN IF @@TRANCOUNT = 2 COMMIT TRAN END SAVE TRAN foo;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.STARTED).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "SET IMPLICIT_TRANSACTIONS ON; IF @@TRANCOUNT = 0 " + + "BEGIN BEGIN TRAN IF @@TRANCOUNT = 2 COMMIT TRAN END SAVE TRAN foo;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.createSavepoint("foo") - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test void createSavepointShouldBeginTransaction() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.AUTO_COMMIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "SET IMPLICIT_TRANSACTIONS ON; IF @@TRANCOUNT =" + - " 0 BEGIN BEGIN TRAN IF @@TRANCOUNT = 2 COMMIT TRAN END SAVE TRAN foo;")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.AUTO_COMMIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), "SET IMPLICIT_TRANSACTIONS ON; IF @@TRANCOUNT =" + + " 0 BEGIN BEGIN TRAN IF @@TRANCOUNT = 2 COMMIT TRAN END SAVE TRAN foo;")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.createSavepoint("foo") - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @ParameterizedTest @@ -259,132 +254,117 @@ void createSavepointShouldBeginTransaction() { void shouldSetIsolationLevel(IsolationLevel isolationLevel) { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SET TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql().toUpperCase())).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SET TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql().toUpperCase())).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.setTransactionIsolationLevel(isolationLevel) - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test void shouldSetLockWaitTimeout() { TestClient client = - TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SET LOCK_TIMEOUT 10000")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SET LOCK_TIMEOUT 10000")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.setLockWaitTimeout(Duration.ofSeconds(10)) - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); client = - TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SET LOCK_TIMEOUT -1")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SET LOCK_TIMEOUT -1")).thenRespond(DoneToken.create(0)).build(); connection = new MssqlConnection(client, metadata, conectionOptions); connection.setLockWaitTimeout(Duration.ofSeconds(-10)) - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); client = - TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SET LOCK_TIMEOUT 0")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withTransactionStatus(TransactionStatus.EXPLICIT).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SET LOCK_TIMEOUT 0")).thenRespond(DoneToken.create(0)).build(); connection = new MssqlConnection(client, metadata, conectionOptions); connection.setLockWaitTimeout(Duration.ZERO) - .as(StepVerifier::create) - .verifyComplete(); + .as(StepVerifier::create) + .verifyComplete(); } @Test void localValidationShouldValidateAgainstConnectionState() { TestClient connected = - TestClient.builder().withConnected(true).build(); + TestClient.builder().withConnected(true).build(); MssqlConnection connection = new MssqlConnection(connected, metadata, conectionOptions); connection.validate(ValidationDepth.LOCAL) - .as(StepVerifier::create) - .expectNext(true) - .verifyComplete(); + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); TestClient disconnected = - TestClient.builder().withConnected(false).build(); + TestClient.builder().withConnected(false).build(); connection = new MssqlConnection(disconnected, metadata, conectionOptions); connection.validate(ValidationDepth.LOCAL) - .as(StepVerifier::create) - .expectNext(false) - .verifyComplete(); + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); } @Test void remoteValidationShouldIssueQuery() { TestClient client = - TestClient.builder().withConnected(true).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SELECT 1")).thenRespond(DoneToken.create(0)).build(); + TestClient.builder().withConnected(true).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SELECT 1")).thenRespond(DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.validate(ValidationDepth.REMOTE) - .as(StepVerifier::create) - .expectNext(true) - .verifyComplete(); + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); } @Test void remoteValidationShouldFail() { TestClient client = - TestClient.builder().withConnected(true).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), - "SELECT 1")).thenRespond(new ErrorToken(1, 1, (byte) 1, (byte) 1, "failed", "", "", 0), DoneToken.create(0)).build(); + TestClient.builder().withConnected(true).expectRequest(SqlBatch.create(1, TransactionDescriptor.empty(), + "SELECT 1")).thenRespond(new ErrorToken(1, 1, (byte) 1, (byte) 1, "failed", "", "", 0), DoneToken.create(0)).build(); MssqlConnection connection = new MssqlConnection(client, metadata, conectionOptions); connection.validate(ValidationDepth.REMOTE) - .as(StepVerifier::create) - .expectNext(false) - .verifyComplete(); + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); } - @Nested - class SanitizeTests { - @Test - void shorterThanMax() { - assertThat(MssqlConnection.sanitize("12345", 10)).isEqualTo("12345"); - } - - @Test - void exactlyMax() { - assertThat(MssqlConnection.sanitize("1234567", 7)).isEqualTo("1234567"); - } - - @Test - void greaterThanMax() { - assertThat(MssqlConnection.sanitize("1234567", 3)).isEqualTo("567"); - } + @Test + void shouldSanitizeProperly() { - @Test - void dropStartingPunctuation() { - assertThat(MssqlConnection.sanitize("1_23_4", 5)).isEqualTo("23_4"); - } + assertThat(MssqlConnection.sanitize("12345", 10)).isEqualTo("12345"); + assertThat(MssqlConnection.sanitize("1234567", 7)).isEqualTo("1234567"); + assertThat(MssqlConnection.sanitize("1234567", 3)).isEqualTo("123"); } private static Stream isolationLevels() { return Stream.of(MssqlIsolationLevel.SERIALIZABLE, MssqlIsolationLevel.READ_COMMITTED, - MssqlIsolationLevel.READ_UNCOMMITTED, MssqlIsolationLevel.REPEATABLE_READ, - MssqlIsolationLevel.SNAPSHOT); + MssqlIsolationLevel.READ_UNCOMMITTED, MssqlIsolationLevel.REPEATABLE_READ, + MssqlIsolationLevel.SNAPSHOT); } }