Skip to content

Commit

Permalink
feat(api): move txt2img into a background task
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 13, 2023
1 parent d1d079d commit 0ef4d60
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 22 deletions.
57 changes: 36 additions & 21 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DiffusionPipeline,
)
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_executor import Executor
from hashlib import sha256
from io import BytesIO
from PIL import Image
Expand Down Expand Up @@ -197,11 +198,14 @@ def pipeline_from_request(pipeline: DiffusionPipeline):
if negative_prompt is not None and negative_prompt.strip() == '':
negative_prompt = None

cfg = get_and_clamp_float(request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
steps = get_and_clamp_int(request.args, 'steps', default_steps, config_params.get('steps').get('max'))
cfg = get_and_clamp_float(
request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
steps = get_and_clamp_int(
request.args, 'steps', default_steps, config_params.get('steps').get('max'))
height = get_and_clamp_int(
request.args, 'height', default_height, config_params.get('height').get('max'))
width = get_and_clamp_int(request.args, 'width', default_width, config_params.get('width').get('max'))
width = get_and_clamp_int(
request.args, 'width', default_width, config_params.get('width').get('max'))

seed = int(request.args.get('seed', -1))
if seed == -1:
Expand All @@ -210,8 +214,30 @@ def pipeline_from_request(pipeline: DiffusionPipeline):
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))

# pipe = load_pipeline(pipeline, model, provider, scheduler)
# , pipe)
return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed)


def run_txt2img_pipeline(pipeline, model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, output):
pipe = load_pipeline(pipeline, model, provider, scheduler)
return (model, provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, pipe)

latents = get_latents_from_seed(seed, width, height)
rng = np.random.RandomState(seed)

image = pipe(
prompt,
height,
width,
generator=rng,
guidance_scale=cfg,
latents=latents,
negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0]
image.save(output)

print('saved txt2img output: %s' % (output))


# setup
Expand Down Expand Up @@ -240,6 +266,7 @@ def load_params():
load_models()
load_params()
app = Flask(__name__)
executor = Executor(app)

# routes

Expand Down Expand Up @@ -284,7 +311,7 @@ def img2img():
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)

(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)
width, seed) = pipeline_from_request(OnnxStableDiffusionImg2ImgPipeline)

rng = np.random.RandomState(seed)
image = pipe(
Expand Down Expand Up @@ -322,26 +349,14 @@ def img2img():
@app.route('/txt2img', methods=['POST'])
def txt2img():
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed, pipe) = pipeline_from_request(OnnxStableDiffusionPipeline)

latents = get_latents_from_seed(seed, width, height)
rng = np.random.RandomState(seed)

image = pipe(
prompt,
height,
width,
generator=rng,
guidance_scale=cfg,
latents=latents,
negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0]
width, seed) = pipeline_from_request(OnnxStableDiffusionPipeline)

(output_file, output_full) = make_output_path('txt2img',
seed, (prompt, cfg, negative_prompt, steps, height, width))
print("txt2img output: %s" % output_full)
image.save(output_full)

executor.submit(run_txt2img_pipeline, OnnxStableDiffusionPipeline, model,
provider, scheduler, prompt, negative_prompt, cfg, steps, height, width, seed, output_full)

return json_with_cors({
'output': output_file,
Expand Down
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ protobuf<4,>=3.20.2
transformers

### Server packages ###
flask
flask
flask_executor

0 comments on commit 0ef4d60

Please sign in to comment.