Skip to content

Commit

Permalink
Vector: make retrieve client function typesafe
Browse files Browse the repository at this point in the history
  • Loading branch information
fwang committed Jan 19, 2024
1 parent d11ee5f commit edbd538
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async function queryEmbeddings(
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `SELECT id, metadata, ${score} AS score FROM ${TABLE_NAME}
sql: `SELECT metadata, ${score} AS score FROM ${TABLE_NAME}
WHERE ${score} < ${1 - threshold}
AND metadata @> :include
${exclude ? "AND NOT metadata @> :exclude" : ""}
Expand All @@ -156,9 +156,8 @@ async function queryEmbeddings(
})
);
return ret.records?.map((record) => ({
id: record[0].stringValue,
metadata: JSON.parse(record[1].stringValue!),
score: 1 - record[2].doubleValue!,
metadata: JSON.parse(record[0].stringValue!),
score: 1 - record[1].doubleValue!,
}));
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"$schema": "https://json.schemastore.org/package.json",
"name": "sst",
"type": "module",
"version": "3.0.1-12",
"version": "3.0.1-13",
"main": "./dist/index.js",
"exports": {
".": "./dist/index.js",
Expand Down
30 changes: 22 additions & 8 deletions sdk/js/src/vector-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export type IngestEvent = {
* }
* ```
*/
metadata: any;
metadata: Record<string, any>;
};

export type RetrieveEvent = {
Expand Down Expand Up @@ -94,7 +94,7 @@ export type RetrieveEvent = {
* release: "2001",
* }
*/
include: any;
include: Record<string, any>;
/**
* Exclude embeddings with metadata that match the provided fields.
* @example
Expand Down Expand Up @@ -123,7 +123,7 @@ export type RetrieveEvent = {
* release: "2001",
* }
*/
exclude?: any;
exclude?: Record<string, any>;
/**
* The threshold of similarity between the prompt and the retrieved embeddings.
* Only embeddings with a similarity score higher than the threshold will be returned.
Expand Down Expand Up @@ -172,7 +172,18 @@ export type RemoveEvent = {
* }
* }
*/
include: any;
include: Record<string, any>;
};

type RetriveResponse = {
/**
* Metadata for the event in JSON format that was provided when ingesting the embedding.
*/
metadata: Record<string, any>;
/**
* The similarity score between the prompt and the retrieved embedding.
*/
score: number;
};

const lambda = new LambdaClient();
Expand All @@ -187,7 +198,7 @@ export const VectorClient = (name: string) => {
})
);

return parsePayload(ret, "Failed to ingest into the vector db");
parsePayload(ret, "Failed to ingest into the vector db");
},

retrieve: async (event: RetrieveEvent) => {
Expand All @@ -197,7 +208,10 @@ export const VectorClient = (name: string) => {
Payload: JSON.stringify(event),
})
);
return parsePayload(ret, "Failed to retrieve from the vector db");
return parsePayload<RetriveResponse>(
ret,
"Failed to retrieve from the vector db"
);
},

remove: async (event: RemoveEvent) => {
Expand All @@ -207,12 +221,12 @@ export const VectorClient = (name: string) => {
Payload: JSON.stringify(event),
})
);
return parsePayload(ret, "Failed to remove from the vector db");
parsePayload(ret, "Failed to remove from the vector db");
},
};
};

function parsePayload(output: InvokeCommandOutput, message: string) {
function parsePayload<T>(output: InvokeCommandOutput, message: string): T {
const payload = JSON.parse(Buffer.from(output.Payload!).toString());

// Set cause to the payload so that it can be logged in CloudWatch
Expand Down

0 comments on commit edbd538

Please sign in to comment.