Skip to content

Commit

Permalink
fix(api): filter platforms based on available providers (fixes #69)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 22, 2023
1 parent 3a5bae6 commit c768cd8
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from glob import glob
from io import BytesIO
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path, scandir
from typing import Tuple

Expand Down Expand Up @@ -72,7 +73,7 @@
'cuda': 'CUDAExecutionProvider',
'directml': 'DmlExecutionProvider',
'nvidia': 'CUDAExecutionProvider',
'rocm': 'ROCmExecutionProvider',
'rocm': 'ROCMExecutionProvider',
}
pipeline_schedulers = {
'ddim': DDIMScheduler,
Expand Down Expand Up @@ -102,6 +103,9 @@
'gaussian-screen': mask_filter_gaussian_screen,
}

# Available ORT providers
available_platforms = []

# loaded from model_path
diffusion_models = []
correction_models = []
Expand Down Expand Up @@ -248,11 +252,21 @@ def load_params(context: ServerContext):
config_params = json.load(f)


def load_platforms():
global available_platforms

providers = get_available_providers()
available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)]

print('available acceleration platforms: %s' % (available_platforms))


context = ServerContext.from_environ()

check_paths(context)
load_models(context)
load_params(context)
load_platforms()

app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
Expand Down

0 comments on commit c768cd8

Please sign in to comment.