-
Notifications
You must be signed in to change notification settings - Fork 324
/
vertex.ts
81 lines (73 loc) · 2.42 KB
/
vertex.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import {
VertexAI,
GenerativeModel as VertexGenerativeModel,
GenerativeModelPreview as VertexGenerativeModelPreview,
type GenerateContentResponse,
type ModelParams as VertexModelParams,
type StreamGenerateContentResult as VertexStreamGenerateContentResult,
} from "@google-cloud/vertexai";
import type {
GeminiChatStreamResponse,
IGeminiSession,
VertexGeminiSessionOptions,
} from "./types.js";
import { getEnv } from "@llamaindex/env";
import type { CompletionResponse } from "../types.js";
import { streamConverter } from "../utils.js";
import { getText } from "./utils.js";
/* Google's Vertex AI backend does not use API key authentication.
*
* To authenticate for local development:
*
* ```
* npm install @google-cloud/vertexai
* gcloud auth application-default login
* ```
* For production, the preferred method is a service account. For more
* details, see https://cloud.google.com/docs/authentication/
*
* */
/**
 * Gemini session that talks to Google's Vertex AI through the
 * `@google-cloud/vertexai` client.
 *
 * The project and location are taken from the constructor options when
 * given, otherwise from the `GOOGLE_VERTEX_PROJECT` and
 * `GOOGLE_VERTEX_LOCATION` environment variables.
 */
export class GeminiVertexSession implements IGeminiSession {
  private vertex: VertexAI;
  private preview = false;

  /**
   * @param options - Optional session settings. `project` and `location`
   *   override the environment variables; `preview` selects the preview
   *   model namespace. All options are forwarded to the `VertexAI`
   *   constructor.
   * @throws Error when either the project or the location cannot be
   *   resolved from the options or the environment.
   */
  constructor(options?: Partial<VertexGeminiSessionOptions>) {
    const project = options?.project ?? getEnv("GOOGLE_VERTEX_PROJECT");
    const location = options?.location ?? getEnv("GOOGLE_VERTEX_LOCATION");

    // Both values are required before a client can be created.
    if (!(project && location)) {
      throw new Error(
        "Set Google Vertex project and location in GOOGLE_VERTEX_PROJECT and GOOGLE_VERTEX_LOCATION env variables",
      );
    }

    // The resolved project/location come last so they win over the spread.
    this.vertex = new VertexAI({ ...options, project, location });
    this.preview = options?.preview ?? false;
  }

  /**
   * Returns a generative model for the given parameters, taken from the
   * preview namespace of the client when this session was created with
   * `preview` enabled, and from the regular client otherwise.
   *
   * @param metadata - Model parameters passed through to the client.
   */
  getGenerativeModel(
    metadata: VertexModelParams,
  ): VertexGenerativeModelPreview | VertexGenerativeModel {
    return this.preview
      ? this.vertex.preview.getGenerativeModel(metadata)
      : this.vertex.getGenerativeModel(metadata);
  }

  /**
   * Extracts the text of a response by delegating to the shared `getText`
   * helper.
   *
   * @param response - A content-generation response from Vertex AI.
   */
  getResponseText(response: GenerateContentResponse): string {
    return getText(response);
  }

  /**
   * Re-emits a streaming result as chat chunks: each streamed response is
   * mapped to `{ delta, raw }`, where `delta` is the response text and
   * `raw` is the untouched response.
   *
   * @param result - The streaming result returned by the Vertex client.
   */
  async *getChatStream(
    result: VertexStreamGenerateContentResult,
  ): GeminiChatStreamResponse {
    const { stream } = result;
    yield* streamConverter(stream, (chunk) => {
      const delta = this.getResponseText(chunk);
      return { delta, raw: chunk };
    });
  }

  /**
   * Re-emits a streaming result as completion chunks: each streamed
   * response is mapped to `{ text, raw }`, where `text` is the response
   * text and `raw` is the untouched response.
   *
   * @param result - The streaming result returned by the Vertex client.
   */
  getCompletionStream(
    result: VertexStreamGenerateContentResult,
  ): AsyncIterable<CompletionResponse> {
    const { stream } = result;
    return streamConverter(stream, (chunk) => {
      const text = this.getResponseText(chunk);
      return { text, raw: chunk };
    });
  }
}