Commit

feat(api): add a way to download models from civitai or other https sources (#117)
ssube committed Feb 11, 2023
1 parent b3e4076 commit 9f20248
Showing 12 changed files with 488 additions and 365 deletions.
22 changes: 18 additions & 4 deletions api/extras.json
@@ -1,9 +1,23 @@
 {
   "diffusion": [
-    ["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
-    ["diffusion-openjourney", "prompthero/openjourney"],
-    ["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"],
-    ["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"]
+    {
+      "name": "diffusion-knollingcase",
+      "source": "Aybeeceedee/knollingcase"
+    },
+    {
+      "name": "diffusion-openjourney",
+      "source": "prompthero/openjourney"
+    },
+    {
+      "name": "diffusion-stablydiffused-aesthetic-v2-6",
+      "source": "civitai://6266?type=Pruned%20Model&format=SafeTensor",
+      "format": "safetensors"
+    },
+    {
+      "name": "diffusion-unstable-ink-dream-v6",
+      "source": "civitai://5796",
+      "format": "safetensors"
+    }
   ],
   "correction": [],
   "upscaling": []
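Note the schema change in this file: entries used to be [name, source] tuples, while the new form is an object with name, source, and an optional format. Both forms remain accepted, since convert_models (below) passes every entry through the tuple_to_* helpers imported from .utils. Those helpers are not shown in this diff; a minimal sketch of the normalization, assuming the tuple layout used by base_models, could look like:

from typing import Any, Dict, List, Tuple, Union

ModelDict = Dict[str, Any]

def tuple_to_diffusion(model: Union[ModelDict, List, Tuple]) -> ModelDict:
    # dict entries, the new extras.json format, pass through unchanged
    if isinstance(model, dict):
        return model

    # legacy tuples are (name, source) with an optional single_vae flag,
    # matching the third element on the upscaling entry in base_models
    name, source, *rest = model
    result: ModelDict = {"name": name, "source": source}
    if rest:
        result["single_vae"] = rest[0]
    return result

The explicit format key matters most for civitai:// sources, where the file type cannot be read off the URL.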
150 changes: 116 additions & 34 deletions api/onnx_web/convert/__main__.py
@@ -1,19 +1,21 @@
-from .correction_gfpgan import convert_correction_gfpgan
-from .diffusion_original import convert_diffusion_original
-from .diffusion_stable import convert_diffusion_stable
-from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext
-
 import warnings
 from argparse import ArgumentParser
 from json import loads
 from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
+from yaml import safe_load
+from jsonschema import validate, ValidationError
 
 import torch
 
+from .correction_gfpgan import convert_correction_gfpgan
+from .diffusion_original import convert_diffusion_original
+from .diffusion_stable import convert_diffusion_stable
+from .upscale_resrgan import convert_upscale_resrgan
+from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+
 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
     "ignore", ".*The shape inference of prim::Constant type is missing.*"
@@ -29,20 +31,39 @@
 logger = getLogger(__name__)
 
 
+model_sources: Dict[str, Tuple[str, str]] = {
+    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
+}
+
+model_source_huggingface = "huggingface://"
+
 # recommended models
 base_models: Models = {
     "diffusion": [
         # v1.x
-        ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
-        ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
-        # v2.x
-        ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
         (
-            "stable-diffusion-onnx-v2-inpainting",
-            "stabilityai/stable-diffusion-2-inpainting",
+            "stable-diffusion-onnx-v1-5",
+            model_source_huggingface + "runwayml/stable-diffusion-v1-5",
         ),
+        # (
+        #     "stable-diffusion-onnx-v1-inpainting",
+        #     model_source_huggingface + "runwayml/stable-diffusion-inpainting",
+        # ),
+        # v2.x
+        # (
+        #     "stable-diffusion-onnx-v2-1",
+        #     model_source_huggingface + "stabilityai/stable-diffusion-2-1",
+        # ),
+        # (
+        #     "stable-diffusion-onnx-v2-inpainting",
+        #     model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
+        # ),
         # TODO: should have its own converter
-        ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
+        (
+            "upscaling-stable-diffusion-x4",
+            model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
+            True,
+        ),
     ],
     "correction": [
         (
@@ -79,35 +100,86 @@
 training_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
-def load_models(args, ctx: ConversionContext, models: Models):
+def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+    cache_name = path.join(ctx.cache_path, name)
+    if format is not None:
+        # add an extension if possible, some of the conversion code checks for it
+        cache_name = "%s.%s" % (cache_name, format)
+
+    for proto in model_sources:
+        api_name, api_root = model_sources.get(proto)
+        if source.startswith(proto):
+            api_source = api_root % (source.removeprefix(proto))
+            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            return download_progress([(api_source, cache_name)])
+
+    if source.startswith(model_source_huggingface):
+        hub_source = source.removeprefix(model_source_huggingface)
+        logger.info("Downloading model from Huggingface Hub: %s", hub_source)
+        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
+        return hub_source
+    elif source.startswith("https://"):
+        logger.info("Downloading model from: %s", source)
+        return download_progress([(source, cache_name)])
+    elif source.startswith("http://"):
+        logger.warning("Downloading model from insecure source: %s", source)
+        return download_progress([(source, cache_name)])
+    elif source.startswith(path.sep) or source.startswith("."):
+        logger.info("Using local model: %s", source)
+        return source
+    else:
+        logger.info("Unknown model location, using path as provided: %s", source)
+        return source
+
+
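For reference, the civitai:// scheme resolves through the model_sources table at the top of the file. This standalone mirror of the loop in fetch_model (Python 3.9+ for str.removeprefix) shows what the two sources from extras.json expand to; only the URL expansion is reproduced here, not the download:

model_sources = {
    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}

def expand_source(source: str) -> str:
    # mirror of the dispatch at the top of fetch_model, minus the download
    for proto, (api_name, api_root) in model_sources.items():
        if source.startswith(proto):
            return api_root % source.removeprefix(proto)
    return source

print(expand_source("civitai://5796"))
# https://civitai.com/api/download/models/5796
print(expand_source("civitai://6266?type=Pruned%20Model&format=SafeTensor"))
# https://civitai.com/api/download/models/6266?type=Pruned%20Model&format=SafeTensor

Query parameters ride along through the expansion, which is how the type and format selectors in the first extras.json entry reach the Civitai API.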
+def convert_models(ctx: ConversionContext, args, models: Models):
     if args.diffusion:
-        for source in models.get("diffusion"):
-            name, file = source
+        for model in models.get("diffusion"):
+            model = tuple_to_diffusion(model)
+            name = model.get("name")
+
             if name in args.skip:
-                logger.info("Skipping model: %s", source[0])
+                logger.info("Skipping model: %s", name)
             else:
-                if file.endswith(".safetensors") or file.endswith(".ckpt"):
-                    convert_diffusion_original(ctx, *source, args.opset, args.half)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+
+                if format in ["safetensors", "ckpt"]:
+                    convert_diffusion_original(
+                        ctx,
+                        model,
+                        source,
+                    )
                 else:
-                    # TODO: make this a parameter in the JSON/dict
-                    single_vae = "upscaling" in source[0]
                     convert_diffusion_stable(
-                        ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
+                        ctx,
+                        model,
+                        source,
                     )
 
     if args.upscaling:
-        for source in models.get("upscaling"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
+        for model in models.get("upscaling"):
+            model = tuple_to_upscaling(model)
+            name = model.get("name")
+
+            if name in args.skip:
+                logger.info("Skipping model: %s", name)
             else:
-                convert_upscale_resrgan(ctx, *source, args.opset)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+                convert_upscale_resrgan(ctx, model, source)
 
     if args.correction:
-        for source in models.get("correction"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
+        for model in models.get("correction"):
+            model = tuple_to_correction(model)
+            name = model.get("name")
+
+            if name in args.skip:
+                logger.info("Skipping model: %s", name)
             else:
-                convert_correction_gfpgan(ctx, *source, args.opset)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+                convert_correction_gfpgan(ctx, model, source)
 
 
 def main() -> int:
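convert_models keys its dispatch on source_format, another .utils helper that this diff only imports. Judging from the explicit "format" keys in extras.json and the format in ["safetensors", "ckpt"] check above, a plausible sketch — the extension fallback is an assumption:

from os import path
from typing import Any, Dict, Optional

def source_format(model: Dict[str, Any]) -> Optional[str]:
    # an explicit "format" key, as on the civitai entries above, wins
    if "format" in model:
        return model["format"]

    # otherwise fall back to the file extension of the source, if any
    _, ext = path.splitext(model.get("source", ""))
    if ext:
        return ext.lstrip(".")

    return None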
@@ -146,24 +218,34 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
 
-    ctx = ConversionContext(model_path, training_device)
+    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
     if not path.exists(model_path):
         logger.info("Model path does not exist, creating: %s", model_path)
         makedirs(model_path)
 
     logger.info("Converting base models.")
-    load_models(args, ctx, base_models)
+    convert_models(ctx, args, base_models)
 
     for file in args.extras:
         if file is not None and file != "":
             logger.info("Loading extra models from %s", file)
             try:
                 with open(file, "r") as f:
-                    data = loads(f.read())
+                    data = safe_load(f.read())
+
+                with open("./schemas/extras.yaml", "r") as f:
+                    schema = safe_load(f.read())
+
+                logger.debug("validating extras file: %s against schema: %s", data, schema)
+
+                try:
+                    validate(data, schema)
                     logger.info("Converting extra models.")
-                    load_models(args, ctx, data)
+                    convert_models(ctx, args, data)
+                except ValidationError as err:
+                    logger.error("Invalid data in extras file: %s", err)
             except Exception as err:
                 logger.error("Error converting extra models: %s", err)

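download_progress is likewise visible only at its call sites, which pass a list of (url, dest) pairs and treat the return value as the local model path. A sketch with requests and tqdm, under those assumptions — the real helper may differ:

from os import makedirs, path
from typing import List, Tuple

import requests
from tqdm import tqdm

def download_progress(urls: List[Tuple[str, str]]) -> str:
    # download each pair, returning the last dest so the caller can
    # hand it to the converters as the local source path
    for url, dest in urls:
        if path.exists(dest):
            continue

        makedirs(path.dirname(dest), exist_ok=True)
        resp = requests.get(url, stream=True, allow_redirects=True, timeout=30)
        resp.raise_for_status()

        total = int(resp.headers.get("content-length", 0))
        with open(dest, "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as bar:
            for chunk in resp.iter_content(chunk_size=8192):
                f.write(chunk)
                bar.update(len(chunk))

    # assumes urls is non-empty, matching how fetch_model calls it
    return dest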
45 changes: 24 additions & 21 deletions api/onnx_web/convert/correction_gfpgan.py
@@ -1,31 +1,34 @@
-import torch
+from logging import getLogger
+from os import path
+from shutil import copyfile
 
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
-from os import path
-from logging import getLogger
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from .utils import ConversionContext
 
+from .utils import ConversionContext, ModelDict
 
 logger = getLogger(__name__)
 
 
 @torch.no_grad()
-def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
-    dest_path = path.join(ctx.model_path, name + ".pth")
-    dest_onnx = path.join(ctx.model_path, name + ".onnx")
-    logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
+def convert_correction_gfpgan(
+    ctx: ConversionContext,
+    model: ModelDict,
+    source: str,
+):
+    name = model.get("name")
+    source = source or model.get("source")
+    scale = model.get("scale")
+
+    dest = path.join(ctx.model_path, name + ".onnx")
+    logger.info("converting GFPGAN model: %s -> %s", name, dest)
 
-    if path.isfile(dest_onnx):
+    if path.isfile(dest):
         logger.info("ONNX model already exists, skipping.")
         return
 
-    if not path.isfile(dest_path):
-        logger.info("PTH model not found, downloading...")
-        download_path = load_file_from_url(
-            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
-        )
-        copyfile(download_path, dest_path)
-
     logger.info("loading and training model")
     model = RRDBNet(
         num_in_ch=3,
@@ -36,7 +39,7 @@ def convert_correction_gfpgan(
         scale=scale,
     )
 
-    torch_model = torch.load(dest_path, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=ctx.map_location)
     # TODO: make sure strict=False is safe here
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"], strict=False)
@@ -54,15 +57,15 @@ def convert_correction_gfpgan(
         "output": {2: "width", 3: "height"},
     }
 
-    logger.info("exporting ONNX model to %s", dest_onnx)
+    logger.info("exporting ONNX model to %s", dest)
     export(
         model,
         rng,
-        dest_onnx,
+        dest,
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=opset,
+        opset_version=ctx.opset,
         export_params=True,
     )
     logger.info("GFPGAN exported to ONNX successfully.")

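Once exported, the ONNX file can be smoke-tested without the basicsr stack. A sketch with onnxruntime; the model filename here is hypothetical, and the input name is read back from the session because the input_names list is collapsed out of this diff:

import numpy as np
from onnxruntime import InferenceSession

session = InferenceSession(
    "../models/correction-gfpgan-v1-3.onnx",  # hypothetical output path
    providers=["CPUExecutionProvider"],
)

# the real input name comes from the input_names list passed to export()
input_name = session.get_inputs()[0].name

# an NCHW float input; the exact shape of the rng tensor used during
# export is also collapsed out of this diff, so this is an assumption
image = np.random.rand(1, 3, 64, 64).astype(np.float32)
output = session.run(None, {input_name: image})[0]
print(output.shape)  # spatial dims scaled up by the model's scale factor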