In [7]:
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL
from PIL import Image

import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image

import math

import torch
import torch.nn as nn
from diffusers import StableDiffusionControlNetPipeline, DDIMScheduler, AutoencoderKL, ControlNetModel

In [33]:
# --- Model identifiers and local checkpoint paths (Stable Diffusion 1.5) ---
base_model_path = "runwayml/stable-diffusion-v1-5"
vae_model_path = "stabilityai/sd-vae-ft-mse"
# NOTE(review): image_encoder_path and ip_ckpt are assigned here but are not
# read by the loading code in this cell; the CLIP encoder below is fetched from
# the Hub by repository id instead.
image_encoder_path = "models/image_encoder/"
ip_ckpt = "models/ip-adapter_sd15.bin"
device = "cuda"

# Download cache for Hugging Face weights. The machine-specific default is kept
# for backward compatibility, but it can now be overridden without editing the
# notebook by exporting HF_CACHE_DIR before starting the kernel.
hf_cache_dir = os.environ.get("HF_CACHE_DIR", "/home/tyk/hf_cache")

# DDIM scheduler handed to the pipeline in a later cell.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# VAE is cast to fp16 here but not moved to `device` in this cell.
vae = AutoencoderKL.from_pretrained(vae_model_path, cache_dir=hf_cache_dir).to(dtype=torch.float16)

# CLIP vision encoder (with projection head) in fp16 on `device`, plus its
# image preprocessor constructed with default settings.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    cache_dir=hf_cache_dir,
).to(device, dtype=torch.float16)
clip_image_processor = CLIPImageProcessor()

In [43]:
# Depth ControlNet checkpoint, loaded in half precision.
controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_path,
    torch_dtype=torch.float16,
    cache_dir='/home/tyk/hf_cache',
)

# Build the SD + ControlNet pipeline around the scheduler and VAE created in
# the configuration cell. The feature extractor and safety checker are
# explicitly disabled.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None,
    cache_dir='/home/tyk/hf_cache',
)
pipe = pipe.to(device)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [60]:
# Display the UNet's cross-attention width; this value is later passed to
# ImageProjModel as `cross_attention_dim`.
pipe.unet.config.cross_attention_dim

768

: 

In [34]:
class ImageProjModel(torch.nn.Module):
    """Turn one image embedding per sample into a short sequence of context tokens.

    A single linear layer widens the embedding to
    ``clip_extra_context_tokens * cross_attention_dim`` features; these are
    reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalised over the last dimension.

    Args:
        cross_attention_dim: width of each produced token.
        clip_embeddings_dim: width of the incoming image embedding.
        clip_extra_context_tokens: number of tokens produced per embedding.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.cross_attention_dim = cross_attention_dim

        # Total number of features the linear layer must emit per sample.
        flat_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flat_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped (-1, clip_extra_context_tokens, cross_attention_dim)."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(token_shape)
        return self.norm(tokens)

In [35]:
# Token width expected by the UNet's cross-attention layers; used below as the
# `cross_attention_dim` argument of ImageProjModel.
pipe.unet.config.cross_attention_dim

768

In [36]:
# The image encoder's configured `projection_dim`; used below as the
# `clip_embeddings_dim` argument of ImageProjModel.
image_encoder.config.projection_dim

1024

In [37]:
# Projection head sized from the loaded models: input width comes from the
# CLIP encoder config, output token width from the UNet config.
# NOTE(review): no checkpoint is loaded into this module in the cells shown, so
# its weights are whatever the default initialisation produced.
image_proj_model = ImageProjModel(
    cross_attention_dim=pipe.unet.config.cross_attention_dim,
    clip_embeddings_dim=image_encoder.config.projection_dim,
    clip_extra_context_tokens=4,
)
image_proj_model = image_proj_model.to(device, dtype=torch.float16)

In [38]:
from diffusers.utils import load_image

# Reference image used as the image prompt, fetched over the network.
reference_image_url = "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
pil_image = load_image(reference_image_url)


In [39]:
# Preprocess the PIL image with the CLIP image processor and keep only the
# pixel tensor (returned as a PyTorch tensor because of return_tensors="pt").
clip_image = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values # [1, 3, 224, 224]

In [40]:
# Encode the preprocessed image with the CLIP vision encoder. This is pure
# inference, so run it under no_grad to avoid building an autograd graph and
# holding activations in GPU memory.
with torch.no_grad():
    clip_image_embeds = image_encoder(clip_image.to(device, dtype=torch.float16)).image_embeds # [1, 1024]

In [41]:
# Project the CLIP image embedding into context tokens with ImageProjModel
# (4 tokens per image, each as wide as the UNet's cross-attention dim).
image_prompt_embeds = image_proj_model(clip_image_embeds) # [1, 4, 768]

In [50]:
# Number of images generated per reference image. It is assigned here because
# this is the first cell that needs it: previously the name was only assigned
# in a later cell, so a fresh "Restart & Run All" failed with NameError.
num_samples = 4

# Tile the image tokens so there is one copy per generated sample:
# repeat along the token axis, then fold the copies into the batch axis.
# The repeat count now uses num_samples instead of a literal 4, so it cannot
# drift out of sync with the view() below.
# NOTE: this cell overwrites image_prompt_embeds in place, so running it twice
# tiles twice; re-run the projection cell first if you need to repeat it.
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

In [51]:
# Sanity check after tiling: the batch axis should be bs_embed * num_samples.
image_prompt_embeds.shape

torch.Size([4, 4, 768])

In [57]:
# Unconditional image tokens: project an all-zero embedding with the same
# shape/dtype/device as the real one, then tile it exactly like the
# conditional tokens. Relies on num_samples, bs_embed and seq_len assigned in
# other cells.
zero_clip_embeds = torch.zeros_like(clip_image_embeds)
uncond_image_prompt_embeds = (
    image_proj_model(zero_clip_embeds)
    .repeat(1, num_samples, 1)
    .view(bs_embed * num_samples, seq_len, -1)
)
# expected: torch.Size([4, 4, 768])

In [58]:
# Sanity check: should match the shape of the tiled conditional image tokens.
uncond_image_prompt_embeds.shape

torch.Size([4, 4, 768])

In [59]:
# Batch sizing: one text prompt per reference image, num_samples images each.
num_prompts = 1
num_samples = 4

# Fixed quality prompts, replicated once per reference image.
prompt = ["best quality, high quality"] * num_prompts
negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"] * num_prompts

# With classifier-free guidance enabled the encoder returns both halves in one
# tensor, which chunk(2) below splits as (negative, positive).
# Expected: torch.Size([8, 77, 768])
prompt_embeds = pipe._encode_prompt(
    prompt,
    device=device,
    num_images_per_prompt=num_samples,
    do_classifier_free_guidance=True,
    negative_prompt=negative_prompt,
)

# Each half: [4, 77, 768]
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)

# Append the image tokens after the text tokens along the sequence axis.
# Result: torch.Size([4, 77 + 4, 768])
prompt_embeds = torch.cat((prompt_embeds_, image_prompt_embeds), dim=1)
negative_prompt_embeds = torch.cat((negative_prompt_embeds_, uncond_image_prompt_embeds), dim=1)

In [56]:
# Shape of the negative-prompt half split off by chunk(2), before the image
# tokens are appended.
negative_prompt_embeds_.shape

torch.Size([4, 77, 768])

In [None]:
import torch

# NOTE(review): this is an annotated, abridged reading excerpt, not runnable
# code. Bodies are elided with `...`, `__call__` puts a non-default parameter
# (`image`) after a defaulted one, and the final `self.controlnet(` call is
# never closed, so the cell does not parse. If it did run, it would rebind the
# name `StableDiffusionControlNetPipeline` imported from diffusers at the top
# of the notebook.
class StableDiffusionControlNetPipeline():

    # Builds the text-conditioning tensor used by the denoising loop.
    def _encode_prompt(self, 
                       prompt,
                       device,
                       num_images_per_prompt,
                       do_classifier_free_guidance,
                       negative_prompt=None,
                       prompt_embeds: Optional[torch.FloatTensor] = None,
                       negative_prompt_embeds: Optional[torch.FloatTensor] = None,
                       lora_scale: Optional[float] = None,):

        # The raw prompt is only encoded when no precomputed prompt_embeds were passed.
        if prompt_embeds is None:
            ...

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]
            

        
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # Tile once per requested image: repeat along the token axis, then fold
        # the copies into the batch axis.
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        # negative_prompt_embeds is already supplied, so this branch is not executed
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            ...

        # When classifier-free guidance (CFG) is enabled
        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)


            # NOTE(review): placeholder — the tensors to concatenate were left
            # out of the excerpt; presumably the negative and positive
            # embeddings are joined here. Confirm against the diffusers source.
            prompt_embeds = torch.cat([])
            

        return prompt_embeds
        
    # NOTE(review): the parameter is spelled `num_inference_step`, but the body
    # below reads `num_inference_steps`; `device`, `negative_prompt_embeds`,
    # `text_encoder_lora_scale`, `controlnet` and `timesteps` are also not
    # defined inside this excerpt.
    def __call__(self,
                 prompt: Union[str, List[str]] = None,
                 image,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_step=50,
                 prompt_embeds = None
                 ):

        # After this call the result is bundled together with the negative embeddings : torch.Size([4, 2*(77 + 4), 768])
        # NOTE(review): the shape in the note above is the author's annotation and has not been verified here.
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )


        # The depth map (original note is incomplete) is passed through prepare_image
        if isinstance(controlnet, ControlNetModel):
            image = self.prepare_image(image=image,
                                       ...
            )


        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    
            
        
class IPAdapter:
    """Study sketch of an IP-Adapter wrapper around a diffusers pipeline.

    The class does not define ``__init__`` here. Its methods expect the
    instance to already provide:

    * ``pipe`` - pipeline exposing ``_encode_prompt`` and ``__call__``
    * ``device`` - device string/object used for tensors and the RNG
    * ``set_scale(scale)`` - method applying the image-prompt strength
    * ``clip_image_processor``, ``image_encoder``, ``image_proj_model`` -
      used by :meth:`get_image_embeds` (same roles as the module-level
      objects of those names built earlier in this notebook)
    """

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)``.

        Follows the same steps the exploratory cells perform by hand:
        preprocess with the CLIP image processor, encode with the CLIP vision
        model, project into context tokens, and project an all-zero embedding
        for the unconditional branch.

        Args:
            pil_image: a single ``PIL.Image.Image`` or a list of them.

        Returns:
            Tuple of two tensors with identical shape, one row per input image.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]

        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=torch.float16)
        ).image_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(self,
                 pil_image,
                 prompt=None, # used for multimodal prompts (e.g. "best quality, high quality, wearing a hat on the beach")
                 negative_prompt=None,
                 scale=1.0,
                 num_samples=4,
                 seed=-1,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 **kwargs, # e.g. controlnet_conditioning_scale=0.7, image=depth_map
                 ):
        """Generate ``num_samples`` images per reference image.

        Args:
            pil_image: one ``PIL.Image.Image`` or a list of them.
            prompt: text prompt (str or list); a generic quality prompt is used
                when None.
            negative_prompt: negative text prompt (str or list); a generic one
                is used when None.
            scale: value forwarded to ``self.set_scale``.
            num_samples: images generated per reference image.
            seed: value for ``torch.Generator.manual_seed``; pass None to run
                without an explicit generator.
            guidance_scale: forwarded to the pipeline call.
            num_inference_steps: forwarded to the pipeline call.
            **kwargs: forwarded unchanged to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        # One text prompt per reference image.
        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        # A text prompt is always supplied, falling back to generic defaults.
        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # Replicate scalar prompts once per reference image.
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # Image tokens produced by the projection model.
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)

        # Tile BOTH the conditional and the unconditional image tokens so each
        # has one row per generated sample. The unconditional tokens were
        # previously left untiled, which made the torch.cat below fail on a
        # batch-size mismatch whenever num_samples > 1.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        # _encode_prompt is called with the raw prompt and negative prompt (not
        # with precomputed embeddings). With CFG enabled its result is split by
        # chunk(2) into (negative, positive).
        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)

            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            # Append the image tokens after the text tokens (sequence axis).
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        # The pipeline receives the combined embeddings; no raw prompt is
        # passed at this point. Remaining kwargs go straight through.
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

In [None]:
# SDXL configuration. These assignments rebind the names used for the SD 1.5
# setup earlier in the notebook, so cells run after this one see these values.
device = "cuda"
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
ip_ckpt = "models/ip-adapter_sdxl_vit-h.bin"
image_encoder_path = "models/image_encoder"

In [None]:
# The SDXL pipeline class is not part of the notebook's import cell, so it is
# imported here; without this the cell fails with NameError on a fresh kernel.
from diffusers import StableDiffusionXLControlNetPipeline

# Depth ControlNet for SDXL, fp16 safetensors variant, moved to `device`.
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device)

# SDXL + ControlNet pipeline in fp16 with the watermarker disabled.
# Note: this rebinds `controlnet` and `pipe` from the SD 1.5 cells.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(device)

In [None]:
# NOTE(review): `IPAdapterXL` is neither defined nor imported in the cells
# shown (only a sketch class named `IPAdapter` is), so this cell cannot run
# as-is on a fresh kernel.
ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

# NOTE(review): `image` and `depth_map` are not assigned in the cells shown;
# the reference image loaded earlier is named `pil_image`. `image=depth_map`
# and `controlnet_conditioning_scale` are extra keyword arguments to generate().
images = ip_model.generate(pil_image=image, image=depth_map, controlnet_conditioning_scale=0.7, num_samples=num_samples, num_inference_steps=30, seed=42)