Skip to content

Commit

Permalink
fix(gui): make prompt input perform better with large LoRA/wildcard l…
Browse files Browse the repository at this point in the history
…ists
  • Loading branch information
ssube committed Dec 17, 2023
1 parent a65e0fd commit e0929ba
Showing 1 changed file with 58 additions and 32 deletions.
90 changes: 58 additions & 32 deletions gui/src/components/input/PromptInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
import { ModelResponse, NetworkModel } from '../../types/api.js';

const { useContext, useMemo } = React;

Expand All @@ -27,56 +27,38 @@ export interface PromptInputProps {
onChange(value: PromptValue): void;
}

export interface PromptTextBlockProps extends PromptInputProps {
models: Maybe<ModelResponse>;
}

export const PROMPT_GROUP = 75;

export function PromptInput(props: PromptInputProps) {
export function PromptTextBlock(props: PromptTextBlockProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;
const { models, selector, onChange } = props;

const { t } = useTranslation();
const store = mustExist(useContext(StateContext));
const { prompt, negativePrompt } = useStore(store, selector, shallow);

const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
staleTime: STALE_TIME,
});

const { t } = useTranslation();

function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
negativePrompt,
});
}

function addToken(name: string) {
onChange({
prompt: `${prompt}, ${name}`,
});
}

function addWildcard(name: string) {
onChange({
prompt: `${prompt}, __${name}__`,
});
}

const tokens = useMemo(() => {
const networks = extractNetworks(prompt);
return getNetworkTokens(models.data, networks);
}, [prompt, models.data]);
return getNetworkTokens(models, networks);
}, [models, prompt]);

return <Stack spacing={2}>
<TextField
label={t('parameter.prompt')}
variant='outlined'
value={prompt}
onChange={(event) => {
props.onChange({
onChange({
prompt: event.target.value,
negativePrompt,
});
Expand All @@ -100,14 +82,54 @@ export function PromptInput(props: PromptInputProps) {
});
}}
/>
</Stack>;
}

export function PromptInput(props: PromptInputProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;

const store = mustExist(useContext(StateContext));
const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
staleTime: STALE_TIME,
});

const { t } = useTranslation();

function addNetwork(type: string, name: string, weight = 1.0) {
const { prompt, negativePrompt } = selector(store.getState());
onChange({
negativePrompt,
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
});
}

function addWildcard(name: string) {
const { prompt, negativePrompt } = selector(store.getState());
onChange({
negativePrompt,
prompt: `${prompt}, __${name}__`,
});
}

return <Stack spacing={2}>
<PromptTextBlock
models={models.data}
onChange={onChange}
selector={selector}
/>
<Stack direction='row' spacing={2}>
<QueryMenu
id='inversion'
labelKey='model.inversion'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
selector: (result) => filterNetworks(result.networks, 'inversion'),
}}
onSelect={(name) => {
addNetwork('inversion', name);
Expand All @@ -119,7 +141,7 @@ export function PromptInput(props: PromptInputProps) {
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
selector: (result) => filterNetworks(result.networks, 'lora'),
}}
onSelect={(name) => {
addNetwork('lora', name);
Expand All @@ -141,6 +163,10 @@ export function PromptInput(props: PromptInputProps) {
</Stack>;
}

export function filterNetworks(networks: Array<NetworkModel>, type: string): Array<string> {
return networks.filter((network) => network.type === type).map((network) => network.name);
}

export const ANY_TOKEN = /<([^>]+)>/g;

export type TokenList = Array<[string, number]>;
Expand All @@ -166,7 +192,7 @@ export function extractNetworks(prompt: string): PromptNetworks {
lora.push([name, parseFloat(weight)]);
break;
default:
// ignore others
// ignore others
}
}

Expand Down

0 comments on commit e0929ba

Please sign in to comment.