Skip to content

Commit

Permalink
fix(api): patch various download fns to use cache (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 6983691 commit 8ea33e9
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 3 deletions.
3 changes: 2 additions & 1 deletion api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from logging import getLogger

from codeformer import CodeFormer
from PIL import Image

from ..device_pool import JobContext
Expand All @@ -23,6 +22,8 @@ def correct_codeformer(
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
from codeformer import CodeFormer

device = job.get_device()
# TODO: terrible names, fix
image = source or source_image
Expand Down
11 changes: 9 additions & 2 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from os import makedirs, path
from sys import exit
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

from jsonschema import ValidationError, validate
from yaml import safe_load
Expand Down Expand Up @@ -107,8 +108,14 @@ def fetch_model(
ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
) -> str:
cache_name = path.join(ctx.cache_path, name)
if model_format is not None:
# add an extension if possible, some of the conversion code checks for it

# add an extension if possible, some of the conversion code checks for it
if model_format is None:
url = urlparse(source)
ext = path.basename(url.path)
if ext is not None:
cache_name = "%s.%s" % (cache_name, ext)
else:
cache_name = "%s.%s" % (cache_name, model_format)

for proto in model_sources:
Expand Down
71 changes: 71 additions & 0 deletions api/onnx_web/hacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import sys
import codeformer.facelib.utils.misc
import basicsr.utils.download_util

from functools import partial
from logging import getLogger
from urllib.parse import urlparse
from os import path
from .utils import ServerContext, base_join

logger = getLogger(__name__)

def unload(exclude):
"""
Remove package modules from cache except excluded ones.
On next import they will be reloaded.
From https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e
Args:
exclude (iter<str>): Sequence of module paths.
"""
pkgs = []
for mod in exclude:
pkg = mod.split('.', 1)[0]
pkgs.append(pkg)

to_unload = []
for mod in sys.modules:
if mod in exclude:
continue

if mod in pkgs:
to_unload.append(mod)
continue

for pkg in pkgs:
if mod.startswith(pkg + '.'):
to_unload.append(mod)
break

logger.debug("Unloading modules for patching: %s", to_unload)
for mod in to_unload:
del sys.modules[mod]


def patch_not_impl():
raise NotImplementedError()

def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
url = urlparse(url)
base = path.basename(url.path)
cache_path = path.join(ctx.model_path, ".cache", base)

logger.debug("Patching download path: %s -> %s", path, cache_path)

return cache_path

def apply_patch_codeformer(ctx: ServerContext):
logger.debug("Patching CodeFormer module...")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)

def apply_patch_basicsr(ctx: ServerContext):
logger.debug("Patching BasicSR module...")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)

def apply_patches(ctx: ServerContext):
apply_patch_basicsr(ctx)
apply_patch_codeformer(ctx)
unload(["basicsr.utils.download_util", "codeformer.facelib.utils.misc"])
3 changes: 3 additions & 0 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from onnxruntime import get_available_providers
from PIL import Image

from onnx_web.hacks import apply_patches

from .chain import (
ChainPipeline,
blend_img2img,
Expand Down Expand Up @@ -405,6 +407,7 @@ def any_first_cpu_last(a: DeviceParams, b: DeviceParams):


context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_models(context)
load_params(context)
Expand Down

0 comments on commit 8ea33e9

Please sign in to comment.