Skip to content

Commit

Permalink
fix(api): continue adding tests, fix bugs encountered
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 28, 2023
1 parent 898d76e commit 047e58c
Show file tree
Hide file tree
Showing 16 changed files with 526 additions and 52 deletions.
22 changes: 11 additions & 11 deletions api/onnx_web/convert/utils.py
Expand Up @@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
scale = rest.pop(0) if len(rest) > 0 else 1
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None

return {
"name": name,
Expand All @@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
single_vae = rest[0] if len(rest) > 0 else False
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
single_vae = rest.pop(0) if len(rest) > 0 else False
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None

return {
"name": name,
Expand All @@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
scale = rest.pop(0) if len(rest) > 0 else 1
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None

return {
"name": name,
Expand Down Expand Up @@ -298,6 +298,7 @@ def onnx_export(
half=False,
external_data=False,
v2=False,
op_block_list=None,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
Expand All @@ -316,8 +317,7 @@ def onnx_export(
opset_version=opset,
)

op_block_list = None
if v2:
if v2 and op_block_list is None:
op_block_list = ["Attention", "MultiHeadAttention"]

if half:
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/diffusers/run.py
Expand Up @@ -97,6 +97,7 @@ def run_txt2img_pipeline(
_pairs, loras, inversions, _rest = parse_prompt(params)

for image, output in zip(images, outputs):
logger.trace("saving output image %s: %s", output, image.size)
dest = save_image(
server,
output,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/image/source_filter.py
Expand Up @@ -47,7 +47,7 @@ def source_filter_noise(
source: Image.Image,
strength: float = 0.5,
):
noise = noise_source_histogram(source, source.size)
noise = noise_source_histogram(source, source.size, (0, 0))
return ImageChops.blend(source, noise, strength)


Expand Down
7 changes: 5 additions & 2 deletions api/onnx_web/worker/context.py
Expand Up @@ -25,6 +25,7 @@ class WorkerContext:
idle: "Value[bool]"
timeout: float
retries: int
initial_retries: int

def __init__(
self,
Expand All @@ -37,6 +38,7 @@ def __init__(
active_pid: "Value[int]",
idle: "Value[bool]",
retries: int,
timeout: float,
):
self.job = None
self.name = name
Expand All @@ -48,12 +50,13 @@ def __init__(
self.active_pid = active_pid
self.last_progress = None
self.idle = idle
self.initial_retries = retries
self.retries = retries
self.timeout = 1.0
self.timeout = timeout

def start(self, job: str) -> None:
self.job = job
self.retries = 3
self.retries = self.initial_retries
self.set_cancel(cancel=False)
self.set_idle(idle=False)

Expand Down
11 changes: 6 additions & 5 deletions api/onnx_web/worker/pool.py
Expand Up @@ -86,15 +86,15 @@ def __init__(
self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock()

def start(self) -> None:
def start(self, *args) -> None:
self.create_health_worker()
self.create_logger_worker()
self.create_progress_worker()

for device in self.devices:
self.create_device_worker(device)
self.create_device_worker(device, *args)

def create_device_worker(self, device: DeviceParams) -> None:
def create_device_worker(self, device: DeviceParams, *args) -> None:
name = device.device

# always recreate queues
Expand Down Expand Up @@ -125,15 +125,16 @@ def create_device_worker(self, device: DeviceParams) -> None:
active_pid=current,
idle=self.worker_idle[name],
retries=self.server.worker_retries,
timeout=self.progress_interval,
)
self.context[name] = context

worker = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
args=(context, self.server),
args=(context, self.server, *args),
daemon=True,
)
worker.daemon = True
self.workers[name] = worker

logger.debug("starting worker for device %s", device)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/worker/worker.py
Expand Up @@ -27,7 +27,7 @@
]


def worker_main(worker: WorkerContext, server: ServerContext):
def worker_main(worker: WorkerContext, server: ServerContext, *args):
apply_patches(server)
setproctitle("onnx-web worker: %s" % (worker.device.device))

Expand Down
30 changes: 30 additions & 0 deletions api/tests/chain/test_tile.py
Expand Up @@ -7,6 +7,8 @@
generate_tile_spiral,
get_tile_grads,
needs_tile,
process_tile_grid,
process_tile_spiral,
)
from onnx_web.params import Size

Expand Down Expand Up @@ -95,3 +97,31 @@ def test_spiral_50_overlap(self):
self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])


class TestProcessTileGrid(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_grid(source, 32, 1, [])

self.assertEqual(blend.size, (64, 64))

def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_grid(source, 32, 1, [])

self.assertEqual(blend.size, (72, 72))


class TestProcessTileSpiral(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_spiral(source, 32, 1, [])

self.assertEqual(blend.size, (64, 64))

def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_spiral(source, 32, 1, [])

self.assertEqual(blend.size, (72, 72))
167 changes: 166 additions & 1 deletion api/tests/convert/test_utils.py
@@ -1,6 +1,14 @@
import unittest

from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress
from onnx_web.convert.utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
tuple_to_correction,
tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling,
)


class ConversionContextTests(unittest.TestCase):
Expand All @@ -17,3 +25,160 @@ class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")


class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_list(self):
source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin"

# make sure this is returned as-is with extra fields
second = tuple_to_source(source)

self.assertEqual(source, second)
self.assertIn("bin", second)


class TupleToCorrectionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin"

# make sure this is returned with extra fields
second = tuple_to_source(source)

self.assertEqual(source, second)
self.assertIn("bin", second)

def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)


class TupleToDiffusionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin"

# make sure this is returned with extra fields
second = tuple_to_diffusion(source)

self.assertEqual(source, second)
self.assertIn("bin", second)

def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)


class TupleToUpscalingTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin"

# make sure this is returned with extra fields
second = tuple_to_source(source)

self.assertEqual(source, second)
self.assertIn("bin", second)

def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")

def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
17 changes: 12 additions & 5 deletions api/tests/helpers.py
@@ -1,9 +1,16 @@
from os import path
from typing import List
from unittest import skipUnless

from onnx_web.params import DeviceParams

def test_with_models(models: List[str]):
def wrapper(func):
# TODO: check if models exist
return func

return wrapper
def test_needs_models(models: List[str]):
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")


def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider")


TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
Empty file added api/tests/image/__init__.py
Empty file.

0 comments on commit 047e58c

Please sign in to comment.