diff --git a/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java b/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java index 278bea6d809..f4a0904824f 100644 --- a/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java +++ b/java/jdbc/src/main/java/io/vitess/jdbc/VitessStatement.java @@ -591,7 +591,7 @@ protected void checkSQLNullOrEmpty(String sql) throws SQLException { protected int[] generateBatchUpdateResult(List cursorWithErrorList, List batchedArgs) throws BatchUpdateException { int[] updateCounts = new int[cursorWithErrorList.size()]; - long[][] generatedKeys = new long[cursorWithErrorList.size()][2]; + ArrayList generatedKeys = new ArrayList(); Vtrpc.RPCError rpcError = null; String batchCommand = null; @@ -603,6 +603,7 @@ protected int[] generateBatchUpdateResult(List cursorWithErrorL try { long rowsAffected = cursorWithError.getCursor().getRowsAffected(); int truncatedUpdateCount; + boolean queryBatchUpsert = false; if (rowsAffected > Integer.MAX_VALUE) { truncatedUpdateCount = Integer.MAX_VALUE; } else { @@ -610,13 +611,15 @@ protected int[] generateBatchUpdateResult(List cursorWithErrorL // mimicking mysql-connector-j here. // but it would fail for: insert into t1 values ('a'), ('b') on duplicate key update ts = now(); truncatedUpdateCount = 1; + queryBatchUpsert = true; } else { truncatedUpdateCount = (int) rowsAffected; } } updateCounts[i] = truncatedUpdateCount; - if (this.retrieveGeneratedKeys) { - generatedKeys[i] = new long[]{cursorWithError.getCursor().getInsertId(), truncatedUpdateCount}; + long insertId = cursorWithError.getCursor().getInsertId(); + if (this.retrieveGeneratedKeys && (!queryBatchUpsert || insertId > 0)) { + generatedKeys.add(new long[]{insertId, truncatedUpdateCount}); } } catch (SQLException ex) { /* This case should not happen as API has returned cursor and not error. @@ -624,14 +627,14 @@ protected int[] generateBatchUpdateResult(List cursorWithErrorL */ updateCounts[i] = Statement.SUCCESS_NO_INFO; if (this.retrieveGeneratedKeys) { - generatedKeys[i] = new long[]{Statement.SUCCESS_NO_INFO, Statement.SUCCESS_NO_INFO}; + generatedKeys.add(new long[]{Statement.SUCCESS_NO_INFO, Statement.SUCCESS_NO_INFO}); } } } else { rpcError = cursorWithError.getError(); updateCounts[i] = Statement.EXECUTE_FAILED; if (this.retrieveGeneratedKeys) { - generatedKeys[i] = new long[]{Statement.EXECUTE_FAILED, Statement.EXECUTE_FAILED}; + generatedKeys.add(new long[]{Statement.EXECUTE_FAILED, Statement.EXECUTE_FAILED}); } } } @@ -642,7 +645,7 @@ protected int[] generateBatchUpdateResult(List cursorWithErrorL throw new BatchUpdateException(rpcError.toString(), sqlState, errno, updateCounts); } if (this.retrieveGeneratedKeys) { - this.batchGeneratedKeys = generatedKeys; + this.batchGeneratedKeys = generatedKeys.toArray(new long[generatedKeys.size()][2]); } return updateCounts; } diff --git a/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java b/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java index 6e0e4a77799..39482e9f14c 100644 --- a/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java +++ b/java/jdbc/src/test/java/io/vitess/jdbc/VitessStatementTest.java @@ -742,5 +742,15 @@ private void testExecute(int fetchSize, boolean simpleExecute, boolean shouldRun Assert.assertEquals(i, 0); // we should only have one i++; } + + VitessStatement noUpdate = new VitessStatement(mockConn); + PowerMockito.when(mockCursor.getInsertId()).thenReturn(0L); + PowerMockito.when(mockCursor.getRowsAffected()).thenReturn(1L); + + noUpdate.addBatch(sqlUpsert); + noUpdate.executeBatch(); + + ResultSet empty = noUpdate.getGeneratedKeys(); + Assert.assertFalse(empty.next()); } }