feat (provider/google): support system instructions #2256

Merged
merged 1 commit into from
Jul 12, 2024
5 changes: 5 additions & 0 deletions .changeset/ninety-beers-do.md
@@ -0,0 +1,5 @@
---
'@ai-sdk/google': patch
---

feat (provider/google): support system instructions
@@ -67,18 +67,18 @@ You can use the following optional settings to customize the Google Generative A
## Language Models

You can create models that call the [Google Generative AI API](https://ai.google.dev/api/rest) using the provider instance.
The first argument is the model id, e.g. `models/gemini-pro`.
The first argument is the model id, e.g. `models/gemini-1.5-pro-latest`.
The models support tool calls and some have multi-modal capabilities.

```ts
const model = google('models/gemini-pro');
const model = google('models/gemini-1.5-pro-latest');
```

Google Generative AI models also support some model-specific settings that are not part of the [standard call settings](/docs/ai-sdk-core/settings).
You can pass them as an options argument:

```ts
const model = google('models/gemini-pro', {
const model = google('models/gemini-1.5-pro-latest', {
topK: 0.2,
});
```
2 changes: 1 addition & 1 deletion examples/ai-core/src/generate-text/google-custom-fetch.ts
@@ -17,7 +17,7 @@ const google = createGoogleGenerativeAI({

async function main() {
const result = await generateText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
prompt: 'Invent a new holiday and describe its traditions.',
});

2 changes: 1 addition & 1 deletion examples/ai-core/src/generate-text/google-tool-call.ts
@@ -8,7 +8,7 @@ dotenv.config();

async function main() {
const result = await generateText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
maxTokens: 512,
tools: {
weather: weatherTool,
2 changes: 1 addition & 1 deletion examples/ai-core/src/generate-text/google.ts
@@ -6,7 +6,7 @@ dotenv.config();

async function main() {
const result = await generateText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
prompt: 'Invent a new holiday and describe its traditions.',
});

@@ -23,7 +23,7 @@ async function main() {
}

const result = await streamText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
tools: { weatherTool },
system: `You are a helpful, respectful and honest assistant.`,
messages,
2 changes: 1 addition & 1 deletion examples/ai-core/src/stream-text/google-chatbot.ts
@@ -19,7 +19,7 @@ async function main() {
messages.push({ role: 'user', content: userInput });

const result = await streamText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
system: `You are a helpful, respectful and honest assistant.`,
messages,
});
2 changes: 1 addition & 1 deletion examples/ai-core/src/stream-text/google-fullstream.ts
@@ -8,7 +8,7 @@ dotenv.config();

async function main() {
const result = await streamText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
tools: {
weather: weatherTool,
cityAttractions: {
3 changes: 2 additions & 1 deletion examples/ai-core/src/stream-text/google.ts
@@ -6,7 +6,8 @@ dotenv.config();

async function main() {
const result = await streamText({
model: google('models/gemini-pro'),
model: google('models/gemini-1.5-pro-latest'),
system: 'You are a comedian. Only give funny answers.',
prompt: 'Invent a new holiday and describe its traditions.',
});

2 changes: 1 addition & 1 deletion packages/google/README.md
@@ -25,7 +25,7 @@ import { google } from '@ai-sdk/google';
import { generateText } from 'ai';

const { text } = await generateText({
model: google('models/gemini-1.5-flash-latest'),
model: google('models/gemini-1.5-pro-latest'),
prompt: 'Write a vegetarian lasagna recipe for 4 people.',
});
```
@@ -1,5 +1,31 @@
import { convertToGoogleGenerativeAIMessages } from './convert-to-google-generative-ai-messages';

describe('system messages', () => {
it('should store system message in system instruction', async () => {
const result = await convertToGoogleGenerativeAIMessages({
prompt: [{ role: 'system', content: 'Test' }],
});

expect(result).toEqual({
systemInstruction: { parts: [{ text: 'Test' }] },
contents: [],
});
});

it('should throw error when there was already a user message', async () => {
await expect(
convertToGoogleGenerativeAIMessages({
prompt: [
{ role: 'user', content: [{ type: 'text', text: 'Test' }] },
{ role: 'system', content: 'Test' },
],
}),
).rejects.toThrow(
'system messages are only supported at the beginning of the conversation',
);
});
});

describe('user messages', () => {
it('should download images for user image parts with URLs', async () => {
const result = await convertToGoogleGenerativeAIMessages({
@@ -24,19 +50,22 @@ describe('user messages', () => {
},
});

expect(result).toEqual([
{
role: 'user',
parts: [
{
inlineData: {
data: 'AAECAw==',
mimeType: 'image/png',
expect(result).toEqual({
systemInstruction: undefined,
contents: [
{
role: 'user',
parts: [
{
inlineData: {
data: 'AAECAw==',
mimeType: 'image/png',
},
},
},
],
},
]);
],
},
],
});
});

it('should add image parts for UInt8Array images', async () => {
@@ -54,23 +83,26 @@ describe('user messages', () => {
},
],

downloadImplementation: async ({ url }) => {
downloadImplementation: async () => {
throw new Error('Unexpected download call');
},
});

expect(result).toEqual([
{
role: 'user',
parts: [
{
inlineData: {
data: 'AAECAw==',
mimeType: 'image/png',
expect(result).toEqual({
systemInstruction: undefined,
contents: [
{
role: 'user',
parts: [
{
inlineData: {
data: 'AAECAw==',
mimeType: 'image/png',
},
},
},
],
},
]);
],
},
],
});
});
});
42 changes: 31 additions & 11 deletions packages/google/src/convert-to-google-generative-ai-messages.ts
@@ -1,6 +1,10 @@
import { LanguageModelV1Prompt } from '@ai-sdk/provider';
import {
LanguageModelV1Prompt,
UnsupportedFunctionalityError,
} from '@ai-sdk/provider';
import { convertUint8ArrayToBase64, download } from '@ai-sdk/provider-utils';
import {
GoogleGenerativeAIContent,
GoogleGenerativeAIContentPart,
GoogleGenerativeAIPrompt,
} from './google-generative-ai-prompt';
@@ -12,21 +16,27 @@ export async function convertToGoogleGenerativeAIMessages({
prompt: LanguageModelV1Prompt;
downloadImplementation?: typeof download;
}): Promise<GoogleGenerativeAIPrompt> {
const messages: GoogleGenerativeAIPrompt = [];
const systemInstructionParts: Array<{ text: string }> = [];
const contents: Array<GoogleGenerativeAIContent> = [];
let systemMessagesAllowed = true;

for (const { role, content } of prompt) {
switch (role) {
case 'system': {
// system message becomes user message:
messages.push({ role: 'user', parts: [{ text: content }] });

// required for to ensure turn-taking:
messages.push({ role: 'model', parts: [{ text: '' }] });
if (!systemMessagesAllowed) {
throw new UnsupportedFunctionalityError({
functionality:
'system messages are only supported at the beginning of the conversation',
});
}

systemInstructionParts.push({ text: content });
break;
}

case 'user': {
systemMessagesAllowed = false;

const parts: GoogleGenerativeAIContentPart[] = [];

for (const part of content) {
@@ -63,12 +73,14 @@ export async function convertToGoogleGenerativeAIMessages({
}
}

messages.push({ role: 'user', parts });
contents.push({ role: 'user', parts });
break;
}

case 'assistant': {
messages.push({
systemMessagesAllowed = false;

contents.push({
role: 'model',
parts: content
.map(part => {
@@ -96,7 +108,9 @@ export async function convertToGoogleGenerativeAIMessages({
}

case 'tool': {
messages.push({
systemMessagesAllowed = false;

contents.push({
role: 'user',
parts: content.map(part => ({
functionResponse: {
@@ -114,5 +128,11 @@ export async function convertToGoogleGenerativeAIMessages({
}
}

return messages;
return {
systemInstruction:
systemInstructionParts.length > 0
? { parts: systemInstructionParts }
: undefined,
contents,
};
}
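The change above can be summarized with a reduced, text-only sketch. It is not the provider's implementation: `SketchMessage`, `SketchPrompt`, and `splitSystemInstruction` are hypothetical names, image and tool parts are left out, and a plain `Error` stands in for `UnsupportedFunctionalityError`. It only illustrates how leading system messages are collected into `systemInstruction` while every later message goes into `contents`.

```typescript
// Simplified sketch with hypothetical names; text-only messages are assumed.
type SketchMessage =
  | { role: 'system'; content: string }
  | { role: 'user' | 'assistant'; content: string };

type SketchPrompt = {
  systemInstruction?: { parts: Array<{ text: string }> };
  contents: Array<{ role: 'user' | 'model'; parts: Array<{ text: string }> }>;
};

function splitSystemInstruction(messages: SketchMessage[]): SketchPrompt {
  const systemParts: Array<{ text: string }> = [];
  const contents: SketchPrompt['contents'] = [];
  let systemMessagesAllowed = true;

  for (const message of messages) {
    if (message.role === 'system') {
      if (!systemMessagesAllowed) {
        // The PR throws UnsupportedFunctionalityError; a plain Error stands in here.
        throw new Error(
          'system messages are only supported at the beginning of the conversation',
        );
      }
      systemParts.push({ text: message.content });
      continue;
    }

    // Any non-system message closes the window for system messages.
    systemMessagesAllowed = false;
    contents.push({
      role: message.role === 'assistant' ? 'model' : 'user',
      parts: [{ text: message.content }],
    });
  }

  return {
    systemInstruction:
      systemParts.length > 0 ? { parts: systemParts } : undefined,
    contents,
  };
}

// Usage: one system message followed by one user message.
const result = splitSystemInstruction([
  { role: 'system', content: 'You are a comedian. Only give funny answers.' },
  { role: 'user', content: 'Invent a new holiday and describe its traditions.' },
]);
console.log(JSON.stringify(result, null, 2));
```

Compared with the removed approach (a `user` message followed by an empty `model` turn), this keeps the conversation history free of synthetic turns.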
@@ -207,7 +207,10 @@ describe('doGenerate', () => {
await model.doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
prompt: [
{ role: 'system', content: 'test system instruction' },
{ role: 'user', content: [{ type: 'text', text: 'Hello' }] },
],
});

expect(await call(0).getRequestBodyJson()).toStrictEqual({
@@ -217,6 +220,7 @@ describe('doGenerate', () => {
parts: [{ text: 'Hello' }],
},
],
systemInstruction: { parts: [{ text: 'test system instruction' }] },
generationConfig: {},
});
}),
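The test above expects `systemInstruction` to appear in the request body only when a system message is present. A minimal sketch of why an `undefined` value is harmless, assuming the body is serialized with `JSON.stringify` (an assumption about the transport; `buildRequestBodySketch` and its types are hypothetical names):

```typescript
// Hypothetical helper; it assumes the request body is produced with JSON.stringify.
type SystemInstructionSketch = { parts: Array<{ text: string }> };
type ContentSketch = { role: 'user' | 'model'; parts: Array<{ text: string }> };

function buildRequestBodySketch(
  contents: ContentSketch[],
  systemInstruction?: SystemInstructionSketch,
): string {
  // JSON.stringify omits object properties whose value is undefined,
  // so prompts without a system message serialize without the key.
  return JSON.stringify({ generationConfig: {}, contents, systemInstruction });
}

const userTurn: ContentSketch[] = [{ role: 'user', parts: [{ text: 'Hello' }] }];

console.log(buildRequestBodySketch(userTurn));
console.log(
  buildRequestBodySketch(userTurn, {
    parts: [{ text: 'test system instruction' }],
  }),
);
```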
5 changes: 4 additions & 1 deletion packages/google/src/google-generative-ai-language-model.ts
@@ -98,14 +98,16 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
topP,
};

const contents = await convertToGoogleGenerativeAIMessages({ prompt });
const { contents, systemInstruction } =
await convertToGoogleGenerativeAIMessages({ prompt });

switch (type) {
case 'regular': {
return {
args: {
generationConfig,
contents,
systemInstruction,
safetySettings: this.settings.safetySettings,
...prepareToolsAndToolConfig(mode),
},
@@ -121,6 +123,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
response_mime_type: 'application/json',
},
contents,
systemInstruction,
safetySettings: this.settings.safetySettings,
},
warnings,
9 changes: 8 additions & 1 deletion packages/google/src/google-generative-ai-prompt.ts
@@ -1,4 +1,11 @@
export type GoogleGenerativeAIPrompt = Array<GoogleGenerativeAIContent>;
export type GoogleGenerativeAIPrompt = {
systemInstruction?: GoogleGenerativeAISystemInstruction;
contents: Array<GoogleGenerativeAIContent>;
};

export type GoogleGenerativeAISystemInstruction = {
parts: Array<{ text: string }>;
};

export type GoogleGenerativeAIContent = {
role: 'user' | 'model';