feat(api): add GFPGAN and Real ESRGAN to model cache
ssube committed Feb 14, 2023
1 parent e9472bc commit 0709c1d
Showing 4 changed files with 35 additions and 45 deletions.
20 changes: 7 additions & 13 deletions api/onnx_web/chain/correct_gfpgan.py
@@ -13,24 +13,19 @@
 logger = getLogger(__name__)


-last_pipeline_instance: Optional[GFPGANer] = None
-last_pipeline_params: Optional[str] = None
-
-
 def load_gfpgan(
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     upscale: UpscaleParams,
-    device: DeviceParams,
+    _device: DeviceParams,
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
+    cache_key = (face_path,)
+    cache_pipe = server.cache.get("gfpgan", cache_key)

-    if last_pipeline_instance is not None and face_path == last_pipeline_params:
+    if cache_pipe is not None:
         logger.info("reusing existing GFPGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe

     logger.debug("loading GFPGAN model from %s", face_path)

@@ -43,8 +38,7 @@ def load_gfpgan(
         upscale=upscale.face_outscale,
     )

-    last_pipeline_instance = gfpgan
-    last_pipeline_params = face_path
+    server.cache.set("gfpgan", cache_key, gfpgan)
     run_gc()

     return gfpgan
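Note: all three loaders in this commit converge on the same get-or-load shape around the server-wide cache. A minimal sketch of that pattern, for reference only: load_cached is a hypothetical helper, not part of this commit, and it assumes ModelCache.get returns None on a miss, as the diff above implies.

from typing import Any, Callable

def load_cached(server, tag: str, key: Any, loader: Callable[[], Any]) -> Any:
    # hypothetical helper distilling the pattern used by load_gfpgan,
    # load_resrgan, and load_stable_diffusion after this commit
    cached = server.cache.get(tag, key)
    if cached is not None:
        # hit: reuse the pipeline without reloading the model
        return cached

    pipe = loader()
    server.cache.set(tag, key, pipe)
    # run_gc() would go here, as in the loaders above, to reclaim memory
    return pipe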
26 changes: 12 additions & 14 deletions api/onnx_web/chain/upscale_resrgan.py
@@ -21,20 +21,19 @@


 def load_resrgan(
-    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+    server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     model_file = "%s.%s" % (params.upscale_model, params.format)
-    model_path = path.join(ctx.model_path, model_file)
-    if not path.isfile(model_path):
-        raise Exception("Real ESRGAN model not found at %s" % model_path)
+    model_path = path.join(server.model_path, model_file)

-    cache_params = (model_path, params.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    cache_key = (model_path, params.format)
+    cache_pipe = server.cache.get("resrgan", cache_key)
+    if cache_pipe is not None:
         logger.info("reusing existing Real ESRGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe

+    if not path.isfile(model_path):
+        raise Exception("Real ESRGAN model not found at %s" % model_path)
+
     if x4_v3_tag in model_file:
         # the x4-v3 model needs a different network
@@ -49,7 +48,7 @@ def load_resrgan(
     elif params.format == "onnx":
         # use ONNX acceleration, if available
         model = OnnxNet(
-            ctx, model_file, provider=device.provider, provider_options=device.options
+            server, model_file, provider=device.provider, provider_options=device.options
         )
     elif params.format == "pth":
         model = RRDBNet(
@@ -72,7 +71,7 @@ def load_resrgan(
     logger.debug("loading Real ESRGAN upscale model from %s", model_path)

     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
+    model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,
@@ -84,8 +83,7 @@
         half=params.half,
     )

-    last_pipeline_instance = upsampler
-    last_pipeline_params = cache_params
+    server.cache.set("resrgan", cache_key, upsampler)
     run_gc()

     return upsampler
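One behavioral detail in this file: the path.isfile check moved below the cache probe, so a cache hit no longer touches the filesystem. A hypothetical call sequence illustrating the effect:

# hypothetical: same params on both calls, so both resolve to the same cache_key
first = load_resrgan(server, params, device)   # miss: validates and loads the model file
second = load_resrgan(server, params, device)  # hit: returns before the isfile check
assert first is second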
25 changes: 10 additions & 15 deletions api/onnx_web/chain/upscale_stable_diffusion.py
@@ -15,30 +15,26 @@
 logger = getLogger(__name__)


-last_pipeline_instance = None
-last_pipeline_params = (None, None)
-
-
 def load_stable_diffusion(
-    ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
+    server: ServerContext, upscale: UpscaleParams, device: DeviceParams
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
+    model_path = path.join(server.model_path, upscale.upscale_model)

-    model_path = path.join(ctx.model_path, upscale.upscale_model)
-    cache_params = (model_path, upscale.format)
+    cache_key = (model_path, upscale.format)
+    cache_pipe = server.cache.get("diffusion", cache_key)

-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    if cache_pipe is not None:
         logger.debug("reusing existing Stable Diffusion upscale pipeline")
-        return last_pipeline_instance
+        return cache_pipe

     if upscale.format == "onnx":
         logger.debug(
             "loading Stable Diffusion upscale ONNX model from %s, using provider %s",
             model_path,
             device.provider,
         )
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider, provider_options=device.options
         )
     else:
@@ -47,15 +43,14 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider
         )

-    last_pipeline_instance = pipeline
-    last_pipeline_params = cache_params
+    server.cache.set("diffusion", cache_key, pipe)
     run_gc()

-    return pipeline
+    return pipe


 def upscale_stable_diffusion(
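Since the pipelines now live in the server-wide cache rather than module globals, a stale entry can be evicted explicitly. A hypothetical example using the drop method shown in model_cache.py below:

# hypothetical: force the next load_stable_diffusion call to rebuild the pipeline
server.cache.drop("diffusion", (model_path, upscale.format))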
9 changes: 6 additions & 3 deletions api/onnx_web/server/model_cache.py
@@ -12,8 +12,9 @@ def __init__(self, limit: int) -> None:
         self.limit = limit

     def drop(self, tag: str, key: Any) -> None:
-        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
-
+        self.cache = [
+            model for model in self.cache if model[0] != tag and model[1] != key
+        ]

     def get(self, tag: str, key: Any) -> Any:
         for t, k, v in self.cache:
@@ -38,7 +39,9 @@ def set(self, tag: str, key: Any, value: Any) -> None:
     def prune(self):
         total = len(self.cache)
         if total > self.limit:
-            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            logger.info(
+                "Removing models from cache, %s of %s", (total - self.limit), total
+            )
             self.cache[:] = self.cache[: self.limit]
         else:
             logger.debug("Model cache below limit, %s of %s", total, self.limit)
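For reference, a sketch of the ModelCache the loaders call into. The diff shows only fragments of get and set, so their bodies below are assumptions inferred from how the loaders use them, not the file's exact contents.

from logging import getLogger
from typing import Any, List, Tuple

logger = getLogger(__name__)


class ModelCache:
    cache: List[Tuple[str, Any, Any]]  # (tag, key, value) tuples

    def __init__(self, limit: int) -> None:
        self.cache = []
        self.limit = limit

    def drop(self, tag: str, key: Any) -> None:
        # note: as written in the diff, this keeps only entries where both the
        # tag AND the key differ, so it can drop more than the exact (tag, key)
        # pair; a narrower filter would be `not (model[0] == tag and model[1] == key)`
        self.cache = [
            model for model in self.cache if model[0] != tag and model[1] != key
        ]

    def get(self, tag: str, key: Any) -> Any:
        # body assumed from the visible loop header and the loaders' usage:
        # return the cached value on a match, None on a miss
        for t, k, v in self.cache:
            if t == tag and k == key:
                return v
        return None

    def set(self, tag: str, key: Any, value: Any) -> None:
        # body assumed: replace any existing entry, then enforce the limit
        self.drop(tag, key)
        self.cache.append((tag, key, value))
        self.prune()

    def prune(self):
        total = len(self.cache)
        if total > self.limit:
            logger.info(
                "Removing models from cache, %s of %s", (total - self.limit), total
            )
            self.cache[:] = self.cache[: self.limit]
        else:
            logger.debug("Model cache below limit, %s of %s", total, self.limit)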
