[Feat] Enable State Dict For Textual Inversion Loader (huggingface#3439)
* enable state dict for textual inversion loader

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* Empty-Commit | restart CI

* add tests

* fix tests

* fix tests

* fix tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
ghunkins and patrickvonplaten committed May 30, 2023
1 parent 7bbc036 commit 8e0cb4e
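
In practice, the change lets a pipeline's `load_textual_inversion` accept an embedding that has already been loaded into memory as a state dict, instead of only a Hub repo id or a file path. A minimal usage sketch of the new call pattern, assuming a hypothetical embedding file `./my_text_inversions.pt` and placeholder token `<my-concept>`:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Load the textual inversion embedding yourself (file name is illustrative) ...
state_dict = torch.load("./my_text_inversions.pt", map_location="cpu")

# ... and pass the state dict directly instead of a repo id or path.
pipe.load_textual_inversion(state_dict, token="<my-concept>")

image = pipe("a photo of <my-concept>").images[0]
```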
Showing 1 changed file with 38 additions and 33 deletions: loaders.py
@@ -470,7 +470,7 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
 
     def load_textual_inversion(
         self,
-        pretrained_model_name_or_path: Union[str, List[str]],
+        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -485,7 +485,7 @@ def load_textual_inversion(
         </Tip>
 
         Parameters:
-            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
+            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                 Can be either:
 
                     - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
@@ -494,6 +494,8 @@ def load_textual_inversion(
                     - A path to a *directory* containing textual inversion weights, e.g.
                       `./my_text_inversion_directory/`.
                     - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
 
                 Or a list of those elements.
             token (`str` or `List[str]`, *optional*):
@@ -618,7 +620,7 @@ def load_textual_inversion(
             "framework": "pytorch",
         }
 
-        if isinstance(pretrained_model_name_or_path, str):
+        if not isinstance(pretrained_model_name_or_path, list):
            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
        else:
            pretrained_model_name_or_paths = pretrained_model_name_or_path
@@ -643,16 +645,38 @@ def load_textual_inversion(
         token_ids_and_embeddings = []
 
         for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            # 1. Load textual inversion file
-            model_file = None
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
+            if not isinstance(pretrained_model_name_or_path, dict):
+                # 1. Load textual inversion file
+                model_file = None
+                # Let's first try to load .safetensors weights
+                if (use_safetensors and weight_name is None) or (
+                    weight_name is not None and weight_name.endswith(".safetensors")
+                ):
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                        )
+                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                    except Exception as e:
+                        if not allow_pickle:
+                            raise e
+
+                        model_file = None
+
+                if model_file is None:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        weights_name=weight_name or TEXT_INVERSION_NAME,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -663,28 +687,9 @@ def load_textual_inversion(
                         subfolder=subfolder,
                         user_agent=user_agent,
                     )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except Exception as e:
-                    if not allow_pickle:
-                        raise e
-
-                    model_file = None
-
-            if model_file is None:
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=weight_name or TEXT_INVERSION_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = torch.load(model_file, map_location="cpu")
+                    state_dict = torch.load(model_file, map_location="cpu")
+            else:
+                state_dict = pretrained_model_name_or_path
 
             # 2. Load token and embedding correcly from file
             loaded_token = None
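
The widened annotation also admits a list of state dicts (`List[Dict[str, torch.Tensor]]`) paired with a list of tokens. A hedged sketch of that form; the checkpoint id, embedding file names, and tokens below are illustrative assumptions:

```python
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Two embeddings loaded up front (hypothetical file names), then handed to the
# loader as a list of state dicts with a matching list of tokens.
state_dicts = [
    load_file("./concept_one.safetensors", device="cpu"),
    load_file("./concept_two.safetensors", device="cpu"),
]
pipe.load_textual_inversion(state_dicts, token=["<concept-one>", "<concept-two>"])
```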