Skip to content

Commit

Permalink
feat(gui): add wildcard menu to web UI
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 16, 2023
1 parent 90bc28d commit cdbdd9b
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 9 deletions.
5 changes: 5 additions & 0 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def list_schedulers(server: ServerContext):
return jsonify(get_pipeline_schedulers())


def list_wildcards(server: ServerContext):
return jsonify(list(get_wildcard_data().keys()))


def img2img(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get("source")
if source_file is None:
Expand Down Expand Up @@ -597,6 +601,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
app.route("/api/settings/platforms")(wrap_route(list_platforms, server)),
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, server, pool=pool)
),
Expand Down
5 changes: 5 additions & 0 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
translation: Record<string, string>;
}>;
},
async wildcards(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'wildcards');
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
Expand Down
2 changes: 2 additions & 0 deletions gui/src/client/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export interface ApiClient {
translation: Record<string, string>;
}>>;

wildcards(): Promise<Array<string>>;

/**
* Start a txt2img pipeline.
*/
Expand Down
3 changes: 3 additions & 0 deletions gui/src/client/local.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ export const LOCAL_CLIENT = {
async strings() {
return {};
},
async wildcards() {
throw new NoServerError();
},
async restart() {
throw new NoServerError();
},
Expand Down
31 changes: 22 additions & 9 deletions gui/src/components/input/PromptInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ export interface PromptInputProps {

export const PROMPT_GROUP = 75;

function splitPrompt(prompt: string): Array<string> {
return prompt
.split(',')
.flatMap((phrase) => phrase.split(' '))
.map((word) => word.trim())
.filter((word) => word.length > 0);
}

export function PromptInput(props: PromptInputProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;
Expand All @@ -48,12 +40,15 @@ export function PromptInput(props: PromptInputProps) {
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
staleTime: STALE_TIME,
});

const { t } = useTranslation();

function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
negativePrompt,
});
}
Expand All @@ -64,6 +59,12 @@ export function PromptInput(props: PromptInputProps) {
});
}

function addWildcard(name: string) {
onChange({
prompt: `${prompt}, __${name}__`,
});
}

const tokens = useMemo(() => {
const networks = extractNetworks(prompt);
return getNetworkTokens(models.data, networks);
Expand Down Expand Up @@ -124,6 +125,18 @@ export function PromptInput(props: PromptInputProps) {
addNetwork('lora', name);
}}
/>
<QueryMenu
id='wildcard'
labelKey='wildcard'
name={t('wildcard')}
query={{
result: wildcards,
selector: (result) => result,
}}
onSelect={(name) => {
addWildcard(name);
}}
/>
</Stack>
</Stack>;
}
Expand Down
1 change: 1 addition & 0 deletions gui/src/strings/de.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ export const I18N_STRINGS_DE = {
'correction-first': 'Korrektur zuerst',
'correction-last': 'Korrektur zuletzt',
},
wildcard: '',
},
},
};
1 change: 1 addition & 0 deletions gui/src/strings/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ export const I18N_STRINGS_EN = {
'correction-first': 'Correction First',
'correction-last': 'Correction Last',
},
wildcard: 'Wildcard',
}
},
};
1 change: 1 addition & 0 deletions gui/src/strings/es.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ export const I18N_STRINGS_ES = {
'correction-first': 'corrección primero',
'correction-last': 'última corrección',
},
wildcard: '',
},
},
};
1 change: 1 addition & 0 deletions gui/src/strings/fr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ export const I18N_STRINGS_FR = {
'correction-first': '',
'correction-last': '',
},
wildcard: '',
},
},
};

0 comments on commit cdbdd9b

Please sign in to comment.