Skip to content

Commit

Permalink
feat(gui): add prompt to upscale tab (fixes #187)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 19, 2023
1 parent 7ef63e1 commit 34832f0
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/converting-models.md
Expand Up @@ -22,7 +22,7 @@ You can start from a diffusers directory, HuggingFace Hub repository, or an SD c

1. LoRA weights from `kohya-ss/sd-scripts` to...
2. SD or Dreambooth checkpoint to...
3. diffusers or LoRA weights from `cloneofsimo/lora` to...
3. diffusers directory or LoRA weights from `cloneofsimo/lora` to...
4. ONNX models

One disadvantage of using ONNX is that LoRA weights must be merged with the base model before being converted,
Expand Down
8 changes: 8 additions & 0 deletions gui/src/client.ts
Expand Up @@ -129,6 +129,8 @@ export interface UpscaleParams {
* Parameters for upscale requests.
*/
export interface UpscaleReqParams {
prompt: string;
negativePrompt?: string;
source: Blob;
}

Expand Down Expand Up @@ -477,6 +479,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}

url.searchParams.append('prompt', params.prompt);

if (doesExist(params.negativePrompt)) {
url.searchParams.append('negativePrompt', params.negativePrompt);
}

const body = new FormData();
body.append('source', params.source, 'source');

Expand Down
42 changes: 7 additions & 35 deletions gui/src/components/control/ImageControl.tsx
@@ -1,6 +1,6 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Casino } from '@mui/icons-material';
import { Button, Stack, TextField } from '@mui/material';
import { Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
Expand All @@ -11,10 +11,9 @@ import { STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { SCHEDULER_LABELS } from '../../strings.js';
import { NumericField } from '../input/NumericField.js';
import { PromptInput } from '../input/PromptInput.js';
import { QueryList } from '../input/QueryList.js';

export const PROMPT_LIMIT = 70;

export interface ImageControlProps {
selector: (state: OnnxState) => BaseImgParams;

Expand All @@ -34,17 +33,6 @@ export function ImageControl(props: ImageControlProps) {
staleTime: STALE_TIME,
});

const promptLength = controlState.prompt.split(' ').length;
const error = promptLength > PROMPT_LIMIT;

function promptHelper() {
if (error) {
return `Too many tokens: ${promptLength}/${PROMPT_LIMIT}`;
} else {
return `Tokens: ${promptLength}/${PROMPT_LIMIT}`;
}
}

return <Stack spacing={2}>
<QueryList
id='schedulers'
Expand Down Expand Up @@ -126,30 +114,14 @@ export function ImageControl(props: ImageControlProps) {
New Seed
</Button>
</Stack>
<TextField
error={error}
label='Prompt'
helperText={promptHelper()}
variant='outlined'
value={controlState.prompt}
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
prompt: event.target.value,
});
}
}}
/>
<TextField
label='Negative Prompt'
variant='outlined'
value={controlState.negativePrompt}
onChange={(event) => {
<PromptInput
prompt={controlState.prompt}
negativePrompt={controlState.negativePrompt}
onChange={(value) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
negativePrompt: event.target.value,
...value,
});
}
}}
Expand Down
60 changes: 60 additions & 0 deletions gui/src/components/input/PromptInput.tsx
@@ -0,0 +1,60 @@
import { doesExist, Maybe } from '@apextoaster/js-utils';
import { TextField } from '@mui/material';
import { Stack } from '@mui/system';
import * as React from 'react';

export interface PromptValue {
prompt: string;
negativePrompt?: string;
}

export interface PromptInputProps extends PromptValue {
onChange?: Maybe<(value: PromptValue) => void>;
}

export const PROMPT_LIMIT = 77;

export function PromptInput(props: PromptInputProps) {
const { prompt = '', negativePrompt = '' } = props;
const promptLength = prompt.split(' ').length;
const error = promptLength > PROMPT_LIMIT;

function promptHelper() {
if (error) {
return `Too many tokens: ${promptLength}/${PROMPT_LIMIT}`;
} else {
return `Tokens: ${promptLength}/${PROMPT_LIMIT}`;
}
}

return <Stack>
<TextField
error={error}
label='Prompt'
helperText={promptHelper()}
variant='outlined'
value={prompt}
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
prompt: event.target.value,
negativePrompt,
});
}
}}
/>
<TextField
label='Negative Prompt'
variant='outlined'
value={negativePrompt}
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
prompt,
negativePrompt: event.target.value,
});
}
}}
/>
</Stack>;
}
8 changes: 8 additions & 0 deletions gui/src/components/tab/Upscale.tsx
Expand Up @@ -9,6 +9,7 @@ import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { PromptInput } from '../input/PromptInput.js';

export function Upscale() {
async function uploadSource() {
Expand Down Expand Up @@ -47,6 +48,13 @@ export function Upscale() {
});
}}
/>
<PromptInput
prompt={params.prompt}
negativePrompt={params.negativePrompt}
onChange={(value) => {
setSource(value);
}}
/>
<UpscaleControl />
<Button
disabled={doesExist(params.source) === false}
Expand Down
8 changes: 6 additions & 2 deletions gui/src/state.ts
Expand Up @@ -411,6 +411,8 @@ export function createStateSlices(server: ServerParams) {
upscaleOrder: server.upscaleOrder.default,
},
upscaleTab: {
negativePrompt: server.negativePrompt.default,
prompt: server.prompt.default,
source: null,
},
setUpscale(upscale) {
Expand All @@ -432,6 +434,8 @@ export function createStateSlices(server: ServerParams) {
resetUpscaleTab() {
set({
upscaleTab: {
negativePrompt: server.negativePrompt.default,
prompt: server.prompt.default,
source: null,
},
});
Expand All @@ -452,12 +456,12 @@ export function createStateSlices(server: ServerParams) {
}));
},
resetBlend() {
set((prev) => ({
set({
blend: {
mask: null,
sources: [],
},
}));
});
},
});

Expand Down

0 comments on commit 34832f0

Please sign in to comment.