Skip to content

Commit

Permalink
Chat: sync token limit at model import time (#3486)
Browse files Browse the repository at this point in the history
Co-authored-by: Dominic Cooney <dominic.cooney@sourcegraph.com>
  • Loading branch information
abeatrix and dominiccooney committed Mar 22, 2024
1 parent 9647ecb commit da1fe66
Show file tree
Hide file tree
Showing 17 changed files with 249 additions and 186 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ data class ModelProvider(
val codyProOnly: Boolean? = null,
val provider: String? = null,
val title: String? = null,
val privateProviders: Map<String, ModelProvider>? = null,
val dotComProviders: List<ModelProvider>? = null,
val ollamaProvidersEnabled: Boolean? = null,
val ollamaProviders: List<ModelProvider>? = null,
val primaryProviders: List<ModelProvider>? = null,
val localProviders: List<ModelProvider>? = null,
)

2 changes: 1 addition & 1 deletion agent/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ export class Agent extends MessageHandler implements ExtensionClient {
})

this.registerAuthenticatedRequest('chat/restore', async ({ modelID, messages, chatID }) => {
const theModel = modelID ? modelID : ModelProvider.get(ModelUsage.Chat).at(0)?.model
const theModel = modelID ? modelID : ModelProvider.getProviders(ModelUsage.Chat).at(0)?.model
if (!theModel) {
throw new Error('No default chat model found')
}
Expand Down
10 changes: 10 additions & 0 deletions lib/shared/src/models/dotcom.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ModelProvider } from '.'
import { DEFAULT_CHAT_MODEL_TOKEN_LIMIT, DEFAULT_FAST_MODEL_TOKEN_LIMIT } from '../prompt/constants'
import { ModelUsage } from './types'

// The models must first be added to the custom chat models list in https://sourcegraph.com/github.com/sourcegraph/sourcegraph/-/blob/internal/completions/httpapi/chat.go?L48-51
Expand All @@ -10,6 +11,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: true,
codyProOnly: false,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 2.1',
Expand All @@ -18,6 +20,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude Instant',
Expand All @@ -26,6 +29,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Haiku',
Expand All @@ -34,6 +38,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Sonnet',
Expand All @@ -42,6 +47,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Claude 3 Opus',
Expand All @@ -50,6 +56,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'GPT-3.5 Turbo',
Expand All @@ -58,6 +65,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_FAST_MODEL_TOKEN_LIMIT,
},
{
title: 'GPT-4 Turbo Preview',
Expand All @@ -66,6 +74,7 @@ export const DEFAULT_DOT_COM_MODELS = [
default: false,
codyProOnly: true,
usage: [ModelUsage.Chat, ModelUsage.Edit],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
{
title: 'Mixtral 8x7B',
Expand All @@ -75,5 +84,6 @@ export const DEFAULT_DOT_COM_MODELS = [
codyProOnly: true,
// TODO: Improve prompt for Mixtral + Edit to see if we can use it there too.
usage: [ModelUsage.Chat],
maxToken: DEFAULT_CHAT_MODEL_TOKEN_LIMIT,
},
] as const satisfies ModelProvider[]
35 changes: 35 additions & 0 deletions lib/shared/src/models/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { beforeAll, describe, expect, it } from 'vitest'
import { ModelProvider } from '../models/index'
import { DEFAULT_FAST_MODEL_TOKEN_LIMIT, tokensToChars } from '../prompt/constants'
import { DOTCOM_URL } from '../sourcegraph-api/environments'
import { DEFAULT_DOT_COM_MODELS } from './dotcom'
import { ModelUsage } from './types'

describe('getMaxCharsByModel', () => {
beforeAll(() => {
ModelProvider.getProviders(ModelUsage.Chat, DOTCOM_URL.toString())
})

it('returns default token limit for unknown model', () => {
const maxChars = ModelProvider.getMaxCharsByModel('unknown-model')
expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT))
})

it('returns max token limit for known chat model', () => {
const maxChars = ModelProvider.getMaxCharsByModel(DEFAULT_DOT_COM_MODELS[0].model)
expect(maxChars).toEqual(tokensToChars(DEFAULT_DOT_COM_MODELS[0].maxToken))
})

it('returns default token limit for unknown model - Enterprise user', () => {
ModelProvider.getProviders(ModelUsage.Chat, 'https://example.com')
const maxChars = ModelProvider.getMaxCharsByModel('unknown-model')
expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT))
})

it('returns max token limit for known model - Enterprise user', () => {
ModelProvider.getProviders(ModelUsage.Chat, 'https://example.com')
ModelProvider.setProviders([new ModelProvider('model-with-limit', [ModelUsage.Chat], 200)])
const maxChars = ModelProvider.getMaxCharsByModel('model-with-limit')
expect(maxChars).toEqual(tokensToChars(200))
})
})
138 changes: 51 additions & 87 deletions lib/shared/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,123 +1,87 @@
import { logError } from '../logger'
import { OLLAMA_DEFAULT_URL } from '../ollama'
import { isDotCom } from '../sourcegraph-api/environments'
import { DEFAULT_FAST_MODEL_TOKEN_LIMIT, tokensToChars } from '../prompt/constants'
import { DEFAULT_DOT_COM_MODELS } from './dotcom'
import { ModelUsage } from './types'
import { getProviderName } from './utils'
import type { ModelUsage } from './types'
import { fetchLocalOllamaModels, getModelInfo } from './utils'

/**
* ModelProvider manages available chat and edit models.
* It stores a set of available providers and methods to add,
* retrieve and select between them.
*/
export class ModelProvider {
// Whether the model is the default model
public default = false
// Whether the model is only available to Pro users
public codyProOnly = false
// The name of the provider of the model, e.g. "Anthropic"
public provider: string
// The title of the model, e.g. "Claude 2.0"
public readonly title: string

constructor(
// The model id that includes the provider name & the model name,
// e.g. "anthropic/claude-2.0"
public readonly model: string,
// The usage of the model, e.g. chat or edit.
public readonly usage: ModelUsage[],
isDefaultModel = true
// The maximum number of tokens that can be processed by the model in a single request.
// NOTE: A token is equivalent to 4 characters/bytes.
public readonly maxToken: number = DEFAULT_FAST_MODEL_TOKEN_LIMIT
) {
const splittedModel = model.split('/')
this.provider = getProviderName(splittedModel[0])
this.title = splittedModel[1]?.replaceAll('-', ' ')
this.default = isDefaultModel
const { provider, title } = getModelInfo(model)
this.provider = provider
this.title = title
}

// Providers available for non-dotcom instances
private static privateProviders: Map<string, ModelProvider> = new Map()
// Providers available for dotcom instances
private static dotComProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS
// Providers available from local ollama instances
private static ollamaProvidersEnabled = false
private static ollamaProviders: ModelProvider[] = []

public static onConfigChange(enableOllamaModels: boolean): void {
ModelProvider.ollamaProvidersEnabled = enableOllamaModels
ModelProvider.ollamaProviders = []
if (enableOllamaModels) {
ModelProvider.getLocalOllamaModels()
}
/**
* Get all the providers currently available to the user
*/
private static get providers(): ModelProvider[] {
return ModelProvider.primaryProviders.concat(ModelProvider.localProviders)
}

/**
* Fetches available Ollama models from the local Ollama server
* and adds them to the list of ollama providers.
* Providers available on the user's Sourcegraph instance
*/
public static getLocalOllamaModels(): void {
const isAgentTesting = process.env.CODY_SHIM_TESTING === 'true'
private static primaryProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS
/**
* Providers available from user's local instances, e.g. Ollama
*/
private static localProviders: ModelProvider[] = []

public static async onConfigChange(enableOllamaModels: boolean): Promise<void> {
// Only fetch local models if user has enabled the config
if (isAgentTesting || !ModelProvider.ollamaProvidersEnabled) {
return
}
// TODO (bee) watch file change to determine if a new model is added
// to eliminate the needs of restarting the extension to get the new models
fetch(new URL('/api/tags', OLLAMA_DEFAULT_URL).href)
.then(response => response.json())
.then(
data => {
const models = new Set<ModelProvider>()
for (const model of data.models) {
const name = `ollama/${model.model}`
const newModel = new ModelProvider(name, [ModelUsage.Chat, ModelUsage.Edit])
models.add(newModel)
}
ModelProvider.ollamaProviders = Array.from(models)
},
error => {
const fetchFailedErrors = ['Failed to fetch', 'fetch failed']
const isFetchFailed = fetchFailedErrors.some(err => error.toString().includes(err))
const serverErrorMsg = 'Please make sure the Ollama server is up & running.'
logError('getLocalOllamaModels: failed ', isFetchFailed ? serverErrorMsg : error)
}
)
ModelProvider.localProviders = enableOllamaModels ? await fetchLocalOllamaModels() : []
}

/**
* Adds a new model provider, instantiated from the given model string,
* to the internal providers set. This allows new models to be added and
* made available for use.
* Sets the primary model providers.
* NOTE: private instances can only support 1 provider atm
*/
public static add(provider: ModelProvider): void {
// private instances can only support 1 provider atm
if (ModelProvider.privateProviders.size) {
ModelProvider.privateProviders.clear()
}
ModelProvider.privateProviders.set(provider.model.trim(), provider)
public static setProviders(providers: ModelProvider[]): void {
ModelProvider.primaryProviders = providers
}

/**
* Gets the model providers based on the endpoint and current model.
* If endpoint is a dotcom endpoint, returns dotComProviders with ollama providers.
* Get the list of the primary models providers with local models.
* If currentModel is provided, sets it as the default model.
*/
public static get(
type: ModelUsage,
endpoint?: string | null,
currentModel?: string
): ModelProvider[] {
const isDotComUser = !endpoint || (endpoint && isDotCom(endpoint))
const models = (
isDotComUser
? ModelProvider.dotComProviders
: Array.from(ModelProvider.privateProviders.values())
)
.concat(ModelProvider.ollamaProviders)
.filter(model => model.usage.includes(type))

if (!isDotComUser) {
return models
}

// Set the current model as default
return models.map(model => {
return {
public static getProviders(type: ModelUsage, currentModel?: string): ModelProvider[] {
return ModelProvider.providers
.filter(m => m.usage.includes(type))
?.map(model => ({
...model,
// Set the current model as default
default: model.model === currentModel,
}
})
}))
}

/**
* Finds the model provider with the given model ID and returns its characters limit.
* The limit is calculated based on the max number of tokens the model can process.
* E.g. 7000 tokens * 4 characters/token = 28000 characters
*/
public static getMaxCharsByModel(modelID: string): number {
const model = ModelProvider.providers.find(m => m.model === modelID)
return tokensToChars(model?.maxToken || DEFAULT_FAST_MODEL_TOKEN_LIMIT)
}
}
36 changes: 36 additions & 0 deletions lib/shared/src/models/utils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { describe, expect, it } from 'vitest'
import { getModelInfo } from './utils'

describe('getModelInfo', () => {
it('splits model ID and returns provider and title', () => {
const result = getModelInfo('Anthropic/Claude 2.0')
expect(result).toEqual({
provider: 'Anthropic',
title: 'Claude 2.0',
})
})

it('handles model ID without title', () => {
const result = getModelInfo('Anthropic/')
expect(result).toEqual({
provider: 'Anthropic',
title: '',
})
})

it('replaces dashes in title with spaces', () => {
const result = getModelInfo('example/model-with-dashes')
expect(result).toEqual({
provider: 'example',
title: 'model with dashes',
})
})

it('handles model ID with multiple dashes', () => {
const result = getModelInfo('fireworks/accounts/fireworks/models/mixtral-8x7b-instruct')
expect(result).toEqual({
provider: 'fireworks',
title: 'mixtral 8x7b instruct',
})
})
})
46 changes: 46 additions & 0 deletions lib/shared/src/models/utils.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import { ModelProvider } from '.'
import { logError } from '../logger'
import { OLLAMA_DEFAULT_URL } from '../ollama'
import { DEFAULT_FAST_MODEL_CHARS_LIMIT } from '../prompt/constants'
import { ModelUsage } from './types'
export function getProviderName(name: string): string {
const providerName = name.toLowerCase()
switch (providerName) {
Expand All @@ -15,3 +20,44 @@ export function getProviderName(name: string): string {
export function supportsFastPath(model: string): boolean {
return model?.startsWith('anthropic/claude-3')
}

/**
* Gets the provider and title from a model ID string.
*/
export function getModelInfo(modelID: string): {
provider: string
title: string
} {
const [providerID, ...rest] = modelID.split('/')
const provider = getProviderName(providerID)
const title = (rest.at(-1) || '').replace(/-/g, ' ')
return { provider, title }
}

/**
* Fetches available Ollama models from the Ollama server.
*/
export async function fetchLocalOllamaModels(): Promise<ModelProvider[]> {
// TODO (bee) watch file change to determine if a new model is added
// to eliminate the needs of restarting the extension to get the new models
return await fetch(new URL('/api/tags', OLLAMA_DEFAULT_URL).href)
.then(response => response.json())
.then(
data =>
data?.models?.map(
(m: { model: string }) =>
new ModelProvider(
`ollama/${m.model}`,
[ModelUsage.Chat, ModelUsage.Edit],
DEFAULT_FAST_MODEL_CHARS_LIMIT
)
),
error => {
const fetchFailedErrors = ['Failed to fetch', 'fetch failed']
const isFetchFailed = fetchFailedErrors.some(err => error.toString().includes(err))
const serverErrorMsg = 'Please make sure the Ollama server is up & running.'
logError('getLocalOllamaModels: failed ', isFetchFailed ? serverErrorMsg : error)
return []
}
)
}
Loading

0 comments on commit da1fe66

Please sign in to comment.