Skip to content

Commit

Permalink
fix(api): patch more download paths (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 14, 2023
1 parent d1b2506 commit 05756b2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 14 deletions.
10 changes: 4 additions & 6 deletions api/onnx_web/chain/correct_codeformer.py
Expand Up @@ -14,19 +14,17 @@
def correct_codeformer(
job: JobContext,
_server: ServerContext,
stage: StageParams,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
source_image: Image.Image = None,
stage_source: Image.Image = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
# must be within the load function for patch to take effect
from codeformer import CodeFormer

device = job.get_device()
# TODO: terrible names, fix
image = source or source_image

pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
return pipe(image)
return pipe(stage_source or source)
7 changes: 5 additions & 2 deletions api/onnx_web/chain/correct_gfpgan.py
Expand Up @@ -2,9 +2,9 @@
from os import path

import numpy as np
from gfpgan import GFPGANer
from PIL import Image


from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc
Expand All @@ -18,7 +18,10 @@ def load_gfpgan(
upscale: UpscaleParams,
_device: DeviceParams,
):
face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
# must be within the load function for patch to take effect
from gfpgan import GFPGANer

face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)

Expand Down
10 changes: 6 additions & 4 deletions api/onnx_web/chain/upscale_resrgan.py
Expand Up @@ -2,10 +2,7 @@
from os import path

import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
Expand All @@ -23,6 +20,11 @@
def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file)

Expand Down Expand Up @@ -75,7 +77,7 @@ def load_resrgan(

# TODO: shouldn't need the PTH file
model_path_pth = path.join(
server.model_path, ".cache", ("%s.pth" % params.upscale_model)
server.cache_path, ("%s.pth" % params.upscale_model)
)
upsampler = RealESRGANer(
scale=params.scale,
Expand Down
38 changes: 36 additions & 2 deletions api/onnx_web/server/hacks.py
Expand Up @@ -48,6 +48,36 @@ def unload(exclude):

# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key
cache_path_map = {
"https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth": (
"pt-inception-2015-12-05-6726825d.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth": (
"detection-resnet50-final.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth": (
"alignment-wflw-4hg.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth": (
"assessment-hyperiqa.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": (
"detection-mobilenet-025-final.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth": (
"headpose-hopenet.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth": (
"matting-modnet-portrait.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth": (
"parsing-bisenet.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth": (
"parsing-parsenet.pth"
),
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth": (
"recognition-arcface-ir-se50.pth"
),
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": (
"correction-codeformer.pth"
),
Expand Down Expand Up @@ -95,9 +125,13 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
parsed = urlparse(url)
cache_path = path.basename(parsed.path)

cache_path = path.join(ctx.model_path, ".cache", cache_path)
cache_path = path.join(ctx.cache_path, cache_path)
logger.debug("Patching download path: %s -> %s", url, cache_path)
return cache_path

if path.exists(cache_path):
return cache_path
else:
raise FileNotFoundError("Missing cache file: %s" % (cache_path))


def apply_patch_basicsr(ctx: ServerContext):
Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/utils.py
Expand Up @@ -25,6 +25,7 @@ def __init__(
default_platform: str = None,
image_format: str = "png",
cache: ModelCache = None,
cache_path: str = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
Expand All @@ -37,6 +38,7 @@ def __init__(
self.default_platform = default_platform
self.image_format = image_format
self.cache = cache or ModelCache(num_workers)
self.cache_path = cache_path or path.join(model_path, ".cache")

@classmethod
def from_environ(cls):
Expand Down

0 comments on commit 05756b2

Please sign in to comment.