Skip to content

Commit

Permalink
fix(api): remove prompt from output name
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 8, 2023
1 parent f4ca6a0 commit 0d4c0a5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
18 changes: 9 additions & 9 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from hashlib import sha256
from io import BytesIO
from PIL import Image
from stringcase import spinalcase
from struct import pack
from os import environ, makedirs, path, scandir
from typing import Tuple, Union
import numpy as np

# defaults
Expand Down Expand Up @@ -74,23 +74,23 @@
}


def get_and_clamp_float(args, key, default_value, max_value, min_value=0.0):
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0):
return min(max(float(args.get(key, default_value)), min_value), max_value)


def get_and_clamp_int(args, key, default_value, max_value, min_value=1):
def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1):
return min(max(int(args.get(key, default_value)), min_value), max_value)


def get_from_map(args, key, values, default):
def get_from_map(args, key: str, values, default):
selected = args.get(key, default)
if selected in values:
return values[selected]
else:
return values[default]


def get_model_path(model):
def get_model_path(model: str):
return safer_join(model_path, model)


Expand All @@ -104,7 +104,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents


def load_pipeline(pipeline, model, provider, scheduler):
def load_pipeline(pipeline, model: str, provider: str, scheduler):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
Expand Down Expand Up @@ -141,9 +141,9 @@ def json_with_cors(data, origin='*'):
return res


def make_output_path(type, params):
def make_output_path(type: str, params: Tuple[Union[str, int, float]]):
sha = sha256()
sha.update(type)
sha.update(type.encode('utf-8'))
for param in params:
if isinstance(param, str):
sha.update(param.encode('utf-8'))
Expand All @@ -154,7 +154,7 @@ def make_output_path(type, params):
else:
print('cannot hash param: %s, %s' % (param, type(param)))

output_file = 'txt2img_%s_%s.png' % (params[0], sha.hexdigest())
output_file = '%s_%s.png' % (type, sha.hexdigest())
output_full = safer_join(output_path, output_file)

return (output_file, output_full)
Expand Down
3 changes: 1 addition & 2 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ protobuf<4,>=3.20.2
transformers

### Server packages ###
flask
stringcase
flask

0 comments on commit 0d4c0a5

Please sign in to comment.