feat(api): write model hashes to image exif
ssube committed Jun 26, 2023
1 parent 003a350 commit 062b1c4
Showing 3 changed files with 96 additions and 16 deletions.
17 changes: 14 additions & 3 deletions api/onnx_web/convert/utils.py
@@ -183,7 +183,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
return model


model_formats = ["onnx", "pth", "ckpt", "safetensors"]
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]


def source_format(model: Dict) -> Optional[str]:
@@ -192,7 +193,7 @@ def source_format(model: Dict) -> Optional[str]:

if "source" in model:
_name, ext = path.splitext(model["source"])
if ext in model_formats:
if ext in MODEL_FORMATS:
return ext

return None
@@ -231,7 +232,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
checkpoint = torch.load(name, map_location=map_location)
else:
logger.debug("searching for tensors with known extensions")
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
checkpoint = load_tensor(next_name, map_location=map_location)
@@ -275,6 +276,16 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
return checkpoint


def resolve_tensor(name: str) -> Optional[str]:
logger.debug("searching for tensors with known extensions: %s", name)
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
return next_name

return None


def onnx_export(
model,
model_args: tuple,
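For context, a minimal sketch of how the new `resolve_tensor` helper behaves; the model directory and file name below are assumptions for illustration, not values from the commit.

```python
from os import path

from onnx_web.convert.utils import resolve_tensor

# Assumed layout: /models/lora/example-lora.safetensors exists on disk.
# resolve_tensor tries each extension in RESOLVE_FORMATS, in order
# ("safetensors", "ckpt", "pt", "bin"), and returns the first path that exists.
found = resolve_tensor(path.join("/models", "lora", "example-lora"))
# -> "/models/lora/example-lora.safetensors", or None if no candidate exists
```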
77 changes: 65 additions & 12 deletions api/onnx_web/output.py
@@ -10,12 +10,30 @@
from piexif.helper import UserComment
from PIL import Image, PngImagePlugin

from onnx_web.convert.utils import resolve_tensor
from onnx_web.server.load import get_extra_hashes

from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext
from .utils import base_join

logger = getLogger(__name__)

HASH_BUFFER_SIZE = 2**22 # 4MB


def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break

sha.update(data)

return sha.hexdigest()


def hash_value(sha, param: Optional[Param]):
if param is None:
@@ -68,24 +86,57 @@ def json_params(


def str_params(
server: ServerContext,
params: ImageParams,
size: Size,
inversions: List[Tuple[str, float]] = None,
loras: List[Tuple[str, float]] = None,
) -> str:
lora_hashes = (
",".join([f"{name}: TODO" for name, weight in loras])
if loras is not None
else ""
)
model_hash = get_extra_hashes().get(params.model, "unknown")
model_name = path.basename(path.normpath(params.model))
hash_map = {
model_name: model_hash,
}

inversion_hashes = ""
if inversions is not None:
inversion_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper(),
)
for name, _weight in inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))

lora_hashes = ""
if loras is not None:
lora_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper(),
)
for name, _weight in loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))

return (
f"{params.input_prompt}.\nNegative prompt: {params.input_negative_prompt}.\n"
f"{params.input_prompt}\nNegative prompt: {params.input_negative_prompt}\n"
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
f"Model hash: TODO, Model: {params.model}, "
f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", '
f"Version: TODO, Tool: onnx-web"
f"Hashes: {dumps(hash_map)}"
)


@@ -157,10 +208,10 @@ def save_image(
)
),
)
exif.add_text("model", "TODO: server.version")
exif.add_text("model", server.server_version)
exif.add_text(
"parameters",
str_params(params, size, inversions=inversions, loras=loras),
str_params(server, params, size, inversions=inversions, loras=loras),
)

image.save(path, format=server.image_format, pnginfo=exif)
@@ -182,11 +233,13 @@ def save_image(
encoding="unicode",
),
ExifIFD.UserComment: UserComment.dump(
str_params(params, size, inversions=inversions, loras=loras),
str_params(
server, params, size, inversions=inversions, loras=loras
),
encoding="unicode",
),
ImageIFD.Make: "onnx-web",
ImageIFD.Model: "TODO: server.version",
ImageIFD.Model: server.server_version,
}
}
)
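A hedged sketch of how the new helpers combine in `str_params`: the model path and network name below are hypothetical, and the None-check is added for safety here, whereas the commit calls `hash_file` directly on the resolved path.

```python
from os import path

from onnx_web.convert.utils import resolve_tensor
from onnx_web.output import hash_file

model_path = "/models"   # assumed ServerContext.model_path
name = "example-lora"    # assumed LoRA network name

# Resolve the tensor file on disk, then stream it through SHA-256
# in 4 MB chunks and keep the uppercase digest, as str_params does.
tensor = resolve_tensor(path.join(model_path, "lora", name))
if tensor is not None:
    digest = hash_file(tensor).upper()
```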
18 changes: 17 additions & 1 deletion api/onnx_web/server/load.py
@@ -92,6 +92,7 @@
upscaling_models: List[str] = []

# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {}


@@ -123,6 +124,10 @@ def get_extra_strings():
return extra_strings


def get_extra_hashes():
return extra_hashes


def get_highres_methods():
return highres_methods

@@ -147,6 +152,7 @@ def load_extras(server: ServerContext):
"""
Load the extras file(s) and collect the relevant parts for the server: labels and strings
"""
global extra_hashes
global extra_strings

labels = {}
@@ -173,8 +179,18 @@
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
if model_type in data:
for model in data[model_type]:
model_name = model["name"]

if "hash" in model:
logger.debug(
"collecting hash for model %s from %s",
model_name,
file,
)

extra_hashes[model_name] = model["hash"]

if "label" in model:
model_name = model["name"]
logger.debug(
"collecting label for model %s from %s",
model_name,
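For reference, a minimal sketch of the extras data that would populate `extra_hashes`; the model name, label, and hash value are illustrative assumptions, not values from the commit.

```python
# Hypothetical extras data, as load_extras sees it after parsing the file.
# Any entry under "diffusion", "correction", "upscaling", or "networks"
# with a "hash" key is collected into extra_hashes, keyed by model name.
data = {
    "diffusion": [
        {
            "name": "diffusion-example-v1",
            "label": "Example Diffusion v1",
            "hash": "0123456789ABCDEF",
        }
    ],
}
```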
