In [39]:
# init
import torch

# for parameters ui
import ipywidgets as widgets
# for config
import json
# for saving results
import shutil
from datetime import datetime
# for image_grid
from PIL import Image
# for prompt embeddings
from compel import Compel, ReturnedEmbeddingsType

# hugging face cache directory (raw string: "\H" is not a valid Python escape sequence)
CACHE_DIR = r"D:\HuggingFaceCache"
# config path
CONFIG_PATH = "sdxl_config.json"

# import models, schedulers and etc
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers import EulerDiscreteScheduler, DDIMScheduler, LMSDiscreteScheduler

class SDXLConfig:
    """Generation settings for this notebook, persisted as JSON at ``CONFIG_PATH``.

    Each constructor argument becomes an attribute of the same name. ``set_ui``
    displays one ipywidget per attribute and writes every widget edit back to that
    attribute, so later cells can read the current settings from the instance.
    """

    def __init__(self, 
                 prompt: str = None,
                 prompt_2: str = None,
                 negative_prompt: str = None,
                 negative_prompt_2: str = None,
                 use_compel: bool = False,
                 num_inference_steps: int = 40,
                 width: int = 768,
                 height: int = 768,
                 guidance_scale: float = 7.5,
                 high_noise_frac: float = 0.8,
                 seed: int = 12345,
                 use_refiner: bool = False
                 ):
        # every attribute assigned here is serialized by to_json, so keep it JSON-friendly
        self.prompt = prompt
        self.prompt_2 = prompt_2
        self.negative_prompt = negative_prompt
        self.negative_prompt_2 = negative_prompt_2
        self.use_compel = use_compel
        self.num_inference_steps = num_inference_steps
        self.width = width
        self.height = height
        self.guidance_scale = guidance_scale
        self.high_noise_frac = high_noise_frac
        self.seed = seed
        self.use_refiner = use_refiner

    @staticmethod
    def to_json(obj):
        """``default=`` hook for ``json.dump``: turn an SDXLConfig into a plain dict.

        Raises:
            TypeError: for any other object, as the json module expects from a
                ``default`` hook (returning None would silently write ``null``).
        """
        if isinstance(obj, SDXLConfig):
            return dict(obj.__dict__)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def from_json(dict: dict):
        """``object_hook`` for ``json.load``: build an SDXLConfig from a decoded dict."""
        return SDXLConfig(**dict)

    @staticmethod
    def load_config():
        """Load the settings from ``CONFIG_PATH``; return defaults if the file does not exist yet."""
        try:
            with open(CONFIG_PATH, "r", encoding="utf-8") as read_file:
                return json.load(read_file, object_hook=SDXLConfig.from_json)
        except FileNotFoundError:
            # first run: nothing saved yet
            return SDXLConfig()

    def save_config(self):
        """Write these settings to ``CONFIG_PATH`` as indented JSON."""
        with open(CONFIG_PATH, "w", encoding="utf-8") as write_file:
            json.dump(self, write_file, indent=1, default=SDXLConfig.to_json)

    def set_ui(self):
        """Display the editing widgets and keep this config in sync with them.

        Each widget starts from the current attribute value; whenever its value
        changes, the new value is assigned back to the attribute it edits. (The
        widgets used to be locals of this method, so their values were unreachable.)
        """
        style = {'description_width': 'initial'}

        def show_bound(widget, attribute):
            # copy every new widget value onto the config attribute it edits
            widget.observe(lambda change: setattr(self, attribute, change["new"]), names="value")
            display(widget)

        # prompts (a Textarea's value must be a string, so a missing prompt starts empty)
        show_bound(widgets.Textarea(value=self.prompt or "", placeholder='Type positive1...', description='Prompt1:', style=style), "prompt")
        show_bound(widgets.Textarea(value=self.prompt_2 or "", placeholder='Type positive2...', description='Prompt2:', style=style), "prompt_2")
        show_bound(widgets.Textarea(value=self.negative_prompt or "", placeholder='Type negative1...', description='Negative Prompt1:', style=style), "negative_prompt")
        show_bound(widgets.Textarea(value=self.negative_prompt_2 or "", placeholder='Type negative2...', description='Negative Prompt2:', style=style), "negative_prompt_2")
        show_bound(widgets.Checkbox(value=self.use_compel, description="Use Compel", indent=False, style=style), "use_compel")

        # inference properties
        show_bound(widgets.IntSlider(value=self.num_inference_steps, min=10, max=100, step=5, description="Num inference steps:", continuous_update=False, style=style), "num_inference_steps")
        show_bound(widgets.IntSlider(value=self.width, min=512, max=1024, step=64, description="Width:", continuous_update=False, style=style), "width")
        show_bound(widgets.IntSlider(value=self.height, min=512, max=1024, step=64, description="Height:", continuous_update=False, style=style), "height")
        show_bound(widgets.FloatSlider(value=self.guidance_scale, min=0, max=10, step=0.25, description="Guidance scale:", continuous_update=False, style=style), "guidance_scale")
        show_bound(widgets.IntSlider(value=self.seed, min=0, max=1000000, step=1, description="Seed:", continuous_update=False, style=style), "seed")
        show_bound(widgets.FloatSlider(value=self.high_noise_frac, min=0, max=1, step=0.05, description="High noise frac:", continuous_update=False, style=style), "high_noise_frac")

        # refiner
        show_bound(widgets.Checkbox(value=self.use_refiner, description="Use refiner", indent=False, style=style), "use_refiner")

# utilities methods
def to_cuda(pipe, start_mess, end_mess):
    """Move ``pipe`` to the CUDA device when one is available, reporting progress.

    Args:
        pipe: object with a ``.to(device)`` method (a diffusers pipeline here).
        start_mess: message printed before the move.
        end_mess: message printed once the move has finished.

    Returns:
        The result of ``pipe.to("cuda")``, or ``pipe`` unchanged when CUDA is
        unavailable (in which case ``end_mess`` is not printed, since nothing moved).
    """
    if not torch.cuda.is_available():
        print("CUDA IS NOT AVAILABLE")
        return pipe
    print(start_mess)
    pipe = pipe.to("cuda")
    print(end_mess)
    return pipe

def image_grid(imgs, rows, cols):
    """Tile images into one RGB image laid out as a ``rows`` x ``cols`` grid, row by row.

    Every cell has the size of ``imgs[0]``; the images are assumed to share that size.

    Args:
        imgs: sequence of PIL images, exactly ``rows * cols`` of them.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if the image count is not ``rows * cols`` or ``imgs`` is empty.
    """
    assert len(imgs) == rows * cols, f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    # an empty grid (rows or cols == 0) has no image to take the cell size from
    assert imgs, "image_grid needs at least one image"

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid

def postprocess_latent(pipe, latent):
    """Decode the first latent of a pipeline output into a PIL image.

    Args:
        pipe: pipeline providing ``vae`` (with ``config.scaling_factor`` and
            ``decode``) and ``image_processor``.
        latent: pipeline output whose ``images`` attribute holds latents (as produced
            with ``output_type="latent"``).

    Returns:
        The first image returned by ``pipe.image_processor.postprocess(..., output_type="pil")``.
    """
    # decode without autograd: detaching afterwards would not prevent the graph from being built
    with torch.no_grad():
        scaled_latents = latent.images / pipe.vae.config.scaling_factor
        decoded = pipe.vae.decode(scaled_latents, return_dict=False)[0].detach()
    return pipe.image_processor.postprocess(decoded, output_type="pil")[0]


In [40]:
# ui
# Load the saved settings from CONFIG_PATH and display one editing widget per setting.
config: SDXLConfig = SDXLConfig.load_config()
config.set_ui()

Textarea(value='dsrs', description='Prompt1:', placeholder='Type positive1...', style=TextStyle(description_wi…

Textarea(value='', description='Prompt2:', placeholder='Type positive2...', style=TextStyle(description_width=…

Textarea(value='ewrw', description='Negative Prompt1:', placeholder='Type negative1...', style=TextStyle(descr…

Textarea(value='', description='Negative Prompt2:', placeholder='Type negative2...', style=TextStyle(descripti…

Checkbox(value=False, description='Use Compel', indent=False, style=CheckboxStyle(description_width='initial')…

IntSlider(value=40, continuous_update=False, description='Num inference steps:', min=10, step=5, style=SliderS…

IntSlider(value=768, continuous_update=False, description='Width:', max=1024, min=512, step=64, style=SliderSt…

IntSlider(value=768, continuous_update=False, description='Height:', max=1024, min=512, step=64, style=SliderS…

FloatSlider(value=7.5, continuous_update=False, description='Guidance scale:', max=10.0, step=0.25, style=Slid…

IntSlider(value=12345, continuous_update=False, description='Seed:', max=1000000, style=SliderStyle(descriptio…

FloatSlider(value=0.8, continuous_update=False, description='High noise frac:', max=1.0, step=0.05, style=Slid…

Checkbox(value=False, description='Use refiner', indent=False, style=CheckboxStyle(description_width='initial'…

In [10]:
# choose models, scheduler type and loading options
# Hugging Face model ids of the two SDXL stages
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model = "stabilityai/stable-diffusion-xl-refiner-1.0"

# pipeline classes that load each stage, and the scheduler class used in place of the default one
base_pipeline_type, refiner_pipeline_type = StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
scheduler_type = LMSDiscreteScheduler

# from_pretrained options: half-precision weights from the "fp16" variant, stored as safetensors
torch_dtype = torch.float16
variant = "fp16"
use_safetensors = True
#safety_checker = None

In [None]:
# load the base pipeline, set its scheduler, move it to CUDA and initialise Compel
# Base SDXL pipeline; downloaded weights are cached under CACHE_DIR.
base_pipe = base_pipeline_type.from_pretrained(base_model, cache_dir=CACHE_DIR, 
                                               torch_dtype=torch_dtype,
                                               variant=variant,
                                               use_safetensors=use_safetensors)


# Replace the default scheduler with a `scheduler_type` built from the pipeline's scheduler config.
# `scheduler` stays a global so later cells can reuse it.
scheduler = scheduler_type.from_config(base_pipe.scheduler.config)
base_pipe.scheduler = scheduler
#base_pipe.unet = torch.compile(base_pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

# base pipeline to CUDA (to_cuda leaves it where it is when CUDA is unavailable)
to_cuda(base_pipe, "Base -> CUDA started", "Base -> CUDA finished")

# Compel prompt encoders, one per text encoder of the base pipeline
# (TODO: maybe move this into the "generate image" cell)
# First encoder: penultimate hidden states only, no pooled embedding.
base_compel_1 = Compel(
    tokenizer=base_pipe.tokenizer,
    text_encoder=base_pipe.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=False,
)
# Second encoder: also returns the pooled embedding, which the pipeline takes as pooled_prompt_embeds.
base_compel_2 = Compel(
    tokenizer=base_pipe.tokenizer_2,
    text_encoder=base_pipe.text_encoder_2,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
)

In [None]:
# generate image
#
# Base stage: reads the settings from `config` (the SDXLConfig created in the UI cell),
# optionally encodes the prompts with Compel, runs the base pipeline and saves the settings.
# With the refiner enabled the base stops at `high_noise_frac` and hands latents to the
# refiner cell; otherwise it runs every step and the finished image is displayed here.

# settings (a None prompt from the JSON file is treated as empty)
prompt = config.prompt or ""
prompt_2 = config.prompt_2 or ""
negative_prompt = config.negative_prompt or ""
negative_prompt_2 = config.negative_prompt_2 or ""
print(f"{prompt}\n{prompt_2}\n{negative_prompt}\n{negative_prompt_2}")

use_compel = config.use_compel
num_inference_steps = config.num_inference_steps
guidance_scale = config.guidance_scale
# Despite its name this holds the base pipeline's output_type: "latent" feeds the refiner cell.
use_refiner = "latent" if config.use_refiner else "pil"
high_noise_frac = config.high_noise_frac
width = config.width
height = config.height
seed = config.seed

# persist the settings used for this run so the UI cell restores them next time
config.save_config()

# Compel embeddings; they stay None (so the pipeline ignores them) when Compel is off
base_positive_prompt_embeds = base_positive_prompt_pooled = None
base_negative_prompt_embeds = base_negative_prompt_pooled = None
base_positive_prompt_embeds_2 = base_negative_prompt_embeds_2 = None
if use_compel:
    with torch.no_grad():
        # an empty second prompt falls back to the first, mirroring the pipeline's own prompt_2 default
        base_positive_prompt_embeds_1 = base_compel_1(prompt)
        base_positive_prompt_embeds_2, base_positive_prompt_pooled = base_compel_2(prompt_2 or prompt)
        base_negative_prompt_embeds_1 = base_compel_1(negative_prompt)
        base_negative_prompt_embeds_2, base_negative_prompt_pooled = base_compel_2(negative_prompt_2 or negative_prompt)

    # Pad the positive/negative conditioning of each encoder to the same length
    (base_positive_prompt_embeds_1, base_negative_prompt_embeds_1) = base_compel_1.pad_conditioning_tensors_to_same_length([base_positive_prompt_embeds_1, base_negative_prompt_embeds_1])
    (base_positive_prompt_embeds_2, base_negative_prompt_embeds_2) = base_compel_2.pad_conditioning_tensors_to_same_length([base_positive_prompt_embeds_2, base_negative_prompt_embeds_2])

    # Concatenate the conditioning tensors of both encoders along the last dimension
    base_positive_prompt_embeds = torch.cat((base_positive_prompt_embeds_1, base_positive_prompt_embeds_2), dim=-1)
    base_negative_prompt_embeds = torch.cat((base_negative_prompt_embeds_1, base_negative_prompt_embeds_2), dim=-1)

# seeded generator; to_cuda leaves the pipeline on the CPU when CUDA is unavailable
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)

# base
base_output = base_pipe(prompt=prompt if prompt and not use_compel else None,
                        prompt_2=prompt_2 if prompt_2 and not use_compel else None,
                        negative_prompt=negative_prompt if negative_prompt and not use_compel else None,
                        negative_prompt_2=negative_prompt_2 if negative_prompt_2 and not use_compel else None,
                        prompt_embeds=base_positive_prompt_embeds,
                        pooled_prompt_embeds=base_positive_prompt_pooled,
                        negative_prompt_embeds=base_negative_prompt_embeds,
                        negative_pooled_prompt_embeds=base_negative_prompt_pooled,
                        num_inference_steps=num_inference_steps,
                        generator=generator,
                        guidance_scale=guidance_scale,
                        output_type=use_refiner,
                        # stop early only when the refiner will run the remaining steps;
                        # otherwise the final image would be left partially denoised
                        denoising_end=high_noise_frac if use_refiner == "latent" else None,
                        width=width,
                        height=height)

torch.cuda.empty_cache()

# without the refiner the base output already is the final image
if use_refiner == "pil":
    display(base_output.images[0])

In [None]:
# refiner
#
# Refiner stage: continues denoising the base latents from `high_noise_frac` to the end.
# It only runs when the base cell produced latents (use_refiner == "latent").
if use_refiner != "latent":
    print("Refiner disabled - skipping (the base cell already produced the final image)")
else:
    # Reuse an already loaded refiner while it still shares the current base pipeline's VAE;
    # re-running the base cell creates a new base pipeline, so the refiner is rebuilt then.
    if "refiner_pipe" not in globals() or refiner_pipe.vae is not base_pipe.vae:
        refiner_pipe = refiner_pipeline_type.from_pretrained(refiner_model, cache_dir=CACHE_DIR,
                                                             torch_dtype=torch_dtype,
                                                             variant=variant,
                                                             use_safetensors=use_safetensors,
                                                             # reuse the base pipeline's VAE and second text encoder objects
                                                             vae=base_pipe.vae,
                                                             text_encoder_2=base_pipe.text_encoder_2)

        refiner_pipe.scheduler = scheduler
        #refiner_pipe.unet = torch.compile(refiner_pipe.unet, mode="reduce-overhead", fullgraph=True) # not for windows

        # refiner pipeline to CUDA
        to_cuda(refiner_pipe, "Refiner -> CUDA started", "Refiner -> CUDA finished")

    # The SDXL refiner encodes prompts with its second text encoder only (shared with the base
    # here), so with Compel it takes that encoder's embeddings, not the concatenated base ones.
    image_refined = refiner_pipe(prompt=prompt if prompt and not use_compel else None,
                                 prompt_2=prompt_2 if prompt_2 and not use_compel else None,
                                 negative_prompt=negative_prompt if negative_prompt and not use_compel else None,
                                 negative_prompt_2=negative_prompt_2 if negative_prompt_2 and not use_compel else None,
                                 prompt_embeds=base_positive_prompt_embeds_2 if use_compel else None,
                                 pooled_prompt_embeds=base_positive_prompt_pooled if use_compel else None,
                                 negative_prompt_embeds=base_negative_prompt_embeds_2 if use_compel else None,
                                 negative_pooled_prompt_embeds=base_negative_prompt_pooled if use_compel else None,
                                 num_inference_steps=num_inference_steps,
                                 generator=generator,
                                 guidance_scale=guidance_scale,
                                 # resume exactly where the base stage stopped (its denoising_end)
                                 denoising_start=high_noise_frac,
                                 image=base_output.images,
                                 ).images[0]

    torch.cuda.empty_cache()
    display(image_refined)

In [None]:
## old but gold stuff

#if use_refiner_checkbox.value:
    #base_pipe.to("cpu")
    #torch.cuda.empty_cache()
    #torch.cuda.ipc_collect() 
    #unrefined_image = postprocess_latent(base_pipe, base_output)
    #display(unrefined_image)
#else:
   #display(base_output.images[0])