-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathhandler.js
296 lines (270 loc) · 9.43 KB
/
handler.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import { DynamoDBClient } from "@aws-sdk/client-dynamodb";
import {
DynamoDBDocumentClient,
QueryCommand,
UpdateCommand,
} from "@aws-sdk/lib-dynamodb";
import {
AccessDeniedException,
BedrockRuntimeClient,
ConverseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { z } from "zod";
import { Auth } from "auth-sdk";
/**
 * Error that carries the HTTP status code the handler's error branch sends
 * back to the client alongside the message.
 */
class HTTPError extends Error {
  /**
   * @param {number} statusCode - HTTP status code for the error response.
   * @param {string} message - Message returned to the client in the `error` field.
   * @param {{ cause?: unknown }} [options] - Standard Error options (e.g. `cause`).
   */
  constructor(statusCode, message, options) {
    super(message, options);
    // Without an explicit name, logs and String(err) would show a generic "Error".
    this.name = "HTTPError";
    this.statusCode = statusCode;
  }
}
/**
 * DynamoDB clients and table configuration for the usage counters. The
 * document client built here is what the handler uses to query and update
 * the per-user and global usage records.
 */
const dynamoDbClient = new DynamoDBClient();
const dynamoDbDocClient = DynamoDBDocumentClient.from(dynamoDbClient);
const { USAGE_TABLE_NAME } = process.env;
/**
 * Zod schema describing the request payload that is forwarded to AWS
 * Bedrock's ConverseStreamCommand as the conversation messages.
 *
 * Each message has a `role` of "user" or "assistant" and a `content` array
 * whose items each carry a `text` string.
 */
const textContentBlockSchema = z.object({
  text: z.string(),
});

const conversationMessageSchema = z.object({
  role: z.enum(["user", "assistant"]),
  content: z.array(textContentBlockSchema),
});

const inputMessageSchema = z.object({
  messages: z.array(conversationMessageSchema),
});
/**
 * Bedrock runtime client shared by every invocation that runs in the same
 * Lambda execution environment, so warm invocations reuse it instead of
 * constructing a new client per request.
 */
const runtimeClient = new BedrockRuntimeClient({ region: "us-east-1" });

/**
 * Returns a shallow copy of a Lambda event with any Authorization header
 * (matched case-insensitively) masked, so bearer tokens are not logged.
 *
 * @param {object} event - The incoming Lambda event.
 * @returns {object} The event itself when it has no headers, otherwise a copy
 *   with the Authorization header value replaced by "[REDACTED]".
 */
const redactEvent = (event) => {
  if (!event?.headers) {
    return event;
  }
  const headers = { ...event.headers };
  for (const name of Object.keys(headers)) {
    if (name.toLowerCase() === "authorization") {
      headers[name] = "[REDACTED]";
    }
  }
  return { ...event, headers };
};

/**
 * Extracts the bearer token from the request's Authorization header.
 *
 * @param {Record<string, string> | null | undefined} headers - Event headers.
 * @returns {string} The token that follows the "Bearer" scheme.
 * @throws {HTTPError} 403 when the header is missing or is not a Bearer token.
 */
const extractBearerToken = (headers) => {
  const header = headers?.Authorization || headers?.authorization;
  const [scheme, token] = (header || "").split(" ");
  if (!header || scheme !== "Bearer" || !token) {
    throw new HTTPError(403, "Missing bearer token in Authorization header");
  }
  return token;
};

/**
 * Builds the usage-table partition keys for the month containing `now`. The
 * month start is computed in UTC so the keys do not depend on the process
 * time zone.
 *
 * @param {string} userId - Id of the authenticated user.
 * @param {Date} [now=new Date()] - Reference instant.
 * @returns {{ userUsageKey: string, globalUsageKey: string }}
 */
const monthlyUsageKeys = (userId, now = new Date()) => {
  const monthStart = new Date(
    Date.UTC(now.getUTCFullYear(), now.getUTCMonth(), 1)
  ).toISOString();
  return {
    userUsageKey: `USER#${userId}#${monthStart}`,
    globalUsageKey: `GLOBAL#${monthStart}`,
  };
};

/**
 * Parses the request body as JSON and validates it with inputMessageSchema.
 *
 * @param {object} event - Lambda event; `body` is base64-decoded first when
 *   `isBase64Encoded` is set.
 * @returns {Array<object>} The validated messages as returned by zod. Only
 *   schema-validated data is forwarded to Bedrock, not the raw body.
 * @throws {HTTPError} 400 when the body is not valid JSON or fails validation.
 */
const parseRequestMessages = (event) => {
  const rawBody = event.isBase64Encoded
    ? Buffer.from(event.body ?? "", "base64").toString("utf8")
    : event.body;
  let messages;
  try {
    messages = JSON.parse(rawBody);
  } catch {
    throw new HTTPError(400, "Invalid JSON format in the request body");
  }
  const result = inputMessageSchema.safeParse({ messages });
  if (!result.success) {
    const [issue] = result.error.issues;
    const issuePath = issue.path.join(".");
    throw new HTTPError(
      400,
      `Invalid value at '${issuePath}': ${issue.message}`
    );
  }
  return result.data.messages;
};

/**
 * The awslambda.streamifyResponse is a utility provided in the Lambda runtime
 * to stream the response. Unfortunately this utility is not available
 * externally at the moment, therefore this method can't be run locally using
 * Serverless Dev mode.
 *
 * Request flow:
 * 1. Authenticate the bearer token.
 * 2. Enforce the monthly per-user and global invocation limits.
 * 3. Validate the body.
 * 4. Stream each Bedrock ConverseStream event to the client as a JSON chunk.
 * 5. Record the usage.
 *
 * Failures are reported as a JSON `{ error }` chunk.
 */
export const handler = awslambda.streamifyResponse(
  async (event, responseStream, context) => {
    /**
     * Set once the first body chunk is written. From then on the status code
     * and headers are already on the wire and must not be re-applied.
     */
    let bodyStarted = false;

    /**
     * responseStream is a Writable stream and doesn't provide a method to
     * update the headers or statusCode, therefore HttpResponseStream.from is
     * used to attach new response metadata. It is only called before the
     * first chunk is written.
     */
    const updateStream = ({ statusCode } = {}) => {
      const httpResponseMetadata = {
        statusCode: statusCode || 200,
        headers: {
          "Content-Type": "application/json",
        },
      };
      responseStream = awslambda.HttpResponseStream.from(
        responseStream,
        httpResponseMetadata
      );
    };

    /**
     * JSON-encodes a payload and writes it to the response. Node writable
     * streams not in object mode reject plain objects, so every chunk is
     * serialized first.
     */
    const writeJson = (payload) => {
      bodyStarted = true;
      responseStream.write(JSON.stringify(payload));
    };

    try {
      const authenticator = new Auth({
        secret: process.env.SHARED_TOKEN_SECRET,
      });
      /**
       * Extract, parse, and validate the JWT token from the Authorization
       * header.
       */
      const token = authenticator.verify(extractBearerToken(event.headers));
      if (!token) {
        throw new HTTPError(403, "Invalid token");
      }
      const { userId } = token;

      /**
       * This is a simple throttle mechanism to limit the number of requests
       * per user per month. It throttles on per-user limits as well as a
       * global limit. At the end of each request the usage is updated in the
       * DynamoDB table, including the request count, input tokens, output
       * tokens, and total tokens.
       *
       * This uses the request count to throttle. Other metrics, like
       * inputTokens or totalTokens, can be used by switching out
       * `invocationCount` with the desired metric.
       *
       * The Model ID is the sort key, so the throttle limits are calculated
       * per model. The check-then-update is not atomic, so concurrent
       * requests can slightly overshoot a limit.
       */
      const modelSortKey = `MODEL#${process.env.MODEL_ID}`;
      const { userUsageKey, globalUsageKey } = monthlyUsageKeys(userId);
      const fetchUsage = async (pk) => {
        const { Items } = await dynamoDbDocClient.send(
          new QueryCommand({
            TableName: USAGE_TABLE_NAME,
            KeyConditionExpression: "PK = :pk AND SK = :sk",
            ExpressionAttributeValues: {
              ":pk": pk,
              ":sk": modelSortKey,
            },
          })
        );
        return Items?.[0];
      };
      const [userUsageMetrics, globalUsageMetrics] = await Promise.all([
        fetchUsage(userUsageKey),
        fetchUsage(globalUsageKey),
      ]);
      // An unset limit becomes NaN, and every comparison with NaN is false,
      // so that limit is effectively disabled.
      const userLimit = Number(process.env.THROTTLE_MONTHLY_LIMIT_USER);
      const globalLimit = Number(process.env.THROTTLE_MONTHLY_LIMIT_GLOBAL);
      if (
        userUsageMetrics?.invocationCount >= userLimit ||
        globalUsageMetrics?.invocationCount >= globalLimit
      ) {
        throw new HTTPError(
          429,
          `User has exceeded the user or global monthly usage limit`
        );
      }

      const messages = parseRequestMessages(event);

      /**
       * Apply the default (200, JSON) response metadata just before starting
       * the streaming response.
       */
      updateStream();

      /**
       * Details about this command can be found in the AWS SDK for JavaScript v3.
       * https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/client/bedrock-runtime/command/ConverseStreamCommand/
       */
      const converseResponse = await runtimeClient.send(
        new ConverseStreamCommand({
          modelId: process.env.MODEL_ID,
          system: [
            {
              text: "You are a helpful bot.",
            },
          ],
          messages,
        })
      );

      let usage = {};
      for await (const streamEvent of converseResponse.stream) {
        const {
          metadata,
          internalServerException,
          modelStreamErrorException,
          validationException,
          throttlingException,
          serviceUnavailableException,
          ...contentResponse
        } = streamEvent;
        /**
         * Metadata includes internal properties that should not be passed to
         * the client, so it is only logged. Its `usage` object is kept for
         * recording usage in the DynamoDB table.
         */
        if (metadata) {
          console.log(metadata);
          usage = metadata.usage ?? {};
        }
        /**
         * The ConverseStream response reports errors in-band as stream
         * events instead of throwing, so each error member is checked here.
         */
        const exception =
          internalServerException ||
          modelStreamErrorException ||
          validationException ||
          throttlingException ||
          serviceUnavailableException;
        if (exception) {
          throw new Error(exception.message, { cause: exception });
        }
        // The metadata-only event leaves nothing to forward.
        if (Object.keys(contentResponse).length > 0) {
          writeJson(contentResponse);
        }
      }

      /**
       * Record the usage (request count and the token counts reported by
       * Bedrock) for both the user and the global counters. Throttling only
       * uses the request count; the token counts are informative.
       */
      const recordUsage = (pk) =>
        dynamoDbDocClient.send(
          new UpdateCommand({
            TableName: USAGE_TABLE_NAME,
            Key: {
              PK: pk,
              SK: modelSortKey,
            },
            UpdateExpression:
              "ADD invocationCount :inc, inputTokens :in, outputTokens :out, totalTokens :tot",
            ExpressionAttributeValues: {
              ":inc": 1,
              ":in": usage.inputTokens || 0,
              ":out": usage.outputTokens || 0,
              ":tot": usage.totalTokens || 0,
            },
          })
        );
      await Promise.all([
        recordUsage(userUsageKey),
        recordUsage(globalUsageKey),
      ]);
    } catch (error) {
      console.error(redactEvent(event));
      console.error(error);
      let statusCode = 500;
      let message = "Internal Error";
      if (error instanceof HTTPError) {
        statusCode = error.statusCode;
        message = error.message;
      } else if (error instanceof AccessDeniedException) {
        message =
          "Access denied to AWS Bedrock - Please ensure the model is enabled in the AWS Bedrock console.";
      }
      // Once body chunks were sent the status is fixed; only append the error.
      if (!bodyStarted) {
        updateStream({ statusCode });
      }
      writeJson({ error: message });
    } finally {
      responseStream.end();
    }
  }
);