Skip to content

Commit

Permalink
feat(api): add img2img endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 7, 2023
1 parent 9973bf1 commit 09ce654
Showing 1 changed file with 79 additions and 21 deletions.
100 changes: 79 additions & 21 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from diffusers import OnnxStableDiffusionPipeline
from diffusers import (
# schedulers
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
# onnx
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
)
from flask import Flask, jsonify, request, send_from_directory, url_for
from io import BytesIO
from PIL import Image
from stringcase import spinalcase
from os import environ, makedirs, path, scandir
import numpy as np

# defaults
default_model = "stable-diffusion-onnx-v1-5"
default_platform = "amd"
default_scheduler = "euler-a"
default_model = 'stable-diffusion-onnx-v1-5'
default_platform = 'amd'
default_scheduler = 'euler-a'
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
Expand All @@ -29,14 +34,14 @@
max_width = 512

# paths
model_path = environ.get('ONNX_WEB_MODEL_PATH', "../models")
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', "../outputs")
model_path = environ.get('ONNX_WEB_MODEL_PATH', '../models')
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', '../outputs')


# pipeline caching
available_models = []
last_pipeline_instance = None
last_pipeline_options = (None, None)
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None

# pipeline params
Expand Down Expand Up @@ -68,6 +73,10 @@ def get_from_map(args, key, values, default):
return values[default]


def get_model_path(model):
return safer_join(model_path, model)


# from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
# 1 is batch size
Expand All @@ -78,22 +87,23 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents


def load_pipeline(model, provider, scheduler):
def load_pipeline(pipeline, model, provider, scheduler):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options

options = (model, provider)
options = (pipeline, model, provider)
if last_pipeline_instance != None and last_pipeline_options == options:
print('reusing existing pipeline')
pipe = last_pipeline_instance
else:
print('loading different pipeline')
pipe = OnnxStableDiffusionPipeline.from_pretrained(
# pipe = OnnxStableDiffusionPipeline.from_pretrained(
pipe = pipeline.from_pretrained(
model,
provider=provider,
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder="scheduler")
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
)
last_pipeline_instance = pipe
last_pipeline_options = options
Expand All @@ -102,7 +112,7 @@ def load_pipeline(model, provider, scheduler):
if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler')
pipe.scheduler = scheduler.from_pretrained(
model, subfolder="scheduler")
model, subfolder='scheduler')
last_pipeline_scheduler = scheduler

return pipe
Expand Down Expand Up @@ -141,6 +151,8 @@ def check_paths():
def load_models():
global available_models
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
load_pipeline(OnnxStableDiffusionPipeline, get_model_path(available_models[0]), platform_providers.get(
default_platform), pipeline_schedulers.get(default_scheduler))


check_paths()
Expand Down Expand Up @@ -176,12 +188,11 @@ def list_schedulers():
return json_with_cors(list(pipeline_schedulers.keys()))


@app.route('/txt2img')
def txt2img():
def pipeline_from_request(pipeline):
user = request.remote_addr

# pipeline stuff
model = safer_join(model_path, request.args.get('model', default_model))
model = get_model_path(request.args.get('model', default_model))
provider = get_from_map(request.args, 'platform',
platform_providers, default_platform)
scheduler = get_from_map(request.args, 'scheduler',
Expand All @@ -198,12 +209,59 @@ def txt2img():
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)

latents = get_latents_from_seed(seed, width, height)

print("txt2img from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))

pipe = load_pipeline(model, provider, scheduler)
pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, cfg, steps, height, width, seed, pipe)


@app.route('/img2img', methods=['POST'])
def img2img():
input_file = request.files.get('source')
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
input_image.thumbnail((default_width, default_height))

strength = get_and_clamp(request.args, 'strength', 1.0, 1.0, 0.0)
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)

image = pipe(
prompt=prompt,
image=input_image,
num_inference_steps=steps,
guidance_scale=cfg,
strength=strength,
).images[0]

output_file = 'img2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file)
print("img2img output: %s" % output_full)
image.save(output_full)

return json_with_cors({
'output': output_file,
'params': {
'model': model,
'provider': provider,
'scheduler': scheduler.__name__,
'cfg': cfg,
'steps': steps,
'height': default_height,
'width': default_width,
'prompt': prompt,
'seed': seed,
}
})


@app.route('/txt2img', methods=['POST'])
def txt2img():
(model, provider, scheduler, prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)

latents = get_latents_from_seed(seed, width, height)

image = pipe(
prompt,
height,
Expand All @@ -213,7 +271,7 @@ def txt2img():
latents=latents
).images[0]

output_file = "txt2img_%s_%s.png" % (seed, spinalcase(prompt[0:64]))
output_file = 'txt2img_%s_%s.png' % (seed, spinalcase(prompt[0:64]))
output_full = safer_join(output_path, output_file)
print("txt2img output: %s" % output_full)
image.save(output_full)
Expand All @@ -229,7 +287,7 @@ def txt2img():
'height': height,
'width': width,
'prompt': prompt,
'seed': seed
'seed': seed,
}
})

Expand Down

0 comments on commit 09ce654

Please sign in to comment.