diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java index cd9387f5e8b..c127da5673c 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java @@ -319,12 +319,21 @@ private Object convertIdToPgType(String id) { @Override public void doDelete(List idList) { - int updateCount = 0; - for (String id : idList) { - int count = this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?", - UUID.fromString(id)); - updateCount = updateCount + count; - } + String sql = "DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?"; + + this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var id = idList.get(i); + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, convertIdToPgType(id)); + } + + @Override + public int getBatchSize() { + return idList.size(); + } + }); } @Override diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java index e0b2bc3578a..8b06c628743 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java @@ -75,6 +75,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Jihoon Kim + * @author CChuYong */ @Testcontainers @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -232,6 +233,47 @@ public void testToPgTypeWithNonUuidIdType() { }); } + @Test + public void testBulkOperationWithUuidIdType() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE") + .run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + List documents = List.of( + new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()), + new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()), + new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>())); + vectorStore.add(documents); + + List idList = documents.stream().map(Document::getId).toList(); + vectorStore.delete(idList); + + dropTable(context); + }); + } + + @Test + public void testBulkOperationWithNonUuidIdType() { + this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE") + .withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false) + .withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT") + .run(context -> { + VectorStore vectorStore = context.getBean(VectorStore.class); + initSchema(context); + + List documents = List.of(new Document("NON_UUID_1", "TEXT", new HashMap<>()), + new Document("NON_UUID_2", "TEXT", new HashMap<>()), + new Document("NON_UUID_3", "TEXT", new HashMap<>())); + vectorStore.add(documents); + + List idList = documents.stream().map(Document::getId).toList(); + vectorStore.delete(idList); + + dropTable(context); + }); + } + @ParameterizedTest(name = "Filter expression {0} should return {1} records ") @MethodSource("provideFilters") public void searchWithInFilter(String expression, Integer expectedRecords) { @@ -436,6 +478,8 @@ void getNativeClientTest() { PgVectorStore vectorStore = context.getBean(PgVectorStore.class); Optional nativeClient = vectorStore.getNativeClient(); assertThat(nativeClient).isPresent(); + + dropTable(context); }); }