Skip to content

Commit

Permalink
Merge pull request #5 from sgomez/tool-choice
Browse files Browse the repository at this point in the history
Tool choice support
  • Loading branch information
sgomez committed May 18, 2024
2 parents fbff767 + 9a5ba88 commit 08ddec5
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 15 deletions.
122 changes: 122 additions & 0 deletions examples/ai-core/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
.SILENT:
.DEFAULT_GOAL := all

define RUN_EXAMPLE_TARGET
echo -- examples/$(subst _,/,$(1))
pnpm tsx src/$(subst _,/,$(1)).ts > /dev/null
endef

define RUN_EXAMPLE_CHAT
echo -- examples/$(subst _,/,$(1))
pnpm tsx src/$(subst _,/,$(1)).ts
endef

all: complex embed embed-many generate-object generate-text stream-object stream-text

# complex
.PHONY: complex complex-run complex-all semantic-router_main
complex: complex-run complex-all
complex-run:
echo - examples/complex:
complex-all: semantic-router_main
semantic-router_main:
$(call RUN_EXAMPLE_TARGET,complex/$@)


# embed
.PHONY: embed embed-run embed-all embed_ollama
embed: embed-run embed-all
embed-run:
echo - examples/embed:
embed-all: embed_ollama
embed_ollama:
$(call RUN_EXAMPLE_TARGET,$@)


# embed-many
.PHONY: embed-many embed-many-run embed-many-all embed-many_ollama
embed-many: embed-many-run embed-many-all
embed-many-run:
echo - examples/embed-many:
embed-many-all: embed-many_ollama
embed-many_ollama:
$(call RUN_EXAMPLE_TARGET,$@)


# generate-object
.PHONY: generate-object generate-object-run generate-object-all generate-object_ollama generate-object_ollama-full-json generate-object_ollama-json generate-object_ollama-multimodal generate-object_ollama-tool
generate-object: generate-object-run generate-object-all
generate-object-run:
echo - examples/generate-object:
generate-object-all: generate-object_ollama generate-object_ollama-full-json generate-object_ollama-json generate-object_ollama-multimodal generate-object_ollama-tool
generate-object_ollama:
$(call RUN_EXAMPLE_TARGET,$@)
generate-object_ollama-full-json:
$(call RUN_EXAMPLE_TARGET,$@)
generate-object_ollama-json:
$(call RUN_EXAMPLE_TARGET,$@)
generate-object_ollama-multimodal:
$(call RUN_EXAMPLE_TARGET,$@)
generate-object_ollama-tool:
$(call RUN_EXAMPLE_TARGET,$@)


# generate-text
.PHONY: generate-text generate-text-run generate-text-all generate-text_ollama generate-text_ollama-completion generate-text_ollama-completion-chat generate-text_ollama-multimodal generate-text_ollama-multimodal-base64 generate-text_ollama-multimodal-url generate-text_ollama-system-message-a generate-text_ollama-system-message-b generate-text_ollama-tool-call
generate-text: generate-text-run generate-text-all
generate-text-run:
echo - examples/generate-text:
generate-text-all: generate-text_ollama generate-text_ollama-completion generate-text_ollama-completion-chat generate-text_ollama-multimodal generate-text_ollama-multimodal-base64 generate-text_ollama-system-message-a generate-text_ollama-system-message-b generate-text_ollama-tool-call
generate-text_ollama:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-completion:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-completion-chat:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-multimodal:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-multimodal-base64:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-multimodal-url: # manual, not supported
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-system-message-a:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-system-message-b:
$(call RUN_EXAMPLE_TARGET,$@)
generate-text_ollama-tool-call:
$(call RUN_EXAMPLE_TARGET,$@)


# stream-object
.PHONY: stream-object stream-object-run stream-object-all stream-object_ollama stream-object_ollama-fullstream stream-object_ollama-json
stream-object: stream-object-run stream-object-all
stream-object-run:
echo - examples/stream-object:
stream-object-all: stream-object_ollama stream-object_ollama-fullstream stream-object_ollama-json
stream-object_ollama:
$(call RUN_EXAMPLE_TARGET,$@)
stream-object_ollama-fullstream:
$(call RUN_EXAMPLE_TARGET,$@)
stream-object_ollama-json:
$(call RUN_EXAMPLE_TARGET,$@)


# stream-text
.PHONY: stream-text stream-text-run stream-text-all stream-text_ollama stream-text_ollama-abort stream-text_ollama-completion stream-text_ollama-completion-chat stream-text_ollama-reader
stream-text: stream-text-run stream-text-all
stream-text-run:
echo - examples/stream-text:
stream-text-all: stream-text_ollama stream-text_ollama-abort stream-text_ollama-completion stream-text_ollama-completion-chat stream-text_ollama-reader
stream-text_ollama:
$(call RUN_EXAMPLE_TARGET,$@)
stream-text_ollama-abort:
$(call RUN_EXAMPLE_TARGET,$@)
stream-text_ollama-chatbot: # manual
$(call RUN_EXAMPLE_CHAT,$@)
stream-text_ollama-completion:
$(call RUN_EXAMPLE_TARGET,$@)
stream-text_ollama-completion-chat:
$(call RUN_EXAMPLE_CHAT,$@)
stream-text_ollama-reader:
$(call RUN_EXAMPLE_TARGET,$@)

33 changes: 33 additions & 0 deletions examples/ai-core/src/generate-object/ollama-tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#! /usr/bin/env -S pnpm tsx

import { generateObject } from 'ai'
import { ollama } from 'ollama-ai-provider'
import { z } from 'zod'

import { buildProgram } from '../tools/command'

/**
 * Example: generate structured character data with `mode: 'tool'`.
 *
 * The model id type is derived from the provider's public factory
 * signature instead of deep-importing from `ollama-ai-provider/src/...`,
 * which reaches into package internals and breaks against the built
 * package (matches the sibling `generate-text/ollama-tool-call.ts`).
 *
 * @param model - Ollama model id selected via the CLI (defaults to 'mistral').
 */
async function main(model: Parameters<typeof ollama>[0]) {
  const result = await generateObject({
    maxTokens: 2000,
    mode: 'tool',
    model: ollama(model),
    prompt:
      'Generate 3 character descriptions for a fantasy role playing game.',
    // Zod schema validated against the model's tool-call output.
    schema: z.object({
      characters: z.array(
        z.object({
          class: z
            .string()
            .describe('Character class, e.g. warrior, mage, or thief.'),
          description: z.string(),
          name: z.string(),
        }),
      ),
    }),
  })

  console.log(JSON.stringify(result.object, null, 2))
}

buildProgram('mistral', main).catch(console.error)
2 changes: 1 addition & 1 deletion examples/ai-core/src/generate-text/ollama-tool-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ async function main(model: Parameters<typeof ollama>[0]) {
console.log(JSON.stringify(result, null, 2))
}

buildProgram('openhermes', main).catch(console.error)
buildProgram('mistral', main).catch(console.error)
3 changes: 3 additions & 0 deletions packages/ollama/src/convert-to-ollama-chat-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { OllamaChatPrompt } from '@/ollama-chat-prompt'
export function convertToOllamaChatMessages(
prompt: LanguageModelV1Prompt,
tools?: LanguageModelV1FunctionTool[],
toolChoice?: string,
): OllamaChatPrompt {
const messages: OllamaChatPrompt = []

Expand All @@ -22,6 +23,7 @@ export function convertToOllamaChatMessages(
messages.push({
content: injectToolsSchemaIntoSystem({
system: content,
toolChoice,
tools,
}),
role: 'system',
Expand Down Expand Up @@ -87,6 +89,7 @@ export function convertToOllamaChatMessages(
messages.unshift({
content: injectToolsSchemaIntoSystem({
system: '',
toolChoice,
tools,
}),
role: 'system',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { inferToolCallsFromResponse } from '@/generate-tool/infer-tool-calls-fro
import { OllamaChatResponseSchema } from '@/ollama-chat-language-model'

describe('inferToolCallsFromAssistantMessage', () => {
it('should infer valid tool calls', () => {
it('should infer valid selected tools response', () => {
// Arrange
const response = {
finish_reason: 'stop',
Expand Down Expand Up @@ -33,6 +33,36 @@ describe('inferToolCallsFromAssistantMessage', () => {
)
})

it('should infer valid selected tool response', () => {
// Arrange
const response = {
finish_reason: 'stop',
message: {
content: JSON.stringify({
arguments: { numbers: [2, 3] },
name: 'sum',
}),
role: 'assistant',
},
} as OllamaChatResponseSchema

// Act
const parsedResponse = inferToolCallsFromResponse(response)

// Assert
expect(parsedResponse.finish_reason).toEqual('tool-calls')
expect(parsedResponse.message.tool_calls).toContainEqual(
expect.objectContaining({
function: {
arguments: JSON.stringify({ numbers: [2, 3] }),
name: 'sum',
},
id: expect.any(String),
type: 'function',
}),
)
})

it('should ignore invalid tool calls', () => {
// Arrange
const response = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ export function inferToolCallsFromResponse(
): OllamaChatResponseSchema {
try {
const tool = JSON.parse(response.message.content)
const parsedTools = toolResponseSchema.parse(tool)

let parsedTools = toolResponseSchema.parse(tool)
if (!Array.isArray(parsedTools)) {
parsedTools = [parsedTools]
}

return {
...response,
Expand All @@ -34,9 +38,15 @@ export function inferToolCallsFromResponse(
}
}

const toolResponseSchema = z.array(
const toolResponseSchema = z.union([
z.array(
z.object({
arguments: z.record(z.unknown()),
name: z.string(),
}),
),
z.object({
arguments: z.record(z.unknown()),
name: z.string(),
}),
)
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import { LanguageModelV1FunctionTool } from '@ai-sdk/provider'
import { describe, expect, it } from 'vitest'

import { injectToolsSchemaIntoSystem } from '@/generate-tool/inject-tools-schema-into-system'

describe('injectToolsSchemaIntoSystem', () => {
  // Shared fixture: two tools, so the toolChoice test can prove that the
  // non-chosen tool is filtered out (was duplicated in each test before).
  const tools: LanguageModelV1FunctionTool[] = [
    {
      description: 'Sum numbers',
      name: 'sum',
      parameters: { type: 'object' },
      type: 'function',
    },
    {
      description: 'Multiply numbers',
      name: 'multiply',
      parameters: { type: 'object' },
      type: 'function',
    },
  ]

  it('should return system message if no tools are present', () => {
    // Arrange
    const system = 'You are a helpful and honest assistant.'

    // Act
    const systemWithTools = injectToolsSchemaIntoSystem({ system })

    // Assert — no tools means the system prompt passes through untouched.
    expect(systemWithTools).toEqual(system)
  })

  it('should return system message with tools', () => {
    // Arrange
    const system = 'You are a helpful and honest assistant.'

    // Act
    const systemWithTools = injectToolsSchemaIntoSystem({ system, tools })

    // Assert — `\.` escapes the regex metacharacter so the literal
    // sentence (not "any char") is matched.
    expect(systemWithTools).toMatch(/You are a helpful and honest assistant\./)
    expect(systemWithTools).toMatch(/You have access to the following tools:/)
    expect(systemWithTools).toMatch(/"name":"sum"/)
    expect(systemWithTools).toMatch(/"name":"multiply"/)
  })

  it('should return system message with chosen tool only', () => {
    // Arrange
    const system = 'You are a helpful and honest assistant.'
    const toolChoice = 'sum'

    // Act
    const systemWithTools = injectToolsSchemaIntoSystem({
      system,
      toolChoice,
      tools,
    })

    // Assert — only the chosen tool's schema is injected.
    expect(systemWithTools).toMatch(/You are a helpful and honest assistant\./)
    expect(systemWithTools).toMatch(/You have access to the following tools:/)
    expect(systemWithTools).toMatch(/"name":"sum"/)
    expect(systemWithTools).not.toMatch(/"name":"multiply"/)
  })
})
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,24 @@ export function injectToolsSchemaIntoSystem({
schemaPrefix = DEFAULT_SCHEMA_PREFIX,
schemaSuffix = DEFAULT_SCHEMA_SUFFIX,
system,
toolChoice,
tools,
}: {
schemaPrefix?: string
schemaSuffix?: string
system: string
toolChoice?: string
tools?: LanguageModelV1FunctionTool[]
// tools: JSONSchema7
}): string {
if (!tools) {
const selectedTools = tools?.filter(
(tool) => !toolChoice || tool.name === toolChoice,
)

if (!selectedTools) {
return system
}

return [
system,
system === null ? null : '', // add a newline if system is not null
schemaPrefix,
JSON.stringify(tools),
schemaSuffix,
]
return [system, schemaPrefix, JSON.stringify(selectedTools), schemaSuffix]
.filter((line) => line !== null)
.join('\n')
}
6 changes: 5 additions & 1 deletion packages/ollama/src/ollama-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ export class OllamaChatLanguageModel implements LanguageModelV1 {
args: {
...baseArguments,
format: 'json',
messages: convertToOllamaChatMessages(prompt, [mode.tool]),
messages: convertToOllamaChatMessages(
prompt,
[mode.tool],
mode.tool.name,
),
tool_choice: {
function: { name: mode.tool.name },
type: 'function',
Expand Down

0 comments on commit 08ddec5

Please sign in to comment.