Skip to content

Commit

Permalink
feat(gui): add retry function to error card
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 18, 2023
1 parent 6226778 commit 8979064
Show file tree
Hide file tree
Showing 14 changed files with 182 additions and 55 deletions.
157 changes: 131 additions & 26 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable max-lines */
import { doesExist } from '@apextoaster/js-utils';
import { doesExist, InvalidArgumentError } from '@apextoaster/js-utils';

import { ServerParams } from '../config.js';
import { range } from '../utils.js';
Expand Down Expand Up @@ -191,6 +191,43 @@ export interface ModelsResponse {
upscaling: Array<string>;
}

export type RetryParams = {
type: 'txt2img';
model: ModelParams;
params: Txt2ImgParams;
upscale?: UpscaleParams;
} | {
type: 'img2img';
model: ModelParams;
params: Img2ImgParams;
upscale?: UpscaleParams;
} | {
type: 'inpaint';
model: ModelParams;
params: InpaintParams;
upscale?: UpscaleParams;
} | {
type: 'outpaint';
model: ModelParams;
params: OutpaintParams;
upscale?: UpscaleParams;
} | {
type: 'upscale';
model: ModelParams;
params: UpscaleReqParams;
upscale?: UpscaleParams;
} | {
type: 'blend';
model: ModelParams;
params: BlendParams;
upscale?: UpscaleParams;
};

export interface ImageResponseWithRetry {
image: ImageResponse;
retry: RetryParams;
}

export interface ApiClient {
/**
* List the available filter masks for inpaint.
Expand Down Expand Up @@ -232,39 +269,41 @@ export interface ApiClient {
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Start an im2img pipeline.
*/
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Start an inpaint pipeline.
*/
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Start an outpaint pipeline.
*/
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Start an upscale pipeline.
*/
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>;
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Start a blending pipeline.
*/
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponse>;
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;

/**
* Check whether some pipeline's output is ready yet.
*/
ready(key: string): Promise<ReadyResponse>;

cancel(key: string): Promise<boolean>;

retry(params: RetryParams): Promise<ImageResponseWithRetry>;
}

/**
Expand Down Expand Up @@ -363,7 +402,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
* Make an API client using the given API root and fetch client.
*/
export function makeClient(root: string, f = fetch): ApiClient {
function throttleRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
return f(url, options).then((res) => parseApiResponse(root, res));
}

Expand Down Expand Up @@ -407,7 +446,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
translation: Record<string, string>;
}>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);

Expand All @@ -420,13 +459,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
const body = new FormData();
body.append('source', params.source, 'source');

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'img2img',
model,
params,
upscale,
},
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);

Expand All @@ -442,12 +489,20 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
method: 'POST',
});
return {
image,
retry: {
type: 'txt2img',
model,
params,
upscale,
},
};
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);

Expand All @@ -464,13 +519,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'inpaint',
model,
params,
upscale,
},
};
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);

Expand Down Expand Up @@ -504,13 +567,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'outpaint',
model,
params,
upscale,
},
};
},
async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise<ImageResponse> {
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model);

Expand All @@ -527,13 +598,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
const body = new FormData();
body.append('source', params.source, 'source');

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'upscale',
model,
params,
upscale,
},
};
},
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model);

Expand All @@ -549,11 +628,19 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append(name, params.sources[i], name);
}

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'blend',
model,
params,
upscale,
}
};
},
async ready(key: string): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
Expand All @@ -571,6 +658,24 @@ export function makeClient(root: string, f = fetch): ApiClient {
});
return res.status === STATUS_SUCCESS;
},
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
switch (retry.type) {
case 'blend':
return this.blend(retry.model, retry.params, retry.upscale);
case 'img2img':
return this.img2img(retry.model, retry.params, retry.upscale);
case 'inpaint':
return this.inpaint(retry.model, retry.params, retry.upscale);
case 'outpaint':
return this.outpaint(retry.model, retry.params, retry.upscale);
case 'txt2img':
return this.txt2img(retry.model, retry.params, retry.upscale);
case 'upscale':
return this.upscale(retry.model, retry.params, retry.upscale);
default:
throw new InvalidArgumentError('unknown request type');
}
}
};
}

Expand Down
3 changes: 3 additions & 0 deletions gui/src/client/local.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ export const LOCAL_CLIENT = {
async cancel(key) {
throw new NoServerError();
},
async retry(params) {
throw new NoServerError();
},
async models() {
throw new NoServerError();
},
Expand Down
2 changes: 1 addition & 1 deletion gui/src/components/ImageHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export function ImageHistory() {

if (doesExist(item.ready) && item.ready.ready) {
if (item.ready.cancelled || item.ready.failed) {
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} />]);
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
continue;
}

Expand Down
37 changes: 26 additions & 11 deletions gui/src/components/card/RetryCard.tsx
Original file line number Diff line number Diff line change
@@ -1,37 +1,42 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Card, CardContent, Typography } from '@mui/material';
import { Delete, Replay } from '@mui/icons-material';
import { Box, Card, CardContent, IconButton, Tooltip, Typography } from '@mui/material';
import { Stack } from '@mui/system';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useMutation } from 'react-query';
import { useStore } from 'zustand';

import { ImageResponse, ReadyResponse } from '../../client/api.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../client/api.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';

export interface ErrorCardProps {
image: ImageResponse;
ready: ReadyResponse;
retry: RetryParams;
}

export function ErrorCard(props: ErrorCardProps) {
const { image, ready } = props;
const { image, ready, retry: retryParams } = props;

const client = mustExist(React.useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));

const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeHistory = useStore(state, (s) => s.removeHistory);
const { t } = useTranslation();

// TODO: actually retry
const retry = useMutation(() => {
// eslint-disable-next-line no-console
console.log('retry', image);
return Promise.resolve(true);
});
async function retryImage() {
removeHistory(image);
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry);
}

const retry = useMutation(retryImage);

return <Card sx={{ maxWidth: params.width.default }}>
<CardContent sx={{ height: params.height.default }}>
Expand All @@ -50,8 +55,18 @@ export function ErrorCard(props: ErrorCardProps) {
current: ready.progress,
total: image.params.steps,
})}</Typography>
<Button onClick={() => retry.mutate()}>{t('loading.retry')}</Button>
<Button onClick={() => removeHistory(image)}>{t('loading.remove')}</Button>
<Stack direction='row' spacing={2}>
<Tooltip title={t('tooltip.retry')}>
<IconButton onClick={() => retry.mutate()}>
<Replay />
</IconButton>
</Tooltip>
<Tooltip title={t('tooltip.delete')}>
<IconButton onClick={() => removeHistory(image)}>
<Delete />
</IconButton>
</Tooltip>
</Stack>
</Stack>
</Box>
</CardContent>
Expand Down
5 changes: 2 additions & 3 deletions gui/src/components/tab/Blend.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { model, blend, upscale } = state.getState();

const output = await client.blend(model, {
const { image, retry } = await client.blend(model, {
...blend,
mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, upscale);

pushHistory(output);
pushHistory(image, retry);
}

const client = mustExist(useContext(ClientContext));
Expand Down

0 comments on commit 8979064

Please sign in to comment.