Skip to content

Commit

Permalink
feat(gui): share image history between tabs, add setting to adjust le…
Browse files Browse the repository at this point in the history
…ngth of history (fixes #22)
  • Loading branch information
ssube committed Jan 11, 2023
1 parent 9bb01cc commit 662bf42
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 94 deletions.
37 changes: 37 additions & 0 deletions gui/src/components/ImageHistory.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { mustExist } from '@apextoaster/js-utils';
import { Grid } from '@mui/material';
import { useContext } from 'react';
import * as React from 'react';
import { useStore } from 'zustand';

import { ApiResponse } from '../api/client.js';
import { StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { LoadingCard } from './LoadingCard.js';

export function ImageHistory() {
const state = useStore(mustExist(useContext(StateContext)));
const { images } = state.history;

const children = [];

if (state.history.loading) {
children.push(<LoadingCard height={512} width={512} />); // TODO: get dimensions from config
}

function removeHistory(image: ApiResponse) {
state.setHistory(images.filter((item) => image.output !== item.output));
}

if (images.length > 0) {
children.push(...images.map((item) => <ImageCard value={item} onDelete={removeHistory} />));
} else {
if (state.history.loading === false) {
children.push(<div>No results. Press Generate.</div>);
}
}

const limited = children.slice(0, state.history.limit);

return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
}
13 changes: 6 additions & 7 deletions gui/src/components/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@ import * as React from 'react';
import { useMutation } from 'react-query';
import { useStore } from 'zustand';

import { equalResponse } from '../api/client.js';
import { ConfigParams, IMAGE_FILTER } from '../config.js';
import { ClientContext, StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { MutationHistory } from './MutationHistory.js';
import { NumericField } from './NumericField.js';

const { useContext, useState } = React;
Expand All @@ -26,12 +23,17 @@ export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;

async function uploadSource() {
return client.img2img({
state.setLoading(true);

const output = await client.img2img({
...state.img2img,
model,
platform,
source: mustExist(source), // TODO: show an error if this doesn't exist
});

state.pushHistory(output);
state.setLoading(false);
}

const client = mustExist(useContext(ClientContext));
Expand Down Expand Up @@ -60,9 +62,6 @@ export function Img2Img(props: Img2ImgProps) {
}}
/>
<Button onClick={() => upload.mutate()}>Generate</Button>
<MutationHistory result={upload} limit={4} element={ImageCard}
isEqual={equalResponse}
/>
</Stack>
</Box>;
}
18 changes: 9 additions & 9 deletions gui/src/components/Inpaint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ import * as React from 'react';
import { useMutation } from 'react-query';
import { useStore } from 'zustand';

import { ApiResponse, equalResponse } from '../api/client.js';
import { ApiResponse } from '../api/client.js';
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
import { ClientContext, StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { MutationHistory } from './MutationHistory.js';
import { NumericField } from './NumericField.js';

const { useContext, useEffect, useRef, useState } = React;
Expand Down Expand Up @@ -72,15 +70,20 @@ export function Inpaint(props: InpaintProps) {

async function uploadSource() {
const canvas = mustExist(canvasRef.current);
return new Promise<ApiResponse>((res, _rej) => {
state.setLoading(true);
return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => {
res(client.inpaint({
client.inpaint({
...state.inpaint,
model,
platform,
mask: mustExist(blob),
source: mustExist(source),
}));
}).then((output) => {
state.pushHistory(output);
state.setLoading(false);
res();
}).catch((err) => rej(err));
});
});
}
Expand Down Expand Up @@ -262,9 +265,6 @@ export function Inpaint(props: InpaintProps) {
}}
/>
<Button onClick={() => upload.mutate()}>Generate</Button>
<MutationHistory result={upload} limit={4} element={ImageCard}
isEqual={equalResponse}
/>
</Stack>
</Box>;
}
58 changes: 0 additions & 58 deletions gui/src/components/MutationHistory.tsx

This file was deleted.

9 changes: 7 additions & 2 deletions gui/src/components/OnnxWeb.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { mustExist } from '@apextoaster/js-utils';
import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, Stack, Tab, Typography } from '@mui/material';
import { Box, Container, Divider, Stack, Tab, Typography } from '@mui/material';
import * as React from 'react';
import { useQuery } from 'react-query';

import { ApiClient } from '../api/client.js';
import { ConfigParams, STALE_TIME } from '../config.js';
import { ClientContext } from '../main.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { ImageHistory } from './ImageHistory.js';
import { Img2Img } from './Img2Img.js';
import { Inpaint } from './Inpaint.js';
import { QueryList } from './QueryList.js';
Expand Down Expand Up @@ -44,7 +45,7 @@ export function OnnxWeb(props: OnnxWebProps) {
ONNX Web
</Typography>
</Box>
<Box sx={{ my: 4 }}>
<Box sx={{ mx: 4, my: 4 }}>
<Stack direction='row' spacing={2}>
<QueryList
id='models'
Expand Down Expand Up @@ -92,6 +93,10 @@ export function OnnxWeb(props: OnnxWebProps) {
<Settings config={config} />
</TabPanel>
</TabContext>
<Divider variant='middle' />
<Box sx={{ mx: 4, my: 4 }}>
<ImageHistory />
</Box>
</Container>
</div>
);
Expand Down
21 changes: 15 additions & 6 deletions gui/src/components/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { useStore } from 'zustand';

import { ConfigParams } from '../config.js';
import { StateContext } from '../main.js';
import { NumericField } from './NumericField.js';

const { useContext } = React;

Expand All @@ -16,12 +17,14 @@ export function Settings(_props: SettingsProps) {
const state = useStore(mustExist(useContext(StateContext)));

return <Stack spacing={2}>
<Stack direction='row' spacing={2}>
<Button onClick={() => state.resetTxt2Img()}>Reset Txt2Img</Button>
<Button onClick={() => state.resetImg2Img()}>Reset Img2Img</Button>
<Button onClick={() => state.resetInpaint()}>Reset Inpaint</Button>
<Button disabled>Reset All</Button>
</Stack>
<NumericField
label='Image History'
min={2}
max={20}
step={1}
value={state.history.limit}
onChange={(value) => state.setLimit(value)}
/>
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
state.setDefaults({
model: event.target.value,
Expand All @@ -42,5 +45,11 @@ export function Settings(_props: SettingsProps) {
scheduler: event.target.value,
});
}} />
<Stack direction='row' spacing={2}>
<Button onClick={() => state.resetTxt2Img()}>Reset Txt2Img</Button>
<Button onClick={() => state.resetImg2Img()}>Reset Img2Img</Button>
<Button onClick={() => state.resetInpaint()}>Reset Inpaint</Button>
<Button disabled>Reset All</Button>
</Stack>
</Stack>;
}
18 changes: 7 additions & 11 deletions gui/src/components/Txt2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@ import * as React from 'react';
import { useMutation } from 'react-query';
import { useStore } from 'zustand';

import { BaseImgParams, equalResponse, paramsFromConfig } from '../api/client.js';
import { ConfigParams } from '../config.js';
import { ClientContext, StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { ImageControl } from './ImageControl.js';
import { MutationHistory } from './MutationHistory.js';
import { NumericField } from './NumericField.js';

const { useContext, useState } = React;
const { useContext } = React;

export interface Txt2ImgProps {
config: ConfigParams;
Expand All @@ -25,11 +22,16 @@ export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;

async function generateImage() {
return client.txt2img({
state.setLoading(true);

const output = await client.txt2img({
...state.txt2img,
model,
platform,
});

state.pushHistory(output);
state.setLoading(false);
}

const client = mustExist(useContext(ClientContext));
Expand Down Expand Up @@ -68,12 +70,6 @@ export function Txt2Img(props: Txt2ImgProps) {
/>
</Stack>
<Button onClick={() => generate.mutate()}>Generate</Button>
<MutationHistory
element={ImageCard}
limit={4}
isEqual={equalResponse}
result={generate}
/>
</Stack>
</Box>;
}
57 changes: 56 additions & 1 deletion gui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { QueryClient, QueryClientProvider } from 'react-query';
import { createStore, StoreApi } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware';

import { ApiClient, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
import { OnnxWeb } from './components/OnnxWeb.js';
import { ConfigState, loadConfig } from './config.js';

Expand All @@ -27,6 +27,17 @@ interface OnnxState {
resetTxt2Img(): void;
resetImg2Img(): void;
resetInpaint(): void;

history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;
};

setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
}

export async function main() {
Expand All @@ -38,6 +49,11 @@ export async function main() {
const defaults = paramsFromConfig(params);
const state = createStore<OnnxState, [['zustand/persist', never]]>(persist((set) => ({
defaults,
history: {
images: [],
limit: 4,
loading: false,
},
txt2img: {
...defaults,
height: params.height.default,
Expand All @@ -50,6 +66,45 @@ export async function main() {
inpaint: {
...defaults,
},
setLimit(limit) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
limit,
},
}));
},
setLoading(loading) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
loading,
},
}));
},
pushHistory(newImage: ApiResponse) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: [
newImage,
...oldState.history.images,
].slice(0, oldState.history.limit),
},
}));
},
setHistory(newHistory: Array<ApiResponse>) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: newHistory,
},
}));
},
setDefaults(newParams) {
set((oldState) => ({
...oldState,
Expand Down

0 comments on commit 662bf42

Please sign in to comment.