Skip to content

Commit 11d7487

Browse files
authored
feat(db-postgres): add vector raw column type (#10422)
For an example of how to add a vector column, enable the `vector` extension, and query your embeddings, see the included test: https://github.com/payloadcms/payload/compare/feat/more-types?expand=1#diff-7d876370487cb625eb42ff1ad7cffa78e8327367af3de2930837ed123f5e3ae6R1-R117
1 parent 82840aa commit 11d7487

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

packages/drizzle/src/postgres/columnToCodeConverter.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ export const columnToCodeConverter: ColumnToCodeConverter = ({
3535
}
3636
}
3737

38+
if (column.type === 'vector') {
39+
if (column.dimensions) {
40+
columnBuilderArgsArray.push(`dimensions: ${column.dimensions}`)
41+
}
42+
}
43+
3844
let columnBuilderArgs = ''
3945

4046
if (columnBuilderArgsArray.length) {

packages/drizzle/src/postgres/schema/buildDrizzleTable.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import {
1313
uniqueIndex,
1414
uuid,
1515
varchar,
16+
vector,
1617
} from 'drizzle-orm/pg-core'
1718

1819
import type { RawColumn, RawTable } from '../../types.js'
@@ -81,6 +82,13 @@ export const buildDrizzleTable = ({
8182
break
8283
}
8384

85+
case 'vector': {
86+
const builder = vector(column.name, { dimensions: column.dimensions })
87+
columns[key] = builder
88+
89+
break
90+
}
91+
8492
default:
8593
columns[key] = rawColumnBuilderMap[column.type](column.name)
8694
break

packages/drizzle/src/types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,11 @@ export type IntegerRawColumn = {
279279
type: 'integer'
280280
} & BaseRawColumn
281281

282+
/**
 * Raw column descriptor for a `vector` column.
 *
 * The table builder forwards `dimensions` to drizzle's `vector()` column
 * builder, and the column-to-code converter only emits a `dimensions`
 * argument when the value is set.
 */
export type VectorRawColumn = BaseRawColumn & {
  /** Number of dimensions of the stored vectors. */
  dimensions?: number
  /** Discriminant identifying this raw column as a vector column. */
  type: 'vector'
}
286+
282287
export type RawColumn =
283288
| ({
284289
type: 'boolean' | 'geometry' | 'jsonb' | 'numeric' | 'serial' | 'text' | 'varchar'
@@ -287,6 +292,7 @@ export type RawColumn =
287292
| IntegerRawColumn
288293
| TimestampRawColumn
289294
| UUIDRawColumn
295+
| VectorRawColumn
290296

291297
export type IDType = 'integer' | 'numeric' | 'text' | 'uuid' | 'varchar'
292298

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/* eslint-disable jest/require-top-level-describe */
2+
import { PostgresAdapter } from '@payloadcms/db-postgres/types'
3+
import { cosineDistance, desc, gt, sql } from 'drizzle-orm'
4+
import path from 'path'
5+
import { buildConfig, getPayload } from 'payload'
6+
import { fileURLToPath } from 'url'
7+
8+
// Absolute path of this test module and the directory that contains it.
const filename = fileURLToPath(import.meta.url)
const dirname = path.dirname(filename)

// Skip on CI, because the database there does not have the `vector` extension.
// `PAYLOAD_DATABASE` may be unset, so optional chaining is used: an unset
// variable skips the suite instead of throwing a TypeError at module load.
const describeToUse =
  process.env.PAYLOAD_DATABASE?.startsWith('postgres') && process.env.CI !== 'true'
    ? describe
    : describe.skip
16+
17+
// Integration suite: registers a raw `vector` column on the `posts` table,
// inserts four documents with 5-element embeddings and queries them by cosine
// similarity directly through drizzle.
describeToUse('postgres vector custom column', () => {
18+
it('should add a vector column and query it', async () => {
19+
const { databaseAdapter } = await import(path.resolve(dirname, '../databaseAdapter.js'))
20+
21+
// Keep a reference to the original init so the wrapper below can delegate to it.
const init = databaseAdapter.init
22+
23+
// set options
24+
// Replace init with a wrapper that customises the adapter returned by the original.
databaseAdapter.init = ({ payload }) => {
25+
const adapter = init({ payload })
26+
27+
// Flag the `vector` extension as enabled on the adapter.
adapter.extensions = {
28+
vector: true,
29+
}
30+
adapter.beforeSchemaInit = [
31+
({ schema, adapter }) => {
32+
// Store a 5-dimensional `vector` raw column under `posts.columns.embedding`.
// NOTE(review): presumably this overrides the column generated for the
// `json` field named `embedding` declared in the config below — confirm.
;(adapter as PostgresAdapter).rawTables.posts.columns.embedding = {
33+
type: 'vector',
34+
dimensions: 5,
35+
name: 'embedding',
36+
}
37+
return schema
38+
},
39+
]
40+
return adapter
41+
}
42+
43+
const config = await buildConfig({
44+
db: databaseAdapter,
45+
secret: 'secret',
46+
collections: [
47+
{
48+
slug: 'users',
49+
auth: true,
50+
fields: [],
51+
},
52+
{
53+
slug: 'posts',
54+
fields: [
55+
{
56+
type: 'json',
57+
name: 'embedding',
58+
},
59+
{
60+
name: 'title',
61+
type: 'text',
62+
},
63+
],
64+
},
65+
],
66+
})
67+
68+
const payload = await getPayload({ config })
69+
70+
// Reference vector: stored on the `cat` post and used as the query vector below.
const catEmbedding = [1.5, -0.4, 7.2, 19.6, 20.2]
71+
72+
await payload.create({
73+
collection: 'posts',
74+
data: {
75+
embedding: [-5.2, 3.1, 0.2, 8.1, 3.5],
76+
title: 'apple',
77+
},
78+
})
79+
80+
await payload.create({
81+
collection: 'posts',
82+
data: {
83+
embedding: catEmbedding,
84+
title: 'cat',
85+
},
86+
})
87+
88+
await payload.create({
89+
collection: 'posts',
90+
data: {
91+
embedding: [-5.1, 2.9, 0.8, 7.9, 3.1],
92+
title: 'fruit',
93+
},
94+
})
95+
96+
await payload.create({
97+
collection: 'posts',
98+
data: {
99+
embedding: [1.7, -0.3, 6.9, 19.1, 21.1],
100+
title: 'dog',
101+
},
102+
})
103+
104+
// SQL expression: 1 minus the cosine distance between each row's embedding and the cat vector.
const similarity = sql<number>`1 - (${cosineDistance(payload.db.tables.posts.embedding, catEmbedding)})`
105+
106+
// Select posts whose similarity exceeds 0.9, ordered from most to least similar.
const res = await payload.db.drizzle
107+
.select()
108+
.from(payload.db.tables.posts)
109+
.where(gt(similarity, 0.9))
110+
.orderBy(desc(similarity))
111+
112+
// Only cat and dog
113+
expect(res).toHaveLength(2)
114+
115+
// similarity sort
116+
expect(res[0].title).toBe('cat')
117+
expect(res[1].title).toBe('dog')
118+
119+
// Log the matched rows through the payload logger.
payload.logger.info(res)
120+
})
121+
})

0 commit comments

Comments
 (0)