feat(api): add option to use any available platform
ssube committed Feb 11, 2023
1 parent 917f5be commit ea3b065
Showing 3 changed files with 37 additions and 17 deletions.
12 changes: 9 additions & 3 deletions api/onnx_web/device_pool.py
@@ -156,7 +156,13 @@ def done(self, key: str) -> Tuple[Optional[bool], int]:
logger.warn("checking status for unknown key: %s", key)
return (None, 0)

def get_next_device(self):
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible
if needs_device is not None:
for i in self.devices:
if self.devices[i].device == needs_device.device:
return i

# use the first/default device if there are no jobs
if len(self.jobs) == 0:
return 0
@@ -179,8 +185,8 @@ def get_next_device(self):
def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()]

def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
device = self.get_next_device()
def submit(self, key: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs) -> None:
device = self.get_next_device(needs_device=needs_device)
logger.info("assigning job %s to device %s", key, device)

context = JobContext(key, self.devices, device_index=device)
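The intent of the pool change, as a minimal sketch: when a job names a device, the pool returns that device's index; otherwise it falls back to its default selection. This is a simplified stand-in that assumes a plain list of `DeviceParams` and omits the job-count balancing the real pool keeps; names mirror the diff, not the actual module.

```python
# Sketch only: simplified stand-ins for DeviceParams and the pool's
# device-selection logic added in this commit.
from typing import List, Optional


class DeviceParams:
    def __init__(self, device: str, provider: str):
        self.device = device
        self.provider = provider


class SketchPool:
    def __init__(self, devices: List[DeviceParams]):
        self.devices = devices

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        # respect an explicit device request if that device is in the pool
        if needs_device is not None:
            for i, d in enumerate(self.devices):
                if d.device == needs_device.device:
                    return i

        # otherwise fall back to the first/default device
        # (the real pool also weighs pending jobs here)
        return 0


pool = SketchPool([
    DeviceParams("cuda", "CUDAExecutionProvider"),
    DeviceParams("cpu", "CPUExecutionProvider"),
])
print(pool.get_next_device())  # 0 -> cuda, no override given
print(pool.get_next_device(DeviceParams("cpu", "CPUExecutionProvider")))  # 1 -> cpu
```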
40 changes: 26 additions & 14 deletions api/onnx_web/serve.py
@@ -157,16 +157,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr

# platform stuff
device_name = request.args.get("platform", available_platforms[0].device)
device = None
device_name = request.args.get("platform")

for platform in available_platforms:
if platform.device == device_name:
device = available_platforms[0]

if device is None:
logger.warn("unknown platform: %s", device_name)
device = available_platforms[0]
if device_name is not None and device_name != "any":
for platform in available_platforms:
if platform.device == device_name:
device = available_platforms[0]

# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
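The effect of the parsing change, sketched under the assumption that `available_platforms` holds objects with a `.device` attribute: a concrete platform name still resolves to a pool entry, while a missing platform or `platform=any` now leaves `device` as `None` and defers the choice to the device pool at submit time.

```python
# Sketch of the new lookup in pipeline_from_request; names follow the diff,
# but _P is a hypothetical stand-in for DeviceParams.
from typing import List, Optional


def resolve_device(device_name: Optional[str], available_platforms: List) -> Optional[object]:
    device = None
    # "any" (or no platform at all) leaves device as None,
    # so the device pool chooses later
    if device_name is not None and device_name != "any":
        for platform in available_platforms:
            if platform.device == device_name:
                device = platform  # note: the diff assigns available_platforms[0] at this point
    return device


class _P:
    def __init__(self, device: str):
        self.device = device


platforms = [_P("cuda"), _P("cpu")]
print(resolve_device("any", platforms))          # None -> let the pool decide
print(resolve_device("cpu", platforms).device)   # cpu
```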
@@ -223,7 +220,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
steps,
scheduler.__name__,
model_path,
device.provider,
device or "any device",
width,
height,
cfg,
@@ -368,16 +365,21 @@ def load_platforms():
)

# make sure CPU is last on the list
def cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == "cpu" and b.device == "cpu":
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == b.device:
return 0

# any should be first, if it's available
if a.device == "any":
return -1

# cpu should be last, if it's available
if a.device == "cpu":
return 1

return -1

available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
available_platforms = sorted(available_platforms, key=cmp_to_key(any_first_cpu_last))

logger.info(
"available acceleration platforms: %s",
@@ -521,6 +523,7 @@ def img2img():
upscale,
source_image,
strength,
needs_device=device,
)

return jsonify(json_params(output, params, size, upscale=upscale))
@@ -535,7 +538,14 @@ def txt2img():
logger.info("txt2img job queued for: %s", output)

executor.submit(
output, run_txt2img_pipeline, context, params, size, output, upscale
output,
run_txt2img_pipeline,
context,
params,
size,
output,
upscale,
needs_device=device,
)

return jsonify(json_params(output, params, size, upscale=upscale))
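One detail worth noting in these call sites: in the new `submit` signature, `needs_device` is declared after `*args`, so it is keyword-only, which is why every endpoint spells it as `needs_device=device`. A tiny, self-contained sketch of that mechanic, using `print` as a stand-in for the pipeline function:

```python
# Sketch: parameters declared after *args are keyword-only, so positional
# arguments keep flowing to fn while needs_device is intercepted by the pool.
from typing import Callable, Optional


def submit_sketch(key: str, fn: Callable[..., None], /, *args,
                  needs_device: Optional[str] = None, **kwargs) -> None:
    print("job", key, "wants device:", needs_device)
    fn(*args, **kwargs)  # the remaining arguments reach the pipeline untouched


submit_sketch("job-1", print, "params", "size", needs_device="cuda")
# -> job job-1 wants device: cuda
# -> params size
```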
@@ -605,6 +615,7 @@ def inpaint():
mask_filter,
strength,
fill_color,
needs_device=device,
)

return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@@ -634,6 +645,7 @@ def upscale():
output,
upscale,
source_image,
needs_device=device,
)

return jsonify(json_params(output, params, size, upscale=upscale))
@@ -711,7 +723,7 @@ def chain():
# build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
executor.submit(
output, pipeline, context, params, empty_source, output=output, size=size
output, pipeline, context, params, empty_source, output=output, size=size, needs_device=device
)

return jsonify(json_params(output, params, size))
2 changes: 2 additions & 0 deletions gui/src/strings.ts
@@ -32,6 +32,8 @@ export const MODEL_LABELS = {

export const PLATFORM_LABELS: Record<string, string> = {
amd: 'AMD GPU',
// eslint-disable-next-line id-blacklist
any: 'Any Platform',
cpu: 'CPU',
cuda: 'CUDA',
directml: 'DirectML',
