Skip to content

Commit

Permalink
[Pipeline download] Improve pipeline download for index and passed co… (
Browse files Browse the repository at this point in the history
huggingface#2980)

* [Pipeline download] Improve pipeline download for index and passed components

* correct

* add more tests

* up
  • Loading branch information
patrickvonplaten authored Apr 5, 2023
1 parent 558c85b commit 43efe0c
Showing 1 changed file with 96 additions and 35 deletions.
131 changes: 96 additions & 35 deletions pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray


def is_safetensors_compatible(filenames, variant=None) -> bool:
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
Expand All @@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:

sf_filenames = set()

passed_components = passed_components or []

for filename in filenames:
_, extension = os.path.splitext(filename)

if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
continue

if extension == ".bin":
pt_filenames.append(filename)
elif extension == ".safetensors":
Expand All @@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)

if filename == "pytorch_model":
filename = "model"
elif filename == f"pytorch_model.{variant}":
filename = f"model.{variant}"
if filename.startswith("pytorch_model"):
filename = filename.replace("pytorch_model", "model")
else:
filename = filename

Expand Down Expand Up @@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = "\d{5}-of-\d{5}"

if variant is not None:
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
variant_file_re = re.compile(
f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
# `text_encoder/pytorch_model.bin.index.fp16.json`
variant_index_re = re.compile(
f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)

variant_file_regex = (
re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
if variant is not None
else None
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

if variant is not None:
variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()

non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes

# all variant filenames will be used by default
usable_filenames = set(variant_filenames)

def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename

for f in non_variant_filenames:
variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
variant_filename = convert_to_variant(f)
if variant_filename not in usable_filenames:
usable_filenames.add(f)

Expand Down Expand Up @@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
return class_obj, class_candidates


def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
    """Resolve which pipeline class should be used.

    Resolution order:

    1. If ``custom_pipeline`` is given, the class is loaded through
       ``get_class_from_dynamic_module``. A value ending in ``.py`` is split
       into its absolute parent folder and its file name; any other value is
       passed through unchanged together with ``CUSTOM_PIPELINE_FILE_NAME``.
    2. Otherwise, if ``class_obj`` is not ``DiffusionPipeline`` itself, it is
       returned as-is.
    3. Otherwise the class named by ``config["_class_name"]`` is looked up on
       the top-level package of ``class_obj``'s module.

    Args:
        class_obj: The class the caller was invoked on.
        config: Mapping that provides the ``"_class_name"`` entry.
        custom_pipeline: Optional custom pipeline identifier or ``.py`` path.
        cache_dir: Forwarded to ``get_class_from_dynamic_module``.
        revision: Forwarded to ``get_class_from_dynamic_module``.

    Returns:
        The resolved pipeline class.
    """
    if custom_pipeline is None:
        # An explicit subclass always wins over what the config says.
        if class_obj != DiffusionPipeline:
            return class_obj

        # Called on the generic base class: look the concrete class up by name
        # on the root package that `class_obj` lives in.
        root_package_name = class_obj.__module__.split(".")[0]
        root_package = importlib.import_module(root_package_name)
        return getattr(root_package, config["_class_name"])

    if custom_pipeline.endswith(".py"):
        # Decompose an explicit script path into folder & file.
        script_path = Path(custom_pipeline)
        module_file = script_path.name
        custom_pipeline = script_path.parent.absolute()
    else:
        module_file = CUSTOM_PIPELINE_FILE_NAME

    return get_class_from_dynamic_module(
        custom_pipeline, module_file=module_file, cache_dir=cache_dir, revision=revision
    )


def load_sub_model(
library_name: str,
class_name: str,
Expand Down Expand Up @@ -779,7 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand All @@ -794,8 +845,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token=use_auth_token,
revision=revision,
from_flax=from_flax,
use_safetensors=use_safetensors,
custom_pipeline=custom_pipeline,
custom_revision=custom_revision,
variant=variant,
**kwargs,
)
else:
cached_folder = pretrained_model_name_or_path
Expand All @@ -810,29 +864,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
variant_exists = is_folder and any(
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
)
if variant_exists:
model_variants[folder] = variant

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

pipeline_class = get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
)
elif cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
pipeline_class = _get_pipeline_class(
cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
)

# DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
Expand Down Expand Up @@ -1095,6 +1137,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)

Expand Down Expand Up @@ -1153,7 +1196,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]

allow_patterns += [
SCHEDULER_CONFIG_NAME,
Expand All @@ -1162,17 +1205,28 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
CUSTOM_PIPELINE_FILE_NAME,
]

# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
)
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, variant=variant)
and not is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
)
):
raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
elif use_safetensors and is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
):
ignore_patterns = ["*.bin", "*.msgpack"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
Expand All @@ -1194,6 +1248,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)

# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
]
# Don't download index files of forbidden patterns either
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]

re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

Expand Down

0 comments on commit 43efe0c

Please sign in to comment.