Skip to content

Commit

Permalink
feat: implement blend tab and copy buttons (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 13, 2023
1 parent 1de591e commit 7fa1783
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 37 deletions.
1 change: 1 addition & 0 deletions api/onnx_web/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_mask import blend_mask
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def blend_img2img(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)

pipe = load_pipeline(
OnnxStableDiffusionImg2ImgPipeline,
Expand Down
4 changes: 3 additions & 1 deletion api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def blend_inpaint(
callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
logger.info("upscaling image by expanding borders", expand)
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)

if mask_image is None:
# if no mask was provided, keep the full source image
Expand Down
36 changes: 36 additions & 0 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from logging import getLogger
from typing import List, Optional

from PIL import Image

from onnx_web.output import save_image

from ..device_pool import JobContext, ProgressCallback
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug

logger = getLogger(__name__)


def blend_mask(
_job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
*,
sources: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using mask")

l_mask = Image.new("RGBA", mask.size, color="black")
l_mask.alpha_composite(mask)
l_mask = l_mask.convert("L")

if is_debug():
save_image(server, "last-mask.png", mask)
save_image(server, "last-mask-l.png", l_mask)

return Image.composite(sources[0], sources[1], l_mask)
39 changes: 38 additions & 1 deletion api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from logging import getLogger
from typing import Any
from typing import Any, List

import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops

from onnx_web.chain import blend_mask
from onnx_web.chain.base import ChainProgress

from ..chain import upscale_outpaint
Expand Down Expand Up @@ -231,3 +232,39 @@ def run_upscale_pipeline(
run_gc()

logger.info("finished upscale job: %s", dest)


def run_blend_pipeline(
job: JobContext,
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams,
sources: List[Image.Image],
mask: Image.Image,
) -> None:
progress = job.get_progress_callback()
stage = StageParams()

image = blend_mask(
job,
server,
stage,
params,
sources=sources,
mask=mask,
callback=progress,
)

image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

del image
run_gc()

logger.info("finished blend job: %s", dest)
37 changes: 37 additions & 0 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .device_pool import DevicePoolExecutor
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
Expand Down Expand Up @@ -735,6 +736,42 @@ def chain():
return jsonify(json_params(output, params, size))


@app.route("/api/blend", methods=["POST"])
def blend():
if "mask" not in request.files:
return error_reply("mask image is required")

mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")

source_file = request.files.get("source:0")
source_0 = Image.open(BytesIO(source_file.read())).convert("RGBA")

source_file = request.files.get("source:1")
source_1 = Image.open(BytesIO(source_file.read())).convert("RGBA")

device, params, size = pipeline_from_request()
upscale = upscale_from_request()

output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)

executor.submit(
output,
run_blend_pipeline,
context,
params,
size,
output,
upscale,
[source_0, source_1],
mask,
needs_device=device,
)

return jsonify(json_params(output, params, size))


@app.route("/api/cancel", methods=["PUT"])
def cancel():
output_file = request.args.get("output", None)
Expand Down
22 changes: 21 additions & 1 deletion gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { doesExist } from '@apextoaster/js-utils';

import { ServerParams } from './config.js';
import { range } from './utils.js';

/**
* Shared parameters for anything using models, which is pretty much everything.
Expand Down Expand Up @@ -482,7 +483,26 @@ export function makeClient(root: string, f = fetch): ApiClient {
});
},
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
throw new Error('TODO');
const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model);

if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}

const body = new FormData();
body.append('mask', params.mask, 'mask');

for (const i of range(params.sources.length)) {
const name = `source:${i.toFixed(0)}`;
body.append(name, params.sources[i], name);
}

// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
body,
method: 'POST',
});
},
async ready(params: ImageResponse): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
Expand Down
41 changes: 32 additions & 9 deletions gui/src/components/ImageCard.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Blender, Brush, ContentCopy, CropFree, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useContext, useState } from 'react';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';

import { ImageResponse } from '../client.js';
import { ConfigContext, StateContext } from '../state.js';
import { BLEND_SOURCES, ConfigContext, StateContext } from '../state.js';
import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js';
import { range, visibleIndex } from '../utils.js';

export interface ImageCardProps {
value: ImageResponse;
Expand All @@ -27,6 +28,8 @@ export function ImageCard(props: ImageCardProps) {
const { params, output, size } = value;

const [_hash, setHash] = useHash();
const [anchor, setAnchor] = useState<Maybe<HTMLElement>>();

const config = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
Expand Down Expand Up @@ -67,11 +70,13 @@ export function ImageCard(props: ImageCardProps) {
setHash('upscale');
}

async function copySourceToBlend() {
async function copySourceToBlend(idx: number) {
const blob = await loadSource();
// TODO: push instead
const sources = mustDefault(state.getState().blend.sources, []);
const newSources = [...sources];
newSources[idx] = blob;
setBlend({
sources: [blob],
sources: newSources,
});
setHash('blend');
}
Expand All @@ -86,6 +91,10 @@ export function ImageCard(props: ImageCardProps) {
window.open(output.url, '_blank');
}

function close() {
setAnchor(undefined);
}

const model = mustDefault(MODEL_LABELS[params.model], params.model);
const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler);

Expand Down Expand Up @@ -137,10 +146,24 @@ export function ImageCard(props: ImageCardProps) {
</GridItem>
<GridItem xs={2}>
<Tooltip title='Blend'>
<IconButton onClick={copySourceToBlend}>
<IconButton onClick={(event) => {
setAnchor(event.currentTarget);
}}>
<Blender />
</IconButton>
</Tooltip>
<Menu
anchorEl={anchor}
open={doesExist(anchor)}
onClose={close}
>
{range(BLEND_SOURCES).map((idx) => <MenuItem key={idx} onClick={() => {
copySourceToBlend(idx).catch((err) => {
// TODO
});
close();
}}>{visibleIndex(idx)}</MenuItem>)}
</Menu>
</GridItem>
<GridItem xs={2}>
<Tooltip title='Delete'>
Expand Down
2 changes: 1 addition & 1 deletion gui/src/components/control/ImageControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export interface ImageControlProps {
}

/**
* doesn't need to use state, the parent component knows which params to pass
* Doesn't need to use state directly, the parent component knows which params to pass
*/
export function ImageControl(props: ImageControlProps) {
const { params } = mustExist(useContext(ConfigContext));
Expand Down
6 changes: 2 additions & 4 deletions gui/src/components/input/MaskCanvas.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient, InvertColors, Undo } from '@mui/icons-material';
import { Button, Stack, Typography } from '@mui/material';
import { createLogger } from 'browser-bunyan';
import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react';
import { useStore } from 'zustand';

import { SAVE_TIME } from '../../config.js';
import { ConfigContext, StateContext } from '../../state.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
import { imageFromBlob } from '../../utils.js';
import { NumericField } from './NumericField';

Expand Down Expand Up @@ -42,11 +41,10 @@ export interface MaskCanvasProps {
onSave: (blob: Blob) => void;
}

const logger = createLogger({ name: 'react', level: 'debug' }); // TODO: hackeroni and cheese

export function MaskCanvas(props: MaskCanvasProps) {
const { source, mask } = props;
const { params } = mustExist(useContext(ConfigContext));
const logger = mustExist(useContext(LoggerContext));

function composite() {
if (doesExist(viewRef.current)) {
Expand Down
39 changes: 24 additions & 15 deletions gui/src/components/tab/Blend.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { mustDefault, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { BLEND_SOURCES, ClientContext, StateContext } from '../../state.js';
import { range } from '../../utils.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { MaskCanvas } from '../input/MaskCanvas.js';
Expand Down Expand Up @@ -41,22 +42,30 @@ export function Blend() {

return <Box>
<Stack spacing={2}>
<ImageInput
filter={IMAGE_FILTER}
image={sources[0]}
hideSelection={true}
label='Source'
onChange={(file) => {
setBlend({
sources: [file],
});
}}
/>
{range(BLEND_SOURCES).map((idx) =>
<ImageInput
key={`source-${idx.toFixed(0)}`}
filter={IMAGE_FILTER}
image={sources[idx]}
hideSelection={true}
label='Source'
onChange={(file) => {
const newSources = [...sources];
newSources[idx] = file;

setBlend({
sources: newSources,
});
}}
/>
)}
<MaskCanvas
source={sources[0]}
mask={blend.mask}
onSave={() => {
// TODO
onSave={(mask) => {
setBlend({
mask,
});
}}
/>
<UpscaleControl />
Expand Down

0 comments on commit 7fa1783

Please sign in to comment.