Skip to content

Commit

Permalink
add model name/select to provider header
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmacarthy committed Apr 21, 2024
1 parent 59faa50 commit 90f5c44
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/extension/cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class LRUCache<T = string> {
}

normalize(src: string): string {
return src.split('\n').join('').replace(/\s+/gm, '').replace(' ', '')
return src.split('\n').join('').replace(/\s+/g, '').replace(/\s/g, '')
}

getKey(prefixSuffix: PrefixSuffix): string {
Expand Down
23 changes: 9 additions & 14 deletions src/webview/model-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,14 @@ export const ModelSelect = ({ model, models, setModel }: Props) => {
}

return (
<div>
<div>
<label htmlFor="modelName">Model name*</label>
</div>
<VSCodeDropdown onChange={handleOnChange} value={model}>
{models?.map((model, index) => {
return (
<option value={model.name} key={`${index}`}>
{getModelShortName(model.name)}
</option>
)
})}
</VSCodeDropdown>
</div>
<VSCodeDropdown onChange={handleOnChange} value={model}>
{models?.map((model, index) => {
return (
<option value={model.name} key={`${index}`}>
{getModelShortName(model.name)}
</option>
)
})}
</VSCodeDropdown>
)
}
4 changes: 2 additions & 2 deletions src/webview/provider-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export const ProviderSelect = () => {
.sort((a, b) => a.modelName.localeCompare(b.modelName))
.map((provider, index) => (
<VSCodeOption key={index} value={provider.id}>
{`${provider.label} (${provider.provider})`}
{`${provider.label} (${provider.modelName})`}
</VSCodeOption>
))}
</VSCodeDropdown>
Expand All @@ -59,7 +59,7 @@ export const ProviderSelect = () => {
.sort((a, b) => a.modelName.localeCompare(b.modelName))
.map((provider, index) => (
<VSCodeOption key={index} value={provider.id}>
{`${provider.label} (${provider.provider})`}
{`${provider.label} (${provider.modelName})`}
</VSCodeOption>
))}
</VSCodeDropdown>
Expand Down
11 changes: 11 additions & 0 deletions src/webview/providers.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
flex-wrap: wrap;
}

.providerHeader, .providerForm vscode-dropdown {
flex-grow: 1;
}

.providerHeader h4 {
margin: 0;
}
Expand All @@ -52,6 +56,8 @@

.providerForm {
width: 100%;
display: flex;
flex-direction: column;
}

.providerForm div {
Expand All @@ -68,6 +74,11 @@
gap: 5px;
}

.providerHeader vscode-dropdown {
flex-grow: 1;
margin-right: 10px;
}

.providerHeader svg {
width: 20px;
}
65 changes: 53 additions & 12 deletions src/webview/providers.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ import styles from './providers.module.css'
import { TwinnyProvider } from '../extension/provider-manager'
import {
DEFAULT_PROVIDER_FORM_VALUES,
FIM_TEMPLATE_FORMAT,
FIM_TEMPLATE_FORMAT
} from '../common/constants'
import { ModelSelect } from './model-select'

export const Providers = () => {
const [showForm, setShowForm] = React.useState(false)
const [provider, setProvider] = React.useState<TwinnyProvider | undefined>()
const { models } = useOllamaModels()
const hasOllamaModels = !!models?.length
const { updateProvider } = useProviders()
const { providers, removeProvider, copyProvider, resetProviders } =
useProviders()

Expand Down Expand Up @@ -50,6 +53,19 @@ export const Providers = () => {
resetProviders()
}

const handleSetModel = (provider: TwinnyProvider, model: string) => {
updateProvider({
...provider,
modelName: model
})
}

const handleChange = (provider: TwinnyProvider, e: unknown) => {
const event = e as unknown as React.ChangeEvent<HTMLInputElement>
const { value } = event.target
handleSetModel(provider, value)
}

return (
<div>
<h3>Providers</h3>
Expand All @@ -69,7 +85,27 @@ export const Providers = () => {
{Object.values(providers).map((provider, index) => (
<div className={styles.provider} key={index}>
<div className={styles.providerHeader}>
<h4>{provider.label}</h4>
{provider.provider === ApiProviders.Ollama &&
hasOllamaModels && (
<ModelSelect
models={models}
model={provider.modelName}
setModel={(model: string) =>
handleSetModel(provider, model)
}
/>
)}
{provider.provider !== ApiProviders.Ollama && (
<VSCodeTextField
required
name="modelName"
onChange={(e) => {
handleChange(provider, e)
}}
value={provider.modelName}
placeholder='Applicable for some providers like "Ollama"'
></VSCodeTextField>
)}
<div className={styles.providerActions}>
<VSCodeButton
appearance="icon"
Expand Down Expand Up @@ -99,6 +135,9 @@ export const Providers = () => {
</div>
<VSCodeDivider />
<div className={styles.providerDetails}>
<div>
<b>Label:</b> {provider.label}
</div>
<div>
<b>Provider:</b> {provider.provider}
</div>
Expand All @@ -110,9 +149,6 @@ export const Providers = () => {
<b>Fim Template:</b> {provider.fimTemplate}
</div>
)}
<div>
<b>Model:</b> {provider.modelName}
</div>
<div>
<b>Hostname:</b> {provider.apiHostname}
</div>
Expand Down Expand Up @@ -271,13 +307,18 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
</div>

{formState.provider === ApiProviders.Ollama && hasOllamaModels && (
<ModelSelect
models={models}
model={formState.modelName}
setModel={(model: string) => {
setFormState({ ...formState, modelName: model })
}}
/>
<div>
<div>
<label htmlFor="apiHostname">Model name*</label>
</div>
<ModelSelect
models={models}
model={formState.modelName}
setModel={(model: string) => {
setFormState({ ...formState, modelName: model })
}}
/>
</div>
)}

{formState.provider !== ApiProviders.Ollama && (
Expand Down
8 changes: 3 additions & 5 deletions src/webview/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import {
EMPTY_MESAGE,
} from '../common/constants'
import { EMPTY_MESAGE } from '../common/constants'
import { CodeLanguage, supportedLanguages } from '../common/languages'
import { LanguageType, ServerMessage } from '../common/types'

Expand Down Expand Up @@ -57,8 +55,8 @@ export const kebabToSentence = (kebabStr: string) => {
export const getLineBreakCount = (str: string) => str.split('\n').length

export const getModelShortName = (name: string) => {
if (name.length > 32) {
return `${name.substring(0, 15)}...${name.substring(name.length - 16)}`
if (name.length > 25) {
return `${name.substring(0, 10)}...${name.substring(name.length - 10)}`
}
return name
}

0 comments on commit 90f5c44

Please sign in to comment.