Skip to content

Commit

Permalink
feat(api): intercept model downloads in libs and use cached copy (fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 8ea33e9 commit 1179092
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 15 deletions.
2 changes: 1 addition & 1 deletion api/launch.bat
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
echo "Downloading and converting models to ONNX format..."
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
python -m onnx_web.convert --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN%
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=extras.json --token=%HF_TOKEN%

echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
1 change: 1 addition & 0 deletions api/launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fi

echo "Downloading and converting models to ONNX format..."
python3 -m onnx_web.convert \
--sources \
--diffusion \
--upscaling \
--correction \
Expand Down
45 changes: 44 additions & 1 deletion api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
source_format,
tuple_to_correction,
tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling,
)

Expand Down Expand Up @@ -101,6 +102,33 @@
4,
),
],
# download only
"sources": [
(
"detection-resnet50-final",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
),
(
"detection-mobilenet025-final",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
),
(
"detection-yolo-v5-l",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
),
(
"detection-yolo-v5-n",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
),
(
"parsing-bisenet",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
),
(
"parsing-parsenet",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
),
],
}


Expand All @@ -113,8 +141,9 @@ def fetch_model(
if model_format is None:
url = urlparse(source)
ext = path.basename(url.path)
file, ext = path.splitext(ext)
if ext is not None:
cache_name = "%s.%s" % (cache_name, ext)
cache_name += ext
else:
cache_name = "%s.%s" % (cache_name, model_format)

Expand Down Expand Up @@ -147,6 +176,19 @@ def fetch_model(


def convert_models(ctx: ConversionContext, args, models: Models):
if args.sources:
for model in models.get("sources"):
model = tuple_to_source(model)
name = model.get("name")

if name in args.skip:
logger.info("Skipping source: %s", name)
else:
model_format = source_format(model)
source = model["source"]
dest = fetch_model(ctx, name, source, model_format=model_format)
logger.info("Finished downloading source: %s -> %s", source, dest)

if args.diffusion:
for model in models.get("diffusion"):
model = tuple_to_diffusion(model)
Expand Down Expand Up @@ -208,6 +250,7 @@ def main() -> int:
)

# model groups
parser.add_argument("--sources", action="store_true", default=False)
parser.add_argument("--correction", action="store_true", default=False)
parser.add_argument("--diffusion", action="store_true", default=False)
parser.add_argument("--upscaling", action="store_true", default=False)
Expand Down
12 changes: 12 additions & 0 deletions api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ def download_progress(urls: List[Tuple[str, str]]):
return str(dest_path.absolute())


def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model

return {
"name": name,
"source": source,
}
else:
return model


def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
Expand Down
50 changes: 37 additions & 13 deletions api/onnx_web/hacks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import sys
import codeformer.facelib.utils.misc
import basicsr.utils.download_util

from functools import partial
from logging import getLogger
from urllib.parse import urlparse
from os import path
from .utils import ServerContext, base_join
from urllib.parse import urlparse

import basicsr.utils.download_util
import codeformer.facelib.utils.misc

from .utils import ServerContext

logger = getLogger(__name__)


def unload(exclude):
"""
Remove package modules from cache except excluded ones.
Expand All @@ -21,7 +23,7 @@ def unload(exclude):
"""
pkgs = []
for mod in exclude:
pkg = mod.split('.', 1)[0]
pkg = mod.split(".", 1)[0]
pkgs.append(pkg)

to_unload = []
Expand All @@ -34,7 +36,7 @@ def unload(exclude):
continue

for pkg in pkgs:
if mod.startswith(pkg + '.'):
if mod.startswith(pkg + "."):
to_unload.append(mod)
break

Expand All @@ -43,28 +45,50 @@ def unload(exclude):
del sys.modules[mod]


# these should be the same sources and names as `convert.base_models.sources`, but inverted so the source is the key
cache_path_map = {
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth": "correction-codeformer.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth": "detection-mobilenet025-final.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth": "parsing-bisenet.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth": "parsing-parsenet.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth": "upscaling-real-esrgan-x2-plus",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth": "detection-yolo-v5-l.pth",
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth": "detection-yolo-v5-n.pth",
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth": "correct-gfpgan-v1-3.pth",
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/detection_Resnet50_Final.pth": "detection-resnet50-final.pth",
"https://s3.eu-central-1.wasabisys.com/nextml-model-data/codeformer/weights/facelib/parsing_parsenet.pth": "parsing-parsenet.pth",
}


def patch_not_impl():
raise NotImplementedError()

def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
url = urlparse(url)
base = path.basename(url.path)
cache_path = path.join(ctx.model_path, ".cache", base)

logger.debug("Patching download path: %s -> %s", path, cache_path)

def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
if url in cache_path_map:
cache_path = cache_path_map.get(url)
else:
parsed = urlparse(url)
cache_path = path.basename(parsed.path)

cache_path = path.join(ctx.model_path, ".cache", cache_path)
logger.debug("Patching download path: %s -> %s", url, cache_path)
return cache_path


def apply_patch_codeformer(ctx: ServerContext):
logger.debug("Patching CodeFormer module...")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)


def apply_patch_basicsr(ctx: ServerContext):
logger.debug("Patching BasicSR module...")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)


def apply_patches(ctx: ServerContext):
apply_patch_basicsr(ctx)
apply_patch_codeformer(ctx)
Expand Down

0 comments on commit 1179092

Please sign in to comment.