In [None]:
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, AutoencoderKL,StableDiffusionPipeline

from transformers import AutoTokenizer, DistilBertForSequenceClassification
from diffusers import EulerDiscreteScheduler
import torch

In [None]:
import torch
from PIL import Image
import numpy as np
import os
import safetensors

In [None]:
from lora_utils import LoRAModuleWeight, LoRAHook, LoRAHookInjector
from embedding.embedding import EmbeddingExtent
from embedding.text_encoder_hijack import TextEncoderHijack
from embedding.textual_inversion import TextualInversionPlug

from convert_prompt_utils import *

In [None]:
# Base Stable Diffusion pipeline: fp16 weights, safety checker disabled.
generator = DiffusionPipeline.from_pretrained(
    "../model/diffuser_model/",
    torch_dtype=torch.float16,
    safety_checker=None,
)

# Optional external VAE (uncomment and point at a real path to use it):
# vae = AutoencoderKL.from_pretrained("vae_path", torch_dtype=torch.float16, subfolder="vae").to("cuda")
# generator.vae = vae

# Swap the scheduler for Euler, reusing the current scheduler's config.
euler_config = generator.scheduler.config
generator.scheduler = EulerDiscreteScheduler.from_config(euler_config)

# Move the pipeline to the GPU; the cell output shows the pipeline repr.
generator.to("cuda")

In [None]:
# Directory holding LoRA ``*.safetensors`` files, relative to the notebook's working directory.
# NOTE(review): the base model is loaded from "../model/..." while this path is "./model/...";
# confirm both locations are intended.
lora_dir_path = "./model/lora"

In [None]:
# Maps each hooked LoRA layer name to the class name of the module it wraps.
# Filled by ``get_module_dict``; passed to ``LoRAModuleWeight`` when loading LoRA files.
module_dict = {}


def get_module_dict(module):
    """Record the wrapped-module class name for every hook of ``module``.

    Args:
        module: object exposing a ``hooks`` mapping of layer name -> hook, where each
            hook has an ``orig_module`` attribute (e.g. ``generator.lora_injector``).

    Side effects:
        Updates the global ``module_dict`` in place (existing keys are overwritten);
        returns ``None``.
    """
    module_dict.update(
        (layer_name, hook.orig_module.__class__.__name__)
        for layer_name, hook in module.hooks.items()
    )

# ``module_dict`` must be populated before any LoRA file is loaded in the next cell, and
# that needs ``generator.lora_injector``. ``install_lora_hook`` is only defined and called
# in a later cell, so on a fresh kernel the injector would not exist yet. Install it here
# with the same attribute wiring; the later ``install_lora_hook(generator)`` call then
# returns early because ``lora_injector`` is already present.
if not hasattr(generator, "lora_injector"):
    _injector = LoRAHookInjector()
    _injector.install_hooks(generator)
    generator.lora_injector = _injector
    generator.load_lora = _injector.load_lora
    generator.apply_lora = _injector.apply_lora
    generator.clear_lora = _injector.clear_lora

get_module_dict(generator.lora_injector)


In [None]:
def preload_loras(loaded_loras, lora_dir_path):
    """Preload every LoRA ``*.safetensors`` file in ``lora_dir_path`` at startup.

    Args:
        loaded_loras: cache dict (LoRA name -> loaded weight), updated in place.
            Names already present are skipped.
        lora_dir_path: directory to scan. If it does not exist, a message is
            printed and nothing is loaded.

    The LoRA name is the file name minus its final ``.safetensors`` extension, so
    ``my.style.safetensors`` is cached as ``"my.style"``. That is the same name
    ``load_lora_by_name`` turns back into ``<name>.safetensors``.
    """
    if not os.path.isdir(lora_dir_path):
        print(f'lora dir {lora_dir_path} not found, skip preloading')
        return
    # Sorted so the load order is deterministic across platforms.
    for file_name in sorted(os.listdir(lora_dir_path)):
        stem, ext = os.path.splitext(file_name)
        if ext != ".safetensors" or stem in loaded_loras:
            continue
        lora_name, lora_weight = load_lora_from_disk(lora_dir_path, file_name)
        if lora_name == "":
            continue
        # Key by the file stem so lookups by name (see load_lora_by_name) always match.
        loaded_loras[stem] = lora_weight

def load_lora_by_name(loaded_loras: dict, lora_dir_path: str, lora_name: str) -> bool:
    """Ensure ``lora_name`` is in ``loaded_loras``, loading it from disk if needed.

    Used for LoRA files added to ``lora_dir_path`` after the initial preload.

    Args:
        loaded_loras: cache dict (LoRA name -> loaded weight), updated in place.
        lora_dir_path: directory containing ``<lora_name>.safetensors``.
        lora_name: LoRA name without extension.

    Returns:
        True if the LoRA is (now) cached under ``lora_name``; False if the file
        could not be loaded.
    """
    if lora_name in loaded_loras:
        return True
    loaded_name, lora_weight = load_lora_from_disk(lora_dir_path, lora_name + ".safetensors")
    if loaded_name == "":
        return False
    # Cache under the requested name, not ``loaded_name``: callers look the LoRA up by
    # ``lora_name``, and the loader may derive a different name (e.g. for dotted names).
    loaded_loras[lora_name] = lora_weight
    return True

def load_lora_from_disk(lora_dir_path: str, file_name: str):
    """Load one LoRA ``.safetensors`` file into a ``LoRAModuleWeight``.

    Args:
        lora_dir_path: directory containing the file.
        file_name: file name including the ``.safetensors`` extension.

    Returns:
        ``(lora_name, weight)``, where ``lora_name`` is ``file_name`` without its final
        extension. Returns ``("", None)`` if the file does not exist or is not a
        ``.safetensors`` file, so callers can always unpack two values.
    """
    lora_name, ext = os.path.splitext(file_name)
    file_path = os.path.join(lora_dir_path, file_name)
    if ext != ".safetensors" or not os.path.exists(file_path):
        return "", None
    # ``import safetensors`` alone does not load the ``safetensors.torch`` submodule.
    import safetensors.torch

    state_dict = safetensors.torch.load_file(file_path)
    lora_weight = LoRAModuleWeight(lora_name, module_dict, state_dict, 1.0, "cuda", torch.float16)
    return lora_name, lora_weight

# In-memory LoRA cache (name -> LoRAModuleWeight), filled from lora_dir_path at startup.
loaded_loras = dict()
preload_loras(loaded_loras, lora_dir_path=lora_dir_path)

In [None]:
import gc
def torch_gc():
    """Empty the CUDA caching allocator and collect IPC memory; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    with torch.cuda.device('cuda'):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def get_lora_by_nameList(lora_name_list):
    """Resolve ``(lora_name, weight)`` pairs into ``(lora_weight_object, weight)`` pairs.

    LoRAs missing from the global ``loaded_loras`` cache are loaded from
    ``lora_dir_path`` on demand. Names that cannot be loaded are reported and skipped.

    Args:
        lora_name_list: iterable of ``(lora_name, weight)`` pairs, e.g. as returned
            by ``find_and_replace_lora``.

    Returns:
        list of ``(lora_weight_object, weight)`` in input order.
    """
    res_lora_list = []
    for lora_name, weight in lora_name_list:
        lora_weight = None
        # load_lora_by_name returns True immediately when the name is already cached.
        if load_lora_by_name(loaded_loras, lora_dir_path, lora_name):
            # .get() guards against the loader having registered a different key,
            # which previously surfaced as a KeyError here.
            lora_weight = loaded_loras.get(lora_name)
        if lora_weight is None:
            print(f'load lora {lora_name} failed, lora name not exist')
            continue
        res_lora_list.append((lora_weight, weight))
    return res_lora_list

def add_lora_weight_to_pipeline(pipeline, lora_name_list):
    """Load the requested LoRAs into ``pipeline`` and apply them.

    Args:
        pipeline: pipeline patched by ``install_lora_hook`` (provides ``load_lora``
            and ``apply_lora``).
        lora_name_list: ``(lora_name, multiplier)`` pairs; string multipliers are
            converted with ``float``.

    Returns:
        list of ``(lora_name, multiplier)`` for the LoRAs actually loaded.
    """
    applied = []
    for lora_weight, multiplier in get_lora_by_nameList(lora_name_list):
        multiplier = float(multiplier) if isinstance(multiplier, str) else multiplier
        pipeline.load_lora(lora_weight, multiplier)
        applied.append((lora_weight.lora_name, multiplier))
    # apply_lora runs even when no LoRA was loaded.
    pipeline.apply_lora()
    return applied

def clear_lora_weight_from_pipeline(pipeline):
    """Call ``pipeline.clear_lora()`` and then run Python GC and CUDA cache cleanup."""
    pipeline.clear_lora()
    # Collect unreachable Python objects first so torch_gc() can return freed CUDA blocks.
    gc.collect()
    torch_gc()

In [None]:
def install_lora_hook(pipe: DiffusionPipeline):
    """Attach a ``LoRAHookInjector`` to ``pipe`` unless one is already attached.

    After installation ``pipe`` exposes ``lora_injector`` plus the injector's
    ``load_lora``, ``apply_lora`` and ``clear_lora`` methods. A pipe that already
    has a ``lora_injector`` attribute is left untouched.
    """
    if hasattr(pipe, "lora_injector"):
        return
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    for attr_name in ("load_lora", "apply_lora", "clear_lora"):
        setattr(pipe, attr_name, getattr(injector, attr_name))

In [None]:
# Patch the pipeline with LoRA hooks (returns early if already installed).
install_lora_hook(pipe=generator)


In [None]:
# Reuse the pipeline's own tokenizer and text encoder.
CLIP_Tokenizer = generator.tokenizer
CLIP_TextModel = generator.text_encoder

# Textual-inversion embeddings.
# NOTE(review): 'textual_inversion_path' looks like a placeholder; point it at the real directory.
TextualInversion = TextualInversionPlug('textual_inversion_path', tokenizer=CLIP_Tokenizer)
TextualInversion.load_textual_inversion()

# Hook the text encoder's embedding layer via TextEncoderHijack.
hijack = TextEncoderHijack()
hijack.hijack_embeding(CLIP_TextModel)

# Prompt -> embedding helper combining tokenizer, encoder, textual inversion and the hijack.
embedding = EmbeddingExtent(
    tokenizer=CLIP_Tokenizer,
    text_encoder=CLIP_TextModel,
    textual_inversion_manager=TextualInversion,
    hijack=hijack,
    device="cuda",
    dtype=torch.float16,
)

# Positive prompt. May contain LoRA tags, which find_and_replace_lora processes in a later cell.
# Fixed typo: "detasiled sunlight" -> "detailed sunlight" (the misspelling tokenizes into junk sub-tokens).
prompt = "(absurdres, highres, ultra detailed), 1 male, handsome, tall muscular guy, very short hair, best ratio four finger and one thumb, best light and shadow, background is back alley, detailed sunlight, sitting, Little cats are gathered next to him, dappled sunlight, day, depth of field, plants, summer, (dutch angle), closed mouth, summer day"

# Negative prompt.
negative = "(hair between eyes), sketch, duplicate, ugly, huge eyes, text, logo, worst face, (bad and mutated hands:1.3), (worst quality:2.0), (low quality:2.0), (blurry:2.0), horror, geometry, bad_prompt, (bad hands), (missing fingers), multiple limbs, bad anatomy, (interlocked fingers:1.2), Ugly Fingers, (extra digit and hands and fingers and legs and arms:1.4), ((2girl)), (deformed fingers:1.2), (long fingers:1.2), extra legs, upper teeth, parted lips, open mouth"


In [None]:
# find_and_replace_lora returns the rewritten prompt and a list of (lora_name, weight)
# pairs (presumably parsed from LoRA tags in the prompt -- confirm in convert_prompt_utils).
# NOTE(review): this cell rebinds ``prompt`` and loads LoRAs, so re-running it without
# clearing LoRAs first is not idempotent.
prompt, lora_name_list = find_and_replace_lora(prompt)
# Always defined, so later cells can read it even when no LoRA was requested.
apply_lora_list = []
if lora_name_list:
    apply_lora_list = add_lora_weight_to_pipeline(generator, lora_name_list)


In [None]:
# Encode both prompts with the same CLIP_stop_at_last_layers setting.
clip_stop_layers = 1
prompt_pre_embedding = embedding(prompt, CLIP_stop_at_last_layers=clip_stop_layers)
negative_prompt_pre_embedding = embedding(negative, CLIP_stop_at_last_layers=clip_stop_layers)

# Pad to a common length: diffusers requires prompt_embeds and
# negative_prompt_embeds to have the same shape.
prompt_pre_embedding, negative_prompt_pre_embedding = embedding.pad_prompt_tensor_same_length(
    prompt_emb=prompt_pre_embedding,
    negative_prompt_emb=negative_prompt_pre_embedding,
    CLIP_stop_at_last_layers=clip_stop_layers,
)




In [None]:
import random

# Draw a fresh seed and print it so a good result can be reproduced
# (uncomment the next line with that value to pin it).
seed = random.randrange(4294967294)
# seed=4196966724
print(seed)
# One seeded CUDA RNG per generated image (num_images_per_prompt=1).
Generator = [torch.Generator(device="cuda").manual_seed(seed)]

image = generator(
    prompt_embeds=prompt_pre_embedding,
    negative_prompt_embeds=negative_prompt_pre_embedding,
    width=512,
    height=768,
    num_inference_steps=20,
    guidance_scale=9,
    generator=Generator,
    num_images_per_prompt=1,
).images

In [None]:
# Display the first generated image as the cell output.
image[0]

In [None]:
# Remove LoRA weights from the pipeline and release memory.
# clear_lora_weight_from_pipeline already runs gc.collect() and torch_gc(),
# so a second torch_gc() call here is unnecessary.
clear_lora_weight_from_pipeline(generator)