Skip to content

Commit

Permalink
feat: add menu for source image filters
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 14, 2023
1 parent 80d00e4 commit 4df28a5
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 26 deletions.
11 changes: 11 additions & 0 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,17 @@ def list_extra_strings(server: ServerContext):
return jsonify(get_extra_strings())


def list_filters(server: ServerContext):
mask_filters = list(get_mask_filters().keys())
source_filters = list(get_source_filters().keys())
return jsonify({
"mask": mask_filters,
"source": source_filters,
})


def list_mask_filters(server: ServerContext):
logger.info("dedicated list endpoint for mask filters is deprecated")
return jsonify(list(get_mask_filters().keys()))


Expand Down Expand Up @@ -502,6 +512,7 @@ def status(server: ServerContext, pool: DevicePoolExecutor):
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, server, app=app)),
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),
Expand Down
27 changes: 19 additions & 8 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ export interface Txt2ImgParams extends BaseImgParams {
*/
export interface Img2ImgParams extends BaseImgParams {
source: Blob;

sourceFilter?: string;
strength: number;
}

Expand Down Expand Up @@ -201,10 +203,15 @@ export interface NetworkModel {
// TODO: add layer/token count
}

export interface FilterResponse {
mask: Array<string>;
source: Array<string>;
}

/**
* List of available models.
*/
export interface ModelsResponse {
export interface ModelResponse {
correction: Array<string>;
diffusion: Array<string>;
networks: Array<NetworkModel>;
Expand Down Expand Up @@ -253,12 +260,12 @@ export interface ApiClient {
/**
* List the available filter masks for inpaint.
*/
masks(): Promise<Array<string>>;
filters(): Promise<FilterResponse>;

/**
* List the available models.
*/
models(): Promise<ModelsResponse>;
models(): Promise<ModelResponse>;

/**
* List the available noise sources for inpaint.
Expand Down Expand Up @@ -433,15 +440,15 @@ export function makeClient(root: string, f = fetch): ApiClient {
}

return {
async masks(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'masks');
async filters(): Promise<FilterResponse> {
const path = makeApiUrl(root, 'settings', 'filters');
const res = await f(path);
return await res.json() as Array<string>;
return await res.json() as FilterResponse;
},
async models(): Promise<ModelsResponse> {
async models(): Promise<ModelResponse> {
const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path);
return await res.json() as ModelsResponse;
return await res.json() as ModelResponse;
},
async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises');
Expand Down Expand Up @@ -483,6 +490,10 @@ export function makeClient(root: string, f = fetch): ApiClient {

url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));

if (doesExist(params.sourceFilter)) {
url.searchParams.append('sourceFilter', params.sourceFilter);
}

if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
Expand Down
2 changes: 1 addition & 1 deletion gui/src/client/local.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export class NoServerError extends BaseError {
* @TODO client-side inference with https://www.npmjs.com/package/onnxruntime-web
*/
export const LOCAL_CLIENT = {
async masks() {
async filters() {
throw new NoServerError();
},
async blend(model, params, upscale) {
Expand Down
54 changes: 39 additions & 15 deletions gui/src/components/tab/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useStore } from 'zustand';

import { IMAGE_FILTER } from '../../config.js';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { ImageControl } from '../control/ImageControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';

export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext));
Expand All @@ -32,8 +33,14 @@ export function Img2Img() {
onSuccess: () => query.invalidateQueries(['ready']),
});

const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME,
});


const state = mustExist(useContext(StateContext));
const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
Expand All @@ -49,19 +56,36 @@ export function Img2Img() {
});
}} />
<ImageControl selector={(s) => s.img2img} onChange={setImg2Img} />
<NumericField
decimal
label={t('parameter.strength')}
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
onChange={(value) => {
setImg2Img({
strength: value,
});
}}
/>
<Stack direction='row' spacing={2}>
<QueryList
id='sources'
labelKey={'sourceFilter'}
name={t('parameter.sourceFilter')}
query={{
result: filters,
selector: (f) => f.source,
}}
value={sourceFilter}
onChange={(newFilter) => {
setImg2Img({
sourceFilter: newFilter,
});
}}
/>
<NumericField
decimal
label={t('parameter.strength')}
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
onChange={(value) => {
setImg2Img({
strength: value,
});
}}
/>
</Stack>
<UpscaleControl />
<Button
disabled={doesExist(source) === false}
Expand Down
5 changes: 3 additions & 2 deletions gui/src/components/tab/Inpaint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function Inpaint() {
const { params } = mustExist(useContext(ConfigContext));
const client = mustExist(useContext(ClientContext));

const masks = useQuery(['masks'], async () => client.masks(), {
const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME,
});
const noises = useQuery(['noises'], async () => client.noises(), {
Expand Down Expand Up @@ -146,7 +146,8 @@ export function Inpaint() {
labelKey={'maskFilter'}
name={t('parameter.maskFilter')}
query={{
result: masks,
result: filters,
selector: (f) => f.mask,
}}
value={filter}
onChange={(newFilter) => {
Expand Down
2 changes: 2 additions & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ export function createStateSlices(server: ServerParams) {
img2img: {
...base,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
setImg2Img(params) {
Expand All @@ -268,6 +269,7 @@ export function createStateSlices(server: ServerParams) {
img2img: {
...base,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
});
Expand Down

0 comments on commit 4df28a5

Please sign in to comment.