-
Notifications
You must be signed in to change notification settings - Fork 295
/
PGVectorStore.ts
320 lines (275 loc) · 10.2 KB
/
PGVectorStore.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import type pg from "pg";
import {
VectorStoreBase,
type IEmbedModel,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
type VectorStoreQueryResult,
} from "./types.js";
import type { BaseNode, Metadata } from "../../Node.js";
import { Document, MetadataMode } from "../../Node.js";
/** Default Postgres schema name; used by PGVectorStore when `config.schemaName` is not provided. */
export const PGVECTOR_SCHEMA = "public";
/** Default table name; used by PGVectorStore when `config.tableName` is not provided. */
export const PGVECTOR_TABLE = "llamaindex_embedding";
/**
 * Provides support for writing and querying vector data in Postgres.
 * Note: Can't be used with data created using the Python version of the vector store (https://docs.llamaindex.ai/en/stable/examples/vector_stores/postgres.html)
 *
 * The schema and table names are interpolated into the SQL statements as-is
 * (identifiers cannot be sent as bound parameters), so they must come from
 * trusted configuration and never from end-user input.
 */
export class PGVectorStore
  extends VectorStoreBase
  implements VectorStoreNoEmbedModel
{
  storesText: boolean = true;

  private collection: string = "";
  private schemaName: string = PGVECTOR_SCHEMA;
  private tableName: string = PGVECTOR_TABLE;
  private connectionString: string | undefined = undefined;
  private dimensions: number = 1536;

  // Lazily created by getDb() on first use.
  private db?: pg.Client;

  /**
   * Constructs a new instance of the PGVectorStore
   *
   * If the `connectionString` is not provided the following env variables are
   * used to connect to the DB:
   * PGHOST=your database host
   * PGUSER=your database user
   * PGPASSWORD=your database password
   * PGDATABASE=your database name
   * PGPORT=your database port
   *
   * @param config - The configuration settings for the instance.
   * @param config.schemaName - The name of the schema (optional). Defaults to PGVECTOR_SCHEMA.
   * @param config.tableName - The name of the table (optional). Defaults to PGVECTOR_TABLE.
   * @param config.connectionString - The connection string (optional).
   * @param config.dimensions - The dimensions of the embedding model (optional). Defaults to 1536.
   */
  constructor(
    config?: {
      schemaName?: string;
      tableName?: string;
      connectionString?: string;
      dimensions?: number;
    } & Partial<IEmbedModel>,
  ) {
    super(config?.embedModel);
    this.schemaName = config?.schemaName ?? PGVECTOR_SCHEMA;
    this.tableName = config?.tableName ?? PGVECTOR_TABLE;
    this.connectionString = config?.connectionString;
    this.dimensions = config?.dimensions ?? 1536;
  }

  /**
   * Setter for the collection property.
   * Using a collection allows for simple segregation of vector data,
   * e.g. by user, source, or access-level.
   * Leave/set blank to ignore the collection value when querying.
   * @param coll Name for the collection.
   */
  setCollection(coll: string) {
    this.collection = coll;
  }

  /**
   * Getter for the collection property.
   * Using a collection allows for simple segregation of vector data,
   * e.g. by user, source, or access-level.
   * Leave/set blank to ignore the collection value when querying.
   * @returns The currently-set collection value. Default is empty string.
   */
  getCollection(): string {
    return this.collection;
  }

  /**
   * Returns the cached database client, creating and initialising it on first
   * use: connects, ensures the `vector` extension exists, registers the
   * pgvector type on the client and creates the schema/table/indexes.
   * @returns The connected client.
   * @throws Re-throws whatever error occurred while connecting or setting up.
   */
  private async getDb(): Promise<pg.Client> {
    if (this.db) {
      return this.db;
    }

    let db: pg.Client | undefined;
    try {
      const pgModule = await import("pg");
      const { Client } = pgModule.default ? pgModule.default : pgModule;
      const { registerType } = await import("pgvector/pg");

      // Create DB connection.
      // Without a connection string the PG* env variables are used - see the
      // constructor documentation.
      db = new Client({
        connectionString: this.connectionString,
      });
      await db.connect();

      // Check vector extension
      await db.query("CREATE EXTENSION IF NOT EXISTS vector");
      await registerType(db);

      // Check schema, table(s), index(es)
      await this.checkSchema(db);

      // All good? Keep the connection reference
      this.db = db;
      return db;
    } catch (err) {
      console.error(err);
      // The client is only cached after a fully successful setup, so close a
      // half-initialised one here instead of leaking the connection.
      await db?.end().catch(() => undefined);
      throw err;
    }
  }

  /**
   * Creates the schema, the embeddings table and its lookup indexes when they
   * do not exist yet.
   * @param db Connected client to run the DDL statements on.
   * @returns The same client that was passed in.
   */
  private async checkSchema(db: pg.Client) {
    await db.query(`CREATE SCHEMA IF NOT EXISTS ${this.schemaName}`);

    const tbl = `CREATE TABLE IF NOT EXISTS ${this.schemaName}.${this.tableName}(
      id uuid DEFAULT gen_random_uuid() PRIMARY KEY,
      external_id VARCHAR,
      collection VARCHAR,
      document TEXT,
      metadata JSONB DEFAULT '{}',
      embeddings VECTOR(${this.dimensions})
    )`;
    await db.query(tbl);

    const idxs = `CREATE INDEX IF NOT EXISTS idx_${this.tableName}_external_id ON ${this.schemaName}.${this.tableName} (external_id);
      CREATE INDEX IF NOT EXISTS idx_${this.tableName}_collection ON ${this.schemaName}.${this.tableName} (collection);`;
    await db.query(idxs);

    // TODO add IVFFlat or HNSW indexing?
    return db;
  }

  /**
   * Connects to the database using the configured connection string, or the
   * PG* environment variables when none was provided.
   * This method also checks and creates the vector extension,
   * the destination table and indexes if not found.
   * @returns A connection to the database; the promise rejects with the error encountered while connecting/setting up.
   */
  client() {
    return this.getDb();
  }

  /**
   * Delete all vector records for the specified collection.
   * NOTE: Uses the collection property controlled by setCollection/getCollection.
   * @returns The result of the delete query.
   */
  async clearCollection() {
    const sql = `DELETE FROM ${this.schemaName}.${this.tableName}
      WHERE collection = $1`;

    const db = await this.getDb();
    return db.query(sql, [this.collection]);
  }

  /**
   * Builds one positional parameter list per node, in the column order used
   * by the INSERT statement in add():
   * id, external_id, collection, document, metadata, embeddings.
   * @param embeddingResults Nodes to convert.
   * @returns One parameter array per node.
   */
  private getDataToInsert(embeddingResults: BaseNode<Metadata>[]) {
    return embeddingResults.map((row) => {
      // An empty id is sent as NULL; add() lets the database generate one.
      const id: string | null = row.id_.length ? row.id_ : null;

      // Work on a copy so the caller's node metadata is not mutated.
      const meta = { ...row.metadata, create_date: new Date() };

      return [
        id,
        "",
        this.collection,
        row.getContent(MetadataMode.EMBED),
        meta,
        "[" + row.getEmbedding().join(",") + "]",
      ];
    });
  }

  /**
   * Adds vector record(s) to the table.
   * NOTE: Uses the collection property controlled by setCollection/getCollection.
   *
   * Rows are inserted one at a time; a row that fails to insert is logged and
   * skipped so that the remaining rows are still written.
   * @param embeddingResults The Nodes to be inserted, optionally including metadata tuples.
   * @returns A list of zero or more id values for the created records.
   */
  async add(embeddingResults: BaseNode<Metadata>[]): Promise<string[]> {
    if (embeddingResults.length === 0) {
      console.debug("Empty list sent to PGVectorStore::add");
      return [];
    }

    // COALESCE: an explicit NULL does not trigger the column DEFAULT, so a
    // node without an id would otherwise violate the primary key constraint.
    // RETURNING: without it the INSERT yields no rows and no id is reported.
    const sql = `INSERT INTO ${this.schemaName}.${this.tableName}
      (id, external_id, collection, document, metadata, embeddings)
      VALUES (COALESCE($1::uuid, gen_random_uuid()), $2, $3, $4, $5, $6)
      RETURNING id`;

    const db = await this.getDb();
    const data = this.getDataToInsert(embeddingResults);

    const ret: string[] = [];
    for (const params of data) {
      try {
        const result = await db.query(sql, params);
        if (result.rows.length) {
          ret.push(result.rows[0].id as string);
        }
      } catch (err) {
        const msg = `${err}`;
        console.log(msg, err);
      }
    }

    return ret;
  }

  /**
   * Deletes a single record from the database by id.
   * NOTE: Uses the collection property controlled by setCollection/getCollection.
   * @param refDocId Unique identifier for the record to delete.
   * @param deleteKwargs Required by VectorStore interface. Currently ignored.
   * @returns Promise that resolves if the delete query did not throw an error.
   */
  async delete(refDocId: string, deleteKwargs?: any): Promise<void> {
    // Only restrict by collection when one has been set.
    const hasCollection = this.collection.length > 0;
    const collectionCriteria = hasCollection ? "AND collection = $2" : "";
    const sql = `DELETE FROM ${this.schemaName}.${this.tableName}
      WHERE id = $1 ${collectionCriteria}`;

    const db = await this.getDb();
    const params = hasCollection ? [refDocId, this.collection] : [refDocId];
    await db.query(sql, params);
  }

  /**
   * Query the vector store for the closest matching data to the query embeddings
   * @param query The VectorStoreQuery to be used
   * @param options Required by VectorStore interface. Currently ignored.
   * @returns Zero or more Document instances with data from the vector store.
   * @throws Error when `query.queryEmbedding` is not set.
   */
  async query(
    query: VectorStoreQuery,
    options?: any,
  ): Promise<VectorStoreQueryResult> {
    // TODO QUERY TYPES:
    //    Distance:       SELECT embedding <=> $1 AS distance FROM items;
    //    Inner Product:  SELECT (embedding <#> $1) * -1 AS inner_product FROM items;
    //    Cosine Sim:     SELECT 1 - (embedding <=> $1) AS cosine_similarity FROM items;

    if (!query.queryEmbedding) {
      // Fail early instead of sending the literal "[undefined]" to the database.
      throw new Error(
        "PGVectorStore.query requires query.queryEmbedding to be set",
      );
    }
    const embedding = "[" + query.queryEmbedding.join(",") + "]";
    const max = query.similarityTopK ?? 2;

    // $1 is always the query embedding; further placeholders are numbered
    // from the current length of `params`.
    const params: Array<string | number> = [embedding];
    const whereClauses: string[] = [];

    if (this.collection.length) {
      params.push(this.collection);
      whereClauses.push(`collection = $${params.length}`);
    }

    for (const filter of query.filters?.filters ?? []) {
      // Bind the metadata key as a parameter as well: interpolating it into
      // the SQL text would allow SQL injection through the filter key.
      params.push(filter.key);
      const keyIndex = params.length;
      params.push(filter.value);
      const valueIndex = params.length;
      whereClauses.push(
        `metadata->>($${keyIndex}::text) = $${valueIndex}`,
      );
    }

    const where =
      whereClauses.length > 0 ? `WHERE ${whereClauses.join(" AND ")}` : "";

    const sql = `SELECT
        v.*,
        embeddings <=> $1 s
      FROM ${this.schemaName}.${this.tableName} v
      ${where}
      ORDER BY s
      LIMIT ${max}
    `;

    const db = await this.getDb();
    const results = await db.query(sql, params);

    const nodes = results.rows.map(
      (row) =>
        new Document({
          id_: row.id,
          text: row.document,
          metadata: row.metadata,
          embedding: row.embeddings,
        }),
    );

    return {
      nodes: nodes,
      // `s` is the `<=>` distance selected above; reported as 1 - distance.
      similarities: results.rows.map((row) => 1 - row.s),
      ids: results.rows.map((row) => row.id),
    };
  }

  /**
   * Required by VectorStore interface. Currently ignored.
   * @param persistPath Ignored.
   * @returns Resolved Promise.
   */
  persist(persistPath: string): Promise<void> {
    return Promise.resolve();
  }
}