In [1]:
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_CACHE"] = "/tmp/wendler/hf_cache"
import sys
sys.path.append("..")
sys.path.append("../NewtonRaphsonInversion")

In [28]:
# Input shards (brace pattern, expanded by webdataset's SimpleShardList).
url = "/share/datasets/datasets/laicoyo/{000000..000009}.tar"
# Root directory under which the activation dataset is written.
save_path = "/data_shared/wendler/turbo_latents_from_laicoyo_inversion"
# UNet attention blocks whose inputs/outputs are captured via run_with_cache.
blocks_to_save = [
    "unet.down_blocks.2.attentions.1",
    "unet.mid_block.attentions.0",
    "unet.up_blocks.0.attentions.0",
    "unet.up_blocks.0.attentions.1",
]
# Maximum number of samples to write.
n_max = 20_000

In [3]:
import webdataset as wds
import logging
import torch
import io

def url_to_dataloader(url, num_workers=4, batch_size=16, shuffle=False,
                      shuffle_bufsize=5000, shuffle_initial=1000):
    """Build a WebLoader yielding batches of (latent, caption) pairs from tar shards.

    Args:
        url: Shard spec accepted by ``wds.SimpleShardList`` (brace patterns allowed).
        num_workers: Number of loader worker processes.
        batch_size: Samples per batch; a trailing partial batch is dropped.
        shuffle: If True, shuffle samples through an in-memory buffer.
        shuffle_bufsize: Shuffle buffer size (only used when ``shuffle`` is True).
        shuffle_initial: Samples buffered before shuffling starts (only with ``shuffle``).

    Returns:
        A ``wds.WebLoader`` whose items are ``(latents, captions)`` batches. Latents
        come from each sample's ``latent.pt`` entry (cast to float32); captions are
        the UTF-8 decoded ``txt`` entries. Samples without ``latent.pt`` are skipped.
    """
    def log_and_continue(exn):
        """Call in an exception handler to ignore any exception, issue a warning, and continue."""
        logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
        return True

    def filter_no_latent(sample):
        # Keep only samples that carry a precomputed latent.
        return 'latent.pt' in sample

    def load_latent(z):
        # NOTE(review): torch.load unpickles arbitrary objects; only use on trusted shards
        # (consider weights_only=True if the files hold plain tensors).
        return torch.load(io.BytesIO(z), map_location='cpu').to(torch.float32)

    # Shuffled and unshuffled pipelines differ only by the shuffle stage, so build
    # the stage list once and insert that stage conditionally.
    pipeline = [
        wds.SimpleShardList(url),
        wds.split_by_node,
        wds.split_by_worker,
        wds.tarfile_to_samples(handler=log_and_continue),
        wds.select(filter_no_latent),
    ]
    if shuffle:
        pipeline.append(wds.shuffle(bufsize=shuffle_bufsize, initial=shuffle_initial))
    pipeline += [
        wds.rename(image="latent.pt", txt="txt"),
        wds.map_dict(image=load_latent, txt=lambda x: x.decode("utf-8")),
        wds.to_tuple("image", "txt"),
        wds.batched(batch_size, partial=False),
    ]

    dataset = wds.DataPipeline(*pipeline)

    # Batching already happens inside the pipeline, so the loader must not re-batch.
    loader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=num_workers,
    )
    return loader

In [4]:
# Sequential (unshuffled) loader over the configured shards, 16 latents per batch.
loader = url_to_dataloader(url, num_workers=4, batch_size=16, shuffle=False)

In [5]:
from SDLens import HookedStableDiffusionXLImg2ImgPipeline
from src.config import RunConfig
from ipywidgets import Text, VBox
import PIL
from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
from PIL import Image
from diffusers.utils.torch_utils import randn_tensor
from IPython.display import display

# Hugging Face model id; later cells branch on whether this is SDXL-base.
model = "stabilityai/sdxl-turbo"

def inversion_callback(pipe, step, timestep, callback_kwargs):
    """Per-step ``callback_on_step_end`` hook for inversion: a pass-through.

    The keyword tensors supplied by the pipeline are handed back unchanged.
    """
    unchanged = callback_kwargs
    return unchanged


def inference_callback(pipe, step, timestep, callback_kwargs):
    """Per-step ``callback_on_step_end`` hook for inference: a pass-through.

    The keyword tensors supplied by the pipeline are handed back unchanged.
    """
    unchanged = callback_kwargs
    return unchanged

class ImageEditorDemo:
    """Invert a latent with one pipeline, then regenerate/edit it with another.

    The input latent is inverted in ``__init__`` (stored in ``self.last_latent``);
    ``edit`` and ``edit_and_cache`` run ``pipe_inference`` starting from that
    inverted latent with a (possibly different) prompt.
    """

    def __init__(self, pipe_inversion, pipe_inference, latents, prompts, cfg, edit_cfg=1.2,
                 noise_seed=7865, noise_device="cuda:0"):
        """Configure both pipelines and invert ``latents`` immediately.

        Args:
            pipe_inversion: Pipeline providing ``invert_latents`` and ``scheduler_inference``.
            pipe_inference: Img2img pipeline used for editing (``run_with_cache`` is
                required for ``edit_and_cache``).
            latents: Latent tensor to invert; its shape also sets the noise shape.
            prompts: Prompt(s) describing ``latents``, used for inversion.
            cfg: Run config read for ``num_inversion_steps``, ``num_inference_steps``,
                ``guidance_scale`` and ``inversion_max_step``. It is shared by
                reference with both pipelines, so mutating ``self.cfg`` affects them.
            edit_cfg: Default guidance scale for ``edit``/``edit_and_cache``.
            noise_seed: Seed of the CPU generator that draws the per-step noise list.
            noise_device: Device on which the noise tensors are created.
        """
        self.pipe_inversion = pipe_inversion
        self.pipe_inference = pipe_inference
        self.load_image = True  # not read within this class; kept for external users

        # One noise tensor per inversion step, shaped like the input latent. The same
        # list is installed in all three schedulers so inversion and inference share it.
        generator = torch.Generator().manual_seed(noise_seed)
        noise = [
            randn_tensor(latents.shape, dtype=pipe_inversion.unet.dtype,
                         device=torch.device(noise_device), generator=generator)
            for _ in range(cfg.num_inversion_steps)
        ]
        for scheduler in (pipe_inversion.scheduler,
                          pipe_inference.scheduler,
                          pipe_inversion.scheduler_inference):
            scheduler.set_noise_list(noise)

        pipe_inversion.set_progress_bar_config(disable=True)
        pipe_inference.set_progress_bar_config(disable=True)
        self.cfg = cfg
        self.pipe_inversion.cfg = cfg
        self.pipe_inference.cfg = cfg
        # (niter, alpha, lr); per the original author 2, 0.1, 0.2 is the default.
        self.inv_hp = [2, 0.1, 0.2]
        self.edit_cfg = edit_cfg
        self.latents = latents
        self.last_latent = self.invert(latents, prompts)
        self.original_latent = self.last_latent

    def invert(self, latents, base_prompts):
        """Run ``invert_latents`` on ``latents`` and return ``result[0][0]``.

        NOTE(review): ``[0][0]`` presumably selects the inverted latent from the
        pipeline's output; confirm against ``SDXLDDIMPipeline.invert_latents``.
        """
        res = self.pipe_inversion.invert_latents(prompt=base_prompts,
                             num_inversion_steps=self.cfg.num_inversion_steps,
                             num_inference_steps=self.cfg.num_inference_steps,
                             latents=latents,
                             guidance_scale=self.cfg.guidance_scale,
                             callback_on_step_end=inversion_callback,
                             strength=self.cfg.inversion_max_step,
                             denoising_start=1.0 - self.cfg.inversion_max_step,
                             inv_hp=self.inv_hp)[0][0]
        return res

    def edit(self, target_prompt, guidance_scale=None):
        """Generate from the inverted latent with ``target_prompt``.

        Returns the first image of the pipeline output, resized to 512x512.
        ``guidance_scale`` defaults to ``self.edit_cfg``.
        """
        if guidance_scale is None:
            guidance_scale = self.edit_cfg
        image = self.pipe_inference(prompt=target_prompt,
                            num_inference_steps=self.cfg.num_inference_steps,
                            negative_prompt="",
                            callback_on_step_end=inference_callback,
                            image=self.last_latent,
                            strength=self.cfg.inversion_max_step,
                            denoising_start=1.0 - self.cfg.inversion_max_step,
                            guidance_scale=guidance_scale).images[0]
        return image.resize((512, 512))

    def edit_and_cache(self, target_prompt, guidance_scale=None, positions_to_cache=None):
        """Like ``edit`` but also records block inputs/outputs via ``run_with_cache``.

        Args:
            target_prompt: Prompt used for generation.
            guidance_scale: Defaults to ``self.edit_cfg``.
            positions_to_cache: Module names to hook; defaults to the notebook-level
                ``blocks_to_save``.

        Returns:
            ``(output, cache)`` exactly as returned by ``run_with_cache`` (PIL output).
        """
        if guidance_scale is None:
            guidance_scale = self.edit_cfg
        if positions_to_cache is None:
            positions_to_cache = blocks_to_save
        image, cache = self.pipe_inference.run_with_cache(prompt=target_prompt,
                            positions_to_cache=positions_to_cache,
                            save_input=True,
                            save_output=True,
                            output_type='pil',
                            num_inference_steps=self.cfg.num_inference_steps,
                            negative_prompt="",
                            callback_on_step_end=inference_callback,
                            image=self.last_latent,
                            strength=self.cfg.inversion_max_step,
                            denoising_start=1.0 - self.cfg.inversion_max_step,
                            guidance_scale=guidance_scale)
        return image, cache

There was a problem when trying to write in your cache folder (/share/u/models/hub). Please, ensure the directory exists and can be written to.
  deprecate("VQEncoderOutput", "0.31", deprecation_message)
  deprecate("VQModel", "0.31", deprecation_message)


In [15]:
# Sampling configuration: SDXL-base uses 20 steps at 1024px, otherwise (SDXL-Turbo)
# 4 steps at 512px; no guidance and inversion_max_step=0.75 in both cases.
# Original note: "4,4,0,0.6 is default settings; 0.6 and 0.7 look the same".
_is_sdxl_base = model == "stabilityai/stable-diffusion-xl-base-1.0"
_n_steps = 20 if _is_sdxl_base else 4
image_size = (1024, 1024) if _is_sdxl_base else (512, 512)
config = RunConfig(num_inference_steps=_n_steps,
                   num_inversion_steps=_n_steps,
                   guidance_scale=0.0,
                   inversion_max_step=0.75)
del _is_sdxl_base, _n_steps




In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dtype = torch.float32
scheduler_class = MyEulerAncestralDiscreteScheduler
# Only request the fp16 weight variant when actually loading in half precision.
weight_variant = "fp16" if dtype == torch.float16 else None
# HF_HOME may be unset; cache_dir=None then lets diffusers use its default cache
# instead of raising KeyError.
hf_cache_dir = os.environ.get("HF_HOME")

if model == "stabilityai/stable-diffusion-xl-base-1.0":
    pipe_inversion = SDXLDDIMPipeline.from_pretrained(model,
                                                      torch_dtype=dtype,
                                                      device_map="balanced",
                                                      variant=weight_variant)
    pipe_inference = HookedStableDiffusionXLImg2ImgPipeline.from_pretrained(model,
                                                                            torch_dtype=dtype,
                                                                            device_map="balanced",
                                                                            variant=weight_variant)
    if dtype == torch.float32:
        pipe_inversion.text_encoder_2.to(dtype)
        pipe_inference.text_encoder_2.to(dtype)
else:
    # SDXL pipelines do not take `safety_checker` (diffusers warns and ignores it),
    # so it is not passed.
    pipe_inversion = SDXLDDIMPipeline.from_pretrained(
        model, use_safetensors=True, cache_dir=hf_cache_dir).to(device)
    pipe_inference = HookedStableDiffusionXLImg2ImgPipeline.from_pretrained(
        model, use_safetensors=True, cache_dir=hf_cache_dir).to(device)

# Swap in the custom Euler-ancestral scheduler (supports set_noise_list) everywhere.
pipe_inference.scheduler            = scheduler_class.from_config(pipe_inference.scheduler.config)
pipe_inversion.scheduler            = scheduler_class.from_config(pipe_inversion.scheduler.config)
pipe_inversion.scheduler_inference  = scheduler_class.from_config(pipe_inference.scheduler.config)

Keyword arguments {'safety_checker': None} are not expected by SDXLDDIMPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Keyword arguments {'safety_checker': None} are not expected by StableDiffusionXLImg2ImgPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [21]:
# Disabled sanity check: compare cached block activations of a 4-step and a 1-step
# edit of the same inverted latent.
RUN_ALIGNMENT_CHECK = False
if RUN_ALIGNMENT_CHECK:
    # It's a bit more involved to align the 1-step and 4-step processes: when not
    # starting from t=999 it is hard to align the scheduler timesteps (via its methods
    # and pipeline args) between the two processes. That's why
    # self.cfg.inversion_max_step = 1.0 is needed to get error 0 between the latents.
    from copy import deepcopy
    import matplotlib.pyplot as plt

    idx = 0
    latents, prompts = next(iter(loader))
    latents = latents.cuda()
    latents *= pipe_inference.vae.config.scaling_factor
    # Pass a copy: the editor (and both pipelines) hold the config by reference and this
    # check mutates it, which would otherwise leak into the notebook-level `config`.
    editor1 = ImageEditorDemo(pipe_inversion, pipe_inference, latents[idx].unsqueeze(0),
                              prompts[idx], deepcopy(config), edit_cfg=0.0)
    edit_prompt = "a black lady with a huge afro " + prompts[idx]

    editor1.cfg.num_inference_steps = 4
    editor1.cfg.inversion_max_step = 1.
    img1, cache1 = editor1.edit_and_cache(edit_prompt, guidance_scale=0.0)
    print(pipe_inference.scheduler.timesteps)

    editor1.cfg.num_inference_steps = 1
    editor1.cfg.inversion_max_step = 1.
    img2, cache2 = editor1.edit_and_cache(edit_prompt, guidance_scale=0.0)
    print(pipe_inference.scheduler.timesteps)

    print(prompts[idx])
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, img, title in zip(axes, (img1, img2), ('Image 1', 'Image 2')):
        ax.imshow(img[0][0])
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    plt.show()

    # Mean absolute difference of the first entry along dim 1 of each cached block output.
    for b in blocks_to_save:
        print(b)
        print(cache1['output'][b].shape)
        print(cache2['output'][b].shape)
        print((cache1['output'][b][:, 0] - cache2['output'][b][:, 0]).abs().mean())



# Save the dataset of cached block activations

In [33]:
import datetime
import numpy as np

# Generation settings: used to build each sample's RunConfig and stored verbatim
# with every sample, so the two can no longer drift apart.
GEN_ARGS = {
    "num_inference_steps": 4,
    "inversion_max_step": 0.75,
    "guidance_scale": 0.0,
    "num_inversion_steps": 4,
    "edit_cfg": 0.0,
}

# Write into a fresh timestamped sub-directory. A separate name (rather than
# re-assigning `save_path`) keeps re-runs of this cell from nesting timestamps, and
# strftime avoids the spaces/colons that str(datetime) would put into the path.
save_dir = os.path.join(save_path, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(save_dir, exist_ok=True)

# One tar per cached block, plus one for the generated images.
writers = {block: wds.TarWriter(f'{save_dir}/{block}.tar') for block in blocks_to_save}
writers['images'] = wds.TarWriter(f'{save_dir}/images.tar')

num_document = 0
try:
    for latents, prompts in loader:
        latents = latents.cuda()
        latents *= pipe_inference.vae.config.scaling_factor
        for latent, prompt in zip(latents, prompts):
            # Fresh config per sample: the editor shares it by reference with both pipelines.
            sample_config = RunConfig(num_inference_steps=GEN_ARGS["num_inference_steps"],
                                      num_inversion_steps=GEN_ARGS["num_inversion_steps"],
                                      guidance_scale=GEN_ARGS["guidance_scale"],
                                      inversion_max_step=GEN_ARGS["inversion_max_step"])
            editor = ImageEditorDemo(pipe_inversion, pipe_inference, latent.unsqueeze(0), prompt,
                                     sample_config, edit_cfg=GEN_ARGS["edit_cfg"])
            with torch.no_grad():
                output, cache = editor.edit_and_cache(prompt, guidance_scale=GEN_ARGS["guidance_scale"])

            key = f"sample_{num_document}"
            for block in cache['input'].keys():
                writers[block].write({
                    "__key__": key,
                    "output.pth": cache['output'][block].cpu(),
                    "diff.pth": (cache['output'][block] - cache['input'][block]).cpu(),
                    "prompt": prompt,
                    "inverted_latent.pth": editor.last_latent.cpu(),
                    "original_latent.pth": latent.cpu(),
                    "gen_args.json": dict(GEN_ARGS),
                })
            # Write the generated image once per sample (not once per block), so keys in
            # images.tar are unique.
            writers['images'].write({
                "__key__": key,
                "images.npy": np.stack(output.images),
            })

            num_document += 1
            if num_document >= n_max:
                break
        if num_document >= n_max:
            break
finally:
    # Always flush/close the tar files, even if generation fails midway.
    for writer in writers.values():
        writer.close()

(1, 512, 512, 3)
(1, 512, 512, 3)
