Skip to content

Commit

Permalink
adding FIM finetuned model hosted on fireworks (#4245)
Browse files Browse the repository at this point in the history
  • Loading branch information
hitesh-1997 committed May 23, 2024
1 parent f3abf5e commit 4b9fb4c
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 26 deletions.
10 changes: 9 additions & 1 deletion lib/shared/src/experimentation/FeatureFlagProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ export enum FeatureFlag {
// Enable StarCoder2 7b and 15b as the default model via Fireworks
CodyAutocompleteStarCoder2Hybrid = 'cody-autocomplete-starcoder2-hybrid',
// Enable the FineTuned model as the default model via Fireworks
CodyAutocompleteFineTunedModel = 'cody-autocomplete-finetuned-model',

// Enable various feature flags to experiment with FIM trained fine-tuned models via Fireworks
CodyAutocompleteFIMFineTunedModelBaseFeatureFlag = 'cody-autocomplete-fim-fine-tuned-model-experiment-flag',
CodyAutocompleteFIMFineTunedModelControl = 'cody-autocomplete-fim-fine-tuned-model-control',
CodyAutocompleteFIMFineTunedModelVariant1 = 'cody-autocomplete-fim-fine-tuned-model-variant-1',
CodyAutocompleteFIMFineTunedModelVariant2 = 'cody-autocomplete-fim-fine-tuned-model-variant-2',
CodyAutocompleteFIMFineTunedModelVariant3 = 'cody-autocomplete-fim-fine-tuned-model-variant-3',
CodyAutocompleteFIMFineTunedModelVariant4 = 'cody-autocomplete-fim-fine-tuned-model-variant-4',

// Enables Claude 3 if the user is in our holdout group
CodyAutocompleteClaude3 = 'cody-autocomplete-claude-3',
// Enables the bfg-mixed context retriever that will combine BFG with the default local editor
Expand Down
78 changes: 66 additions & 12 deletions vscode/src/completions/providers/create-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ import {
featureFlagProvider,
} from '@sourcegraph/cody-shared'

import * as vscode from 'vscode'
import { logError } from '../../log'

import {
type AnthropicOptions,
createProviderConfig as createAnthropicProviderConfig,
} from './anthropic'
import { createProviderConfig as createExperimentalOllamaProviderConfig } from './experimental-ollama'
import {
FIREWORKS_FIM_FINE_TUNED_MODEL_1,
FIREWORKS_FIM_FINE_TUNED_MODEL_2,
FIREWORKS_FIM_FINE_TUNED_MODEL_3,
FIREWORKS_FIM_FINE_TUNED_MODEL_4,
type FireworksOptions,
createProviderConfig as createFireworksProviderConfig,
} from './fireworks'
Expand Down Expand Up @@ -150,6 +154,50 @@ export async function createProviderConfig(
return createAnthropicProviderConfig({ client })
}

async function resolveFinetunedModelProviderFromFeatureFlags(): ReturnType<
typeof resolveDefaultProviderFromVSCodeConfigOrFeatureFlags
> {
/**
* The traffic allocated to the fine-tuned-base feature flag is further split between multiple feature flag in function.
*/
const [finetuneControl, finetuneVariant1, finetuneVariant2, finetuneVariant3, finetuneVariant4] =
await Promise.all([
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelControl
),
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelVariant1
),
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelVariant2
),
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelVariant3
),
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelVariant4
),
])
if (finetuneVariant1) {
return { provider: 'fireworks', model: FIREWORKS_FIM_FINE_TUNED_MODEL_1 }
}
if (finetuneVariant2) {
return { provider: 'fireworks', model: FIREWORKS_FIM_FINE_TUNED_MODEL_2 }
}
if (finetuneVariant3) {
return { provider: 'fireworks', model: FIREWORKS_FIM_FINE_TUNED_MODEL_3 }
}
if (finetuneVariant4) {
return { provider: 'fireworks', model: FIREWORKS_FIM_FINE_TUNED_MODEL_4 }
}
if (finetuneControl) {
return { provider: 'fireworks', model: 'starcoder-hybrid' }
}

// Extra free traffic - redirect to the current production model which could be different than control
return { provider: 'fireworks', model: 'starcoder-hybrid' }
}

async function resolveDefaultProviderFromVSCodeConfigOrFeatureFlags(
configuredProvider: string | null
): Promise<{
Expand All @@ -160,15 +208,27 @@ async function resolveDefaultProviderFromVSCodeConfigOrFeatureFlags(
return { provider: configuredProvider }
}

const [starCoder2Hybrid, starCoderHybrid, llamaCode13B, claude3, finetunedModel] = await Promise.all(
[
const [starCoder2Hybrid, starCoderHybrid, llamaCode13B, claude3, finetunedFIMModelExperiment] =
await Promise.all([
featureFlagProvider.evaluateFeatureFlag(FeatureFlag.CodyAutocompleteStarCoder2Hybrid),
featureFlagProvider.evaluateFeatureFlag(FeatureFlag.CodyAutocompleteStarCoderHybrid),
featureFlagProvider.evaluateFeatureFlag(FeatureFlag.CodyAutocompleteLlamaCode13B),
featureFlagProvider.evaluateFeatureFlag(FeatureFlag.CodyAutocompleteClaude3),
featureFlagProvider.evaluateFeatureFlag(FeatureFlag.CodyAutocompleteFineTunedModel),
]
)
featureFlagProvider.evaluateFeatureFlag(
FeatureFlag.CodyAutocompleteFIMFineTunedModelBaseFeatureFlag
),
])

// We run fine tuning experiment for VSC client only.
// We disable for all agent clients like the JetBrains plugin.
const isFinetuningExperimentDisabled = vscode.workspace
.getConfiguration()
.get<boolean>('cody.advanced.agent.running', false)

if (!isFinetuningExperimentDisabled && finetunedFIMModelExperiment) {
// The traffic in this feature flag is interpreted as a traffic allocated to the fine-tuned experiment.
return resolveFinetunedModelProviderFromFeatureFlags()
}

if (llamaCode13B) {
return { provider: 'fireworks', model: 'llama-code-13b' }
Expand All @@ -179,12 +239,6 @@ async function resolveDefaultProviderFromVSCodeConfigOrFeatureFlags(
}

if (starCoderHybrid) {
// Adding the fine-tuned model here for the A/B test setup.
// Among all the users in starcoder-hybrid - some % of them will be redirected to the fine-tuned model.
if (finetunedModel) {
return { provider: 'fireworks', model: 'fireworks-completions-fine-tuned' }
}

return { provider: 'fireworks', model: 'starcoder-hybrid' }
}

Expand Down
54 changes: 41 additions & 13 deletions vscode/src/completions/providers/fireworks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import {
tokensToChars,
tracer,
} from '@sourcegraph/cody-shared'

import { fetch } from '@sourcegraph/cody-shared'
import { getLanguageConfig } from '../../tree-sitter/language'
import { getSuffixAfterFirstNewline } from '../text-processing'
Expand Down Expand Up @@ -69,6 +68,19 @@ const PROVIDER_IDENTIFIER = 'fireworks'
const EOT_STARCODER = '<|endoftext|>'
const EOT_LLAMA_CODE = ' <EOT>'

// Fireworks hosted model identifier strings
export const FIREWORKS_FIM_FINE_TUNED_MODEL_1 = 'fim-fine-tuned-model-variant-1'
export const FIREWORKS_FIM_FINE_TUNED_MODEL_2 = 'fim-fine-tuned-model-variant-2'
export const FIREWORKS_FIM_FINE_TUNED_MODEL_3 = 'fim-fine-tuned-model-variant-3'
export const FIREWORKS_FIM_FINE_TUNED_MODEL_4 = 'fim-fine-tuned-model-variant-4'

const FIREWORKS_FIM_FINE_TUNED_MODEL_FAMILY = [
FIREWORKS_FIM_FINE_TUNED_MODEL_1,
FIREWORKS_FIM_FINE_TUNED_MODEL_2,
FIREWORKS_FIM_FINE_TUNED_MODEL_3,
FIREWORKS_FIM_FINE_TUNED_MODEL_4,
]

// Model identifiers can be found in https://docs.fireworks.ai/explore/ and in our internal
// conversations
const MODEL_MAP = {
Expand All @@ -83,8 +95,10 @@ const MODEL_MAP = {
'llama-code-13b': 'fireworks/accounts/fireworks/models/llama-v2-13b-code',

// Fine-tuned model mapping
'fireworks-completions-fine-tuned':
'fireworks/accounts/sourcegraph/models/codecompletion-mixtral-rust-152k-005e',
[FIREWORKS_FIM_FINE_TUNED_MODEL_1]: FIREWORKS_FIM_FINE_TUNED_MODEL_1,
[FIREWORKS_FIM_FINE_TUNED_MODEL_2]: FIREWORKS_FIM_FINE_TUNED_MODEL_2,
[FIREWORKS_FIM_FINE_TUNED_MODEL_3]: FIREWORKS_FIM_FINE_TUNED_MODEL_3,
[FIREWORKS_FIM_FINE_TUNED_MODEL_4]: FIREWORKS_FIM_FINE_TUNED_MODEL_4,
}

type FireworksModel =
Expand All @@ -111,8 +125,12 @@ function getMaxContextTokens(model: FireworksModel): number {
// Llama 2 on Fireworks supports up to 4k tokens. We're constraining it here to better
// compare the results
return 2048
case 'fireworks-completions-fine-tuned':
return 2048
case FIREWORKS_FIM_FINE_TUNED_MODEL_1:
case FIREWORKS_FIM_FINE_TUNED_MODEL_2:
case FIREWORKS_FIM_FINE_TUNED_MODEL_3:
case FIREWORKS_FIM_FINE_TUNED_MODEL_4: {
return 3072
}
default:
return 1200
}
Expand Down Expand Up @@ -180,7 +198,10 @@ class FireworksProvider extends Provider {
const languageConfig = getLanguageConfig(this.options.document.languageId)

// In StarCoder we have a special token to announce the path of the file
if (!isStarCoderFamily(this.model) && this.model !== 'fireworks-completions-fine-tuned') {
if (
!isStarCoderFamily(this.model) &&
!FIREWORKS_FIM_FINE_TUNED_MODEL_FAMILY.includes(this.model)
) {
intro.push(ps`Path: ${PromptString.fromDisplayPath(this.options.document.uri)}`)
}

Expand All @@ -193,6 +214,13 @@ class FireworksProvider extends Provider {
intro.push(
ps`Additional documentation for \`${contextPrompts.symbol}\`:\n\n${contextPrompts.content}`
)
} else if (FIREWORKS_FIM_FINE_TUNED_MODEL_FAMILY.includes(this.model)) {
// Fine-tuned model have a additional <file_sep> tag.
intro.push(
ps`<file_sep>Here is a reference snippet of code from ${PromptString.fromDisplayPath(
snippet.uri
)}\n${contextPrompts.content}`
)
} else {
intro.push(
ps`Here is a reference snippet of code from ${PromptString.fromDisplayPath(
Expand Down Expand Up @@ -257,7 +285,10 @@ class FireworksProvider extends Provider {
model,
} satisfies CodeCompletionsParams

if (requestParams.model.includes('starcoder2')) {
if (
requestParams.model.includes('starcoder2') ||
FIREWORKS_FIM_FINE_TUNED_MODEL_FAMILY.includes(requestParams.model)
) {
requestParams.stopSequences = [
...(requestParams.stopSequences || []),
'<fim_prefix>',
Expand Down Expand Up @@ -322,12 +353,8 @@ class FireworksProvider extends Provider {
// c.f. https://github.com/facebookresearch/codellama/blob/main/llama/generation.py#L402
return ps`<PRE> ${intro}${prefix} <SUF>${suffix} <MID>`
}
if (this.model === 'fireworks-completions-fine-tuned') {
const fixedPrompt = ps`You are an expert in writing code in many different languages.
Your goal is to perform code completion for the following code, keeping in mind the rest of the code and the file meta data.
Metadata details: filename: ${filename}. The code of the file until where you have to start completion:
`
return ps`${intro} \n ${fixedPrompt}\n${prefix}`
if (FIREWORKS_FIM_FINE_TUNED_MODEL_FAMILY.includes(this.model)) {
return ps`${intro}<fim_suffix>${filename}\n${suffix}<fim_prefix>${prefix}<fim_middle>`
}
console.error('Could not generate infilling prompt for', this.model)
return ps`${intro}${prefix}`
Expand Down Expand Up @@ -396,6 +423,7 @@ class FireworksProvider extends Provider {
...(self.fireworksConfig?.parameters?.stop || []),
],
stream: true,
languageId: self.options.document.languageId,
}

const headers = new Headers()
Expand Down

0 comments on commit 4b9fb4c

Please sign in to comment.