In [None]:
# init
import os
import torch
# configs
from sdconfigs import SDXLConfig, UIData
# utilities
import utilities as utils
import copy
# compel helper
from compel_helper import SDXLCompelHelper
# for additional ui
import ipywidgets as widgets
from ipywidgets import Layout
from ipywidgets import VBox, HBox
from PIL import Image

# hugging face cache directory; can be overridden with the SDXL_CACHE_DIR environment variable.
# Raw string: in a normal literal "\H" is an invalid escape sequence (warning on every import).
CACHE_DIR = os.environ.get("SDXL_CACHE_DIR", r"D:\HuggingFaceCache")
# config init
DEFAULT_CONFIG_PATH = "sdxl_config.json"
def save_config_to_default(config:SDXLConfig):
    """Save ``config`` to DEFAULT_CONFIG_PATH (used by the Save button and before every generation)."""
    config.save_config(DEFAULT_CONFIG_PATH)
# Config of the last successful generation; None until the first image has been produced.
last_config:SDXLConfig = None
# Lazily loaded pipelines, (re)built by set_base_pipe_from_config / set_refiner_pipe_from_config.
base_pipe = None
refiner_pipe = None

In [None]:
# load pipeline and generate image methods
def set_base_pipe_from_config(config:SDXLConfig):
    """Load the SDXL base pipeline described by ``config`` into the global ``base_pipe``.

    The pipeline gets a ``config.scheduler_type`` scheduler built from its default
    scheduler config (with the Karras-sigma / timestep-spacing overrides from ``config``),
    and is then moved to CUDA via ``utils.to_cuda``.

    Args:
        config: supplies base_pipeline_type, base_model, torch_dtype, variant,
            use_safetensors, scheduler_type, use_karras_sigmas and timestep_spacing.
    """
    global base_pipe
    # Drop the previous pipeline before loading the new one so its weights can be freed
    # first, instead of holding two base models during a model switch. Components that
    # an existing refiner shares with it stay alive through that refiner.
    base_pipe = None
    torch.cuda.empty_cache()

    pipe = config.base_pipeline_type.from_pretrained(config.base_model, cache_dir=CACHE_DIR,
                                                     torch_dtype=config.torch_dtype,
                                                     variant=config.variant,
                                                     use_safetensors=config.use_safetensors)

    pipe.scheduler = config.scheduler_type.from_config(pipe.scheduler.config,
                                                       use_karras_sigmas=config.use_karras_sigmas,
                                                       timestep_spacing=config.timestep_spacing)
    #pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

    # base pipeline to CUDA
    utils.to_cuda(pipe, "Base -> CUDA started", "Base -> CUDA finished")

    # Publish only a fully configured pipeline. If loading fails, base_pipe stays None
    # and the next generation retries, rather than reusing a half-initialised pipe.
    base_pipe = pipe

def set_refiner_pipe_from_config(config:SDXLConfig, base_pipe):
    """Load the SDXL refiner described by ``config`` into the global ``refiner_pipe``.

    The refiner reuses ``base_pipe``'s VAE, ``text_encoder_2`` and scheduler object
    instead of loading its own, and is moved to CUDA via ``utils.to_cuda``.

    Args:
        config: supplies refiner_pipeline_type, refiner_model, torch_dtype, variant
            and use_safetensors.
        base_pipe: the loaded base pipeline whose components are shared.
    """
    global refiner_pipe
    # Drop the previous refiner before loading the new one so its weights can be freed
    # first, instead of holding two refiners at once.
    refiner_pipe = None
    torch.cuda.empty_cache()

    pipe = config.refiner_pipeline_type.from_pretrained(config.refiner_model, cache_dir=CACHE_DIR,
                                                        torch_dtype=config.torch_dtype,
                                                        variant=config.variant,
                                                        use_safetensors=config.use_safetensors,
                                                        vae=base_pipe.vae,
                                                        text_encoder_2=base_pipe.text_encoder_2)

    # the same scheduler object as the base pipeline (not a copy)
    pipe.scheduler = base_pipe.scheduler
    #pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

    # refiner pipeline to CUDA
    utils.to_cuda(pipe, "Refiner -> CUDA started", "Refiner -> CUDA finished")

    # Publish only a fully configured refiner; on failure refiner_pipe stays None and is retried.
    refiner_pipe = pipe

def generate_refined_image_from_config(config:SDXLConfig, refiner_pipe, generator, prompt, prompt_2, negative_prompt, negative_prompt_2, base_output_images, use_ensemble_of_experts:bool) -> Image:
    """Run the refiner pipeline on the base pipeline's output and return the first image.

    Args:
        config: generation settings; reads use_compel, the four prompt fields,
            num_inference_steps, guidance_scale and high_noise_frac.
        refiner_pipe: the loaded refiner pipeline.
        generator: torch generator, the same one the base pass used.
        prompt, prompt_2, negative_prompt, negative_prompt_2: text prompts forwarded
            unchanged (the caller passes None for them when compel is used).
        base_output_images: ``images`` of the base pipeline output, passed as ``image``.
        use_ensemble_of_experts: if True the refiner starts at config.high_noise_frac,
            otherwise denoising_start is 0.0.

    Returns:
        ``images[0]`` of the refiner output.
    """
    if config.use_compel:
        # Compel embeddings come from the refiner's second tokenizer/encoder only;
        # the first tokenizer/encoder pair is passed as None.
        refiner_compel = SDXLCompelHelper(None, None, refiner_pipe.tokenizer_2, refiner_pipe.text_encoder_2)
        embeds, pooled, neg_embeds, neg_pooled = refiner_compel.get_embeddings(
            config.prompt, config.prompt_2, config.negative_prompt, config.negative_prompt_2)
    else:
        embeds = pooled = neg_embeds = neg_pooled = None

    # where the refiner picks up in the denoising schedule
    if use_ensemble_of_experts:
        start_at = config.high_noise_frac
    else:
        start_at = 0.0

    refined = refiner_pipe(
        prompt=prompt,
        prompt_2=prompt_2,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=embeds,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=neg_embeds,
        negative_pooled_prompt_embeds=neg_pooled,
        num_inference_steps=config.num_inference_steps,
        generator=generator,
        guidance_scale=config.guidance_scale,
        denoising_start=start_at,
        image=base_output_images,
    )
    return refined.images[0]

def generate_image_from_config(config:SDXLConfig, update_pipe:bool) -> Image:
    """Generate one image from ``config`` with the SDXL base pipeline and optional refiner.

    Side effects:
        - (re)builds the global pipelines when needed;
        - saves ``config`` to DEFAULT_CONFIG_PATH;
        - saves the result via ``utils.save_results``;
        - empties the CUDA cache at the end, including when inference raises.

    Args:
        config: generation settings (prompts, seed, steps, guidance, high_noise_frac,
            use_refiner, use_compel, extra pipeline kwargs, ...).
        update_pipe: force the base pipeline (and the refiner, when used) to be rebuilt.

    Returns:
        The first image of the base output, or of the refiner output when
        ``config.use_refiner`` is set.
    """
    global refiner_pipe
    if base_pipe is None or update_pipe:
        # The refiner borrows the base pipeline's VAE, text_encoder_2 and scheduler
        # (set_refiner_pipe_from_config). Forget any refiner built against the previous
        # base so it is rebuilt on the next refiner run, instead of reusing stale
        # components (e.g. the base changed while use_refiner was off).
        refiner_pipe = None
        set_base_pipe_from_config(config)

    try:
        generator = torch.Generator("cuda").manual_seed(config.seed)

        # warnings
        use_ensemble_of_experts:bool = True
        if config.high_noise_frac == 1.0 and config.use_refiner:
            use_ensemble_of_experts = False
            print("High noise fraction == 1.0 and Use refiner is True -> Set use_ensemble_of_experts to False")
        elif config.high_noise_frac < 1.0 and not config.use_refiner:
            print("High noise fraction < 1.0 and Use refiner is False")

        # prompt embeds (compel); they stay None when compel is disabled
        (base_positive_prompt_embeds, base_positive_prompt_pooled, base_negative_prompt_embeds, base_negative_prompt_pooled) = [None, None, None, None]
        if config.use_compel:
            # init compel
            compel = SDXLCompelHelper(base_pipe.tokenizer, base_pipe.text_encoder, base_pipe.tokenizer_2, base_pipe.text_encoder_2)
            (base_positive_prompt_embeds, base_positive_prompt_pooled, base_negative_prompt_embeds, base_negative_prompt_pooled) = compel.get_embeddings(config.prompt, config.prompt_2, config.negative_prompt, config.negative_prompt_2)

        # Prepare for inference. With compel only the embeddings are passed.
        # Empty prompts are passed as None.
        prompt=config.prompt if config.prompt != "" and not config.use_compel else None
        prompt_2=config.prompt_2 if config.prompt_2 != "" and not config.use_compel else None
        negative_prompt=config.negative_prompt if config.negative_prompt != "" and not config.use_compel else None
        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != "" and not config.use_compel else None
        prompt_embeds=base_positive_prompt_embeds
        pooled_prompt_embeds=base_positive_prompt_pooled
        negative_prompt_embeds=base_negative_prompt_embeds
        negative_pooled_prompt_embeds=base_negative_prompt_pooled
        # latents are handed to the refiner; otherwise the base returns PIL images
        output_type="latent" if config.use_refiner else "pil"

        # save data to default config
        save_config_to_default(config)

        # inference
        base_output = base_pipe(prompt=prompt,
                                prompt_2=prompt_2,
                                negative_prompt=negative_prompt,
                                negative_prompt_2=negative_prompt_2,
                                prompt_embeds=prompt_embeds,
                                pooled_prompt_embeds=pooled_prompt_embeds,
                                negative_prompt_embeds=negative_prompt_embeds,
                                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                                num_inference_steps=config.num_inference_steps,
                                generator=generator,
                                guidance_scale=config.guidance_scale,
                                output_type=output_type,
                                denoising_end=config.high_noise_frac,
                                **config.kwargs
                                )

        # results
        if not config.use_refiner:
            image = base_output.images[0]
        else:
            if refiner_pipe is None or update_pipe:
                set_refiner_pipe_from_config(config, base_pipe)
            image = generate_refined_image_from_config(config, refiner_pipe, generator, prompt, prompt_2, negative_prompt, negative_prompt_2, base_output.images, use_ensemble_of_experts)

        # save results
        utils.save_results(image, DEFAULT_CONFIG_PATH, config.image, config.mask)
    finally:
        # clean gpu cache, whether or not inference succeeded
        torch.cuda.empty_cache()
    return image

In [None]:
# Text box with the path of the config file that the next cell loads.
load_config_text_w = widgets.Text(
    value=DEFAULT_CONFIG_PATH,
    description='Load config from:',
    style={'description_width': 'initial'},
    layout=Layout(width='auto'),
)
display(load_config_text_w)

In [None]:
# config ui
def generate_image_to_ui():
    """Generate an image from the current ``config`` and show it in ``res_image_ui``.

    The Generate button is disabled while generation runs. It is always re-enabled
    afterwards, even if generation raises, so one failed run does not lock the UI.
    ``last_config`` is only updated after a successful run.
    """
    global last_config
    generate_btn.disabled = True
    try:
        # Rebuild the pipelines on the first run, or when model parameters changed.
        update_pipe = last_config is None or not config.model_params_equals(last_config)
        image = generate_image_from_config(config, update_pipe)
        last_config = copy.deepcopy(config)
        res_image_ui.value = utils.compressed_img_to_bytes(image, 'PNG')
    finally:
        generate_btn.disabled = False
def on_prompt_text_area_changed(change):
    """Observer for the prompt text area's ``value``: forward the new text to on_prompt_changed."""
    new_text = change.new
    on_prompt_changed(new_text)
def on_prompt_changed(value:str):
    """In turbo mode, generate a new image each time the prompt text ends with a comma."""
    if not config.is_turbo:
        return
    if value.endswith(","):
        generate_image_to_ui()

# Read the config file named in the "Load config from" box and build its widget set.
load_config_path = load_config_text_w.value
config:SDXLConfig = SDXLConfig.load_config(load_config_path)
ui:UIData = config.get_ui()
# Edits to the prompt box go through on_prompt_changed (may auto-generate in turbo mode).
ui.prompt.observe(on_prompt_text_area_changed, 'value')

# result image area: blank 512x512 RGB placeholder until the first generation
image = Image.new("RGB", (512, 512))
res_image_ui = widgets.Image(
    value=utils.compressed_img_to_bytes(image, 'PNG'),
    format='raw',
    layout=Layout(),
)

# save button: writes the current config to DEFAULT_CONFIG_PATH
def s(b):
    """Click handler for save_btn; the button argument is unused."""
    save_config_to_default(config)

save_btn = widgets.Button(description="Save config", layout=Layout(width='50%'))
save_btn.on_click(s)

# generate button
def g(b):
    """Click handler for generate_btn; the button argument is unused."""
    generate_image_to_ui()

generate_btn = widgets.Button(description="Generate", layout=Layout(width='50%'))
generate_btn.on_click(g)

# display whole ui: prompt widgets | parameter widgets | result image + buttons
system_btns_box = HBox(
    [generate_btn, save_btn],
    layout=Layout(margin='10px 0 0 0'),
)
# NOTE: this rebinds `ui` from the UIData object to the top-level layout HBox.
ui = HBox([
    VBox(ui.prompt_box, layout=Layout(width='33%', margin='10px 20px 10px 0')),
    VBox(ui.params_box, layout=Layout(width='33%', margin='10px 0 10px 0')),
    VBox([res_image_ui, system_btns_box], layout=Layout(width='33%', margin='10px 0 0 20px')),
])
display(ui)