Skip to content

Commit

Permalink
fix(api): pass current device when loading GFPGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 6, 2023
1 parent 811b664 commit c7e0041
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from realesrgan import RealESRGANer

from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan

Expand All @@ -20,7 +20,7 @@


def load_gfpgan(
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams, upsampler: Optional[RealESRGANer] = None
):
global last_pipeline_instance
global last_pipeline_params
Expand Down Expand Up @@ -54,7 +54,7 @@ def load_gfpgan(


def correct_gfpgan(
_job: JobContext,
job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand All @@ -69,7 +69,8 @@ def correct_gfpgan(
return source_image

logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
device = job.get_device()
gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler)

output = np.array(source_image)
_, _, output = gfpgan.enhance(
Expand Down

0 comments on commit c7e0041

Please sign in to comment.