Skip to content

Commit

Permalink
feat(api): add provider for each available CUDA device (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 4, 2023
1 parent f6dbab3 commit 98b6e4d
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 19 deletions.
22 changes: 15 additions & 7 deletions api/onnx_web/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
from multiprocessing import Value
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .params import (
DeviceParams,
)

logger = getLogger(__name__)


class JobContext:
def __init__(
self,
key: str,
devices: List[str],
devices: List[DeviceParams],
cancel: bool = False,
device_index: int = -1,
progress: int = 0,
Expand All @@ -24,7 +28,7 @@ def __init__(
def is_cancelled(self) -> bool:
return self.cancel.value

def get_device(self) -> str:
def get_device(self) -> DeviceParams:
'''
Get the device assigned to this job.
'''
Expand All @@ -45,7 +49,8 @@ def on_progress(step: int, timestep: int, latents: Any):
if self.is_cancelled():
raise Exception('job has been cancelled')
else:
logger.debug('setting progress for job %s to %s', self.key, step)
logger.debug('setting progress for job %s to %s',
self.key, step)
self.set_progress(step)

return on_progress
Expand All @@ -63,6 +68,7 @@ class Job:
'''
Link a future to its context.
'''

def __init__(
self,
key: str,
Expand All @@ -88,16 +94,18 @@ class DevicePoolExecutor:
jobs: List[Job] = None
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
self.devices = devices
self.jobs = []

device_count = len(devices)
if pool is None:
logger.info('creating thread pool executor for %s devices: %s', device_count, devices)
logger.info(
'creating thread pool executor for %s devices: %s', device_count, devices)
self.pool = ThreadPoolExecutor(device_count)
else:
logger.info('using existing pool for %s devices: %s', device_count, devices)
logger.info('using existing pool for %s devices: %s',
device_count, devices)
self.pool = pool

def cancel(self, key: str) -> bool:
Expand Down Expand Up @@ -142,4 +150,4 @@ def job_done(f: Future):
future.add_done_callback(job_done)

def status(self) -> Dict[str, Tuple[bool, int]]:
return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs]
return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs]
12 changes: 6 additions & 6 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run_txt2img_pipeline(
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler, device=device)
params.model, device.provider, params.scheduler, device=device.torch_device())

latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
Expand Down Expand Up @@ -92,7 +92,7 @@ def run_img2img_pipeline(
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler, device=device)
params.model, device.provider, params.scheduler, device=device.torch_device())

rng = np.random.RandomState(params.seed)

Expand Down Expand Up @@ -137,8 +137,8 @@ def run_inpaint_pipeline(
strength: float,
fill_color: str,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
# device = job.get_device()
# progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
Expand Down Expand Up @@ -182,8 +182,8 @@ def run_upscale_pipeline(
upscale: UpscaleParams,
source_image: Image.Image,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
# device = job.get_device()
# progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
Expand Down
13 changes: 13 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def tojson(self) -> Dict[str, int]:
}


class DeviceParams:
    '''
    Pair a device name with the provider that should be used to run on it.
    '''

    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
        '''
        device: name of the device, such as `cpu` or `cuda:0`.
        provider: name of the provider used for this device.
        options: optional extra settings stored alongside the provider,
            such as `{'device_id': 0}`; defaults to None.
        '''
        self.device = device
        self.provider = provider
        self.options = options

    def __repr__(self) -> str:
        '''
        Readable representation, so lists of devices are legible in log output.
        '''
        return 'DeviceParams(%r, %r, %r)' % (self.device, self.provider, self.options)

    def torch_device(self) -> str:
        '''
        Get the torch device name for this device.

        Device names starting with `cuda` are returned unchanged (keeping any
        `:index` suffix); every other device falls back to `cpu`.
        '''
        if self.device.startswith('cuda'):
            return self.device

        return 'cpu'


class ImageParams:
def __init__(
self,
Expand Down
21 changes: 15 additions & 6 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
)
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from glob import glob
from io import BytesIO
from jsonschema import validate
from logging import getLogger
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path
from typing import Tuple
from typing import List, Tuple


from .chain import (
Expand Down Expand Up @@ -69,6 +68,7 @@
)
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
Expand All @@ -88,6 +88,7 @@

import gc
import numpy as np
import torch
import yaml

logger = getLogger(__name__)
Expand Down Expand Up @@ -147,7 +148,7 @@
}

# Available ORT providers
available_platforms = []
available_platforms: List[DeviceParams] = []

# loaded from model_path
diffusion_models = []
Expand Down Expand Up @@ -310,8 +311,16 @@ def load_platforms():
global available_platforms

providers = get_available_providers()
available_platforms = [p for p in platform_providers if (
platform_providers[p] in providers and p not in context.block_platforms)]

for potential in platform_providers:
if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda':
for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams('%s:%s' % (potential, i), providers[potential], {
'device_id': i,
}))
else:
available_platforms.append(DeviceParams(potential, providers[potential]))

logger.info('available acceleration platforms: %s', available_platforms)

Expand Down Expand Up @@ -404,7 +413,7 @@ def list_params():

@app.route('/api/settings/platforms')
def list_platforms():
    '''
    Respond with the device names of the available platforms as a JSON list.
    '''
    # available_platforms holds DeviceParams objects, which jsonify cannot
    # serialize by default, so only the device name of each is returned.
    return jsonify([p.device for p in available_platforms])


@app.route('/api/settings/schedulers')
Expand Down
1 change: 1 addition & 0 deletions onnx-web.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"randn",
"realesr",
"resrgan",
"rocm",
"RRDB",
"runwayml",
"scandir",
Expand Down

0 comments on commit 98b6e4d

Please sign in to comment.