Skip to content

Commit

Permalink
feat(gui): add cancel to API client
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 4, 2023
1 parent 294c831 commit 900a95e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 19 deletions.
12 changes: 11 additions & 1 deletion gui/src/client.ts
Expand Up @@ -141,6 +141,7 @@ export interface ImageResponse {
* Status response from the ready endpoint.
*/
export interface ReadyResponse {
progress: number;
ready: boolean;
}

Expand Down Expand Up @@ -213,6 +214,8 @@ export interface ApiClient {
* Check whether some pipeline's output is ready yet.
*/
ready(params: ImageResponse): Promise<ReadyResponse>;

cancel(params: ImageResponse): Promise<boolean>;
}

/**
Expand Down Expand Up @@ -495,7 +498,14 @@ export function makeClient(root: string, f = fetch): ApiClient {

const res = await f(path);
return await res.json() as ReadyResponse;
}
},
async cancel(params: ImageResponse): Promise<boolean> {
const path = makeApiUrl(root, 'cancel');
path.searchParams.append('output', params.output.key);

const res = await f(path);
return res.status === STATUS_SUCCESS;
},
};
}

Expand Down
6 changes: 3 additions & 3 deletions gui/src/components/ImageHistory.tsx
Expand Up @@ -17,12 +17,12 @@ export function ImageHistory() {

const children = [];

if (doesExist(loading)) {
children.push(<LoadingCard key='loading' loading={loading} />);
if (loading.length > 0) {
children.push(...loading.map((item) => <LoadingCard key={`loading-${item.image.output.key}`} loading={item.image} />));
}

if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
children.push(...history.map((item) => <ImageCard key={`history-${item.output.key}`} value={item} onDelete={removeHistory} />));
} else {
if (doesExist(loading) === false) {
children.push(<Typography>No results. Press Generate.</Typography>);
Expand Down
28 changes: 24 additions & 4 deletions gui/src/components/LoadingCard.tsx
@@ -1,8 +1,8 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Card, CardContent, CircularProgress } from '@mui/material';
import { Button, Card, CardContent, CircularProgress } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useMutation, useQuery } from 'react-query';
import { useStore } from 'zustand';

import { ImageResponse } from '../client.js';
Expand All @@ -17,15 +17,34 @@ export function LoadingCard(props: LoadingCardProps) {
const client = mustExist(React.useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));

const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);
const clearLoading = useStore(state, (s) => s.clearLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

async function doCancel() {
const cancelled = await client.cancel(props.loading);
if (cancelled) {
clearLoading();
}
}

const cancel = useMutation(doCancel);
const query = useQuery('ready', () => client.ready(props.loading), {
// data will always be ready without this, even if the API says its not
cacheTime: 0,
refetchInterval: POLL_TIME,
});

function progress() {
if (doesExist(query.data)) {
return Math.ceil(query.data.progress / props.loading.params.steps);
}

return 0;
}

function ready() {
return doesExist(query.data) && query.data.ready;
}
Expand All @@ -44,7 +63,8 @@ export function LoadingCard(props: LoadingCardProps) {
justifyContent: 'center',
minHeight: params.height.default,
}}>
<CircularProgress />
<CircularProgress value={progress()} />
<Button>Cancel</Button>
</div>
</CardContent>
</Card>;
Expand Down
46 changes: 35 additions & 11 deletions gui/src/state.ts
@@ -1,5 +1,5 @@
/* eslint-disable no-null/no-null */
import { Maybe } from '@apextoaster/js-utils';
import { doesExist, Maybe } from '@apextoaster/js-utils';
import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand';

Expand All @@ -12,6 +12,7 @@ import {
InpaintParams,
ModelParams,
OutpaintPixels,
ReadyResponse,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
Expand All @@ -23,6 +24,11 @@ import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
*/
type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;

interface LoadingItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
}

interface BrushSlice {
brush: BrushParams;

Expand All @@ -38,12 +44,14 @@ interface DefaultSlice {
interface HistorySlice {
history: Array<ImageResponse>;
limit: number;
loading: Maybe<ImageResponse>;
loading: Array<LoadingItem>;

// TODO: hack until setLoading removes things
clearLoading(): void;
pushHistory(image: ImageResponse): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setLoading(image: Maybe<ImageResponse>): void;
setLoading(image: ImageResponse, ready?: Maybe<ReadyResponse>): void;
}

interface ModelSlice {
Expand Down Expand Up @@ -264,17 +272,39 @@ export function createStateSlices(server: ServerParams) {
const createHistorySlice: Slice<HistorySlice> = (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
loading: null,
loading: [],
clearLoading() {
set((prev) => ({
...prev,
loading: [],
}));
},
pushHistory(image) {
set((prev) => ({
...prev,
history: [
image,
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
loading: null,
loading: [],
}));
},
setLoading(image, ready) {
set((prev) => {
const loading = [...prev.loading];
const idx = loading.findIndex((it) => it.image.output.key === image.output.key);
if (idx >= 0) {
loading[idx].ready = ready;
} else {
loading.push({ image, ready });
}

return {
...prev,
loading,
};
});
},
removeHistory(image) {
set((prev) => ({
...prev,
Expand All @@ -287,12 +317,6 @@ export function createStateSlices(server: ServerParams) {
limit,
}));
},
setLoading(loading) {
set((prev) => ({
...prev,
loading,
}));
},
});

const createOutpaintSlice: Slice<OutpaintSlice> = (set) => ({
Expand Down

0 comments on commit 900a95e

Please sign in to comment.