Skip to content

Commit

Permalink
fix(api): correctly cache diffusers scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 1179092 commit 9c5043e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_tile_latents(
def load_pipeline(
pipeline: DiffusionPipeline,
model: str,
scheduler: Any,
scheduler_type: Any,
device: DeviceParams,
lpw: bool,
):
Expand All @@ -79,7 +79,7 @@ def load_pipeline(
custom_pipeline = None

logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained(
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
Expand All @@ -100,22 +100,22 @@ def load_pipeline(

last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler
last_pipeline_scheduler = scheduler_type

if last_pipeline_scheduler != scheduler:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained(
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder="scheduler",
)

if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device)
scheduler = scheduler.to(device.torch_device())

pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler
last_pipeline_scheduler = scheduler_type
run_gc()

return pipe

0 comments on commit 9c5043e

Please sign in to comment.