3737import java .util .LinkedHashSet ;
3838import java .util .List ;
3939import java .util .Map ;
40- import java .util .Set ;
4140import java .util .concurrent .atomic .AtomicBoolean ;
4241import java .util .function .BiFunction ;
4342import java .util .function .Function ;
4948import org .reactivestreams .Publisher ;
5049
5150import org .springframework .dao .DataAccessException ;
51+ import org .springframework .dao .InvalidDataAccessApiUsageException ;
5252import org .springframework .data .domain .Pageable ;
5353import org .springframework .data .domain .Sort ;
5454import org .springframework .data .r2dbc .UncategorizedR2dbcException ;
5757import org .springframework .data .r2dbc .function .connectionfactory .ConnectionProxy ;
5858import org .springframework .data .r2dbc .function .convert .ColumnMapRowMapper ;
5959import org .springframework .data .r2dbc .support .R2dbcExceptionTranslator ;
60+ import org .springframework .data .relational .core .sql .Insert ;
6061import org .springframework .lang .Nullable ;
6162import org .springframework .util .Assert ;
6263
@@ -336,9 +337,17 @@ <T> FetchSpec<T> exchange(String sql, BiFunction<Row, RowMetadata, T> mappingFun
336337 logger .debug ("Executing SQL statement [" + sql + "]" );
337338 }
338339
340+ if (sqlSupplier instanceof PreparedOperation <?>) {
341+ return ((PreparedOperation <?>) sqlSupplier ).bind (it .createStatement (sql ));
342+ }
343+
339344 BindableOperation operation = namedParameters .expand (sql , dataAccessStrategy .getBindMarkersFactory (),
340345 new MapBindParameterSource (byName ));
341346
347+ if (logger .isTraceEnabled ()) {
348+ logger .trace ("Expanded SQL [" + operation .toQuery () + "]" );
349+ }
350+
342351 Statement statement = it .createStatement (operation .toQuery ());
343352
344353 byName .forEach ((name , o ) -> {
@@ -366,6 +375,7 @@ <T> FetchSpec<T> exchange(String sql, BiFunction<Row, RowMetadata, T> mappingFun
366375
367376 public ExecuteSpecSupport bind (int index , Object value ) {
368377
378+ assertNotPreparedOperation ();
369379 Assert .notNull (value , () -> String .format ("Value at index %d must not be null. Use bindNull(…) instead." , index ));
370380
371381 Map <Integer , SettableValue > byIndex = new LinkedHashMap <>(this .byIndex );
@@ -376,6 +386,8 @@ public ExecuteSpecSupport bind(int index, Object value) {
376386
377387 public ExecuteSpecSupport bindNull (int index , Class <?> type ) {
378388
389+ assertNotPreparedOperation ();
390+
379391 Map <Integer , SettableValue > byIndex = new LinkedHashMap <>(this .byIndex );
380392 byIndex .put (index , SettableValue .empty (type ));
381393
@@ -384,6 +396,8 @@ public ExecuteSpecSupport bindNull(int index, Class<?> type) {
384396
385397 public ExecuteSpecSupport bind (String name , Object value ) {
386398
399+ assertNotPreparedOperation ();
400+
387401 Assert .hasText (name , "Parameter name must not be null or empty!" );
388402 Assert .notNull (value ,
389403 () -> String .format ("Value for parameter %s must not be null. Use bindNull(…) instead." , name ));
@@ -396,6 +410,7 @@ public ExecuteSpecSupport bind(String name, Object value) {
396410
397411 public ExecuteSpecSupport bindNull (String name , Class <?> type ) {
398412
413+ assertNotPreparedOperation ();
399414 Assert .hasText (name , "Parameter name must not be null or empty!" );
400415
401416 Map <String , SettableValue > byName = new LinkedHashMap <>(this .byName );
@@ -404,6 +419,12 @@ public ExecuteSpecSupport bindNull(String name, Class<?> type) {
404419 return createInstance (this .byIndex , byName , this .sqlSupplier );
405420 }
406421
422+ private void assertNotPreparedOperation () {
423+ if (sqlSupplier instanceof PreparedOperation <?>) {
424+ throw new InvalidDataAccessApiUsageException ("Cannot add bindings to a PreparedOperation" );
425+ }
426+ }
427+
407428 protected ExecuteSpecSupport createInstance (Map <Integer , SettableValue > byIndex , Map <String , SettableValue > byName ,
408429 Supplier <String > sqlSupplier ) {
409430 return new ExecuteSpecSupport (byIndex , byName , sqlSupplier );
@@ -881,20 +902,19 @@ private <R> FetchSpec<R> exchange(BiFunction<Row, RowMetadata, R> mappingFunctio
881902 throw new IllegalStateException ("Insert fields is empty!" );
882903 }
883904
884- BindableOperation bindableInsert = dataAccessStrategy .insertAndReturnGeneratedKeys (table , byName .keySet ());
905+ PreparedOperation <Insert > operation = dataAccessStrategy .getStatements ().insert (table , Collections .emptyList (),
906+ it -> {
907+ byName .forEach (it ::bind );
908+ });
885909
886- String sql = bindableInsert .toQuery ();
910+ String sql = operation .toQuery ();
887911 Function <Connection , Statement > insertFunction = it -> {
888912
889913 if (logger .isDebugEnabled ()) {
890914 logger .debug ("Executing SQL statement [" + sql + "]" );
891915 }
892916
893- Statement statement = it .createStatement (sql ).returnGeneratedValues ();
894-
895- byName .forEach ((k , v ) -> bindableInsert .bind (statement , k , v ));
896-
897- return statement ;
917+ return operation .bind (it .createStatement (sql ));
898918 };
899919
900920 Function <Connection , Flux <Result >> resultFunction = it -> Flux .from (insertFunction .apply (it ).execute ());
@@ -998,34 +1018,25 @@ private <MR> FetchSpec<MR> exchange(Object toInsert, BiFunction<Row, RowMetadata
9981018
9991019 OutboundRow outboundRow = dataAccessStrategy .getOutboundRow (toInsert );
10001020
1001- Set <String > columns = new LinkedHashSet <>();
1002-
1003- outboundRow .forEach ((k , v ) -> {
1004-
1005- if (v .hasValue ()) {
1006- columns .add (k );
1007- }
1008- });
1021+ PreparedOperation <Insert > operation = dataAccessStrategy .getStatements ().insert (table , Collections .emptyList (),
1022+ it -> {
1023+ outboundRow .forEach ((k , v ) -> {
10091024
1010- BindableOperation bindableInsert = dataAccessStrategy .insertAndReturnGeneratedKeys (table , columns );
1025+ if (v .hasValue ()) {
1026+ it .bind (k , v );
1027+ }
1028+ });
1029+ });
10111030
1012- String sql = bindableInsert .toQuery ();
1031+ String sql = operation .toQuery ();
10131032
10141033 Function <Connection , Statement > insertFunction = it -> {
10151034
10161035 if (logger .isDebugEnabled ()) {
10171036 logger .debug ("Executing SQL statement [" + sql + "]" );
10181037 }
10191038
1020- Statement statement = it .createStatement (sql ).returnGeneratedValues ();
1021-
1022- outboundRow .forEach ((k , v ) -> {
1023- if (v .hasValue ()) {
1024- bindableInsert .bind (statement , k , v );
1025- }
1026- });
1027-
1028- return statement ;
1039+ return operation .bind (it .createStatement (sql ));
10291040 };
10301041
10311042 Function <Connection , Flux <Result >> resultFunction = it -> Flux .from (insertFunction .apply (it ).execute ());
0 commit comments