Skip to content

Commit

Permalink
Working on AI demo
Browse files Browse the repository at this point in the history
  • Loading branch information
fwang committed Jan 15, 2024
1 parent c9d275a commit 71a5a84
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 75 deletions.
2 changes: 2 additions & 0 deletions examples/playground/functions/vector-example/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export const seeder = async () => {
for (const tag of tags) {
console.log("ingesting tag", tag.id);
await client.ingest({
externalId: tag.id,
text: tag.text,
metadata: { type: "tag", id: tag.id },
});
Expand All @@ -61,6 +62,7 @@ export const seeder = async () => {
const image = imageBuffer.toString("base64");

await client.ingest({
externalId: movie.id,
text: movie.summary,
image,
metadata: { type: "movie", id: movie.id },
Expand Down
27 changes: 22 additions & 5 deletions internal/components/src/auto/run.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { PulumiFn } from "@pulumi/pulumi/automation";
import { runtime } from "@pulumi/pulumi";
import { initializeLinkRegistry } from "../components/link";
import { interpolate, runtime } from "@pulumi/pulumi";
import {
initializeLinkRegistry,
makeLinkable,
makeAWSLinkable,
} from "../components/link";

export async function run(program: PulumiFn) {
process.chdir($cli.paths.root);
Expand All @@ -22,7 +26,7 @@ export async function run(program: PulumiFn) {
};
}
return undefined;
},
}
);

runtime.registerStackTransformation(
Expand All @@ -40,14 +44,27 @@ export async function run(program: PulumiFn) {

if (!normalizedName.match(/^[A-Z][a-zA-Z0-9]*$/)) {
throw new Error(
`Invalid component name "${normalizedName}". Component names must start with an uppercase letter and contain only alphanumeric characters.`,
`Invalid component name "${normalizedName}". Component names must start with an uppercase letter and contain only alphanumeric characters.`
);
}

return undefined;
},
}
);

await initializeLinkRegistry();
makeLinkable(aws.dynamodb.Table, function () {
return {
type: `{ tableName: string }`,
value: { tableName: this.name },
};
});
makeAWSLinkable(aws.dynamodb.Table, function () {
return {
actions: ["dynamodb:*"],
resources: [this.arn, interpolate`${this.arn}/*`],
};
});

return await program();
}
110 changes: 73 additions & 37 deletions internal/components/src/components/bucket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@ import { prefixName, hashNumberToString } from "./helpers/naming";
import { Component } from "./component";
import { AWSLinkable, Link, Linkable } from "./link";
import { FunctionPermissionArgs } from ".";
import { create } from "domain";

/**
* Properties to create a DNS validated certificate managed by AWS Certificate Manager.
*/
export interface BucketArgs {
/**
* Whether the bucket should block public access
* @default - true
* Enable public access to the files in the bucket
* @default false
* @example
* ```js
* {
* public: true
* }
* ```
*/
blockPublicAccess?: Input<boolean>;
public?: Input<boolean>;
nodes?: {
bucket?: aws.s3.BucketV2Args;
};
Expand All @@ -38,47 +45,76 @@ export class Bucket extends Component implements Linkable, AWSLinkable {
super("sst:sst:Bucket", name, args, opts);

const parent = this;
const blockPublicAccess = normalizeBlockPublicAccess();

const randomId = new RandomId(`${name}Id`, { byteLength: 6 }, { parent });

const bucket = new aws.s3.BucketV2(
`${name}Bucket`,
{
bucket: randomId.dec.apply((dec) =>
prefixName(
name.toLowerCase(),
`-${hashNumberToString(parseInt(dec), 8)}`
)
),
forceDestroy: true,
...args?.nodes?.bucket,
},
{
parent,
}
);
const publicAccess = normalizePublicAccess();

const bucket = createBucket();
createPublicAccess();

output(blockPublicAccess).apply((blockPublicAccess) => {
if (!blockPublicAccess) return;
this.bucket = bucket;

new aws.s3.BucketPublicAccessBlock(
`${name}PublicAccessBlock`,
function createBucket() {
const randomId = new RandomId(`${name}Id`, { byteLength: 6 }, { parent });

return new aws.s3.BucketV2(
`${name}Bucket`,
{
bucket: bucket.bucket,
blockPublicAcls: true,
blockPublicPolicy: true,
ignorePublicAcls: true,
restrictPublicBuckets: true,
bucket: randomId.dec.apply((dec) =>
prefixName(
name.toLowerCase(),
`-${hashNumberToString(parseInt(dec), 8)}`
)
),
forceDestroy: true,
...args?.nodes?.bucket,
},
{ parent }
{
parent,
}
);
});
}

this.bucket = bucket;
function createPublicAccess() {
publicAccess.apply((publicAccess) => {
const publicAccessBlock = new aws.s3.BucketPublicAccessBlock(
`${name}PublicAccessBlock`,
{
bucket: bucket.bucket,
blockPublicAcls: true,
blockPublicPolicy: !publicAccess,
ignorePublicAcls: true,
restrictPublicBuckets: !publicAccess,
},
{ parent }
);

if (!publicAccess) return;

new aws.s3.BucketPolicy(
`${name}Policy`,
{
bucket: bucket.bucket,
policy: aws.iam.getPolicyDocumentOutput({
statements: [
{
principals: [
{
type: "*",
identifiers: ["*"],
},
],
actions: ["s3:GetObject"],
resources: [$util.interpolate`${bucket.arn}/*`],
},
],
}).json,
},
{ parent, dependsOn: publicAccessBlock }
);
});
}

function normalizeBlockPublicAccess() {
return output(args?.blockPublicAccess).apply((v) => v ?? true);
function normalizePublicAccess() {
return output(args?.public).apply((v) => v ?? false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,8 @@ import {
} from "@aws-sdk/client-bedrock-runtime";
import { useClient } from "../../helpers/aws/client";

const {
CLUSTER_ARN,
SECRET_ARN,
EMBEDDING_MODEL_ID,
DATABASE_NAME,
TABLE_NAME,
} = process.env;

export type IngestEvent = {
externalId: string;
text: string;
image?: string;
metadata: any;
Expand All @@ -29,14 +22,24 @@ export type RetrieveEvent = {
count?: number;
};

export type RemoveEvent = {
externalId: string;
};

const {
CLUSTER_ARN,
SECRET_ARN,
EMBEDDING_MODEL_ID,
DATABASE_NAME,
TABLE_NAME,
} = process.env;

export async function ingest(event: IngestEvent) {
console.log(event);
const embedding = await generateEmbedding(event.text, event.image);
const metadata = JSON.stringify(event.metadata);
await storeEmbedding(metadata, embedding);
await storeEmbedding(metadata, embedding, event.externalId);
}
export async function retrieve(event: RetrieveEvent) {
console.log(event);
const embedding = await generateEmbedding(event.prompt);
const metadata = JSON.stringify(event.metadata);
const result = await queryEmbeddings(
Expand All @@ -49,6 +52,9 @@ export async function retrieve(event: RetrieveEvent) {
results: result,
};
}
export async function remove(event: RemoveEvent) {
await removeEmbedding(event.externalId);
}

async function generateEmbedding(text: string, image?: string) {
const ret = await useClient(BedrockRuntimeClient).send(
Expand All @@ -66,20 +72,32 @@ async function generateEmbedding(text: string, image?: string) {
return payload.embedding;
}

async function storeEmbedding(metadata: string, embedding: number[]) {
async function storeEmbedding(
metadata: string,
embedding: number[],
externalId: string
) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `INSERT INTO ${TABLE_NAME} (embedding, metadata)
VALUES (ARRAY[${embedding.join(",")}], :metadata)`,
sql: `INSERT INTO ${TABLE_NAME} (embedding, metadata, external_id)
VALUES (ARRAY[${embedding.join(",")}], :metadata, :external_id)
ON CONFLICT (external_id) DO UPDATE
SET embedding = ARRAY[${embedding.join(
","
)}], metadata = :metadata`,
parameters: [
{
name: "metadata",
value: { stringValue: metadata },
typeHint: "JSON",
},
{
name: "external_id",
value: { stringValue: externalId },
},
],
})
);
Expand Down Expand Up @@ -117,3 +135,20 @@ async function queryEmbeddings(
score: 1 - record[2].doubleValue!,
}));
}

async function removeEmbedding(externalId: string) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `DELETE FROM ${TABLE_NAME} WHERE external_id = :external_id`,
parameters: [
{
name: "external_id",
value: { stringValue: externalId },
},
],
})
);
}
7 changes: 7 additions & 0 deletions internal/components/src/components/link.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ export function makeLinkable<T>(
) {
obj.prototype.getSSTLink = cb;
}

export function makeAWSLinkable<T>(
obj: { new (...args: any[]): T },
cb: (this: T) => FunctionPermissionArgs
) {
obj.prototype.getSSTAWSPermissions = cb;
}
23 changes: 22 additions & 1 deletion internal/components/src/components/providers/embeddings-table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Provider implements dynamic.ResourceProvider {
await this.createTable(inputs);
await this.createEmbeddingIndex(inputs);
await this.createMetadataIndex(inputs);
await this.createExternalIdIndex(inputs);
return {
id: inputs.tableName,
outs: {},
Expand All @@ -46,6 +47,7 @@ class Provider implements dynamic.ResourceProvider {
await this.createTable(news);
await this.createEmbeddingIndex(news);
await this.createMetadataIndex(news);
await this.createExternalIdIndex(news);
return {
outs: {},
};
Expand Down Expand Up @@ -115,7 +117,8 @@ class Provider implements dynamic.ResourceProvider {
sql: `create table ${inputs.tableName} (
id bigserial primary key,
embedding vector(${inputs.vectorSize}),
metadata jsonb
metadata jsonb,
external_id varchar(255) unique
);`,
})
);
Expand Down Expand Up @@ -162,6 +165,24 @@ class Provider implements dynamic.ResourceProvider {
throw error;
}
}

async createExternalIdIndex(inputs: Inputs) {
const client = useClient(RDSDataClient);
try {
await client.send(
new ExecuteStatementCommand({
resourceArn: inputs.clusterArn,
secretArn: inputs.secretArn,
database: inputs.databaseName,
sql: `create unique index on ${inputs.tableName} (external_id);`,
})
);
} catch (error: any) {
// ERROR: relation "embeddings" already exists; SQLState: 42P07
if (error.message.endsWith("SQLState: 42P07")) return;
throw error;
}
}
}

export class EmbeddingsTable extends dynamic.Resource {
Expand Down
Loading

0 comments on commit 71a5a84

Please sign in to comment.