Skip to content

Commit

Permalink
fix(api): pass device ID in provider params
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 5, 2023
1 parent 9d1f941 commit 37dd892
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_resrgan.py
Expand Up @@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams

# use ONNX acceleration, if available
if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Expand Up @@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De

if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)

last_pipeline_instance = pipeline
last_pipeline_params = cache_params
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/diffusion/load.py
Expand Up @@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
provider_options=device.options,
subfolder='scheduler',
)
pipe = pipeline.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
safety_checker=None,
scheduler=scheduler,
sess_options=device.options,
)

if device is not None and hasattr(pipe, 'to'):
Expand All @@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
provider_options=device.options,
subfolder='scheduler',
)

Expand Down
7 changes: 2 additions & 5 deletions api/onnx_web/onnx/onnx_net.py
Expand Up @@ -48,14 +48,11 @@ def __init__(
server: ServerContext,
model: str,
provider: str = 'DmlExecutionProvider',
sess_options: Optional[dict] = None,
provider_options: Optional[dict] = None,
) -> None:
'''
TODO: get platform provider from request params
'''
model_path = path.join(server.model_path, model)
self.session = InferenceSession(
model_path, providers=[provider], sess_options=sess_options)
model_path, providers=[provider], provider_options=provider_options)

def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
Expand Down

0 comments on commit 37dd892

Please sign in to comment.