Skip to content

Commit

Permalink
feat: show tokens for networks in prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 12, 2023
1 parent 3ffbc00 commit 44e4833
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 56 deletions.
3 changes: 3 additions & 0 deletions api/onnx_web/diffusers/utils.py
Expand Up @@ -379,6 +379,9 @@ def encode_prompt(
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
"""
TODO: does not work with SDXL, fix or turn into a pipeline patch
"""
return [
pipe._encode_prompt(
remove_tokens(prompt),
Expand Down
73 changes: 34 additions & 39 deletions api/schemas/extras.yaml
Expand Up @@ -10,34 +10,53 @@ $defs:
- type: number
- type: string

lora_network:
tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]

embedding_network:
type: object
required: [name, source]
properties:
format:
$ref: "#/defs/tensor_format"
label:
type: string
model:
type: string
enum: [concept, embeddings]
name:
type: string
source:
type: string
label:
token:
type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number

textual_inversion_network:
lora_network:
type: object
required: [name, source]
required: [name, source, type]
properties:
name:
label:
type: string
source:
model:
type: string
format:
enum: [cloneofsimo, sd-scripts]
name:
type: string
enum: [concept, embeddings]
label:
source:
type: string
token:
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight:
type: number

Expand All @@ -46,8 +65,7 @@ $defs:
required: [name, source]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
$ref: "#/defs/tensor_format"
half:
type: boolean
label:
Expand Down Expand Up @@ -85,7 +103,7 @@ $defs:
inversions:
type: array
items:
$ref: "#/$defs/textual_inversion_network"
$ref: "#/$defs/embedding_network"
loras:
type: array
items:
Expand Down Expand Up @@ -142,31 +160,6 @@ $defs:
source:
type: string

source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]

translation:
type: object
additionalProperties: False
Expand Down Expand Up @@ -194,7 +187,9 @@ properties:
networks:
type: array
items:
$ref: "#/$defs/source_network"
oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources:
type: array
items:
Expand Down
89 changes: 76 additions & 13 deletions gui/src/components/input/PromptInput.tsx
@@ -1,5 +1,5 @@
import { mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material';
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { Chip, TextField } from '@mui/material';
import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
Expand All @@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';

const { useContext } = React;

Expand Down Expand Up @@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
staleTime: STALE_TIME,
});

const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);

const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
tokens: tokens.length,
});

function addToken(type: string, name: string, weight = 1.0) {
function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt,
});
}

function addToken(name: string) {
onChange({
prompt: `${prompt}, ${name}`,
});
}

const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);

return <Stack spacing={2}>
<TextField
label={t('parameter.prompt')}
helperText={helper}
variant='outlined'
value={prompt}
onChange={(event) => {
Expand All @@ -77,6 +79,7 @@ export function PromptInput(props: PromptInputProps) {
});
}}
/>
{tokens.map(([token, _weight]) => <Chip label={token} onClick={() => addToken(token)} />)}
<TextField
label={t('parameter.negativePrompt')}
variant='outlined'
Expand All @@ -98,7 +101,7 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
onSelect={(name) => {
addToken('inversion', name);
addNetwork('inversion', name);
}}
/>
<QueryMenu
Expand All @@ -110,9 +113,69 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
onSelect={(name) => {
addToken('lora', name);
addNetwork('lora', name);
}}
/>
</Stack>
</Stack>;
}

export const ANY_TOKEN = /<([^>])+>/g;

export type TokenList = Array<[string, number]>;

export interface PromptNetworks {
inversion: TokenList;
lora: TokenList;
}

export function extractNetworks(prompt: string): PromptNetworks {
const inversion: TokenList = [];
const lora: TokenList = [];

for (const token of prompt.matchAll(ANY_TOKEN)) {
const [_whole, match] = Array.from(token);
const [type, name, weight, ..._rest] = match.split(':');

switch (type) {
case 'inversion':
inversion.push([name, parseFloat(weight)]);
break;
case 'lora':
lora.push([name, parseFloat(weight)]);
break;
default:
// ignore others
}
}

return {
inversion,
lora,
};
}

// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
const tokens: TokenList = [];

if (doesExist(models)) {
for (const [name, weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]);
}
}

for (const [name, weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') {
for (const token of model.tokens) {
tokens.push([token, weight]);
}
}
}
}

return tokens;
}
23 changes: 19 additions & 4 deletions gui/src/types/api.ts
Expand Up @@ -39,13 +39,28 @@ export interface ReadyResponse {
ready: boolean;
}

export interface NetworkModel {
export interface ControlNetwork {
name: string;
type: 'control' | 'inversion' | 'lora';
// TODO: add token
// TODO: add layer/token count
type: 'control';
}

export interface EmbeddingNetwork {
label: string;
name: string;
token: string;
type: 'inversion';
// TODO: add layer count
}

export interface LoraNetwork {
name: string;
label: string;
tokens: Array<string>;
type: 'lora';
}

export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;

export interface FilterResponse {
mask: Array<string>;
source: Array<string>;
Expand Down

0 comments on commit 44e4833

Please sign in to comment.