Skip to content

Commit

Permalink
Add (experimental) ai/prompts construction helpers for StarChat and O…
Browse files Browse the repository at this point in the history
…penAssistant (#343)
  • Loading branch information
MaxLeiter committed Jul 18, 2023
1 parent 75a790f commit 9320e95
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 20 deletions.
5 changes: 5 additions & 0 deletions .changeset/lemon-beans-grab.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

Add (experimental) prompt construction helpers for StarChat and OpenAssistant
5 changes: 5 additions & 0 deletions docs/pages/docs/api-reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ title: API Reference
- [`StreamingTextResponse`](./api-reference/streaming-text-response)
- [`AIStream`](./api-reference/ai-stream)
- [`streamToResponse`](./api-reference/stream-to-response)

## Prompt Construction Helpers

- [`buildOpenAssistantPrompt`](./api-reference/prompts#experimental_buildopenassistantprompt)
- [`buildStarChatBetaPrompt`](./api-reference/prompts#experimental_buildstarchatbetaprompt)
43 changes: 43 additions & 0 deletions docs/pages/docs/api-reference/prompts.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
---
title: ai/prompts
---

import { Callout } from 'nextra-theme-docs'

# `ai/prompts`

The `ai/prompts` module contains functions to assist with converting `Message`'s into prompts that can be used with the [`useChat`](./use-chat) and [`useCompletion`](./use-completion) hooks.

<Callout>
The `experimental_` prefix on the functions in this module indicates that the
API is not yet stable and may change in the future without a major version
bump.
</Callout>

## `experimental_buildOpenAssistantPrompt`

Uses `<|prompter|>`, `<|endoftext|>`, and `<|assistant>` tokens. If a `Message` with an unsupported `role` is passed, an error will be thrown.

```ts filename="route.ts" {6}
import { experimental_buildOpenAssistantPrompt } from 'ai/prompts'

const { messages } = await req.json()
const response = Hf.textGenerationStream({
model: 'OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5',
inputs: experimental_buildOpenAssistantPrompt(messages)
})
```

## `experimental_buildStarChatBetaPrompt`

Uses `<|user|>`, `<|end|>`, `<|system|>`, and `<|assistant>` tokens. If a `Message` with an unsupported `role` is passed, an error will be thrown.

```ts filename="route.ts" {6}
import { buildStarChatBetaPrompt } from 'ai/prompts'

const { messages } = await req.json()
const response = Hf.textGenerationStream({
model: 'HuggingFaceH4/starchat-beta',
inputs: experimental_buildStarChatBetaPrompt(messages)
})
```
22 changes: 3 additions & 19 deletions examples/next-huggingface/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,36 +1,20 @@
import { HfInference } from '@huggingface/inference'
import { HuggingFaceStream, StreamingTextResponse } from 'ai'
import { buildOpenAssistantPrompt } from 'ai/prompts'

// Create a new HuggingFace Inference instance
const Hf = new HfInference(process.env.HUGGINGFACE_API_KEY)

// IMPORTANT! Set the runtime to edge
export const runtime = 'edge'

// Build a prompt from the messages
function buildPrompt(
messages: { content: string; role: 'system' | 'user' | 'assistant' }[]
) {
return (
messages
.map(({ content, role }) => {
if (role === 'user') {
return `<|prompter|>${content}<|endoftext|>`
} else {
return `<|assistant|>${content}<|endoftext|>`
}
})
.join('') + '<|assistant|>'
)
}

export async function POST(req: Request) {
// Extract the `messages` from the body of the request
const { messages } = await req.json()

const response = await Hf.textGenerationStream({
const response = Hf.textGenerationStream({
model: 'OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5',
inputs: buildPrompt(messages),
inputs: buildOpenAssistantPrompt(messages),
parameters: {
max_new_tokens: 200,
// @ts-ignore (this is a valid parameter specifically in OpenAssistant models)
Expand Down
12 changes: 11 additions & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
"module": "./dist/index.mjs",
"require": "./dist/index.js"
},
"./prompts": {
"types": "./prompts/dist/index.d.ts",
"import": "./prompts/dist/index.mjs",
"module": "./prompts/dist/index.mjs",
"require": "./prompts/dist/index.js"
},
"./react": {
"types": "./react/dist/index.d.ts",
"react-server": "./react/dist/index.server.mjs",
Expand Down Expand Up @@ -106,6 +112,10 @@
"url": "https://github.com/vercel-labs/ai/issues"
},
"keywords": [
"ai", "nextjs", "svelte", "react", "vue"
"ai",
"nextjs",
"svelte",
"react",
"vue"
]
}
47 changes: 47 additions & 0 deletions packages/core/prompts/huggingface.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import {
buildOpenAssistantPrompt,
buildStarChatBetaPrompt
} from './huggingface'
import type { Message } from '../shared/types'

describe('buildStarChatBetaPrompt', () => {
it('should return a string with user, assistant, and system messages', () => {
const messages: Pick<Message, 'content' | 'role'>[] = [
{ content: 'You are a chat bot.', role: 'system' },
{ content: 'Hello!', role: 'user' },
{ content: 'Hi there!', role: 'assistant' }
]

const expectedPrompt = `<|system|>\nYou are a chat bot.<|end|>\n<|user|>\nHello!<|end|>\n<|assistant|>\nHi there!<|end|>\n<|assistant|>`
const prompt = buildStarChatBetaPrompt(messages)
expect(prompt).toEqual(expectedPrompt)
})

it('should throw an error if a function message is included', () => {
const messages: Pick<Message, 'content' | 'role'>[] = [
{ content: 'someFunction()', role: 'function' }
]
expect(() => buildStarChatBetaPrompt(messages)).toThrow()
})
})

describe('buildOpenAssistantPrompt', () => {
it('should return a string with user and assistant messages', () => {
const messages: Pick<Message, 'content' | 'role'>[] = [
{ content: 'Hello!', role: 'user' },
{ content: 'Hi there!', role: 'assistant' }
]

const expectedPrompt =
'<|prompter|>Hello!<|endoftext|><|assistant|>Hi there!<|endoftext|><|assistant|>'
const prompt = buildOpenAssistantPrompt(messages)
expect(prompt).toEqual(expectedPrompt)
})

it('should throw an error if a function message is included', () => {
const messages: Pick<Message, 'content' | 'role'>[] = [
{ content: 'someFunction()', role: 'function' }
]
expect(() => buildOpenAssistantPrompt(messages)).toThrow()
})
})
51 changes: 51 additions & 0 deletions packages/core/prompts/huggingface.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { Message } from '../shared/types'

/**
* A prompt constructor for the HuggingFace StarChat Beta model.
* Does not support `function` messages.
* @see https://huggingface.co/HuggingFaceH4/starchat-beta
*/
export function buildStarChatBetaPrompt(
messages: Pick<Message, 'content' | 'role'>[]
) {
return (
messages
.map(({ content, role }) => {
if (role === 'user') {
return `<|user|>\n${content}<|end|>\n`
} else if (role === 'assistant') {
return `<|assistant|>\n${content}<|end|>\n`
} else if (role === 'system') {
return `<|system|>\n${content}<|end|>\n`
} else if (role === 'function') {
throw new Error('StarChat Beta does not support function calls.')
}
})
.join('') + '<|assistant|>'
)
}

/**
* A prompt constructor for HuggingFace OpenAssistant models.
* Does not support `function` or `system` messages.
* @see https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5
*/
export function buildOpenAssistantPrompt(
messages: Pick<Message, 'content' | 'role'>[]
) {
return (
messages
.map(({ content, role }) => {
if (role === 'user') {
return `<|prompter|>${content}<|endoftext|>`
} else if (role === 'function') {
throw new Error('OpenAssistant does not support function calls.')
} else if (role === 'system') {
throw new Error('OpenAssistant does not support system messages.')
} else {
return `<|assistant|>${content}<|endoftext|>`
}
})
.join('') + '<|assistant|>'
)
}
1 change: 1 addition & 0 deletions packages/core/prompts/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from './huggingface'
7 changes: 7 additions & 0 deletions packages/core/tsup.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ export default defineConfig([
external: ['react', 'svelte', 'vue'],
dts: true
},
{
entry: ['prompts/index.ts'],
format: ['cjs', 'esm'],
external: ['react', 'svelte', 'vue'],
outDir: 'prompts/dist',
dts: true
},
// React APIs
{
entry: ['react/index.ts'],
Expand Down

0 comments on commit 9320e95

Please sign in to comment.