In [None]:
# init
import torch

# for parameters ui
from ipywidgets import interact, fixed
import ipywidgets as widgets
# for config
import json
# for saving results
import shutil
from datetime import datetime
# for image_grid
from PIL import Image
# for prompt embeddings
from compel import Compel, ReturnedEmbeddingsType

# hugging face cache directory
# Raw string: "\H" is not a recognised escape sequence, so the non-raw literal
# only worked by accident and triggers an invalid-escape warning on import.
CACHE_DIR = r"D:\HuggingFaceCache"
# config path (JSON file read by SDXLConfig.load_config / written by save_config)
CONFIG_PATH = "sdxl_config.json"

# import models, schedulers and etc
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers import EulerDiscreteScheduler, DDIMScheduler, LMSDiscreteScheduler

# Lookup tables: the string keys are the values SDXLConfig stores in its
# ``*_str`` attributes; the values are the objects those strings resolve to.
BASE_PIPELINES = {
    "StableDiffusionXLPipeline": StableDiffusionXLPipeline,
}
REFINER_PIPELINES = {
    "StableDiffusionXLImg2ImgPipeline": StableDiffusionXLImg2ImgPipeline,
}
SCHEDULERS = {
    "EulerDiscreteScheduler": EulerDiscreteScheduler,
    "DDIMScheduler": DDIMScheduler,
    "LMSDiscreteScheduler": LMSDiscreteScheduler,
}
PRECISION = {
    "torch.float16": torch.float16,
}

class SDXLConfig:
    """Settings for one SDXL generation run, persisted as JSON at ``CONFIG_PATH``.

    Pipeline classes, the scheduler class and the torch dtype are stored as
    strings (so the object is JSON-serialisable) and resolved through the
    module-level ``BASE_PIPELINES`` / ``REFINER_PIPELINES`` / ``SCHEDULERS`` /
    ``PRECISION`` tables by the read-only properties below.
    """

    def __init__(self,
                 base_pipe_model: str = "stabilityai/stable-diffusion-xl-base-1.0",
                 refiner_pipe_model: str = "stabilityai/stable-diffusion-xl-refiner-1.0",
                 torch_dtype_str: str = "torch.float16",
                 base_pipeline_type_str: str = "StableDiffusionXLPipeline",
                 refiner_pipeline_type_str: str = "StableDiffusionXLImg2ImgPipeline",
                 scheduler_type_str: str = "LMSDiscreteScheduler",
                 variant: str = "fp16",
                 use_safetensors: bool = True,
                 #safety_checker = None
                 prompt: str = None,
                 prompt_2: str = None,
                 negative_prompt: str = None,
                 negative_prompt_2: str = None,
                 use_compel: bool = False,
                 num_inference_steps: int = 40,
                 width: int = 768,
                 height: int = 768,
                 guidance_scale: float = 7.5,
                 high_noise_frac: float = 0.8,
                 seed: int = 12345,
                 use_refiner: bool = False
                 ):
        """Store every argument unchanged on the instance under the same name."""
        self.base_pipe_model = base_pipe_model
        self.refiner_pipe_model = refiner_pipe_model
        self.torch_dtype_str = torch_dtype_str
        self.base_pipeline_type_str = base_pipeline_type_str
        self.refiner_pipeline_type_str = refiner_pipeline_type_str
        self.scheduler_type_str = scheduler_type_str
        self.variant = variant
        self.use_safetensors = use_safetensors
        self.prompt = prompt
        self.prompt_2 = prompt_2
        self.negative_prompt = negative_prompt
        self.negative_prompt_2 = negative_prompt_2
        self.use_compel = use_compel
        self.num_inference_steps = num_inference_steps
        self.width = width
        self.height = height
        self.guidance_scale = guidance_scale
        self.high_noise_frac = high_noise_frac
        self.seed = seed
        self.use_refiner = use_refiner

    @property
    def torch_dtype(self):
        """Entry of ``PRECISION`` selected by ``torch_dtype_str`` (KeyError if unknown)."""
        return PRECISION[self.torch_dtype_str]

    @property
    def base_pipeline_type(self):
        """Entry of ``BASE_PIPELINES`` selected by ``base_pipeline_type_str``."""
        return BASE_PIPELINES[self.base_pipeline_type_str]

    @property
    def refiner_pipeline_type(self):
        """Entry of ``REFINER_PIPELINES`` selected by ``refiner_pipeline_type_str``."""
        return REFINER_PIPELINES[self.refiner_pipeline_type_str]

    @property
    def scheduler_type(self):
        """Entry of ``SCHEDULERS`` selected by ``scheduler_type_str``."""
        return SCHEDULERS[self.scheduler_type_str]

    # Left undecorated on purpose: it is used as ``SDXLConfig.to_json`` (the
    # ``default=`` hook of json.dump) and also keeps working as ``cfg.to_json()``.
    def to_json(obj):
        """Return the attribute dict of an ``SDXLConfig`` for JSON serialisation.

        Raises:
            TypeError: if ``obj`` is not an ``SDXLConfig``. A ``default=`` hook
                that returns ``None`` would make json.dump write ``null`` for
                any unsupported object instead of reporting it.
        """
        if isinstance(obj, SDXLConfig):
            return obj.__dict__
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    @classmethod
    def from_json(cls, dict: dict):
        """Build a config from a mapping of constructor keyword arguments.

        The parameter keeps its original name (``dict``) so existing keyword
        callers are not broken. Unknown keys raise ``TypeError`` from
        ``__init__``.
        """
        return cls(**dict)

    @classmethod
    def load_config(cls):
        """Read ``CONFIG_PATH`` and return the config stored in it.

        ``from_json`` is installed as ``object_hook``, so every JSON object in
        the file is turned into a config instance.
        """
        with open(CONFIG_PATH, "r") as read_file:
            return json.load(read_file, object_hook=cls.from_json)

    def save_config(self):
        """Write this config to ``CONFIG_PATH`` as indented JSON."""
        with open(CONFIG_PATH, "w") as write_file:
            json.dump(self, write_file, skipkeys=True, indent=1, default=SDXLConfig.to_json)

    def set_ui(self):
        """Create one ipywidgets control per editable setting.

        Every control is wired through ``interact`` to a setter that stores the
        widget value on this instance under the given attribute name.
        """
        def f(x, f_name):
            setattr(self, f_name, x)

        def bind(f_name, widget):
            # The attribute name is passed explicitly instead of being parsed
            # back out of an f-string "=" debug expression.
            interact(f, x=widget, f_name=fixed(f_name))

        # models, precisions, schedulers
        style = {'description_width': 'initial'}
        bind("base_pipe_model", widgets.Text(value=self.base_pipe_model, placeholder='', description='Base model:', style=style))
        bind("refiner_pipe_model", widgets.Text(value=self.refiner_pipe_model, placeholder='', description='Refiner model:', style=style))
        bind("torch_dtype_str", widgets.Dropdown(value=self.torch_dtype_str, options=PRECISION.keys(), description='dtype:', style=style))
        bind("base_pipeline_type_str", widgets.Dropdown(value=self.base_pipeline_type_str, options=BASE_PIPELINES.keys(), description='Base type:', style=style))
        bind("refiner_pipeline_type_str", widgets.Dropdown(value=self.refiner_pipeline_type_str, options=REFINER_PIPELINES.keys(), description='Refiner type:', style=style))
        bind("scheduler_type_str", widgets.Dropdown(value=self.scheduler_type_str, options=SCHEDULERS.keys(), description='Scheduler type:', style=style))

        # prompts
        # The prompt attributes default to None; the text widgets are given ""
        # in that case (presumably their string value rejects None - verify
        # against the installed ipywidgets version).
        bind("prompt", widgets.Textarea(value=self.prompt or "", placeholder='Type positive1...', description='Prompt1:', style=style))
        bind("prompt_2", widgets.Textarea(value=self.prompt_2 or "", placeholder='Type positive2...', description='Prompt2:', style=style))
        bind("negative_prompt", widgets.Textarea(value=self.negative_prompt or "", placeholder='Type negative1...', description='Negative Prompt1:', style=style))
        bind("negative_prompt_2", widgets.Textarea(value=self.negative_prompt_2 or "", placeholder='Type negative2...', description='Negative Prompt2:', style=style))
        bind("use_compel", widgets.Checkbox(value=self.use_compel, description="Use Compel", indent=False, style=style))

        # inference properties
        bind("num_inference_steps", widgets.IntSlider(value=self.num_inference_steps, min=10, max=100, step=5, description="Num inference steps:", continuous_update=False, style=style))
        bind("width", widgets.IntSlider(value=self.width, min=512, max=1024, step=64, description="Width:", continuous_update=False, style=style))
        # height mirrors the width control (same range and step)
        bind("height", widgets.IntSlider(value=self.height, min=512, max=1024, step=64, description="Height:", continuous_update=False, style=style))
        bind("guidance_scale", widgets.FloatSlider(value=self.guidance_scale, min=0, max=10, step=0.25, description="Guidance scale:", continuous_update=False, style=style))
        bind("seed", widgets.IntSlider(value=self.seed, min=0, max=1000000, step=1, description="Seed:", continuous_update=False, style=style))
        bind("high_noise_frac", widgets.FloatSlider(value=self.high_noise_frac, min=0, max=1, step=0.05, description="High noise frac:", continuous_update=False, style=style))

        # refiner
        bind("use_refiner", widgets.Checkbox(value=self.use_refiner, description="Use refiner", indent=False, style=style))

# utilities methods
def to_cuda(pipe, start_mess, end_mess):
    """Move ``pipe`` to the CUDA device when one is available.

    Args:
        pipe: any object exposing ``.to(device)``.
        start_mess: printed just before the move starts.
        end_mess: printed after the move has completed.

    Returns:
        Whatever ``pipe.to("cuda")`` returns, or ``pipe`` itself when CUDA is
        not available. Returning the result keeps the move visible to the
        caller even if ``.to`` hands back a new object instead of mutating
        in place.
    """
    if not torch.cuda.is_available():
        # No "finished" message here: nothing was moved.
        print("CUDA IS NOT AVAILABLE")
        return pipe

    print(start_mess)
    pipe = pipe.to("cuda")
    print(end_mess)
    return pipe

def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    The cell size is taken from the first image; image ``i`` is pasted at
    row ``i // cols``, column ``i % cols``.

    Args:
        imgs: sequence of images exposing ``.size`` as ``(width, height)``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        The newly created grid image.

    Raises:
        ValueError: if ``imgs`` is empty or ``len(imgs) != rows * cols``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if not imgs:
        raise ValueError("imgs must contain at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

def postprocess_latent(pipe, latent):
    """Decode ``latent.images`` with the pipeline's VAE and return the first PIL image.

    The latents are divided by ``pipe.vae.config.scaling_factor`` before
    decoding; the decoded output is detached and handed to
    ``pipe.image_processor.postprocess`` with ``output_type="pil"``.
    """
    scaled_latents = latent.images / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(scaled_latents, return_dict=False)[0]
    decoded = decoded.detach()
    pil_images = pipe.image_processor.postprocess(decoded, output_type="pil")
    return pil_images[0]


In [None]:
# ui
# Load the settings stored at CONFIG_PATH, then build one widget per setting;
# the widgets write their values back onto `config`.
config: SDXLConfig = SDXLConfig.load_config()
config.set_ui()

In [None]:
# load the base pipeline, set the scheduler, move to CUDA, init Compel
base_pipe_model = config.base_pipeline_type.from_pretrained(
    config.base_pipe_model,
    cache_dir=CACHE_DIR,
    torch_dtype=config.torch_dtype,
    variant=config.variant,
    use_safetensors=config.use_safetensors,
)

# build the configured scheduler class from the config of the pipeline's own scheduler
scheduler = config.scheduler_type.from_config(base_pipe_model.scheduler.config)
base_pipe_model.scheduler = scheduler
#base_pipe.unet = torch.compile(base_pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

# base pipeline to CUDA
to_cuda(base_pipe_model, "Base -> CUDA started", "Base -> CUDA finished")

# Compel init (TODO: maybe this should move to the "generate image" cell)
# One Compel per tokenizer/text-encoder pair; only the second one requests pooled output.
embeddings_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
base_compel_1 = Compel(
    tokenizer=base_pipe_model.tokenizer,
    text_encoder=base_pipe_model.text_encoder,
    returned_embeddings_type=embeddings_type,
    requires_pooled=False,
)
base_compel_2 = Compel(
    tokenizer=base_pipe_model.tokenizer_2,
    text_encoder=base_pipe_model.text_encoder_2,
    returned_embeddings_type=embeddings_type,
    requires_pooled=True,
)

In [None]:
# generate image
import os  # cell-local: used to build the results paths portably

if config.use_compel:
    # Build the embeddings only when they are actually used; the raw prompts
    # may be None/empty otherwise.
    base_positive_prompt_embeds_1 = base_compel_1(config.prompt)
    base_positive_prompt_embeds_2, base_positive_prompt_pooled = base_compel_2(config.prompt_2)
    base_negative_prompt_embeds_1 = base_compel_1(config.negative_prompt)
    base_negative_prompt_embeds_2, base_negative_prompt_pooled = base_compel_2(config.negative_prompt_2)

    # Pad the conditioning tensors to ensure that they all have the same length
    (base_positive_prompt_embeds_2, base_negative_prompt_embeds_2) = base_compel_2.pad_conditioning_tensors_to_same_length(
        [base_positive_prompt_embeds_2, base_negative_prompt_embeds_2]
    )
    # NOTE(review): the *_1 tensors are not padded; torch.cat below needs them
    # to match the *_2 tensors on every dimension except the last - confirm
    # for long prompts.

    # Concatenate the conditioning tensors corresponding to both sets of prompts
    base_positive_prompt_embeds = torch.cat((base_positive_prompt_embeds_1, base_positive_prompt_embeds_2), dim=-1)
    base_negative_prompt_embeds = torch.cat((base_negative_prompt_embeds_1, base_negative_prompt_embeds_2), dim=-1)

    # the pipeline receives embeddings only, never the raw prompts as well
    prompt = prompt_2 = negative_prompt = negative_prompt_2 = None
else:
    # keep the names defined so later cells can reference them
    base_positive_prompt_embeds = base_negative_prompt_embeds = None
    base_positive_prompt_pooled = base_negative_prompt_pooled = None

    # empty strings are passed as None
    prompt = config.prompt or None
    prompt_2 = config.prompt_2 or None
    negative_prompt = config.negative_prompt or None
    negative_prompt_2 = config.negative_prompt_2 or None

prompt_embeds = base_positive_prompt_embeds
pooled_prompt_embeds = base_positive_prompt_pooled
negative_prompt_embeds = base_negative_prompt_embeds
negative_pooled_prompt_embeds = base_negative_prompt_pooled

# same fallback as to_cuda(): use the CPU generator when CUDA is unavailable
generator_device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(generator_device).manual_seed(config.seed)

# base
# Latents are handed to the refiner; without a refiner the base pipeline must
# run the full schedule, so denoising is only cut short when refining.
output_type = "latent" if config.use_refiner else "pil"
denoising_end = config.high_noise_frac if config.use_refiner else None

print(f"{prompt}\n{prompt_2}\n{negative_prompt}\n{negative_prompt_2}")
base_output = base_pipe_model(prompt=prompt,
                              prompt_2=prompt_2,
                              negative_prompt=negative_prompt,
                              negative_prompt_2=negative_prompt_2,
                              prompt_embeds=prompt_embeds,
                              pooled_prompt_embeds=pooled_prompt_embeds,
                              negative_prompt_embeds=negative_prompt_embeds,
                              negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                              num_inference_steps=config.num_inference_steps,
                              generator=generator,
                              guidance_scale=config.guidance_scale,
                              output_type=output_type,
                              denoising_end=denoising_end,
                              width=config.width,
                              height=config.height)

# save data to config
config.save_config()

# save results (image plus the config that produced it, under the same timestamp)
if not config.use_refiner:
    cur_date = datetime.now().strftime("%m_%d_%Y-%H_%M_%S")
    image = base_output.images[0]
    os.makedirs("results", exist_ok=True)
    shutil.copyfile(CONFIG_PATH, os.path.join("results", f"{cur_date}.json"))
    image.save(os.path.join("results", f"{cur_date}.png"))
    display(image)

# clean gpu cache
torch.cuda.empty_cache()

In [None]:
# refiner
# All settings are read from `config`. `prompt`, `prompt_2`, `negative_prompt`,
# `negative_prompt_2`, `generator`, `base_output` and the `base_*` embedding
# tensors come from the "generate image" cell, which must have been run first.

if not config.use_refiner:
    # the base cell only emits latents for the refiner when use_refiner is set
    print("Refiner is disabled in the config - skipping")
else:
    refiner_pipe_model = config.refiner_pipeline_type.from_pretrained(
        config.refiner_pipe_model,
        cache_dir=CACHE_DIR,
        torch_dtype=config.torch_dtype,
        variant=config.variant,
        use_safetensors=config.use_safetensors,
        # share these components with the base pipeline
        vae=base_pipe_model.vae,
        text_encoder_2=base_pipe_model.text_encoder_2,
    )

    refiner_pipe_model.scheduler = scheduler
    #refiner_pipe.unet = torch.compile(refiner_pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

    # refiner pipeline to CUDA (the refiner itself, not the base pipeline)
    to_cuda(refiner_pipe_model, "Refiner -> CUDA started", "Refiner -> CUDA finished")

    # NOTE(review): the embeddings below were built with the base pipeline's
    # two text encoders; confirm the refiner accepts tensors of that width.
    image_refined = refiner_pipe_model(
        prompt=prompt,
        prompt_2=prompt_2,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=base_positive_prompt_embeds if config.use_compel else None,
        pooled_prompt_embeds=base_positive_prompt_pooled if config.use_compel else None,
        negative_prompt_embeds=base_negative_prompt_embeds if config.use_compel else None,
        negative_pooled_prompt_embeds=base_negative_prompt_pooled if config.use_compel else None,
        num_inference_steps=config.num_inference_steps,
        generator=generator,
        guidance_scale=config.guidance_scale,
        # start at the fraction the base cell passes as denoising_end
        denoising_start=config.high_noise_frac,
        image=base_output.images,
        #original_size = (height, width),
        #target_size = (height, width)
    ).images[0]

    torch.cuda.empty_cache()
    display(image_refined)

In [None]:
## old but gold stuff

#if use_refiner_checkbox.value:
    #base_pipe.to("cpu")
    #torch.cuda.empty_cache()
    #torch.cuda.ipc_collect() 
    #unrefined_image = postprocess_latent(base_pipe, base_output)
    #display(unrefined_image)
#else:
   #display(base_output.images[0])