-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Chat: sync token limit at model import time (#3486)
Co-authored-by: Dominic Cooney <dominic.cooney@sourcegraph.com>
- Loading branch information
1 parent
9647ecb
commit da1fe66
Showing
17 changed files
with
249 additions
and
186 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import { beforeAll, describe, expect, it } from 'vitest' | ||
import { ModelProvider } from '../models/index' | ||
import { DEFAULT_FAST_MODEL_TOKEN_LIMIT, tokensToChars } from '../prompt/constants' | ||
import { DOTCOM_URL } from '../sourcegraph-api/environments' | ||
import { DEFAULT_DOT_COM_MODELS } from './dotcom' | ||
import { ModelUsage } from './types' | ||
|
||
// Tests for ModelProvider.getMaxCharsByModel: the character limit is derived from the | ||
// matching provider's maxToken, falling back to DEFAULT_FAST_MODEL_TOKEN_LIMIT. | ||
describe('getMaxCharsByModel', () => { | ||
    beforeAll(() => { | ||
        // getProviders(type, currentModel?) expects a model ID as its second argument and | ||
        // only returns a filtered copy of the list, so calling it with an endpoint URL | ||
        // configures nothing. Register the dotcom defaults explicitly instead. | ||
        ModelProvider.setProviders(DEFAULT_DOT_COM_MODELS) | ||
    }) | ||
|
||
    it('returns default token limit for unknown model', () => { | ||
        const maxChars = ModelProvider.getMaxCharsByModel('unknown-model') | ||
        expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT)) | ||
    }) | ||
|
||
    it('returns max token limit for known chat model', () => { | ||
        const knownModel = DEFAULT_DOT_COM_MODELS[0] | ||
        const maxChars = ModelProvider.getMaxCharsByModel(knownModel.model) | ||
        expect(maxChars).toEqual(tokensToChars(knownModel.maxToken)) | ||
    }) | ||
|
||
    it('returns default token limit for unknown model - Enterprise user', () => { | ||
        // Simulate an enterprise instance by replacing the primary providers with a | ||
        // single instance-provided model, then look up a model that is not registered. | ||
        ModelProvider.setProviders([new ModelProvider('enterprise-model', [ModelUsage.Chat], 200)]) | ||
        const maxChars = ModelProvider.getMaxCharsByModel('unknown-model') | ||
        expect(maxChars).toEqual(tokensToChars(DEFAULT_FAST_MODEL_TOKEN_LIMIT)) | ||
    }) | ||
|
||
    it('returns max token limit for known model - Enterprise user', () => { | ||
        // The limit passed to the constructor (200 tokens) must be the one reported back. | ||
        ModelProvider.setProviders([new ModelProvider('model-with-limit', [ModelUsage.Chat], 200)]) | ||
        const maxChars = ModelProvider.getMaxCharsByModel('model-with-limit') | ||
        expect(maxChars).toEqual(tokensToChars(200)) | ||
    }) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,123 +1,87 @@ | ||
import { logError } from '../logger' | ||
import { OLLAMA_DEFAULT_URL } from '../ollama' | ||
import { isDotCom } from '../sourcegraph-api/environments' | ||
import { DEFAULT_FAST_MODEL_TOKEN_LIMIT, tokensToChars } from '../prompt/constants' | ||
import { DEFAULT_DOT_COM_MODELS } from './dotcom' | ||
import { ModelUsage } from './types' | ||
import { getProviderName } from './utils' | ||
import type { ModelUsage } from './types' | ||
import { fetchLocalOllamaModels, getModelInfo } from './utils' | ||
|
||
/** | ||
* ModelProvider manages available chat and edit models. | ||
* It stores a set of available providers and methods to add, | ||
* retrieve and select between them. | ||
*/ | ||
export class ModelProvider { | ||
// Whether the model is the default model | ||
public default = false | ||
// Whether the model is only available to Pro users | ||
public codyProOnly = false | ||
// The name of the provider of the model, e.g. "Anthropic" | ||
public provider: string | ||
// The title of the model, e.g. "Claude 2.0" | ||
public readonly title: string | ||
|
||
constructor( | ||
// The model id that includes the provider name & the model name, | ||
// e.g. "anthropic/claude-2.0" | ||
public readonly model: string, | ||
// The usage of the model, e.g. chat or edit. | ||
public readonly usage: ModelUsage[], | ||
isDefaultModel = true | ||
// The maximum number of tokens that can be processed by the model in a single request. | ||
// NOTE: A token is equivalent to 4 characters/bytes. | ||
public readonly maxToken: number = DEFAULT_FAST_MODEL_TOKEN_LIMIT | ||
) { | ||
const splittedModel = model.split('/') | ||
this.provider = getProviderName(splittedModel[0]) | ||
this.title = splittedModel[1]?.replaceAll('-', ' ') | ||
this.default = isDefaultModel | ||
const { provider, title } = getModelInfo(model) | ||
this.provider = provider | ||
this.title = title | ||
} | ||
|
||
// Providers available for non-dotcom instances | ||
private static privateProviders: Map<string, ModelProvider> = new Map() | ||
// Providers available for dotcom instances | ||
private static dotComProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS | ||
// Providers available from local ollama instances | ||
private static ollamaProvidersEnabled = false | ||
private static ollamaProviders: ModelProvider[] = [] | ||
|
||
public static onConfigChange(enableOllamaModels: boolean): void { | ||
ModelProvider.ollamaProvidersEnabled = enableOllamaModels | ||
ModelProvider.ollamaProviders = [] | ||
if (enableOllamaModels) { | ||
ModelProvider.getLocalOllamaModels() | ||
} | ||
/** | ||
* Get all the providers currently available to the user | ||
*/ | ||
private static get providers(): ModelProvider[] { | ||
return ModelProvider.primaryProviders.concat(ModelProvider.localProviders) | ||
} | ||
|
||
/** | ||
* Fetches available Ollama models from the local Ollama server | ||
* and adds them to the list of ollama providers. | ||
* Providers available on the user's Sourcegraph instance | ||
*/ | ||
public static getLocalOllamaModels(): void { | ||
const isAgentTesting = process.env.CODY_SHIM_TESTING === 'true' | ||
private static primaryProviders: ModelProvider[] = DEFAULT_DOT_COM_MODELS | ||
/** | ||
* Providers available from user's local instances, e.g. Ollama | ||
*/ | ||
private static localProviders: ModelProvider[] = [] | ||
|
||
public static async onConfigChange(enableOllamaModels: boolean): Promise<void> { | ||
// Only fetch local models if user has enabled the config | ||
if (isAgentTesting || !ModelProvider.ollamaProvidersEnabled) { | ||
return | ||
} | ||
// TODO (bee) watch file change to determine if a new model is added | ||
// to eliminate the needs of restarting the extension to get the new models | ||
fetch(new URL('/api/tags', OLLAMA_DEFAULT_URL).href) | ||
.then(response => response.json()) | ||
.then( | ||
data => { | ||
const models = new Set<ModelProvider>() | ||
for (const model of data.models) { | ||
const name = `ollama/${model.model}` | ||
const newModel = new ModelProvider(name, [ModelUsage.Chat, ModelUsage.Edit]) | ||
models.add(newModel) | ||
} | ||
ModelProvider.ollamaProviders = Array.from(models) | ||
}, | ||
error => { | ||
const fetchFailedErrors = ['Failed to fetch', 'fetch failed'] | ||
const isFetchFailed = fetchFailedErrors.some(err => error.toString().includes(err)) | ||
const serverErrorMsg = 'Please make sure the Ollama server is up & running.' | ||
logError('getLocalOllamaModels: failed ', isFetchFailed ? serverErrorMsg : error) | ||
} | ||
) | ||
ModelProvider.localProviders = enableOllamaModels ? await fetchLocalOllamaModels() : [] | ||
} | ||
|
||
/** | ||
* Adds a new model provider, instantiated from the given model string, | ||
* to the internal providers set. This allows new models to be added and | ||
* made available for use. | ||
* Sets the primary model providers. | ||
* NOTE: private instances can only support one provider at the moment. | ||
*/ | ||
public static add(provider: ModelProvider): void { | ||
// private instances can only support 1 provider atm | ||
if (ModelProvider.privateProviders.size) { | ||
ModelProvider.privateProviders.clear() | ||
} | ||
ModelProvider.privateProviders.set(provider.model.trim(), provider) | ||
public static setProviders(providers: ModelProvider[]): void { | ||
ModelProvider.primaryProviders = providers | ||
} | ||
|
||
/** | ||
* Gets the model providers based on the endpoint and current model. | ||
* If endpoint is a dotcom endpoint, returns dotComProviders with ollama providers. | ||
* Get the list of primary model providers, together with any local models. | ||
* If currentModel is provided, sets it as the default model. | ||
*/ | ||
public static get( | ||
type: ModelUsage, | ||
endpoint?: string | null, | ||
currentModel?: string | ||
): ModelProvider[] { | ||
const isDotComUser = !endpoint || (endpoint && isDotCom(endpoint)) | ||
const models = ( | ||
isDotComUser | ||
? ModelProvider.dotComProviders | ||
: Array.from(ModelProvider.privateProviders.values()) | ||
) | ||
.concat(ModelProvider.ollamaProviders) | ||
.filter(model => model.usage.includes(type)) | ||
|
||
if (!isDotComUser) { | ||
return models | ||
} | ||
|
||
// Set the current model as default | ||
return models.map(model => { | ||
return { | ||
public static getProviders(type: ModelUsage, currentModel?: string): ModelProvider[] { | ||
return ModelProvider.providers | ||
.filter(m => m.usage.includes(type)) | ||
?.map(model => ({ | ||
...model, | ||
// Set the current model as default | ||
default: model.model === currentModel, | ||
} | ||
}) | ||
})) | ||
} | ||
|
||
/** | ||
 * Returns the character limit for the provider whose model ID equals `modelID`. | ||
 * The provider's `maxToken` is converted to characters with `tokensToChars`; when no | ||
 * provider matches (or its `maxToken` is falsy), DEFAULT_FAST_MODEL_TOKEN_LIMIT is used. | ||
 * E.g. 7000 tokens * 4 characters/token = 28000 characters | ||
 */ | ||
public static getMaxCharsByModel(modelID: string): number { | ||
    const tokenLimit = ModelProvider.providers.find(p => p.model === modelID)?.maxToken | ||
    return tokensToChars(tokenLimit || DEFAULT_FAST_MODEL_TOKEN_LIMIT) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import { describe, expect, it } from 'vitest' | ||
import { getModelInfo } from './utils' | ||
|
||
// Table-driven tests for getModelInfo: each case lists a model ID and the | ||
// { provider, title } object the helper is expected to return for it. | ||
describe('getModelInfo', () => { | ||
    const cases = [ | ||
        { | ||
            name: 'splits model ID and returns provider and title', | ||
            modelID: 'Anthropic/Claude 2.0', | ||
            expected: { provider: 'Anthropic', title: 'Claude 2.0' }, | ||
        }, | ||
        { | ||
            name: 'handles model ID without title', | ||
            modelID: 'Anthropic/', | ||
            expected: { provider: 'Anthropic', title: '' }, | ||
        }, | ||
        { | ||
            name: 'replaces dashes in title with spaces', | ||
            modelID: 'example/model-with-dashes', | ||
            expected: { provider: 'example', title: 'model with dashes' }, | ||
        }, | ||
        { | ||
            name: 'handles model ID with multiple dashes', | ||
            modelID: 'fireworks/accounts/fireworks/models/mixtral-8x7b-instruct', | ||
            expected: { provider: 'fireworks', title: 'mixtral 8x7b instruct' }, | ||
        }, | ||
    ] | ||
|
||
    // Register one test per case, in the order listed above. | ||
    for (const { name, modelID, expected } of cases) { | ||
        it(name, () => { | ||
            expect(getModelInfo(modelID)).toEqual(expected) | ||
        }) | ||
    } | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.