Skip to content

Commit

Permalink
fix(gui): dedupe and sort available prompt tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 13, 2023
1 parent 95e2d6d commit 35171e6
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions gui/src/components/input/PromptInput.tsx
Expand Up @@ -12,7 +12,7 @@ import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';

const { useContext } = React;
const { useContext, useMemo } = React;

/**
* @todo replace with a selector
Expand Down Expand Up @@ -64,8 +64,10 @@ export function PromptInput(props: PromptInputProps) {
});
}

const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);
const tokens = useMemo(() => {
const networks = extractNetworks(prompt);
return getNetworkTokens(models.data, networks);
}, [prompt, models.data]);

return <Stack spacing={2}>
<TextField
Expand All @@ -80,7 +82,7 @@ export function PromptInput(props: PromptInputProps) {
}}
/>
<Stack direction='row' spacing={2}>
{tokens.map(([token, _weight]) => <Chip
{tokens.map((token) => <Chip
color={prompt.includes(token) ? 'primary' : 'default'}
label={token}
onClick={() => addToken(token)}
Expand Down Expand Up @@ -162,26 +164,26 @@ export function extractNetworks(prompt: string): PromptNetworks {
}

// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
const tokens: TokenList = [];
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): Array<string> {
const tokens: Set<string> = new Set();

if (doesExist(models)) {
for (const [name, weight] of networks.inversion) {
for (const [name, _weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]);
tokens.add(model.token);
}
}

for (const [name, weight] of networks.lora) {
for (const [name, _weight] of networks.lora) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') {
for (const token of mustDefault(model.tokens, [])) {
tokens.push([token, weight]);
tokens.add(token);
}
}
}
}

return tokens;
return Array.from(tokens).sort();
}

0 comments on commit 35171e6

Please sign in to comment.