Skip to content

Commit

Permalink
feat(api): attempt to calculate total steps for chain pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 12, 2023
1 parent 4ccdedb commit 55ddb9f
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 12 deletions.
10 changes: 9 additions & 1 deletion api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, StageParams
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
Expand Down Expand Up @@ -85,6 +85,14 @@ def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self

def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, _kwargs in self.stages:
steps += callback.steps(params, size)

return steps


def __call__(
self,
worker: WorkerContext,
Expand Down
7 changes: 7 additions & 0 deletions api/onnx_web/chain/blend_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ def run(
output.paste(sources[n], (x * size[0], y * size[1]))

return [*sources, output]

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
14 changes: 14 additions & 0 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,17 @@ def run(
outputs.extend(result.images)

return outputs

def steps(
self,
params: ImageParams,
*args,
) -> int:
return params.steps # TODO: multiply by strength

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
7 changes: 7 additions & 0 deletions api/onnx_web/chain/source_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ def run(
outputs.append(output)

return outputs

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
7 changes: 7 additions & 0 deletions api/onnx_web/chain/source_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ def run(
logger.exception("error loading image from S3")

return outputs

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1 # TODO: len(source_keys)
14 changes: 14 additions & 0 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,17 @@ def run(
output = list(sources)
output.extend(result.images)
return output

def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return params.steps

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
7 changes: 7 additions & 0 deletions api/onnx_web/chain/source_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ def run(
outputs.append(output)

return outputs

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
11 changes: 9 additions & 2 deletions api/onnx_web/chain/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@ def run(

def steps(
self,
_params: ImageParams,
params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()
return 1

def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources
1 change: 1 addition & 0 deletions api/onnx_web/chain/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def needs_tile(
source: Optional[Image.Image] = None,
) -> bool:
tile = min(max_tile, stage_tile)
logger.debug("")

if source is not None:
return source.width > tile or source.height > tile
Expand Down
10 changes: 6 additions & 4 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,15 @@ def load_pipeline(
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_decoder_session"]._model_path = vae_decoder
components["vae_decoder_session"]._model_path = vae_decoder # "#\\not a real path on any system"

logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_encoder_session"]._model_path = vae_encoder
components["vae_encoder_session"]._model_path = vae_encoder # "#\\not a real path on any system"

else:
logger.debug("loading VAE decoder from %s", vae_decoder)
Expand Down Expand Up @@ -439,12 +439,14 @@ def load_pipeline(

if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
components["vae_decoder_session"], vae_decoder
components["vae_decoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)

if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"], vae_encoder
components["vae_encoder_session"],
pipe, # TODO: find the right class to provide here. ORTModel is missing the dict json method
)

if not server.show_progress:
Expand Down
11 changes: 6 additions & 5 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
validate(data, schema)

# get defaults from the regular parameters
device, _params, _size = pipeline_from_request(server, data=data)
device, base_params, base_size = pipeline_from_request(server, data=data)
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
Expand Down Expand Up @@ -450,22 +450,23 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):

logger.info("running chain pipeline with %s stages", len(pipeline.stages))

output = make_output_name(server, "chain", params, size, count=len(pipeline.stages))
output = make_output_name(server, "chain", base_params, base_size, count=len(pipeline.stages))
job_name = output[0]

# build and run chain pipeline
pool.submit(
job_name,
pipeline,
server,
params,
base_params,
[],
output=output,
size=size,
size=base_size,
needs_device=device,
)

return jsonify(json_params(output, params, size))
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
return jsonify(json_params(output, step_params, base_size))


def blend(server: ServerContext, pool: DevicePoolExecutor):
Expand Down

0 comments on commit 55ddb9f

Please sign in to comment.