In [None]:
from diffusers import DiffusionPipeline, UNet2DConditionModel, StableDiffusionXLPipeline, LCMScheduler
from diffusers.callbacks import PipelineCallback
from huggingface_hub import hf_hub_download

import matplotlib.pyplot as plt
import torch
import glob, random, os

In [None]:
def display_slider_images(images, titles):
    """Show a row of images side by side, one titled panel per image.

    Parameters
    ----------
    images : sequence
        Images accepted by ``Axes.imshow`` (e.g. the PIL images returned by a
        diffusers pipeline). Nothing is drawn when the sequence is empty.
    titles : iterable
        One title per image; passed to ``Axes.set_title`` unchanged, so
        non-string values such as slider scales are fine. Extra entries on
        either side are ignored (pairing is done with ``zip``).
    """
    if len(images) == 0:
        # plt.subplots(1, 0) would raise; there is simply nothing to show.
        return

    # squeeze=False always returns a 2-D array of Axes, so one image needs no special case.
    fig, axes = plt.subplots(1, len(images), figsize=(len(images) * 3, 3), squeeze=False)

    for ax, img, title in zip(axes[0], images, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    # Release the figure explicitly. This helper runs once per slider in a loop,
    # and non-inline backends would otherwise keep every figure alive.
    plt.close(fig)

In [None]:
class ConceptSliderCallback(PipelineCallback):
    """
    Switch Concept Slider LoRA adapters on part-way through sampling.

    diffusers calls this callback at the *end* of every denoising step. At the end of
    step ``attach_step = num_timesteps - int(num_timesteps * slider_strength)`` the
    adapters in ``slider_names`` are set to ``slider_scales``. They therefore affect the
    steps after ``attach_step``. That step and every step before it run with whatever
    weights the adapters had when the pipeline call started.

    After the final step the adapters are always reset to ``0.0``. The next pipeline
    call then starts neutral instead of inheriting this call's scale.

    Use strength < 1 if you want more precise edits (recommend: .7 - .9)

    Args:
        slider_strength: fraction of the schedule the slider is attached for, in [0, 1].
        slider_names: adapter name, or list of adapter names, already loaded on the pipeline.
        slider_scales: one weight per adapter in ``slider_names``, or a single number that
            is applied to every adapter. Defaults to ``0`` for each adapter.

    Raises:
        ValueError: if ``slider_names`` is missing, if ``slider_strength`` is outside
            [0, 1], or if the number of scales does not match the number of names.
    """
    tensor_inputs = []

    def __init__(self, slider_strength=1, slider_names=None, slider_scales=None):
        super().__init__()
        if slider_names is None:
            raise ValueError("`slider_names` must name at least one loaded adapter")
        if isinstance(slider_names, str):
            # A bare string would otherwise be treated as a sequence of characters below.
            slider_names = [slider_names]
        slider_names = list(slider_names)

        if slider_scales is None:
            slider_scales = [0] * len(slider_names)
        elif isinstance(slider_scales, (int, float)):
            slider_scales = [slider_scales] * len(slider_names)
        # Copy so later mutation of the caller's list cannot affect this callback.
        slider_scales = list(slider_scales)

        if len(slider_scales) != len(slider_names):
            raise ValueError(
                f"Got {len(slider_names)} slider names but {len(slider_scales)} slider scales"
            )
        if not 0 <= slider_strength <= 1:
            raise ValueError(f"`slider_strength` must be in [0, 1], got {slider_strength}")

        self.slider_names = slider_names
        self.slider_scales = slider_scales
        self.slider_strength = slider_strength

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        """Attach the sliders at ``attach_step`` and zero them after the final step.

        Always acts on ``pipeline`` (the pipeline that is running), never on a global.
        ``callback_kwargs`` is returned unchanged.
        """
        num_steps = pipeline.num_timesteps
        # Step whose *end* switches the sliders on (see class docstring for the timing).
        attach_step = num_steps - int(num_steps * self.slider_strength)

        if step_index == attach_step:
            pipeline.set_adapters(self.slider_names, adapter_weights=self.slider_scales)

        # Reset unconditionally. Skipping this when strength == 1 leaked the previous scale
        # into the first step of the next pipeline call.
        # (There is a better implementation if a callback at the beginning of a step
        # existed in diffusers.)
        if step_index == num_steps - 1:
            pipeline.set_adapters(
                self.slider_names, adapter_weights=[0.0] * len(self.slider_names)
            )

        return callback_kwargs

In [None]:
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"


device = 'cuda:0'
weight_dtype = torch.bfloat16

# Build an SDXL UNet from the base model's config; the distilled DMD2 weights are loaded next.
# Passing a repo id directly to `from_config` is deprecated in diffusers, so load the config first.
unet_config = UNet2DConditionModel.load_config(base_model_id, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config).to(device, weight_dtype)

# The checkpoint is a plain state dict of tensors. `weights_only=True` avoids unpickling
# arbitrary objects from a downloaded file. Loading to CPU first is fine because
# `load_state_dict` copies each tensor into the UNet's existing device/dtype.
dmd2_state_dict = torch.load(
    hf_hub_download(repo_name, ckpt_name), map_location="cpu", weights_only=True
)
unet.load_state_dict(dmd2_state_dict)
del dmd2_state_dict  # drop the extra CPU copy of the weights

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=weight_dtype)
# DMD2 is sampled few-step with an LCM-style scheduler built from the base scheduler's config.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to(device).to(weight_dtype)

# Iterate through every slider you discovered and visualize

In [None]:
sliderspace_path = '../trained_sliders/sdxl/robot/'
slider_scales = [-2, -1, 0, 1, 2]
# Trained sliders can be .pt or .safetensors files; both load via `load_lora_weights`.
# Sorted so sliders (and their adapter indices) come out in a stable order across runs.
sliderspace = sorted(
    glob.glob(os.path.join(sliderspace_path, '*.pt'))
    + glob.glob(os.path.join(sliderspace_path, '*.safetensors'))
)
if not sliderspace:
    print(f"No slider files (*.pt / *.safetensors) found in {sliderspace_path}")

prompt = 'image of a robot'
seed = random.randint(0, 2**15)
print(f"seed = {seed}")  # printed so an interesting grid can be regenerated exactly

# DMD2's 4-step UNet was distilled on these timesteps; LCMScheduler's default 4-step
# schedule differs, so pass them explicitly.
dmd2_timesteps = [999, 749, 499, 249]

# Remove sliderspace adapters left behind by an earlier (possibly interrupted) run of this
# cell. Unlike `get_active_adapters`, `get_list_adapters` also reports loaded-but-inactive ones.
stale_adapters = {
    name
    for component_adapters in pipe.get_list_adapters().values()
    for name in component_adapters
    if 'sliderspace' in name
}
if stale_adapters:
    pipe.delete_adapters(sorted(stale_adapters))

for slider_idx, slider in enumerate(sliderspace):
    image_list = []
    adapter_name = f'sliderspace_{slider_idx}'

    pipe.load_lora_weights(slider, adapter_name=adapter_name)
    try:
        # Start neutral; the callback switches the slider on during sampling.
        pipe.set_adapters(adapter_name, adapter_weights=0)

        for scale in slider_scales:
            sliders_fn = ConceptSliderCallback(slider_strength=1,
                                               slider_names=[adapter_name],
                                               slider_scales=[scale])

            image = pipe(prompt,
                         num_inference_steps=len(dmd2_timesteps),
                         timesteps=dmd2_timesteps,
                         guidance_scale=0,
                         # A private generator reproduces torch.manual_seed's stream
                         # without mutating global RNG state.
                         generator=torch.Generator(device='cpu').manual_seed(seed),
                         callback_on_step_end=sliders_fn,
                         ).images[0]
            image_list.append(image)
    finally:
        # Always unload, even if generation fails, so the next slider starts clean.
        pipe.delete_adapters(adapter_name)

    slider_label = os.path.splitext(os.path.basename(slider))[0].split('_')[-1]
    print(f"Slider {slider_label}")
    display_slider_images(image_list, slider_scales)