Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat: sync token limit at model import time #3486

Merged
merged 17 commits into from
Mar 22, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ data class ModelProvider(
val codyProOnly: Boolean? = null,
val provider: String? = null,
val title: String? = null,
val privateProviders: Map<String, ModelProvider>? = null,
val dotComProviders: List<ModelProvider>? = null,
val ollamaProvidersEnabled: Boolean? = null,
val primaryProviders: List<ModelProvider>? = null,
val ollamaProviders: List<ModelProvider>? = null,
)

2 changes: 1 addition & 1 deletion agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ export class Agent extends MessageHandler implements ExtensionClient {
})

this.registerAuthenticatedRequest('chat/restore', async ({ modelID, messages, chatID }) => {
const theModel = modelID ? modelID : ModelProvider.get(ModelUsage.Chat).at(0)?.model
const theModel = modelID ? modelID : ModelProvider.getProviders(ModelUsage.Chat).at(0)?.model
if (!theModel) {
throw new Error('No default chat model found')
}
Expand Down
10 changes: 10 additions & 0 deletions lib/shared/src/models/dotcom.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ModelProvider } from '.'
import { DEFAULT_CHAT_MODEL_TOKEN_LIMIT, DEFAULT_FAST_MODEL_TOKEN_LIMIT } from '../prompt/constants'
import { ModelUsage } from './types'

// The models must first be added to the custom chat models list in https://sourcegraph.com/github.com/sourcegraph/sourcegraph/-/blob/internal/completions/httpapi/chat.go?L48-51
Expand All @@ -10,6 +11,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: true,
codyProOnly: false,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 2.1',
Expand All @@ -18,6 +20,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude Instant',
Expand All @@ -26,6 +29,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Haiku',
Expand All @@ -34,6 +38,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Sonnet',
Expand All @@ -42,6 +47,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Opus',
Expand All @@ -50,6 +56,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'GPT-3.5 Turbo',
Expand All @@ -58,6 +65,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'GPT-4 Turbo Preview',
Expand All @@ -66,6 +74,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Mixtral 8x7B',
Expand All @@ -75,5 +84,6 @@ export const DEFAULT_DOT_COM_MODELS = [
codyProOnly: true,
// TODO: Improve prompt for Mixtral + Edit to see if we can use it there too.
usage: [ModelUsage.Chat],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
] as const satisfies ModelProvider[]
35 changes: 35 additions & 0 deletions lib/shared/src/models/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { beforeAll, describe, expect, it } from 'vitest'
import { ModelProvider } from '../models/index'
import { DEFAULT_FAST_MODEL_TOKEN_LIMIT, tokensToChars } from '../prompt/constants'
import { DEFAULT_DOT_COM_MODELS } from './dotcom'
import { ModelUsage } from './types'

describe('getMaxCharsByModel', () => {
    beforeAll(() => {
        // Prime the provider registry with the default dotcom models.
        // NOTE(review): getProviders() now takes (usage, currentModel?) — the
        // URL string previously passed as the second argument was silently
        // treated as a model ID, not an endpoint, so it has been removed.
        ModelProvider.getProviders(ModelUsage.Chat)
    })

    it('returns default token limit for unknown model', () => {
        const maxChars = ModelProvider.getMaxCharsByModel('unknown-model')
        expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT))
    })

    it('returns max token limit for known chat model', () => {
        const maxChars = ModelProvider.getMaxCharsByModel(DEFAULT_DOT_COM_MODELS[0].model)
        expect(maxChars).toEqual(tokensToChars(DEFAULT_DOT_COM_MODELS[0].maxToken))
    })

    it('returns default token limit for unknown model - Enterprise user', () => {
        // Simulate an enterprise instance by replacing the primary providers
        // (enterprise instances supply their own provider list).
        ModelProvider.setProviders([new ModelProvider('enterprise-model', [ModelUsage.Chat])])
        const maxChars = ModelProvider.getMaxCharsByModel('unknown-model')
        expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT))
    })

    it('returns max token limit for known model - Enterprise user', () => {
        ModelProvider.setProviders([new ModelProvider('model-with-limit', [ModelUsage.Chat], 200)])
        const maxChars = ModelProvider.getMaxCharsByModel('model-with-limit')
        expect(maxChars).toEqual(tokensToChars(200))
    })
})
155 changes: 77 additions & 78 deletions lib/shared/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import { logError } from '../logger'
import { OLLAMA_DEFAULT_URL } from '../ollama'
import { isDotCom } from '../sourcegraph-api/environments'
import {
DEFAULT_FAST_MODEL_CHARS_LIMIT,
DEFAULT_FAST_MODEL_TOKEN_LIMIT,
tokensToChars,
} from '../prompt/constants'
import { DEFAULT_DOT_COM_MODELS } from './dotcom'
import { ModelUsage } from './types'
import { getProviderName } from './utils'
import { getModelInfo } from './utils'

/**
* ModelProvider manages available chat and edit models.
Expand All @@ -13,111 +17,106 @@ import { getProviderName } from './utils'
export class ModelProvider {
public default = false
public codyProOnly = false
// The name of the provider of the model, e.g. "Anthropic"
public provider: string
// The title of the model, e.g. "Claude 2.0"
public readonly title: string

constructor(
// The model id that includes the provider name & the model name,
// e.g. "anthropic/claude-2.0"
public readonly model: string,
// The usage of the model, e.g. chat or edit.
public readonly usage: ModelUsage[],
isDefaultModel = true
// The maximum number of tokens that can be processed by the model in a single request.
// NOTE: A token is equivalent to 4 characters/bytes.
public readonly maxToken: number = DEFAULT_FAST_MODEL_TOKEN_LIMIT
) {
const splittedModel = model.split('/')
this.provider = getProviderName(splittedModel[0])
this.title = splittedModel[1]?.replaceAll('-', ' ')
this.default = isDefaultModel
}

// Providers available for non-dotcom instances
private static privateProviders: Map<string, ModelProvider> = new Map()
// Providers available for dotcom instances
private static dotComProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS
// Providers available from local ollama instances
private static ollamaProvidersEnabled = false
private static ollamaProviders: ModelProvider[] = []

public static onConfigChange(enableOllamaModels: boolean): void {
ModelProvider.ollamaProvidersEnabled = enableOllamaModels
ModelProvider.ollamaProviders = []
if (enableOllamaModels) {
ModelProvider.getLocalOllamaModels()
}
const { provider, title } = getModelInfo(model)
this.provider = provider
this.title = title
}

/**
* Fetches available Ollama models from the local Ollama server
* and adds them to the list of ollama providers.
* Providers available on the user's instance
*/
private static primaryProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS
/**
* Providers available from local ollama instances
*/
public static getLocalOllamaModels(): void {
const isAgentTesting = process.env.CODY_SHIM_TESTING === 'true'
private static ollamaProviders: ModelProvider[] = []
public static async onConfigChange(enableOllamaModels: boolean): Promise<void> {
// Only fetch local models if user has enabled the config
if (isAgentTesting || !ModelProvider.ollamaProvidersEnabled) {
return
}
// TODO (bee) watch file change to determine if a new model is added
// to eliminate the needs of restarting the extension to get the new models
fetch(new URL('/api/tags', OLLAMA_DEFAULT_URL).href)
.then(response => response.json())
.then(
data => {
const models = new Set<ModelProvider>()
for (const model of data.models) {
const name = `ollama/${model.model}`
const newModel = new ModelProvider(name, [ModelUsage.Chat, ModelUsage.Edit])
models.add(newModel)
}
ModelProvider.ollamaProviders = Array.from(models)
},
error => {
const fetchFailedErrors = ['Failed to fetch', 'fetch failed']
const isFetchFailed = fetchFailedErrors.some(err => error.toString().includes(err))
const serverErrorMsg = 'Please make sure the Ollama server is up & running.'
logError('getLocalOllamaModels: failed ', isFetchFailed ? serverErrorMsg : error)
}
)
ModelProvider.ollamaProviders = enableOllamaModels ? await fetchLocalOllamaModels() : []
}

/**
* Adds a new model provider, instantiated from the given model string,
* to the internal providers set. This allows new models to be added and
* made available for use.
* Sets the primary model providers.
* NOTE: private instances can only support 1 provider atm
*/
public static add(provider: ModelProvider): void {
// private instances can only support 1 provider atm
if (ModelProvider.privateProviders.size) {
ModelProvider.privateProviders.clear()
}
ModelProvider.privateProviders.set(provider.model.trim(), provider)
public static setProviders(providers: ModelProvider[]): void {
ModelProvider.primaryProviders = providers
}

/**
* Gets the model providers based on the endpoint and current model.
* If endpoint is a dotcom endpoint, returns dotComProviders with ollama providers.
* Get the list of the primary models providers with local models.
* If currentModel is provided, sets it as the default model.
*/
public static get(
type: ModelUsage,
endpoint?: string | null,
currentModel?: string
): ModelProvider[] {
const isDotComUser = !endpoint || (endpoint && isDotCom(endpoint))
const models = (
isDotComUser
? ModelProvider.dotComProviders
: Array.from(ModelProvider.privateProviders.values())
)
public static getProviders(type: ModelUsage, currentModel?: string): ModelProvider[] {
const models = ModelProvider.primaryProviders
abeatrix marked this conversation as resolved.
Show resolved Hide resolved
.concat(ModelProvider.ollamaProviders)
.filter(model => model.usage.includes(type))

if (!isDotComUser) {
return models
}

// Set the current model as default
return models.map(model => {
return models?.map(model => {
return {
...model,
default: model.model === currentModel,
}
})
}

/**
* Finds the model provider with the given model ID and returns its characters limit.
* The limit is calculated based on the max number of tokens the model can process.
* E.g. 7000 tokens * 4 characters/token = 28000 characters
*/
public static getMaxCharsByModel(modelID: string): number {
const model = ModelProvider.primaryProviders
.concat(ModelProvider.ollamaProviders)
.find(m => m.model === modelID)
return tokensToChars(model?.maxToken || DEFAULT_FAST_MODEL_TOKEN_LIMIT)
}
}

/**
* Fetches available Ollama models from the local Ollama server
* and adds them to the list of ollama providers.
*/
export async function fetchLocalOllamaModels(): Promise<ModelProvider[]> {
if (process.env.CODY_SHIM_TESTING === 'true') {
abeatrix marked this conversation as resolved.
Show resolved Hide resolved
return []
}
// TODO (bee) watch file change to determine if a new model is added
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no file to watch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we can watch for changes in ~/.ollama/models/manifests/registry.ollama.ai/library where the ollama models are stored.

// to eliminate the needs of restarting the extension to get the new models
return await fetch(new URL('/api/tags', OLLAMA_DEFAULT_URL).href)
.then(response => response.json())
.then(
data =>
data?.models?.map(
(m: { model: string }) =>
new ModelProvider(
`ollama/${m.model}`,
[ModelUsage.Chat, ModelUsage.Edit],
DEFAULT_FAST_MODEL_CHARS_LIMIT
)
),
error => {
const fetchFailedErrors = ['Failed to fetch', 'fetch failed']
const isFetchFailed = fetchFailedErrors.some(err => error.toString().includes(err))
const serverErrorMsg = 'Please make sure the Ollama server is up & running.'
logError('getLocalOllamaModels: failed ', isFetchFailed ? serverErrorMsg : error)
return []
}
)
}
36 changes: 36 additions & 0 deletions lib/shared/src/models/utils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { describe, expect, it } from 'vitest'
import { getModelInfo } from './utils'

describe('getModelInfo', () => {
    it('splits model ID and returns provider and title', () => {
        expect(getModelInfo('Anthropic/Claude 2.0')).toEqual({
            provider: 'Anthropic',
            title: 'Claude 2.0',
        })
    })

    it('handles model ID without title', () => {
        // A trailing slash yields an empty title rather than throwing.
        expect(getModelInfo('Anthropic/')).toEqual({
            provider: 'Anthropic',
            title: '',
        })
    })

    it('replaces dashes in title with spaces', () => {
        expect(getModelInfo('example/model-with-dashes')).toEqual({
            provider: 'example',
            title: 'model with dashes',
        })
    })

    it('handles model ID with multiple dashes', () => {
        // Only the last path segment becomes the title.
        expect(getModelInfo('fireworks/accounts/fireworks/models/mixtral-8x7b-instruct')).toEqual({
            provider: 'fireworks',
            title: 'mixtral 8x7b instruct',
        })
    })
})
13 changes: 13 additions & 0 deletions lib/shared/src/models/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,16 @@ export function getProviderName(name: string): string {
/**
 * Whether the given model ID qualifies for the fast-path completions route.
 * Only Claude 3 models served via the "anthropic/" provider match.
 *
 * NOTE: the optional chain implies callers may pass a nullish model at
 * runtime despite the `string` type; the result is coerced so the function
 * always returns the declared `boolean` (previously it could return
 * `undefined` for nullish input).
 */
export function supportsFastPath(model: string): boolean {
    return Boolean(model?.startsWith('anthropic/claude-3'))
}

/**
* Gets the provider and title from a model ID string.
*/
/**
 * Gets the provider and title from a model ID string.
 * The provider is derived from the first path segment; the title is the
 * last segment (empty when the ID has no segment after the provider),
 * with dashes rendered as spaces for display.
 */
export function getModelInfo(modelID: string): {
    provider: string
    title: string
} {
    const segments = modelID.split('/')
    const provider = getProviderName(segments[0])
    const lastSegment = segments.length > 1 ? segments[segments.length - 1] : undefined
    const title = (lastSegment ?? '').split('-').join(' ')
    return { provider, title }
}