Skip to content

Commit

Permalink
fix(api): use ORT session for correct device when loading blended nets
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 18, 2023
1 parent c465b61 commit 9f9b73b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 2 additions & 3 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
StableDiffusionPipeline,
)
from onnx import load_model
from onnxruntime import SessionOptions
from transformers import CLIPTokenizer

from onnx_web.diffusers.utils import expand_prompt
Expand Down Expand Up @@ -271,7 +270,7 @@ def load_pipeline(
text_encoder
)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = SessionOptions()
text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers(
list(text_encoder_names), list(text_encoder_values)
)
Expand All @@ -292,7 +291,7 @@ def load_pipeline(
)
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = SessionOptions()
unet_opts = device.sess_options(cache=False)
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
Expand Down
8 changes: 5 additions & 3 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def ort_provider(self) -> Union[str, Tuple[str, Any]]:
else:
return (self.provider, self.options)

def sess_options(self) -> SessionOptions:
if self.sess_options_cache is not None:
def sess_options(self, cache = True) -> SessionOptions:
if cache and self.sess_options_cache is not None:
return self.sess_options_cache

sess = SessionOptions()
Expand All @@ -139,7 +139,9 @@ def sess_options(self) -> SessionOptions:
logger.debug("enabling ONNX deterministic compute")
sess.use_deterministic_compute = True

self.sess_options_cache = sess
if cache:
self.sess_options_cache = sess

return sess

def torch_str(self) -> str:
Expand Down

0 comments on commit 9f9b73b

Please sign in to comment.