In [None]:
# init
import torch

# configs
from sdconfigs import SDXLConfig
# utilities
import utilities as utils
# compel helper
from compel_helper import SDXLCompelHelper

# hugging face cache directory
# Raw string: "\H" is not a valid escape sequence, so the plain literal triggers a
# SyntaxWarning on recent Python versions. The resulting path is unchanged.
CACHE_DIR = r"D:\HuggingFaceCache"

In [None]:
# ui
# Obtain the SDXLConfig instance through SDXLConfig.load_config() and call set_ui() on it.
# Every later cell reads its settings (models, prompts, seed, refiner/compel flags) from `config`.
# NOTE(review): presumably load_config() restores previously saved settings and set_ui()
# renders the input widgets - confirm against sdconfigs.
config: SDXLConfig = SDXLConfig.load_config()
config.set_ui()

In [None]:
# Load the base pipeline, install the configured scheduler, create the compel
# helper and move the pipeline to CUDA.
base_pipe_model = config.base_pipeline_type.from_pretrained(
    config.base_pipe_model,
    cache_dir=CACHE_DIR,
    torch_dtype=config.torch_dtype,
    variant=config.variant,
    use_safetensors=config.use_safetensors,
)

# Build a scheduler of the configured type from the config of the scheduler the
# pipeline was loaded with; `scheduler` is kept so the refiner cell can reuse it.
base_pipe_model.scheduler = scheduler = config.scheduler_type.from_config(
    base_pipe_model.scheduler.config
)
# torch.compile of the UNet stays disabled (not for windows):
#base_pipe_model.unet = torch.compile(base_pipe_model.unet, mode="reduce-overhead", fullgraph=True)

# Compel helper built from both tokenizer / text-encoder pairs of the base pipeline.
compel = SDXLCompelHelper(
    base_pipe_model.tokenizer,
    base_pipe_model.text_encoder,
    base_pipe_model.tokenizer_2,
    base_pipe_model.text_encoder_2,
)

# Base pipeline -> CUDA; the two strings are the start / finish messages handed to the helper.
utils.to_cuda(base_pipe_model, "Base -> CUDA started", "Base -> CUDA finished")

In [None]:
# generate image
# Runs the base pipeline. With the refiner enabled the base stops at
# `config.high_noise_frac` and hands latents to the refiner cell; without it the
# base runs the whole schedule and its PIL image is saved and displayed here.

# Seeded CUDA generator; the refiner cell reuses this same object.
generator = torch.Generator("cuda").manual_seed(config.seed)
if config.high_noise_frac == 1.0 and config.use_refiner:
    # Only a notice: the refiner is enabled although the base is asked to do the full schedule.
    print("High noise fraction == 1.0 and Use refiner is True")

# prompt embeds
# NOTE(review): computed even when config.use_compel is False (then unused below) -
# confirm get_embeddings has no needed side effects before making this conditional.
(base_positive_prompt_embeds,
 base_positive_prompt_pooled,
 base_negative_prompt_embeds,
 base_negative_prompt_pooled) = compel.get_embeddings(config.prompt,
                                                      config.prompt_2,
                                                      config.negative_prompt,
                                                      config.negative_prompt_2)

# prepare for inference
# Exactly one of the two input styles is used: plain prompt strings (empty strings
# become None) or the compel embeddings. The four prompt variables are also read
# by the refiner cell.
prompt = config.prompt if config.prompt != "" and not config.use_compel else None
prompt_2 = config.prompt_2 if config.prompt_2 != "" and not config.use_compel else None
negative_prompt = config.negative_prompt if config.negative_prompt != "" and not config.use_compel else None
negative_prompt_2 = config.negative_prompt_2 if config.negative_prompt_2 != "" and not config.use_compel else None
prompt_embeds = base_positive_prompt_embeds if config.use_compel else None
pooled_prompt_embeds = base_positive_prompt_pooled if config.use_compel else None
negative_prompt_embeds = base_negative_prompt_embeds if config.use_compel else None
negative_pooled_prompt_embeds = base_negative_prompt_pooled if config.use_compel else None
# Latents are handed to the refiner; without a refiner the final PIL image is produced here.
output_type = "latent" if config.use_refiner else "pil"

# base
print(f"{config.prompt}\n{config.prompt_2}\n{config.negative_prompt}\n{config.negative_prompt_2}")
base_output = base_pipe_model(prompt=prompt,
                              prompt_2=prompt_2,
                              negative_prompt=negative_prompt,
                              negative_prompt_2=negative_prompt_2,
                              prompt_embeds=prompt_embeds,
                              pooled_prompt_embeds=pooled_prompt_embeds,
                              negative_prompt_embeds=negative_prompt_embeds,
                              negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                              num_inference_steps=config.num_inference_steps,
                              generator=generator,
                              guidance_scale=config.guidance_scale,
                              output_type=output_type,
                              # Stop early only when the refiner will finish the schedule.
                              # Otherwise a high_noise_frac below 1.0 would leave the
                              # saved image partially denoised.
                              denoising_end=config.high_noise_frac if config.use_refiner else None,
                              width=config.width,
                              height=config.height)

# save data to config
config.save_config()

# save results
if not config.use_refiner:
    image = base_output.images[0]
    utils.save_results(image, SDXLConfig.CONFIG_PATH)
    display(image)

# clean gpu cache
torch.cuda.empty_cache()

In [None]:
# refiner
# Finishes denoising the latents produced by the base cell. It relies on
# `base_output`, `generator`, `scheduler` and the four prompt variables from the
# earlier cells.

if not config.use_refiner:
    # With the refiner disabled the base cell has already produced, saved and
    # displayed the final PIL image. Running the refiner anyway would feed it
    # decoded images together with `denoising_start` and overwrite the result.
    print("Use refiner is False - refiner step skipped")
else:
    # The VAE and second text encoder are shared with the base pipeline instead of loaded again.
    refiner_pipe_model = config.refiner_pipeline_type.from_pretrained(config.refiner_pipe_model, cache_dir=CACHE_DIR,
                                        torch_dtype=config.torch_dtype,
                                        variant=config.variant,
                                        use_safetensors=config.use_safetensors,
                                        vae=base_pipe_model.vae,
                                        text_encoder_2=base_pipe_model.text_encoder_2)

    # Same scheduler object as the base pipeline.
    refiner_pipe_model.scheduler = scheduler
    # torch.compile of the UNet stays disabled (not for windows):
    #refiner_pipe_model.unet = torch.compile(refiner_pipe_model.unet, mode="reduce-overhead", fullgraph=True)

    # refiner pipeline to CUDA
    utils.to_cuda(refiner_pipe_model, "Refiner -> CUDA started", "Refiner -> CUDA finished")

    # prompt embeds
    # The helper is given only the second tokenizer / text encoder; the first pair is passed as None.
    compel_refiner = SDXLCompelHelper(None, None, refiner_pipe_model.tokenizer_2, refiner_pipe_model.text_encoder_2)
    (base_positive_prompt_embeds_refiner,
     base_positive_prompt_pooled_refiner,
     base_negative_prompt_embeds_refiner,
     base_negative_prompt_pooled_refiner) = compel_refiner.get_embeddings(config.prompt,
                                                                          config.prompt_2,
                                                                          config.negative_prompt,
                                                                          config.negative_prompt_2)

    # As in the base cell: embeddings are used with compel, plain prompt strings otherwise.
    prompt_embeds = base_positive_prompt_embeds_refiner if config.use_compel else None
    pooled_prompt_embeds = base_positive_prompt_pooled_refiner if config.use_compel else None
    negative_prompt_embeds = base_negative_prompt_embeds_refiner if config.use_compel else None
    negative_pooled_prompt_embeds = base_negative_prompt_pooled_refiner if config.use_compel else None

    # denoising_start mirrors the base cell's denoising_end, so the refiner
    # continues where the base stopped.
    image_refined = refiner_pipe_model(prompt=prompt,
                                       prompt_2=prompt_2,
                                       negative_prompt=negative_prompt,
                                       negative_prompt_2=negative_prompt_2,
                                       prompt_embeds=prompt_embeds,
                                       pooled_prompt_embeds=pooled_prompt_embeds,
                                       negative_prompt_embeds=negative_prompt_embeds,
                                       negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                                       num_inference_steps=config.num_inference_steps,
                                       generator=generator,
                                       guidance_scale=config.guidance_scale,
                                       denoising_start=config.high_noise_frac,
                                       image=base_output.images,
                                       #original_size = (height, width),
                                       #target_size = (height, width)
                                       ).images[0]

    utils.save_results(image_refined, SDXLConfig.CONFIG_PATH)
    display(image_refined)

    torch.cuda.empty_cache()

In [None]:
## old but gold stuff

#if use_refiner_checkbox.value:
    #base_pipe.to("cpu")
    #torch.cuda.empty_cache()
    #torch.cuda.ipc_collect() 
    #unrefined_image = postprocess_latent(base_pipe, base_output)
    #display(unrefined_image)
#else:
   #display(base_output.images[0])