Skip to content

Commit

Permalink
Vector: support storing multiple vectors in the db
Browse files Browse the repository at this point in the history
  • Loading branch information
fwang committed Jan 18, 2024
1 parent b1f8565 commit 0744e2e
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 110 deletions.
7 changes: 6 additions & 1 deletion examples/playground/functions/vector-example/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ export const seeder = async () => {
for (const tag of tags) {
console.log("ingesting tag", tag.id);
await client.ingest({
//model: "amazon.titan-embed-image-v1",
model: "text-embedding-ada-002",
text: tag.text,
metadata: { type: "tag", id: tag.id },
});
Expand All @@ -61,6 +63,7 @@ export const seeder = async () => {
const image = imageBuffer.toString("base64");

await client.ingest({
model: "text-embedding-ada-002",
text: movie.summary,
image,
metadata: { type: "movie", id: movie.id },
Expand All @@ -75,8 +78,10 @@ export const seeder = async () => {

export const app = async (event) => {
const ret = await client.retrieve({
model: "text-embedding-ada-002",
text: event.queryStringParameters?.text,
metadata: { type: "movie" },
include: { type: "movie" },
exclude: { id: "movie1" },
});

return {
Expand Down
5 changes: 3 additions & 2 deletions examples/playground/sst.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ export default $config({
},
async run() {
const vector = new sst.Vector("MyVectorDB", {
//model: "amazon.titan-embed-image-v1",
model: "text-embedding-ada-002",
openAiApiKey: new sst.Secret("OpenAiApiKey").value,
});

const seeder = new sst.Function("Seeder", {
Expand All @@ -43,6 +42,8 @@ export default $config({
url: true,
});

const bucket = new sst.Bucket("MyBucket");

const app = new sst.Function("MyApp", {
handler: "functions/vector-example/index.app",
link: [vector],
Expand Down
128 changes: 96 additions & 32 deletions internal/components/src/components/handlers/vector-handler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,70 @@ import {
import { OpenAI } from "openai";
import { useClient } from "../../helpers/aws/client";

const ModelInfo = {
"amazon.titan-embed-text-v1": {
provider: "bedrock" as const,
shortName: "brtt1",
},
"amazon.titan-embed-image-v1": {
provider: "bedrock" as const,
shortName: "brti1",
},
"text-embedding-ada-002": {
provider: "openai" as const,
shortName: "oata2",
},
};

type Model = keyof typeof ModelInfo;

export type IngestEvent = {
model?: Model;
text?: string;
image?: string;
metadata: any;
metadata: Record<string, any>;
};

export type RetrieveEvent = {
model?: Model;
text?: string;
image?: string;
metadata: any;
include: Record<string, any>;
exclude?: Record<string, any>;
threshold?: number;
count?: number;
};

export type RemoveEvent = {
metadata: string;
include: Record<string, any>;
};

const {
CLUSTER_ARN,
SECRET_ARN,
DATABASE_NAME,
TABLE_NAME,
MODEL,
MODEL_PROVIDER,
// modal provider dependent (optional)
OPENAI_API_KEY,
} = process.env;

export async function ingest(event: IngestEvent) {
const embedding = await generateEmbedding(event.text, event.image);
const model = normalizeModel(event.model);
const embedding = await generateEmbedding(model, event.text, event.image);
const metadata = JSON.stringify(event.metadata);
await storeEmbedding(metadata, embedding);
await storeEmbedding(model, metadata, embedding);
}
export async function retrieve(event: RetrieveEvent) {
const embedding = await generateEmbedding(event.text, event.image);
const metadata = JSON.stringify(event.metadata);
const model = normalizeModel(event.model);
const embedding = await generateEmbedding(model, event.text, event.image);
const include = JSON.stringify(event.include);
// The return type of JSON.stringify() is always "string".
// This is wrong when "event.exclude" is undefined.
const exclude = JSON.stringify(event.exclude) as string | undefined;
const result = await queryEmbeddings(
metadata,
model,
include,
exclude,
embedding,
event.threshold ?? 0,
event.count ?? 10
Expand All @@ -57,35 +82,49 @@ export async function retrieve(event: RetrieveEvent) {
};
}
export async function remove(event: RemoveEvent) {
const metadata = JSON.stringify(event.metadata);
await removeEmbedding(metadata);
const include = JSON.stringify(event.include);
await removeEmbedding(include);
}

async function generateEmbedding(text?: string, image?: string) {
if (MODEL_PROVIDER === "openai") {
return await generateEmbeddingOpenAI(text!);
function normalizeModel(model?: Model) {
model = model ?? "amazon.titan-embed-image-v1";
if (ModelInfo[model].provider === "openai" && !OPENAI_API_KEY) {
throw new Error(
`To use the model "${model}", an OpenAI API key is necessary. Please ensure that "openAiApiKey" has been configured in the Vector component.`
);
}
return await generateEmbeddingBedrock(text, image);
return model;
}

async function generateEmbeddingOpenAI(text: string) {
async function generateEmbedding(model: Model, text?: string, image?: string) {
if (ModelInfo[model].provider === "openai") {
return await generateEmbeddingOpenAI(model, text!);
}
return await generateEmbeddingBedrock(model, text, image);
}

async function generateEmbeddingOpenAI(model: Model, text: string) {
const openAi = new OpenAI({ apiKey: OPENAI_API_KEY });
const embeddingResponse = await openAi.embeddings.create({
model: "text-embedding-ada-002",
model,
input: text,
encoding_format: "float",
});
return embeddingResponse.data[0].embedding;
}

async function generateEmbeddingBedrock(text?: string, image?: string) {
async function generateEmbeddingBedrock(
model: Model,
text?: string,
image?: string
) {
const ret = await useClient(BedrockRuntimeClient).send(
new InvokeModelCommand({
body: JSON.stringify({
inputText: text,
inputImage: image,
}),
modelId: MODEL,
modelId: model,
contentType: "application/json",
accept: "*/*",
})
Expand All @@ -94,15 +133,23 @@ async function generateEmbeddingBedrock(text?: string, image?: string) {
return payload.embedding;
}

async function storeEmbedding(metadata: string, embedding: number[]) {
async function storeEmbedding(
model: Model,
metadata: string,
embedding: number[]
) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `INSERT INTO ${TABLE_NAME} (embedding, metadata)
VALUES (ARRAY[${embedding.join(",")}], :metadata)`,
sql: `INSERT INTO ${TABLE_NAME} (model, embedding, metadata)
VALUES (:model, ARRAY[${embedding.join(",")}], :metadata)`,
parameters: [
{
name: "model",
value: { stringValue: ModelInfo[model].shortName },
},
{
name: "metadata",
value: { stringValue: metadata },
Expand All @@ -114,7 +161,9 @@ async function storeEmbedding(metadata: string, embedding: number[]) {
}

async function queryEmbeddings(
metadata: string,
model: Model,
include: string,
exclude: string | undefined,
embedding: number[],
threshold: number,
count: number
Expand All @@ -126,16 +175,31 @@ async function queryEmbeddings(
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `SELECT id, metadata, ${score} AS score FROM ${TABLE_NAME}
WHERE ${score} < ${1 - threshold}
AND metadata @> :metadata
WHERE model = :model
AND ${score} < ${1 - threshold}
AND metadata @> :include
${exclude ? "AND NOT metadata @> :exclude" : ""}
ORDER BY ${score}
LIMIT ${count}`,
parameters: [
{
name: "metadata",
value: { stringValue: metadata },
name: "model",
value: { stringValue: ModelInfo[model].shortName },
},
{
name: "include",
value: { stringValue: include },
typeHint: "JSON",
},
...(exclude
? [
{
name: "exclude",
value: { stringValue: exclude },
typeHint: "JSON" as const,
},
]
: []),
],
})
);
Expand All @@ -146,17 +210,17 @@ async function queryEmbeddings(
}));
}

async function removeEmbedding(metadata: string) {
async function removeEmbedding(include: string) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `DELETE FROM ${TABLE_NAME} WHERE metadata @> :metadata`,
sql: `DELETE FROM ${TABLE_NAME} WHERE metadata @> :include`,
parameters: [
{
name: "metadata",
value: { stringValue: metadata },
name: "include",
value: { stringValue: include },
typeHint: "JSON",
},
],
Expand Down
19 changes: 2 additions & 17 deletions internal/components/src/components/providers/embeddings-table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ export interface PostgresTableInputs {
secretArn: Input<string>;
databaseName: Input<string>;
tableName: Input<string>;
vectorSize: Input<number>;
}

interface Inputs {
clusterArn: string;
secretArn: string;
databaseName: string;
tableName: string;
vectorSize: number;
}

class Provider implements dynamic.ResourceProvider {
Expand All @@ -43,9 +41,6 @@ class Provider implements dynamic.ResourceProvider {
await this.createDatabase(news);
await this.enablePgvectorExtension(news);
await this.enablePgtrgmExtension(news);
if (olds.vectorSize !== news.vectorSize) {
await this.removeTable(news);
}
await this.createTable(news);
await this.createEmbeddingIndex(news);
await this.createMetadataIndex(news);
Expand Down Expand Up @@ -117,7 +112,8 @@ class Provider implements dynamic.ResourceProvider {
database: inputs.databaseName,
sql: `create table ${inputs.tableName} (
id bigserial primary key,
embedding vector(${inputs.vectorSize}),
model char(5),
embedding vector(1536),
metadata jsonb
);`,
})
Expand All @@ -129,17 +125,6 @@ class Provider implements dynamic.ResourceProvider {
}
}

async removeTable(inputs: Inputs) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: inputs.clusterArn,
secretArn: inputs.secretArn,
database: inputs.databaseName,
sql: `drop table if exists ${inputs.tableName};`,
})
);
}

async createEmbeddingIndex(inputs: Inputs) {
try {
await useClient(RDSDataClient).send(
Expand Down
6 changes: 3 additions & 3 deletions internal/components/src/components/secret.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import { Component } from "./component";
export class SecretMissingError extends VisibleError {
constructor(public readonly secretName: string) {
super(
`Set a value for ${secretName} with \`sst secrets set ${secretName} <value>\``,
`Set a value for ${secretName} with \`sst secrets set ${secretName} <value>\``
);
}
}

export class Secret extends Component implements Linkable {
private _value?: string;
private _value: string;
private _name: string;
private _placeholder?: string;

Expand All @@ -23,7 +23,7 @@ export class Secret extends Component implements Linkable {
{
placeholder,
},
{},
{}
);
this._name = name;
this._placeholder = placeholder;
Expand Down
Loading

0 comments on commit 0744e2e

Please sign in to comment.