In [None]:
from google.colab import userdata
import huggingface_hub

In [None]:
# Fetch the Hugging Face access token stored under this name in Colab's user data
HuggingFace_API_KEY = userdata.get("HuggingFace_API_KEY")
# Log in to the Hugging Face Hub with that token
huggingface_hub.login(
    token=HuggingFace_API_KEY,
)

### webUI形式をダウンロードした後にdiffusers形式に変換する

In [None]:
# Counterfeit-V3.0のWebUI形式のsafetensorsファイルをwgetコマンドでダウンロード
!wget -P "./webUI_model" https://huggingface.co/gsdf/Counterfeit-V3.0/resolve/main/Counterfeit-V3.0_fp16.safetensors?download=true
# OrangeMixsのWebUI形式のsafetensorsファイルをwgetコマンドでダウンロード
!wget -P "./webUI_model" https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1B_orangemixs.safetensors?download=true
# WebUI形式のCouterfeit-V3.0のsafetensorsファイルを、diffusers形式のファイルに変換
!python convert_original_stable_diffusion_to_diffusers.py --original_config_file v1-inference.yaml --checkpoint_path webUI_model/Counterfeit-V3.0_fp16.safetensors --image_size 512 --prediction_type epsilon --extract_ema  --upcast_attention  --dump_path diffusers_model/Counterfeit-V3.0 --from_safetensors  --device cpu
# WebUI形式のOrangeMixsのsafetensorsファイルを、diffusers形式のファイルに変換
!python convert_original_stable_diffusion_to_diffusers.py --original_config_file v1-inference.yaml --checkpoint_path webUI_model/AOM3A1B_orangemixs.safetensors --image_size 512 --prediction_type epsilon --extract_ema  --upcast_attention  --dump_path diffusers_model/AOM3A1B_orangemixs --from_safetensors  --device cpu
# WebUI形式のOrangeMixsのVAEのptファイルをダウンロード
!wget -P "./webUI_model" https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt?download=true
# WebUI形式のOrangeMixsのVAEのptファイルを、diffusers形式のVAEファイルに変換
!python convert_vae_pt_to_diffusers.py --vae_pt_path webUI_model/orangemixs.vae.pt --dump_path diffusers_model/orangemixs.vae

In [None]:
import os
from datetime import datetime

import diffusers
import IPython
import matplotlib.pyplot as plt
import PIL
import torch
import torchvision

### マージモデルを作成する

In [None]:
MODEL_counterfeit = "./diffusers_model/Counterfeit-V3.0"
MODEL_orangemix = "./diffusers_model/AOM3A1B_orangemixs"


def _blend_parameters(target_module, source_module, target_weight, source_weight):
    """Overwrite the parameters of ``target_module`` in place with a weighted sum.

    For every named parameter ``p`` of ``target_module`` the new value is
    ``target_weight * p + source_weight * q``, where ``q`` is the parameter of
    ``source_module`` registered under the same name.

    Args:
        target_module: module whose parameters are updated in place.
        source_module: module supplying the second operand; it is not modified.
        target_weight: scalar applied to the current values of ``target_module``.
        source_weight: scalar applied to the values of ``source_module``.

    Raises:
        KeyError: if ``source_module`` lacks a parameter name present in ``target_module``.
    """
    source_params = dict(source_module.named_parameters())
    # This is plain weight arithmetic, so keep autograd from recording it.
    with torch.no_grad():
        for name, param in target_module.named_parameters():
            param.copy_(target_weight * param + source_weight * source_params[name])


# The Counterfeit pipeline doubles as the container for the merged model.
merge_model_pipe = diffusers.StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path=MODEL_counterfeit,
                                                                     torch_dtype=torch.float16)
pipe_2 = diffusers.StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path=MODEL_orangemix,
                                                           torch_dtype=torch.float16)

# Weighted merge per component; weights are (Counterfeit, OrangeMix).
_blend_parameters(merge_model_pipe.unet, pipe_2.unet, 0.25, 0.75)  # U-Net
_blend_parameters(merge_model_pipe.text_encoder, pipe_2.text_encoder, 0.25, 0.75)  # text encoder
_blend_parameters(merge_model_pipe.vae, pipe_2.vae, 0.95, 0.05)  # VAE

# Swap in a UniPC multistep scheduler built from the existing scheduler config, then save the merged pipeline.
merge_model_pipe.scheduler = diffusers.UniPCMultistepScheduler.from_config(merge_model_pipe.scheduler.config)
merge_model_pipe.save_pretrained(save_directory="./diffusers_model/Counterfeit-V3.0_merge_AOM3A1B_orangemixs")

In [None]:
# Use the first CUDA GPU when PyTorch reports one, otherwise fall back to the CPU
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Path of the diffusers-format model to load for generation
MODEL = "./diffusers_model/Counterfeit-V3.0_merge_AOM3A1B_orangemixs"

In [None]:
# Lift the prompt token-count and emphasis-weight limits by loading a long-prompt-weighting custom pipeline
# https://huggingface.co/docs/diffusers/v0.12.0/en/using-diffusers/custom_pipeline_examples#long-prompt-weighting-stable-diffusion
# DEVICE may fall back to the CPU; diffusers warns that float16 pipelines fail on CPU,
# so only use half precision when running on CUDA.
PIPE_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
pipe = diffusers.DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=MODEL,
                                                   custom_pipeline="AlanB/lpw_stable_diffusion_update",  # custom pipeline that removes the limits
                                                   torch_dtype=PIPE_DTYPE)
# Replace the safety checker with a pass-through that returns the images unchanged
pipe.safety_checker = lambda images, **kwargs: (images, None)
pipe.enable_vae_tiling()
# Assign the result so the cell does not echo the whole pipeline repr
pipe = pipe.to(DEVICE)

In [None]:
# Prompt and negative prompt passed to the text-to-image call below
positive_prompt = "(nsfw:1.25), masterpiece, 1girl, brown eyes, black long hair, smile, looking at viewer, (simple shirt:1.05), simple background, upper body, best quality, extremely detailed"
negative_prompt_text = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# Generator seeded with a fixed value so repeated runs draw the same random numbers
seeded_generator = torch.Generator().manual_seed(1234)

# Generate a batch of 512x512 images and return them as PIL images
output = pipe.text2img(
    prompt=positive_prompt,
    negative_prompt=negative_prompt_text,
    height=512,
    width=512,
    num_inference_steps=20,
    guidance_scale=7.0,
    num_images_per_prompt=10,
    generator=seeded_generator,
    max_embeddings_multiples=5,
    output_type="pil",
)

In [None]:
print(type(output.images))
# Show all generated images in a grid, five per row. The row count is derived from the
# number of images so the cell keeps working when the batch size changes
# (a fixed 2x5 grid raises as soon as there are more than 10 images).
GRID_COLUMNS = 5
image_count = len(output.images)
grid_rows = max(1, (image_count + GRID_COLUMNS - 1) // GRID_COLUMNS)  # ceiling division, at least one row
fig = plt.figure(figsize=(10, 2.5 * grid_rows))  # 2.5 in per row: 10x5 for the usual two rows
for i, image in enumerate(output.images):
    ax = fig.add_subplot(grid_rows, GRID_COLUMNS, i + 1)
    ax.imshow(image)
    ax.set_title(i)  # title each tile with its index in output.images
    ax.axis("off")
fig.subplots_adjust(wspace=0, hspace=0.25)
plt.show()

In [None]:
# Display every generated image and save it as a JPEG file.
OUTPUT_DIR = "./Counterfeit3-0_OrangeMixs_lpw"
# Image.save does not create missing directories, so make sure the target exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# One zero-padded timestamp (YYYYMMDDHHMM) for the whole batch: unpadded fields are
# ambiguous (e.g. 2024-1-12 vs 2024-11-2) and do not sort chronologically, and taking
# the time once keeps all files of a batch under the same prefix.
now = datetime.now()
now_timestamp = now.strftime("%Y%m%d%H%M")
for j, image in enumerate(output.images):
    print(type(image))
    IPython.display.display(image)
    # Files are named <timestamp>_<index>.jpeg
    image.save(os.path.join(OUTPUT_DIR, "{a}_{b}.jpeg".format(a=now_timestamp, b=j)),
               "JPEG",
               quality=95)