In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Optional
import argparse
import ast
from pathlib import Path
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderTiny
from huggingface_hub import hf_hub_download
import gc
import torch.nn.functional as F
import os, glob
import torch
from tqdm.auto import tqdm
import time, datetime
import numpy as np
from torch.optim import AdamW
from contextlib import ExitStack
from safetensors.torch import load_file
import torch.nn as nn
import random
from transformers import CLIPModel
import sys
sys.path.append('../')
from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
from utils.inference_util import StableDiffusionXLPipelineSliders
# Set the transformers logger verbosity to "warning".
from transformers import logging
logging.set_verbosity_warning()

# This import rebinds the name `logging` to the diffusers logging module,
# whose verbosity is then set to "error".
from diffusers import logging
logging.set_verbosity_error()
# NOTE(review): `modules` is an alias of DEFAULT_TARGET_REPLACE, so if that
# object is a mutable list the `+=` below extends the imported list in place.
# Confirm whether utils.lora relies on this before changing it.
modules = DEFAULT_TARGET_REPLACE
modules += UNET_TARGET_REPLACE_MODULE_CONV

In [None]:
# --- Model and checkpoint configuration ---
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"

device = 'cuda:0'
weight_dtype = torch.bfloat16

# If you want to compose multiple sliders - please use numsliders_to_sample > 1
numsliders_to_sample = 1

# --- Build the UNet and load the checkpoint weights into it ---
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to(device, weight_dtype)

# map_location="cpu" makes the load independent of the device the checkpoint
# was saved from; load_state_dict then copies the values into the UNet's
# existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
unet_state_dict = torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cpu")
unet.load_state_dict(unet_state_dict)
del unet_state_dict  # the weights now live in `unet`; drop the extra copy

# --- Build the slider-aware pipeline around that UNet ---
pipe = StableDiffusionXLPipelineSliders.from_pretrained(base_model_id, unet=unet, torch_dtype=weight_dtype)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to(device).to(weight_dtype)
unet = pipe.unet

## Change these parameters based on how you trained your sliderspace sliders
train_method = 'xattn-strict'
rank = 1
alpha = 1
print(rank, alpha, train_method)

# NOTE(review): `modules` aliases DEFAULT_TARGET_REPLACE; if that is a list,
# this `+=` extends it in place (again, on every run of this cell). Kept as-is
# because utils.lora is not visible here - confirm whether it depends on it.
modules = DEFAULT_TARGET_REPLACE
modules += UNET_TARGET_REPLACE_MODULE_CONV

# One LoRA network per slider that will be applied at the same time.
networks = {}
for i in range(numsliders_to_sample):
    networks[i] = LoRANetwork(
        unet,
        rank=int(rank),
        multiplier=1.0,
        alpha=int(alpha),
        train_method=train_method,
        fast_init=True,
    ).to(device, dtype=weight_dtype)


# Iterate through every discovered slider and visualize its effect

In [None]:
prompt = 'image of a robot'
sliderspace_path = '../trained_sliders/sdxl/robot/'
slider_scales = [2]
num_images = 6

# Sorted so the sliders are visited in the same order on every run.
sliderspace = sorted(glob.glob(f'{sliderspace_path}/*.pt'))
if not sliderspace:
    print(f"No slider checkpoints (*.pt) found in {sliderspace_path}")

# One seed per image, shared by all sliders so that every slider is shown on
# the same set of samples. Defined here so the cell runs on a fresh kernel.
seeds = [random.randint(0, 2**15) for _ in range(num_images)]

for slider in sliderspace:
    image_list = []

    # Read the checkpoint once and reuse it for every network.
    slider_weights = torch.load(slider, map_location='cpu')
    for net in networks:
        networks[net].load_state_dict(slider_weights)

    for im in range(num_images):
        seed = seeds[im]
        for scale in slider_scales:
            for net in networks:
                networks[net].set_lora_slider(scale)

            generator = torch.manual_seed(seed)

            images = pipe(prompt, num_images_per_prompt=1, num_inference_steps=4, guidance_scale=0, generator=generator,
                         networks=networks,
                         # From which of the inference steps the slider is applied.
                         # Increase it for more precise edits.
                         # (Note: for distilled models it might not show the full slider
                         # capacity because the first timestep is so powerful.)
                         apply_sliders_from=0,
                         # Leave this None unless you want to skip applying the slider
                         # at the final timesteps.
                         apply_sliders_till=None,
                         ).images[0]
            image_list.append(images)

    print(f"Slider {os.path.basename(slider).replace('.pt','').split('_')[-1]}")
    if not image_list:
        # Nothing was generated (num_images or slider_scales is empty).
        continue

    # squeeze=False always returns a 2-D array of axes, so a single image
    # (one subplot) is handled the same way as several.
    fig, ax = plt.subplots(1, len(image_list), frameon=False, dpi=300, squeeze=False)
    # 3 inches per generated image (there are num_images * len(slider_scales) of them).
    fig.set_size_inches(len(image_list) * 3, 3)

    for a, generated in zip(ax.ravel(), image_list):
        a.imshow(generated)
        a.axis('off')
        a.set_position([0, 0, 1, 1])

    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    # Ensure the figure has no padding
    fig.tight_layout(pad=0)
    plt.show()

# Make GIFs by sweeping the slider scale

In [None]:
def generate_smooth_gif(
    pipe,
    networks,
    prompt,
    slider_path,
    num_frames=30,
    scale_start=0,
    scale_end=2,
    seed=None,
    output_path="interpolation.gif",
    fps=10,
    dpi=300,
):
    """Render an animated GIF that sweeps one slider between two scales.

    The slider checkpoint at ``slider_path`` is loaded into every network in
    ``networks``. One image is then generated per scale value, linearly spaced
    from ``scale_start`` to ``scale_end``, always with the same generation
    seed so that only the slider scale changes between frames. Each frame is
    shrunk to fit within 256x256 and the frames are saved as a looping GIF.

    Args:
        pipe: Pipeline object; called with the prompt and slider arguments and
            expected to return an object with an ``images`` list.
        networks: Mapping of key -> slider network exposing
            ``load_state_dict`` and ``set_lora_slider``.
        prompt: Text prompt passed to ``pipe``.
        slider_path: Path of the slider checkpoint to load.
        num_frames: Number of scale values (frames); must be at least 1.
        scale_start: Slider scale of the first frame.
        scale_end: Slider scale of the last frame.
        seed: If given, the generation seed is derived deterministically from
            it (first ``randint(0, 2**15)`` draw of a ``random.Random(seed)``).
            If None, the seed is drawn from the global ``random`` module.
        output_path: Where the GIF is written.
        fps: Playback speed in frames per second; must be positive.
        dpi: Passed through to the image ``save`` call.

    Returns:
        The list of generated (resized) frames.

    Raises:
        ValueError: If ``num_frames`` is below 1 or ``fps`` is not positive.
    """
    # Fail before any expensive work instead of after generating every frame.
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    if fps <= 0:
        raise ValueError("fps must be positive")

    # Use a private RNG when a seed is given so the global `random` state is
    # left untouched. The drawn value is the same one that
    # `random.seed(seed); random.randint(0, 2**15)` would produce.
    rng = random.Random(seed) if seed is not None else random
    frame_seed = rng.randint(0, 2**15)

    # Read the checkpoint once and reuse it for every network.
    # NOTE(review): torch.load unpickles the file - only load sliders you trust.
    slider_weights = torch.load(slider_path, map_location="cpu")
    for net in networks:
        networks[net].load_state_dict(slider_weights)

    # Simple linear interpolation for forward frames
    scales = np.linspace(scale_start, scale_end, num_frames)

    images = []
    pipe.set_progress_bar_config(disable=True)
    try:
        for scale in tqdm(scales):
            for net in networks:
                networks[net].set_lora_slider(float(scale))

            # Re-seed before every frame so all frames start from the same noise.
            generator = torch.manual_seed(frame_seed)

            image = pipe(
                prompt,
                num_images_per_prompt=1,
                num_inference_steps=4,
                guidance_scale=0,
                generator=generator,
                networks=networks,
                # From which of the inference steps the slider is applied.
                # Increase it for more precise edits.
                # (Note: for distilled models it might not show the full slider
                # capacity because the first timestep is so powerful.)
                apply_sliders_from=1,
                # Leave this None unless you want to skip applying the slider
                # at the final timesteps.
                apply_sliders_till=None,
            ).images[0]

            if not isinstance(image, Image.Image):
                image = Image.fromarray(np.uint8(image))

            # Resize image while maintaining aspect ratio
            image.thumbnail((256, 256), Image.Resampling.LANCZOS)
            images.append(image)
    finally:
        # Always restore the progress bar, even if generation failed.
        pipe.set_progress_bar_config(disable=False)

    duration = 1000 / fps  # Convert fps to milliseconds per frame
    images[0].save(
        output_path,
        save_all=True,
        append_images=images[1:],
        duration=duration,
        loop=0,
        optimize=True,
        quality=70,
        dpi=dpi
    )
    return images

In [None]:
# Optional overrides; by default the prompt and slider folder from the
# visualization cell are reused.
# sliderspace_path = '../trained_sliders/sdxl/ancient ruins/'
# prompt = 'picture of a ancient ruin'

slider_scale_start = 0
slider_scale_end = 2

num_frames = 50 # how many scales you want to interpolate between the start and end scale
num_images = 2

# Sorted so the sliders are processed in the same order on every run.
sliderspace = sorted(glob.glob(f'{sliderspace_path}/*.pt'))
# GIFs are written to a folder that mirrors the slider folder.
gifs_path = sliderspace_path.replace('trained_sliders','gifs')

os.makedirs(gifs_path, exist_ok=True)

seeds = [random.randint(0, 2**15) for _ in range(num_images)]
print(seeds)

for slider in sliderspace:
    slider_name = os.path.basename(slider)
    print(slider_name)

    for seed in seeds:
        # <slider file name without extension>_<seed>.gif
        slider_stem = os.path.splitext(slider_name)[0]
        output_filename = os.path.join(gifs_path, f"{slider_stem}_{seed}.gif")

        images = generate_smooth_gif(
            pipe=pipe,
            networks=networks,
            prompt=prompt,
            slider_path=slider,
            num_frames=num_frames,  # Adjust for smoother/faster interpolation
            scale_start=slider_scale_start,
            scale_end=slider_scale_end,
            output_path=output_filename,
            seed=seed,
            fps=20  # Adjust for slower/faster playback
        )

        # Optional: Display first and last frame for verification.
        # Titles report the scales that were actually used.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        ax1.imshow(images[0])
        ax1.set_title(f"Start (Scale = {slider_scale_start})")
        ax1.axis('off')
        ax2.imshow(images[-1])
        ax2.set_title(f"End (Scale = {slider_scale_end})")
        ax2.axis('off')
        plt.show()

# Generate Images with Multiple Sliders Composed

## Make sure you initialized more than one slider network (set `numsliders_to_sample > 1` in the setup cell)


In [None]:
prompt = 'picture of a spaceship'
sliderspace_path = '../trained_sliders/spaceship/'

slider_scale = 1
num_images = 10
# Sorted so the candidate list is the same on every run.
sliderspace = sorted(glob.glob(f'{sliderspace_path}/*.pt'))

# Every network receives its own distinct slider, so at least that many
# checkpoints must be available.
if len(sliderspace) < len(networks):
    raise ValueError(
        f"Found {len(sliderspace)} slider checkpoint(s) in {sliderspace_path!r} "
        f"but {len(networks)} are needed (one per network)."
    )

image_list = []
for idx in range(num_images):
    seed = random.randint(0,2**15)
    generator = torch.manual_seed(seed)

    # Draw one distinct slider per network for this image.
    sliderspace_samples = random.sample(sliderspace, len(networks))
    for i, net in enumerate(networks):
        networks[net].load_state_dict(torch.load(sliderspace_samples[i], map_location='cpu'))
        networks[net].set_lora_slider(slider_scale)
        # NOTE(review): entering and immediately leaving the network's context
        # is kept from the original; its effect depends on LoRANetwork's
        # __enter__/__exit__, which are not visible here - confirm before removing.
        with networks[net]:
            pass

    # Keep every network's context active while the image is generated.
    with ExitStack() as es:
        for net in networks:
            es.enter_context(networks[net])
        images = pipe(prompt, num_images_per_prompt=1, num_inference_steps=4, guidance_scale=0, generator=generator,
                     networks=networks,
                     # From which of the inference steps the slider is applied.
                     # Increase it for more precise edits.
                     # (Note: for distilled models it might not show the full slider
                     # capacity because the first timestep is so powerful.)
                     apply_sliders_from=1,
                     # Leave this None unless you want to skip applying the slider
                     # at the final timesteps.
                     apply_sliders_till=None,
                     ).images[0]
    image_list.append(images)

print('Randomly Sampled and Composed Sliders')
if image_list:
    # squeeze=False always returns a 2-D array of axes, so a single image
    # (one subplot) is handled the same way as several.
    fig, ax = plt.subplots(1, len(image_list), frameon=False, dpi=600, squeeze=False)
    for a, generated in zip(ax.ravel(), image_list):
        a.imshow(generated)
        a.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()