Skip to content

Commit

Permalink
feat: add outscaling option
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent 091c4e6 commit 8d3ebed
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 29 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Based on guides by:
- [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu)
- [Download and convert models](#download-and-convert-models)
- [Test the models](#test-the-models)
- [Upscaling and face correction](#upscaling-and-face-correction)
- [Usage](#usage)
- [Running the containers](#running-the-containers)
- [Configuring and running the server](#configuring-and-running-the-server)
Expand Down Expand Up @@ -310,6 +311,13 @@ If the script works, there will be an image of an astronaut in `outputs/test.png

If you get any errors, check [the known errors section](#known-errors-and-solutions).

### Upscaling and face correction

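Upscaling and face correction use the following pre-trained models:

- GFPGAN v1.3 (face correction): https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
- Real-ESRGAN x4plus (upscaling): https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth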

## Usage

### Running the containers
Expand Down
18 changes: 16 additions & 2 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@
'gaussian-screen': mask_filter_gaussian_screen,
}

# TODO: load from model_path
# Hard-coded model identifiers consumed by upscale_from_request:
#   [0] is passed as the upscale model name (first positional UpscaleParams argument),
#   [1] is passed as face_model; it is currently a download URL rather than a name.
upscale_models = [
'RealESRGAN_x4plus',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
]


# Return send_from_directory's result for `filename`, looked up under the
# directory formed by joining '..' with bundle_path. When no filename is
# given, 'index.html' is used.
def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename)
Expand Down Expand Up @@ -192,9 +198,17 @@ def border_from_request() -> Border:
# Build an UpscaleParams from the query arguments of the current request.
#   denoise  - float read from 'denoise' through get_and_clamp_float
#   scale    - int read from 'scale' through get_and_clamp_int
#   outscale - int read from 'outscale' through get_and_clamp_int
#   faces    - True only when the 'faces' argument is exactly the string 'true'
# The three trailing numbers handed to the clamp helpers are presumably
# (default, max, min) - TODO confirm against the helper definitions.
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
faces = request.args.get('faces', 'false') == 'true'
# NOTE(review): this listing is a rendered diff with both sides interleaved.
# The `platform` local and the one-line return directly below appear to be the
# removed (pre-commit) form; the multi-line return after them is the added form,
# which also forwards the upscale model, outscale and face model. Confirm
# against the repository before treating either return as live code.
platform = 'onnx'
return UpscaleParams(scale=scale, faces=faces, platform=platform, denoise=denoise)
return UpscaleParams(
upscale_models[0],
scale=scale,
outscale=outscale,
faces=faces,
face_model=upscale_models[1],
platform='onnx',
denoise=denoise,
)

def check_paths():
if not path.exists(model_path):
Expand Down
59 changes: 36 additions & 23 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from onnxruntime import InferenceSession
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Any
from typing import Any, Union

import numpy as np
import torch
Expand All @@ -15,16 +14,9 @@
)

# TODO: these should all be params or config
fp16 = False
outscale = 4
pre_pad = 0
tile_pad = 10

gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_name = 'RealESRGAN_x4plus'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']


class ONNXImage():
def __init__(self, source) -> None:
Expand All @@ -51,6 +43,9 @@ def clamp_(self, min, max):
# Return the wrapped `source` object as-is (no copy or conversion is made here).
def numpy(self):
return self.source

# Return the shape of the wrapped `source`, as reported by np.shape.
def size(self):
return np.shape(self.source)


class ONNXNet():
'''
Expand Down Expand Up @@ -87,29 +82,44 @@ def to(self, device):


# Plain container for upscaling options: the constructor stores every argument
# on the instance under an attribute of the same name.
class UpscaleParams():
# NOTE(review): this listing is a rendered diff with both sides interleaved.
# The one-line __init__ signature and the `self.denoise` assignment directly
# below appear to be the removed (pre-commit) form; the multi-line __init__
# that follows is the added form. Confirm against the repository.
def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None:
self.denoise = denoise
# Added form: `upscale_model` is the only argument without a default;
# `face_model` may be None.
def __init__(
self,
upscale_model: str,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_model: Union[str, None] = None,
platform: str = 'onnx',
half=False
) -> None:
self.upscale_model = upscale_model
self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces
self.face_model = face_model
self.platform = platform
self.half = half


def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_path = path.join(ctx.model_path, resrgan_name + '.pth')
model_path = path.join(ctx.model_path, '%s.%s' %
(params.upscale_model, params.platform))
if not path.isfile(model_path):
for url in resrgan_url:
model_path = load_file_from_url(
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
raise Exception('Real ESRGAN model not found at %s' % model_path)

# use ONNX acceleration, if available
if params.platform == 'onnx':
model = ONNXNet(ctx)
else:
elif params.platform == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
else:
raise Exception('unknown platform %s' % params.platform)

dni_weight = None
if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1:
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
Expand All @@ -123,7 +133,7 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
tile=tile,
tile_pad=tile_pad,
pre_pad=pre_pad,
half=fp16)
half=params.half)

return upsampler

Expand All @@ -134,8 +144,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
image = np.array(source_image)
upsampler = make_resrgan(ctx, params)

# TODO: what is outscale for here?
output, _ = upsampler.enhance(image, outscale=outscale)
output, _ = upsampler.enhance(image, outscale=params.outscale)

if params.faces:
output = upscale_gfpgan(ctx, params, output)
Expand All @@ -144,14 +153,18 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima


def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN')
print('correcting faces with GFPGAN model: %s' % params.face_model)

if params.face_model is None:
print('no face model given, skipping')
return image

if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512)

face_enhancer = GFPGANer(
model_path=gfpgan_url,
upscale=outscale,
model_path=params.face_model,
upscale=params.outscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)
Expand Down
6 changes: 6 additions & 0 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@
"max": 1,
"step": 0.01
},
"outscale": {
"default": 1,
"min": 1,
"max": 4,
"step": 1
},
"width": {
"default": 512,
"min": 64,
Expand Down
2 changes: 2 additions & 0 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export interface UpscaleParams {
denoise: number;
faces: boolean;
scale: number;
outscale: number;
}

export interface ApiResponse {
Expand Down Expand Up @@ -170,6 +171,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('faces', String(upscale.faces));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
}

export function makeClient(root: string, f = fetch): ApiClient {
Expand Down
4 changes: 2 additions & 2 deletions gui/src/components/ImageCard.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
import { Brush, ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
import { Box, Button, Card, CardContent, CardMedia, Grid, Paper } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
Expand Down Expand Up @@ -86,7 +86,7 @@ export function ImageCard(props: ImageCardProps) {
</GridItem>
<GridItem xs={2}>
<Button onClick={copySourceToInpaint}>
<ContentCopyTwoTone />
<Brush />
</Button>
</GridItem>
<GridItem xs={2}>
Expand Down
8 changes: 7 additions & 1 deletion gui/src/components/ImageInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ export function ImageInput(props: ImageInputProps) {
}

if (doesExist(props.image)) {
return <img src={URL.createObjectURL(props.image)} />;
return <img
src={URL.createObjectURL(props.image)}
style={{
maxWidth: 512,
maxHeight: 512,
}}
/>;
} else {
return <div>Please select an image.</div>;
}
Expand Down
5 changes: 4 additions & 1 deletion gui/src/components/MaskCanvas.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Button, Stack } from '@mui/material';
import { Button, Stack, Typography } from '@mui/material';
import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
import { useStore } from 'zustand';
Expand Down Expand Up @@ -230,6 +230,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
}
}}
/>
<Typography variant='body1'>
Black pixels in the mask will stay the same, white pixels will be replaced with pixels from the noise source.
</Typography>
<Stack direction='row' spacing={4}>
<NumericField
label='Brush Color'
Expand Down
13 changes: 13 additions & 0 deletions gui/src/components/UpscaleControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ export function UpscaleControl(props: UpscaleControlProps) {
});
}}
/>
<NumericField
label='Outscale'
disabled={params.enabled === false}
min={config.outscale.min}
max={config.outscale.max}
step={config.outscale.step}
value={params.outscale}
onChange={(outscale) => {
setUpscale({
outscale,
});
}}
/>
<NumericField
label='Denoise'
decimal
Expand Down
1 change: 1 addition & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ export function createStateSlices(base: ConfigParams) {
enabled: false,
faces: false,
scale: 1,
outscale: 1,
},
setUpscale(upscale) {
set((prev) => ({
Expand Down
1 change: 1 addition & 0 deletions onnx-web.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"Onnx",
"onnxruntime",
"outpaint",
"outscale",
"pndm",
"pretrained",
"protobuf",
Expand Down

0 comments on commit 8d3ebed

Please sign in to comment.