Skip to content

Commit

Permalink
fix(api): use consistent cache key for each model type
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 3, 2023
1 parent a9fa767 commit 47b1094
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 21 deletions.
6 changes: 3 additions & 3 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image

from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
Expand All @@ -28,7 +28,7 @@ def load(

face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.correction, cache_key)

if cache_pipe is not None:
logger.info("reusing existing GFPGAN pipeline")
Expand All @@ -46,7 +46,7 @@ def load(
upscale=upscale.face_outscale,
)

server.cache.set("gfpgan", cache_key, gfpgan)
server.cache.set(ModelTypes.correction, cache_key, gfpgan)
run_gc([device])

return gfpgan
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/chain/upscale_bsrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
Expand All @@ -28,7 +28,7 @@ def load(
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("bsrgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

if cache_pipe is not None:
logger.debug("reusing existing BSRGAN pipeline")
Expand All @@ -43,7 +43,7 @@ def load(
sess_options=device.sess_options(),
)

server.cache.set("bsrgan", cache_key, pipe)
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
run_gc([device])

return pipe
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
Expand All @@ -29,7 +29,7 @@ def load(
model_path = path.join(server.model_path, model_file)

cache_key = (model_path, params.format)
cache_pipe = server.cache.get("resrgan", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
if cache_pipe is not None:
logger.info("reusing existing Real ESRGAN pipeline")
return cache_pipe
Expand Down Expand Up @@ -66,7 +66,7 @@ def load(
half=False, # TODO: use server optimizations
)

server.cache.set("resrgan", cache_key, upsampler)
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
run_gc([device])

return upsampler
Expand Down
10 changes: 5 additions & 5 deletions api/onnx_web/chain/upscale_swinir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
Expand All @@ -28,7 +28,7 @@ def load(
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("swinir", cache_key)
cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

if cache_pipe is not None:
logger.info("reusing existing SwinIR pipeline")
Expand All @@ -43,7 +43,7 @@ def load(
sess_options=device.sess_options(),
)

server.cache.set("swinir", cache_key, pipe)
server.cache.set(ModelTypes.upscaling, cache_key, pipe)
run_gc([device])

return pipe
Expand Down Expand Up @@ -75,7 +75,7 @@ def run(
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("SwinIR input shape: %s", image.shape)
logger.trace("SwinIR input shape: %s", image.shape)

scale = upscale.outscale
dest = np.zeros(
Expand All @@ -86,7 +86,7 @@ def run(
image.shape[3] * scale,
)
)
logger.info("SwinIR output shape: %s", dest.shape)
logger.trace("SwinIR output shape: %s", dest.shape)

dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
Expand Down
12 changes: 6 additions & 6 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, ImageParams
from ..server import ServerContext
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper
Expand Down Expand Up @@ -119,14 +119,14 @@ def load_pipeline(
scheduler_key = (params.scheduler, model)
scheduler_type = pipeline_schedulers[params.scheduler]

cache_pipe = server.cache.get("diffusion", pipe_key)
cache_pipe = server.cache.get(ModelTypes.diffusion, pipe_key)

if cache_pipe is not None:
logger.debug("reusing existing diffusion pipeline")
pipe = cache_pipe

# update scheduler
cache_scheduler = server.cache.get("scheduler", scheduler_key)
cache_scheduler = server.cache.get(ModelTypes.scheduler, scheduler_key)
if cache_scheduler is None:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
Expand All @@ -141,7 +141,7 @@ def load_pipeline(
scheduler = scheduler.to(device.torch_str())

pipe.scheduler = scheduler
server.cache.set("scheduler", scheduler_key, scheduler)
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
run_gc([device])

else:
Expand Down Expand Up @@ -342,8 +342,8 @@ def load_pipeline(
optimize_pipeline(server, pipe)
patch_pipeline(server, pipe, pipeline, pipeline_class, params)

server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])

if hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
apply_patch_facexlib,
apply_patches,
)
from .model_cache import ModelCache
from .model_cache import ModelCache, ModelTypes
from .context import ServerContext
8 changes: 8 additions & 0 deletions api/onnx_web/server/model_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from logging import getLogger
from typing import Any, List, Tuple

Expand All @@ -6,6 +7,13 @@
cache: List[Tuple[str, Any, Any]] = []


class ModelTypes(str, Enum):
    """Categories of cached models, used as the namespace key in ModelCache.

    Mixing in ``str`` keeps members equal to their plain-string values, so
    callers that previously passed literals like ``"diffusion"`` remain
    compatible with ``server.cache.get``/``set`` lookups keyed on members.
    """

    # face-correction models (e.g. the GFPGAN pipeline)
    correction = "correction"
    # full diffusion pipelines
    diffusion = "diffusion"
    # diffusion schedulers, cached separately from their pipelines
    scheduler = "scheduler"
    # upscaling models; shared by the BSRGAN, Real ESRGAN, and SwinIR stages
    upscaling = "upscaling"


class ModelCache:
# cache: List[Tuple[str, Any, Any]]
limit: int
Expand Down

0 comments on commit 47b1094

Please sign in to comment.