In [3]:
import torch
import os
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

# ---- Configuration ----------------------------------------------------------
# Base Stable Diffusion v1.5 checkout; its "tokenizer" and "text_encoder"
# subfolders are loaded at the end of this cell.
pretrained_model_name_or_path = "/root/autodl-tmp/stable_diffusion/stable-diffusion-v1-5"

# Textual-inversion embedding files to merge (one per trained concept).
learned_embedding_path_list = [
    "/root/autodl-tmp/textual_inversion/trained_embeddings/custom_chair/original/learned_embeds_factor=0.0.bin",
    "/root/autodl-tmp/textual_inversion/trained_embeddings/custom_cat/original/learned_embeds_factor=0.0.bin",
]

# Root directory that receives the merged embedding file.
merge_path = "/root/autodl-tmp/textual_inversion/merged_embeddings"

# Output sub-directory name and output file stem; left as None here and
# filled in by the merge cell from learned_embedding_path_list.
specific_path = None
embedding_name_path = None

# ---- Model components -------------------------------------------------------
# Not used in the merge cell itself; presumably consumed by later cells.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,  # weights loaded in half precision
)

In [4]:
# Merge several textual-inversion embedding files into one {token: embedding} .bin.
#
# Each input path is expected to look like <root>/<dataset>/<variant>/<name>.bin.
# The merged file is written to
#     <merge_path>/<dataset1>&<dataset2>&.../<name1>&<name2>&....bin
#
# specific_path / embedding_name_path are rebuilt from scratch on every run
# (any earlier value is overwritten) so that re-running this cell is idempotent
# instead of growing names like "custom_chair&custom_cat&custom_chair&custom_cat".
if not learned_embedding_path_list:
    raise ValueError("learned_embedding_path_list is empty; nothing to merge")

learned_embeddings_dict = {}
dataset_names = []
embedding_names = []
for learned_embedding_path in learned_embedding_path_list:
    # <root>/<dataset>/<variant>/<name>.bin -> "<dataset>"
    all_embedding_path = os.path.dirname(learned_embedding_path)
    dataset_path = os.path.dirname(all_embedding_path)
    dataset_name = os.path.basename(dataset_path)
    dataset_names.append(dataset_name)

    # "<name>.bin" -> "<name>"
    embedding_name = os.path.splitext(os.path.basename(learned_embedding_path))[0]
    embedding_names.append(embedding_name)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only merge embedding files from a trusted source.
    loaded_learned_embedding = torch.load(learned_embedding_path, map_location="cpu")
    if not loaded_learned_embedding:
        raise ValueError(f"{learned_embedding_path} contains no learned embeddings")

    # Keep every token stored in the file; taking only the first key would
    # silently drop the rest of a multi-token (e.g. previously merged) file.
    for trained_token, embedding in loaded_learned_embedding.items():
        if trained_token in learned_embeddings_dict:
            raise ValueError(f"{trained_token} is repetitive, please use another placeholder token")
        learned_embeddings_dict[trained_token] = embedding

specific_path = "&".join(dataset_names)
embedding_name_path = "&".join(embedding_names)

output_path = os.path.join(merge_path, specific_path)
os.makedirs(output_path, exist_ok=True)
output_embeddings_path = os.path.join(output_path, f"{embedding_name_path}.bin")
torch.save(learned_embeddings_dict, output_embeddings_path)