Skip to content

Commit

Permalink
feat(gui): add menus for upscaling and correction models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 17, 2023
1 parent ee6308a commit 0080d86
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 163 deletions.
66 changes: 38 additions & 28 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils';

import { ConfigParams } from './config.js';

export interface BaseImgParams {
export interface ModelParams {
/**
* Which ONNX model to use.
*/
model?: string;
model: string;

/**
* Hardware accelerator or CPU mode.
*/
platform?: string;
platform: string;

/**
* Scheduling algorithm.
*/
scheduler?: string;
upscaling: string;
correction: string;
}

export interface BaseImgParams {
scheduler: string;
prompt: string;
negativePrompt?: string;

Expand Down Expand Up @@ -90,18 +91,24 @@ export interface ApiReady {
ready: boolean;
}

export interface ApiModels {
diffusion: Array<string>;
correction: Array<string>;
upscaling: Array<string>;
}

export interface ApiClient {
masks(): Promise<Array<string>>;
models(): Promise<Array<string>>;
models(): Promise<ApiModels>;
noises(): Promise<Array<string>>;
params(): Promise<ConfigParams>;
platforms(): Promise<Array<string>>;
schedulers(): Promise<Array<string>>;

img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;

ready(params: ApiResponse): Promise<ApiReady>;
}
Expand All @@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200;
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
return {
cfg: defaults.cfg.default,
model: defaults.model.default,
negativePrompt: defaults.negativePrompt.default,
platform: defaults.platform.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
Expand Down Expand Up @@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));

if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}

if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}

if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
Expand All @@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
return url;
}

export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('model', params.model);
url.searchParams.append('platform', params.platform);
}

export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
if (upscale.enabled) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
Expand All @@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async models(): Promise<Array<string>> {
async models(): Promise<ApiModels> {
const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path);
return await res.json() as Array<string>;
return await res.json() as ApiModels;
},
async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises');
Expand All @@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}

const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);

url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));

if (doesExist(upscale)) {
Expand All @@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}

const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);

if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
Expand All @@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async inpaint(params: InpaintParams, upscale?: UpscaleParams) {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}

const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);

url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);

if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
Expand All @@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}

const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);

url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);

Expand Down
4 changes: 3 additions & 1 deletion gui/src/components/ImageControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) {
id='schedulers'
labels={SCHEDULER_LABELS}
name='Scheduler'
result={schedulers}
query={{
result: schedulers,
}}
value={mustDefault(params.scheduler, '')}
onChange={(value) => {
if (doesExist(props.onChange)) {
Expand Down
21 changes: 6 additions & 15 deletions gui/src/components/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,23 @@ import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams, IMAGE_FILTER } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { IMAGE_FILTER } from '../config.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js';

const { useContext } = React;

export interface Img2ImgProps {
config: ConfigParams;

model: string;
platform: string;
}

export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;
export function Img2Img() {
const config = mustExist(useContext(ConfigContext));

async function uploadSource() {
const { img2img, upscale } = state.getState();
const { model, img2img, upscale } = state.getState();

const output = await client.img2img({
const output = await client.img2img(model, {
...img2img,
model,
platform,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, upscale);

Expand Down
34 changes: 14 additions & 20 deletions gui/src/components/Inpaint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import * as React from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { IMAGE_FILTER, STALE_TIME } from '../config.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
Expand All @@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js';

const { useContext } = React;

export interface InpaintProps {
config: ConfigParams;

model: string;
platform: string;
}

export function Inpaint(props: InpaintProps) {
const { config, model, platform } = props;
export function Inpaint() {
const config = mustExist(useContext(ConfigContext));
const client = mustExist(useContext(ClientContext));

const masks = useQuery('masks', async () => client.masks(), {
staleTime: STALE_TIME,
});
Expand All @@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) {

async function uploadSource(): Promise<void> {
// these are not watched by the component, only sent by the mutation
const { inpaint, outpaint, upscale } = state.getState();
const { model, inpaint, outpaint, upscale } = state.getState();

if (outpaint.enabled) {
const output = await client.outpaint({
const output = await client.outpaint(model, {
...inpaint,
...outpaint,
model,
platform,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);

setLoading(output);
} else {
const output = await client.inpaint({
const output = await client.inpaint(model, {
...inpaint,
model,
platform,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
Expand Down Expand Up @@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) {
id='masks'
labels={MASK_LABELS}
name='Mask Filter'
result={masks}
query={{
result: masks,
}}
value={filter}
onChange={(newFilter) => {
setInpaint({
Expand All @@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) {
id='noises'
labels={NOISE_LABELS}
name='Noise Source'
result={noises}
query={{
result: noises,
}}
value={noise}
onChange={(newNoise) => {
setInpaint({
Expand Down
89 changes: 89 additions & 0 deletions gui/src/components/ModelControl.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { mustExist } from '@apextoaster/js-utils';
import { Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';

import { STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { QueryList } from './QueryList.js';

export function ModelControl() {
const client = mustExist(useContext(ClientContext));
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.model);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setModel);

const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});

return <Stack direction='row' spacing={2}>
<QueryList
id='platforms'
labels={PLATFORM_LABELS}
name='Platform'
query={{
result: platforms,
}}
value={params.platform}
onChange={(platform) => {
setModel({
platform,
});
}}
/>
<QueryList
id='diffusion'
labels={MODEL_LABELS}
name='Diffusion Model'
query={{
result: models,
selector: (result) => result.diffusion,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='upscaling'
labels={MODEL_LABELS}
name='Upscaling Model'
query={{
result: models,
selector: (result) => result.upscaling,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='correction'
labels={MODEL_LABELS}
name='Correction Model'
query={{
result: models,
selector: (result) => result.correction,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>

</Stack>;
}

0 comments on commit 0080d86

Please sign in to comment.