Skip to content

Commit

Permalink
feat(gui): save source and mask images while changing tabs
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 14, 2023
1 parent e872eea commit 4e82241
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 108 deletions.
22 changes: 9 additions & 13 deletions gui/src/components/ImageInput.tsx
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { PhotoCamera } from '@mui/icons-material';
import { Button, Stack } from '@mui/material';
import * as React from 'react';

const { useState } = React;

export interface ImageInputProps {
filter: string;
hidden?: boolean;
image?: Maybe<Blob>;
label: string;

onChange: (file: File) => void;
renderImage?: (image: string | undefined) => React.ReactNode;
renderImage?: (image: Maybe<Blob>) => React.ReactNode;
}

export function ImageInput(props: ImageInputProps) {
const [image, setImage] = useState<string>();

function renderImage() {
if (mustDefault(props.hidden, false)) {
return undefined;
}

if (doesExist(props.renderImage)) {
return props.renderImage(image);
return props.renderImage(props.image);
}

return <img src={image} />;
if (doesExist(props.image)) {
return <img src={URL.createObjectURL(props.image)} />;
} else {
return <div>Please select an image.</div>;
}
}

return <Stack direction='row' spacing={2}>
Expand All @@ -41,11 +42,6 @@ export function ImageInput(props: ImageInputProps) {
if (doesExist(files) && files.length > 0) {
const file = mustExist(files[0]);

if (doesExist(image)) {
URL.revokeObjectURL(image);
}

setImage(URL.createObjectURL(file));
props.onChange(file);
}
}}
Expand Down
12 changes: 7 additions & 5 deletions gui/src/components/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js';

const { useContext, useState } = React;
const { useContext } = React;

export interface Img2ImgProps {
config: ConfigParams;
Expand All @@ -27,7 +27,7 @@ export function Img2Img(props: Img2ImgProps) {
...params,
model,
platform,
source: mustExist(source), // TODO: show an error if this doesn't exist
source: mustExist(params.source), // TODO: show an error if this doesn't exist
});

setLoading(output);
Expand All @@ -46,11 +46,13 @@ export function Img2Img(props: Img2ImgProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);

const [source, setSource] = useState<File>();

return <Box>
<Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={setSource} />
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
setImg2Img({
source: file,
});
}} />
<ImageControl config={config} params={params} onChange={(newParams) => {
setImg2Img(newParams);
}} />
Expand Down
189 changes: 109 additions & 80 deletions gui/src/components/Inpaint.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Box, Button, Stack } from '@mui/material';
import { throttle } from 'lodash';
import * as React from 'react';
import { useCallback } from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER, SAVE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
Expand Down Expand Up @@ -67,51 +69,49 @@ export function Inpaint(props: InpaintProps) {
const { config, model, platform } = props;
const client = mustExist(useContext(ClientContext));

async function uploadSource() {
const canvas = mustExist(canvasRef.current);
return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => {
client.inpaint({
...params,
model,
platform,
mask: mustExist(blob),
source: mustExist(source),
}).then((output) => {
setLoading(output);
res();
}).catch((err) => rej(err));
});
});
}

function drawSource(file: File) {
function drawSource(file: Blob): Promise<void> {
const image = new Image();
image.onload = () => {
const canvas = mustExist(canvasRef.current);
const ctx = mustExist(canvas.getContext('2d'));
ctx.drawImage(image, 0, 0);
URL.revokeObjectURL(src);
};

const src = URL.createObjectURL(file);
image.src = src;
}
return new Promise<void>((res, _rej) => {
image.onload = () => {
const canvas = mustExist(canvasRef.current);
const ctx = mustExist(canvas.getContext('2d'));
ctx.drawImage(image, 0, 0);
URL.revokeObjectURL(src);

// putting a save call here has a tendency to go into an infinite loop
res();
};

function changeMask(file: File) {
setMask(file);
const src = URL.createObjectURL(file);
image.src = src;
});
}

// always draw the mask to the canvas
drawSource(file);
function saveMask(): Promise<void> {
return new Promise((res, _rej) => {
if (doesExist(canvasRef.current)) {
canvasRef.current.toBlob((blob) => {
setInpaint({
mask: mustExist(blob),
});
res();
});
} else {
res();
}
});
}

function changeSource(file: File) {
setSource(file);
async function uploadSource(): Promise<void> {
const output = await client.inpaint({
...params,
model,
platform,
mask: mustExist(params.mask),
source: mustExist(params.source),
});

// draw the source to the canvas if the mask has not been set
if (doesExist(mask) === false) {
drawSource(file);
}
setLoading(output);
}

function floodMask(flooder: (n: number) => number) {
Expand All @@ -133,45 +133,10 @@ export function Inpaint(props: InpaintProps) {
}

ctx.putImageData(image, 0, 0);
}

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);

const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);

// painting state
const [clicks, setClicks] = useState<Array<Point>>([]);
const [painting, setPainting] = useState(false);
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);

// image state
const [mask, setMask] = useState<File>();
const [source, setSource] = useState<File>();

useEffect(() => {
const canvas = mustExist(canvasRef.current);
const ctx = mustExist(canvas.getContext('2d'));
ctx.fillStyle = grayToRGB(brushColor);

for (const click of clicks) {
ctx.beginPath();
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
ctx.fill();
}

clicks.length = 0;
}, [clicks.length]);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
save();
}

function renderCanvas() {
return <canvas
Expand Down Expand Up @@ -217,10 +182,74 @@ export function Inpaint(props: InpaintProps) {
/>;
}

const save = useCallback(throttle(saveMask, SAVE_TIME), []);
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);

const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);

// painting state
const [clicks, setClicks] = useState<Array<Point>>([]);
const [painting, setPainting] = useState(false);
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);

useEffect(function changeMask() {
// always draw the new mask to the canvas
if (doesExist(params.mask)) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
drawSource(params.mask);
}
}, [params.mask]);

useEffect(function changeSource() {
// draw the source to the canvas if the mask has not been set
if (doesExist(params.source) && doesExist(params.mask) === false) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
drawSource(params.source);
}
}, [params.source]);

useEffect(() => {
// including clicks.length prevents the initial render from saving a blank canvas
if (doesExist(canvasRef.current) && clicks.length > 0) {
const ctx = mustExist(canvasRef.current.getContext('2d'));
ctx.fillStyle = grayToRGB(brushColor);

for (const click of clicks) {
ctx.beginPath();
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
ctx.fill();
}

clicks.length = 0;

// eslint-disable-next-line @typescript-eslint/no-floating-promises
save();
}
}, [clicks.length]);

return <Box>
<Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={changeSource} />
<ImageInput filter={IMAGE_FILTER} label='Mask' onChange={changeMask} renderImage={renderCanvas} />
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
setInpaint({
source: file,
});
}} />
<ImageInput filter={IMAGE_FILTER} image={params.mask} label='Mask' onChange={(file) => {
setInpaint({
mask: file,
});
}} renderImage={renderCanvas} />
<Stack direction='row' spacing={4}>
<NumericField
decimal
Expand Down
15 changes: 11 additions & 4 deletions gui/src/config.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { Maybe } from '@apextoaster/js-utils';

import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js';

export interface ConfigNumber {
Expand All @@ -12,16 +14,20 @@ export interface ConfigString {
keys: Array<string>;
}

export type KeyFilter<T extends object> = {
[K in keyof T]: T[K] extends number ? K : T[K] extends string ? K : never;
export type KeyFilter<T extends object, TValid = number | string> = {
[K in keyof T]: T[K] extends TValid ? K : never;
}[keyof T];

export type ConfigFiles<T extends object> = {
[K in KeyFilter<T, Blob | File>]: Maybe<T[K]>;
};

export type ConfigRanges<T extends object> = {
[K in KeyFilter<T>]: T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never;
};

export type ConfigState<T extends object> = {
[K in KeyFilter<T>]: T[K] extends number ? number : T[K] extends string ? string : never;
export type ConfigState<T extends object, TValid = number | string> = {
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
};

export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams>>;
Expand All @@ -45,6 +51,7 @@ export const DEFAULT_BRUSH = {
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds
export const SAVE_TIME = 5_000; // 5 seconds

export async function loadConfig(): Promise<Config> {
const configPath = new URL('./config.json', window.origin);
Expand Down
16 changes: 14 additions & 2 deletions gui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import { OnnxWeb } from './components/OnnxWeb.js';
import { loadConfig } from './config.js';
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';

const { createContext } = React;

export async function main() {
// load config from GUI server
const config = await loadConfig();
Expand Down Expand Up @@ -41,6 +39,20 @@ export async function main() {
...createDefaultSlice(...slice),
}), {
name: 'onnx-web',
partialize(s) {
return {
...s,
img2img: {
...s.img2img,
source: undefined,
},
inpaint: {
...s.inpaint,
mask: undefined,
source: undefined,
},
};
},
storage: createJSONStorage(() => localStorage),
version: 3,
}));
Expand Down

0 comments on commit 4e82241

Please sign in to comment.