Skip to content

Commit

Permalink
feat(api): return all types of models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 17, 2023
1 parent dba6113 commit ee6308a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
43 changes: 27 additions & 16 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from glob import glob
from io import BytesIO
from PIL import Image
from os import makedirs, path, scandir
Expand Down Expand Up @@ -45,6 +46,7 @@
from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
make_output_name,
safer_join,
Expand All @@ -58,7 +60,6 @@
import numpy as np

# pipeline caching
available_models = []
config_params = {}

# pipeline params
Expand Down Expand Up @@ -95,14 +96,10 @@
'gaussian-screen': mask_filter_gaussian_screen,
}

# TODO: load from model_path
upscale_models = [
'RealESRGAN_x4plus',
]

face_models = [
'GFPGANv1.3',
]
# loaded from model_path
diffusion_models = []
correction_models = []
upscaling_models = []


def url_from_rule(rule) -> str:
Expand Down Expand Up @@ -183,13 +180,16 @@ def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = request.args.get('faces', 'false') == 'true'

return UpscaleParams(
upscale_models[0],
upscaling,
correction_model=correction,
scale=scale,
outscale=outscale,
faces=faces,
face_model=face_models[0],
platform='onnx',
denoise=denoise,
)
Expand All @@ -204,9 +204,16 @@ def check_paths(context: ServerContext):


def load_models(context: ServerContext):
global available_models
available_models = [f.name for f in scandir(
context.model_path) if f.is_dir()]
global diffusion_models
global correction_models
global upscaling_models

diffusion_models = glob(context.model_path, 'diffusion-*')
diffusion_models.append(glob(context.model_path, 'stable-diffusion-*'))

correction_models = glob(context.model_path, 'correction-*')
upscaling_models = glob(context.model_path, 'upscaling-*')



def load_params(context: ServerContext):
Expand Down Expand Up @@ -271,7 +278,11 @@ def list_mask_filters():

@app.route('/api/settings/models')
def list_models():
return jsonify(available_models)
return jsonify({
'diffusion': diffusion_models,
'correction': correction_models,
'upscaling': upscaling_models,
})


@app.route('/api/settings/noises')
Expand Down Expand Up @@ -397,7 +408,7 @@ def inpaint():
return jsonify({
'output': output,
'params': params.tojson(),
'size': upscale.resize(size.with_border(expand)).tojson(),
'size': upscale.resize(size.add_border(expand)).tojson(),
})


Expand Down
10 changes: 5 additions & 5 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ class UpscaleParams():
def __init__(
self,
upscale_model: str,
correction_model: Union[str, None] = None,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_model: Union[str, None] = None,
platform: str = 'onnx',
half=False
) -> None:
self.upscale_model = upscale_model
self.correction_model = correction_model
self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces
self.face_model = face_model
self.platform = platform
self.half = half

Expand Down Expand Up @@ -158,16 +158,16 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima


def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN model: %s' % params.face_model)
print('correcting faces with GFPGAN model: %s' % params.correction_model)

if params.face_model is None:
if params.correction_model is None:
print('no face model given, skipping')
return image

if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512)

face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model))
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))

# TODO: doesn't have a model param, not sure how to pass ONNX model
face_enhancer = GFPGANer(
Expand Down
17 changes: 13 additions & 4 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
from hashlib import sha256


Expand Down Expand Up @@ -89,15 +89,15 @@ def __init__(self, width: int, height: int) -> None:
self.width = width
self.height = height

def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)

def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
}

def with_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)


def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
Expand All @@ -107,6 +107,15 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
return min(max(int(args.get(key, default_value)), min_value), max_value)


def get_from_list(args: Any, key: str, values: List[Any]):
selected = args.get(key, values[0])
if selected in values:
return selected
else:
print('invalid selection: %s' % (selected))
return values[0]


def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
selected = args.get(key, default)
if selected in values:
Expand Down

0 comments on commit ee6308a

Please sign in to comment.