1515 */
1616package org .springframework .ai .vectorstore ;
1717
18- import java .sql .PreparedStatement ;
19- import java .sql .ResultSet ;
20- import java .sql .SQLException ;
21- import java .util .List ;
22- import java .util .Map ;
23- import java .util .Optional ;
24- import java .util .UUID ;
25-
18+ import com .fasterxml .jackson .core .JsonProcessingException ;
19+ import com .fasterxml .jackson .databind .ObjectMapper ;
20+ import com .pgvector .PGvector ;
21+ import io .micrometer .observation .ObservationRegistry ;
2622import org .postgresql .util .PGobject ;
2723import org .slf4j .Logger ;
2824import org .slf4j .LoggerFactory ;
4642import org .springframework .lang .Nullable ;
4743import org .springframework .util .StringUtils ;
4844
49- import com .fasterxml .jackson .core .JsonProcessingException ;
50- import com .fasterxml .jackson .databind .ObjectMapper ;
51- import com .pgvector .PGvector ;
52-
53- import io .micrometer .observation .ObservationRegistry ;
45+ import java .sql .PreparedStatement ;
46+ import java .sql .ResultSet ;
47+ import java .sql .SQLException ;
48+ import java .util .ArrayList ;
49+ import java .util .List ;
50+ import java .util .Map ;
51+ import java .util .Optional ;
52+ import java .util .UUID ;
5453
5554/**
5655 * Uses the "vector_store" table to store the Spring AI vector data. The table and the
@@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
8180
8281 public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter ();
8382
83+ public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000 ;
84+
8485 private final String vectorTableName ;
8586
8687 private final String vectorIndexName ;
@@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
109110
110111 private final BatchingStrategy batchingStrategy ;
111112
113+ private final int maxDocumentBatchSize ;
114+
112115 public PgVectorStore (JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel ) {
113116 this (jdbcTemplate , embeddingModel , INVALID_EMBEDDING_DIMENSION , PgDistanceType .COSINE_DISTANCE , false ,
114117 PgIndexType .NONE , false );
@@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
132135
133136 this (DEFAULT_SCHEMA_NAME , vectorTableName , DEFAULT_SCHEMA_VALIDATION , jdbcTemplate , embeddingModel , dimensions ,
134137 distanceType , removeExistingVectorStoreTable , createIndexMethod , initializeSchema );
135-
136138 }
137139
138140 private PgVectorStore (String schemaName , String vectorTableName , boolean vectorTableValidationsEnabled ,
@@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
141143
142144 this (schemaName , vectorTableName , vectorTableValidationsEnabled , jdbcTemplate , embeddingModel , dimensions ,
143145 distanceType , removeExistingVectorStoreTable , createIndexMethod , initializeSchema ,
144- ObservationRegistry .NOOP , null , new TokenCountBatchingStrategy ());
146+ ObservationRegistry .NOOP , null , new TokenCountBatchingStrategy (), MAX_DOCUMENT_BATCH_SIZE );
145147 }
146148
147149 private PgVectorStore (String schemaName , String vectorTableName , boolean vectorTableValidationsEnabled ,
148150 JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel , int dimensions , PgDistanceType distanceType ,
149151 boolean removeExistingVectorStoreTable , PgIndexType createIndexMethod , boolean initializeSchema ,
150152 ObservationRegistry observationRegistry , VectorStoreObservationConvention customObservationConvention ,
151- BatchingStrategy batchingStrategy ) {
153+ BatchingStrategy batchingStrategy , int maxDocumentBatchSize ) {
152154
153155 super (observationRegistry , customObservationConvention );
154156
@@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
172174 this .initializeSchema = initializeSchema ;
173175 this .schemaValidator = new PgVectorSchemaValidator (jdbcTemplate );
174176 this .batchingStrategy = batchingStrategy ;
177+ this .maxDocumentBatchSize = maxDocumentBatchSize ;
175178 }
176179
177180 public PgDistanceType getDistanceType () {
@@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() {
180183
181184 @ Override
182185 public void doAdd (List <Document > documents ) {
186+ this .embeddingModel .embed (documents , EmbeddingOptionsBuilder .builder ().build (), this .batchingStrategy );
183187
184- int size = documents .size ();
188+ List <List <Document >> batchedDocuments = batchDocuments (documents );
189+ batchedDocuments .forEach (this ::insertOrUpdateBatch );
190+ }
185191
186- this .embeddingModel .embed (documents , EmbeddingOptionsBuilder .builder ().build (), this .batchingStrategy );
192+ private List <List <Document >> batchDocuments (List <Document > documents ) {
193+ List <List <Document >> batches = new ArrayList <>();
194+ for (int i = 0 ; i < documents .size (); i += this .maxDocumentBatchSize ) {
195+ batches .add (documents .subList (i , Math .min (i + this .maxDocumentBatchSize , documents .size ())));
196+ }
197+ return batches ;
198+ }
187199
188- this .jdbcTemplate .batchUpdate (
189- "INSERT INTO " + getFullyQualifiedTableName ()
190- + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
191- + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? " ,
192- new BatchPreparedStatementSetter () {
193- @ Override
194- public void setValues (PreparedStatement ps , int i ) throws SQLException {
195-
196- var document = documents .get (i );
197- var content = document .getContent ();
198- var json = toJson (document .getMetadata ());
199- var embedding = document .getEmbedding ();
200- var pGvector = new PGvector (embedding );
201-
202- StatementCreatorUtils .setParameterValue (ps , 1 , SqlTypeValue .TYPE_UNKNOWN ,
203- UUID .fromString (document .getId ()));
204- StatementCreatorUtils .setParameterValue (ps , 2 , SqlTypeValue .TYPE_UNKNOWN , content );
205- StatementCreatorUtils .setParameterValue (ps , 3 , SqlTypeValue .TYPE_UNKNOWN , json );
206- StatementCreatorUtils .setParameterValue (ps , 4 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
207- StatementCreatorUtils .setParameterValue (ps , 5 , SqlTypeValue .TYPE_UNKNOWN , content );
208- StatementCreatorUtils .setParameterValue (ps , 6 , SqlTypeValue .TYPE_UNKNOWN , json );
209- StatementCreatorUtils .setParameterValue (ps , 7 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
210- }
200+ private void insertOrUpdateBatch (List <Document > batch ) {
201+ String sql = "INSERT INTO " + getFullyQualifiedTableName ()
202+ + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
203+ + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? " ;
204+
205+ this .jdbcTemplate .batchUpdate (sql , new BatchPreparedStatementSetter () {
206+ @ Override
207+ public void setValues (PreparedStatement ps , int i ) throws SQLException {
208+
209+ var document = batch .get (i );
210+ var content = document .getContent ();
211+ var json = toJson (document .getMetadata ());
212+ var embedding = document .getEmbedding ();
213+ var pGvector = new PGvector (embedding );
214+
215+ StatementCreatorUtils .setParameterValue (ps , 1 , SqlTypeValue .TYPE_UNKNOWN ,
216+ UUID .fromString (document .getId ()));
217+ StatementCreatorUtils .setParameterValue (ps , 2 , SqlTypeValue .TYPE_UNKNOWN , content );
218+ StatementCreatorUtils .setParameterValue (ps , 3 , SqlTypeValue .TYPE_UNKNOWN , json );
219+ StatementCreatorUtils .setParameterValue (ps , 4 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
220+ StatementCreatorUtils .setParameterValue (ps , 5 , SqlTypeValue .TYPE_UNKNOWN , content );
221+ StatementCreatorUtils .setParameterValue (ps , 6 , SqlTypeValue .TYPE_UNKNOWN , json );
222+ StatementCreatorUtils .setParameterValue (ps , 7 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
223+ }
211224
212- @ Override
213- public int getBatchSize () {
214- return size ;
215- }
216- });
225+ @ Override
226+ public int getBatchSize () {
227+ return batch . size () ;
228+ }
229+ });
217230 }
218231
219232 private String toJson (Map <String , Object > map ) {
@@ -285,7 +298,7 @@ private String comparisonOperator() {
285298 // Initialize
286299 // ---------------------------------------------------------------------------------
287300 @ Override
288- public void afterPropertiesSet () throws Exception {
301+ public void afterPropertiesSet () {
289302
290303 logger .info ("Initializing PGVectorStore schema for table: {} in schema: {}" , this .getVectorTableName (),
291304 this .getSchemaName ());
@@ -390,7 +403,7 @@ public enum PgIndexType {
390403 * speed-recall tradeoff). There’s no training step like IVFFlat, so the index can
391404 * be created without any data in the table.
392405 */
393- HNSW ;
406+ HNSW
394407
395408 }
396409
@@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper<Document> {
443456
444457 private static final String COLUMN_DISTANCE = "distance" ;
445458
446- private ObjectMapper objectMapper ;
459+ private final ObjectMapper objectMapper ;
447460
448461 public DocumentRowMapper (ObjectMapper objectMapper ) {
449462 this .objectMapper = objectMapper ;
@@ -509,6 +522,8 @@ public static class Builder {
509522
510523 private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy ();
511524
525+ private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE ;
526+
512527 @ Nullable
513528 private VectorStoreObservationConvention searchObservationConvention ;
514529
@@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
576591 return this ;
577592 }
578593
594+ public Builder withMaxDocumentBatchSize (int maxDocumentBatchSize ) {
595+ this .maxDocumentBatchSize = maxDocumentBatchSize ;
596+ return this ;
597+ }
598+
579599 public PgVectorStore build () {
580600 return new PgVectorStore (this .schemaName , this .vectorTableName , this .vectorTableValidationsEnabled ,
581601 this .jdbcTemplate , this .embeddingModel , this .dimensions , this .distanceType ,
582602 this .removeExistingVectorStoreTable , this .indexType , this .initializeSchema ,
583- this .observationRegistry , this .searchObservationConvention , this .batchingStrategy );
603+ this .observationRegistry , this .searchObservationConvention , this .batchingStrategy ,
604+ this .maxDocumentBatchSize );
584605 }
585606
586607 }
0 commit comments