Skip to content

Commit

Permalink
fix(scripts): update release tests with support for batches
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 26, 2023
1 parent cb8e9e7 commit 6809d2d
Show file tree
Hide file tree
Showing 19 changed files with 57 additions and 42 deletions.
99 changes: 57 additions & 42 deletions api/scripts/test-release.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

logger = getLogger(__name__)

FAST_TEST = 20
SLOW_TEST = 50


def test_root() -> str:
if len(sys.argv) > 1:
Expand All @@ -42,7 +45,7 @@ def __init__(
self,
name: str,
query: str,
max_attempts: int = 20,
max_attempts: int = FAST_TEST,
mse_threshold: float = 0.001,
source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None,
Expand Down Expand Up @@ -95,23 +98,23 @@ def __init__(
TestCase(
"img2img-sd-v1-5-512-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim",
source="txt2img-sd-v1-5-256-muffin",
source="txt2img-sd-v1-5-256-muffin-0",
),
TestCase(
"inpaint-v1-512-white",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-white",
),
TestCase(
"inpaint-v1-512-black",
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
),
TestCase(
Expand All @@ -120,8 +123,9 @@ def __init__(
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=256&bottom=256&left=256&right=256"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.025,
),
TestCase(
Expand All @@ -130,8 +134,9 @@ def __init__(
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=512&bottom=512&left=0&right=0"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
),
TestCase(
Expand All @@ -140,42 +145,43 @@ def __init__(
"inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask"
"&top=0&bottom=0&left=512&right=512"
),
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
),
TestCase(
"upscale-resrgan-x2-1024-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"upscale-resrgan-x4-2048-muffin",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4",
source="txt2img-sd-v1-5-512-muffin",
source="txt2img-sd-v1-5-512-muffin-0",
),
TestCase(
"blend-512-muffin-black",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-black",
source=[
"txt2img-sd-v1-5-512-muffin",
"txt2img-sd-v2-1-512-muffin",
"txt2img-sd-v1-5-512-muffin-0",
"txt2img-sd-v2-1-512-muffin-0",
],
),
TestCase(
"blend-512-muffin-white",
"upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2",
mask="mask-white",
source=[
"txt2img-sd-v2-1-512-muffin",
"txt2img-sd-v1-5-512-muffin",
"txt2img-sd-v2-1-512-muffin-0",
"txt2img-sd-v1-5-512-muffin-0",
],
),
]


def generate_image(root: str, test: TestCase) -> Optional[str]:
def generate_images(root: str, test: TestCase) -> Optional[str]:
files = {}
if test.source is not None:
if isinstance(test.source, list):
Expand Down Expand Up @@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]:
resp = requests.post(f"{root}/api/{test.query}", files=files)
if resp.status_code == 200:
json = resp.json()
return json.get("output")
return json.get("outputs")
else:
logger.warning("request failed: %s", resp.status_code)
logger.warning("request failed: %s: %s", resp.status_code, resp.text)
return None


Expand All @@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool:
return False


def download_image(root: str, key: str) -> Image.Image:
resp = requests.get(f"{root}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
return Image.open(BytesIO(resp.content))
else:
logger.warning("request failed: %s", resp.status_code)
return None
def download_images(root: str, keys: List[str]) -> List[Image.Image]:
images = []
for key in keys:
resp = requests.get(f"{root}/output/{key}")
if resp.status_code == 200:
logger.debug("downloading image: %s", key)
images.append(Image.open(BytesIO(resp.content)))
else:
logger.warning("request failed: %s", resp.status_code)

return images


def find_mse(result: Image.Image, ref: Image.Image) -> float:
Expand All @@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float:
def run_test(
root: str,
test: TestCase,
ref: Image.Image,
) -> bool:
"""
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""

key = generate_image(root, test)
if key is None:
keys = generate_images(root, test)
if keys is None:
raise ValueError("could not generate")

attempts = 0
while attempts < test.max_attempts:
if check_ready(root, key):
logger.debug("image is ready: %s", key)
if check_ready(root, keys[0]):
logger.debug("image is ready: %s", keys)
break
else:
logger.debug("waiting for image to be ready")
Expand All @@ -282,16 +290,25 @@ def run_test(
if attempts == test.max_attempts:
raise ValueError("image was not ready in time")

result = download_image(root, key)
result.save(test_path(path.join("test-results", f"{test.name}.png")))
mse = find_mse(result, ref)
results = download_images(root, keys)

if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
return True
else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
return False
passed = True
for i in range(len(results)):
result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))

ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None

mse = find_mse(result, ref)

if mse < test.mse_threshold:
logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold)
else:
logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold)
passed = False

return passed


def main():
Expand All @@ -303,9 +320,7 @@ def main():
for test in TEST_DATA:
try:
logger.info("starting test: %s", test.name)
ref_name = test_path(path.join("test-refs", f"{test.name}.png"))
ref = Image.open(ref_name) if path.exists(ref_name) else None
if run_test(root, test, ref):
if run_test(root, test):
logger.info("test passed: %s", test.name)
passed.append(test.name)
else:
Expand Down

0 comments on commit 6809d2d

Please sign in to comment.