In [None]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
# NOTE(review): this cell is unfinished. The import statement below is
# truncated (it does not parse), and `is_torch2_available` is not imported in
# any visible cell. Presumably it was meant to pull the attention processors
# (e.g. the `IPAttnProcessor` used by `IPAdapter.set_scale`) from a local
# module -- TODO confirm the module path and complete the import.
if is_torch2_available():
  from .atten

In [None]:
class Resampler(nn.Module):
  """Map a sequence of image features onto a fixed number of learned query tokens.

  The module owns `num_quries` learned latent vectors of width `dim`. `forward`
  broadcasts them over the batch, projects the input features to `dim`, runs
  every entry of `self.layers`, then projects the latents to `output_dim` and
  layer-normalises them.

  NOTE(review): `self.layers` is still an empty `ModuleList`, so `depth`,
  `dim_head`, `heads` and `ff_mult` are accepted but unused, and the output
  currently depends only on the batch size of `x`, not on its values.
  TODO: populate `self.layers` with (attention, feed-forward) pairs.

  Args:
    dim: width of the latent tokens.
    depth: intended number of layers (unused until `self.layers` is filled).
    dim_head: intended per-head width (unused for now).
    heads: intended number of attention heads (unused for now).
    num_quries: number of learned latent tokens. The spelling is kept as is
      because it is part of the constructor's keyword interface.
    embedding_dim: size of the last dimension of the input features.
    output_dim: size of the last dimension of the returned tokens.
    ff_mult: intended feed-forward expansion factor (unused for now).
  """

  def __init__(
      self,
      dim=1024,
      depth=8,
      dim_head=64,
      heads=16,
      num_quries=8,
      embedding_dim=768,
      output_dim=1024,
      ff_mult=4
      ):
    super().__init__()

    # Learned query tokens, drawn from a normal distribution scaled by 1/sqrt(dim).
    self.latents = nn.Parameter(torch.randn(1, num_quries, dim)/dim**0.5)
    self.proj_in = nn.Linear(embedding_dim, dim)
    self.proj_out = nn.Linear(dim, output_dim)
    self.norm_out = nn.LayerNorm(output_dim)

    # TODO: append `depth` (attention, feed-forward) pairs here.
    self.layers = nn.ModuleList([])

  def forward(self, x):
    """Return the resampled tokens for `x`.

    Args:
      x: tensor whose first dimension is the batch size and whose last
        dimension is `embedding_dim`.

    Returns:
      Tensor of shape `(x.size(0), num_quries, output_dim)`.
    """
    # One copy of the learned queries per batch element.
    latents = self.latents.repeat(x.size(0), 1, 1)
    x = self.proj_in(x)

    # Assumes each entry of `self.layers` is an (attention, feed-forward) pair
    # applied with residual connections -- TODO confirm once layers are added.
    # With the current empty list this loop does nothing.
    for attn, ff in self.layers:
      latents = attn(x, latents) + latents
      latents = ff(latents) + latents

    latents = self.proj_out(latents)
    return self.norm_out(latents)

In [None]:
class ImageProjModel(nn.Module):
  """Linear projection that turns one embedding vector into several context tokens.

  A single `nn.Linear` maps an input of width `clip_embeddings_dim` to
  `clip_extra_context_tokens * cross_attention_dim` values. These are reshaped
  into `clip_extra_context_tokens` tokens of width `cross_attention_dim` and
  layer-normalised over the last dimension.

  Args:
    cross_attention_dim: width of each produced token.
    clip_embeddings_dim: size of the last dimension of the input.
    clip_extra_context_tokens: number of tokens produced per input vector.
  """

  def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
    super().__init__()

    self.cross_attention_dim = cross_attention_dim
    self.clip_extra_context_tokens = clip_extra_context_tokens

    # All tokens are produced by one linear layer and split apart in forward().
    projected_width = clip_extra_context_tokens * cross_attention_dim
    self.proj = nn.Linear(clip_embeddings_dim, projected_width)
    self.norm = nn.LayerNorm(cross_attention_dim)

  def forward(self, image_embeds):
    """Return normalised tokens of shape `(-1, clip_extra_context_tokens, cross_attention_dim)`."""
    token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
    tokens = self.proj(image_embeds).reshape(*token_shape)
    return self.norm(tokens)

In [None]:
class IPAdapter:
    """Condition a diffusion pipeline on a reference image.

    The adapter moves the pipeline to `device`, swaps the UNet attention
    processors, loads a CLIP vision encoder plus an image-projection model, and
    restores their weights from `ip_ckpt`.

    NOTE(review): `CLIPVisionModelWithProjection`, `CLIPImageProcessor`,
    `Image`, `AttnProcessor` and `IPAttnProcessor` are not imported in any
    visible cell; they must be in scope before this class is used. The
    processor constructor arguments below follow the public IP-Adapter
    reference implementation -- confirm against the local attention module.

    Args:
        sd_pipe: diffusion pipeline exposing `.to()`, `.unet` and
            `._encode_prompt()`.
        image_encoder_path: path or model id for the CLIP vision encoder.
        ip_ckpt: checkpoint path holding `image_proj` and `ip_adapter` state dicts.
        device: device everything is moved to.
        num_tokens: number of image context tokens appended to the prompt.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # The encoder has to exist before init_proj(), which reads its config.
        # (Previously init_proj() was also called here first and its result was
        # stored in `image_encoder`, which failed before the encoder was loaded.)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        """Build the fp16 image-projection model sized from the UNet and encoder configs."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """Install one attention processor per UNet attention layer.

        Self-attention layers (`attn1`) get a plain `AttnProcessor`;
        cross-attention layers (`attn2`) get an `IPAttnProcessor` sized for the
        block they belong to.

        Raises:
            ValueError: if a processor name does not start with `mid_block`,
                `up_blocks` or `down_blocks`.
        """
        unet = self.pipe.unet
        attn_procs = {}
        # The blocks live in ModuleLists, so the processor names carry indices, e.g.
        #   'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor'
        #   'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor'
        #   'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor'
        #   'mid_block.attentions.0.transformer_blocks.0.attn1.processor'
        # (down_blocks 0-2, up_blocks 1-3 and mid_block in the UNet this was inspected on).
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

            if name.startswith('mid_block'):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith('up_blocks'):
                # The block index is the single digit right after the prefix.
                block_id = int(name[len('up_blocks.')])
                # Up blocks walk the channel list in reverse order.
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith('down_blocks'):
                block_id = int(name[len('down_blocks.')])
                hidden_size = unet.config.block_out_channels[block_id]
            else:
                raise ValueError(f"Unexpected attention processor name: {name}")

            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.float16)

        unet.set_attn_processor(attn_procs)

    def load_ip_adapter(self):
        """Load projection and attention-processor weights from `self.ip_ckpt`."""
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.ip_ckpt, map_location='cpu')
        self.image_proj_model.load_state_dict(state_dict['image_proj'])
        # Wrapping the processors in a ModuleList gives them positional keys.
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict['ip_adapter'])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Encode image(s) into conditional and unconditional prompt tokens.

        Args:
            pil_image: a single `Image.Image` or a sequence of images.

        Returns:
            `(image_prompt_embeds, uncond_image_prompt_embeds)`; the second is
            the projection of an all-zero embedding of the same shape.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]

        # CLIP processor
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values

        # CLIP encoder (keyword fixed: it was misspelled `dtye`)
        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

        # projection
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set `scale` on every `IPAttnProcessor` installed in the UNet."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images conditioned on `pil_image` and an optional text prompt.

        Args:
            pil_image: a single `Image.Image` or a sequence of images.
            prompt: text prompt or list of prompts; a generic quality prompt
                is used when omitted.
            negative_prompt: negative prompt or list; a generic one is used
                when omitted.
            scale: value written to every `IPAttnProcessor.scale`.
            num_samples: images generated per input image.
            seed: seed for the generator; `None` leaves sampling unseeded.
            guidance_scale: forwarded to the pipeline call.
            num_inference_steps: forwarded to the pipeline call.
            **kwargs: forwarded to the pipeline call.

        Returns:
            The `.images` attribute of the pipeline output.
        """
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # Builtin `list` replaces the un-imported `typing.List`.
        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape

        # Duplicate each image's tokens `num_samples` times along the batch axis.
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            # The result is split in two halves: negative first, then positive.
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

    
      

In [None]:
class IPAdapterXL(IPAdapter):
    """`IPAdapter` variant for pipelines whose `encode_prompt` also returns pooled embeddings."""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images conditioned on `pil_image` and an optional text prompt.

        Args:
            pil_image: a single `Image.Image` or a sequence of images.
            prompt: text prompt or list of prompts; a generic quality prompt
                is used when omitted.
            negative_prompt: negative prompt or list; a generic one is used
                when omitted.
            scale: value passed to `set_scale`.
            num_samples: images generated per input image.
            seed: seed for the generator; `None` leaves sampling unseeded.
            num_inference_steps: forwarded to the pipeline call.
            **kwargs: forwarded to the pipeline call.

        Returns:
            The `.images` attribute of the pipeline output.
        """
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # Builtin `list` replaces the un-imported `typing.List`.
        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        # Duplicate each image's tokens `num_samples` times along the batch axis.
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt(
                prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            # Append the image tokens after the text tokens. Without this step
            # the image embeddings computed above never reach the pipeline.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

In [None]:
class IPAdapterPlus(IPAdapter):
    """`IPAdapter` variant that projects penultimate CLIP hidden states through a `Resampler`."""

    def init_proj(self):
        """Build the fp16 `Resampler` sized from the UNet and image-encoder configs.

        The resampler's input width is the encoder's `hidden_size`, because
        `get_image_embeds` feeds it hidden states rather than the projected
        image embedding. The number of output tokens follows `self.num_tokens`.
        """
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            # `num_quries` is the keyword as spelled in Resampler's signature.
            num_quries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
            ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Encode image(s) into conditional and unconditional prompt tokens.

        Args:
            pil_image: a single `Image.Image` or a sequence of images.

        Returns:
            `(image_prompt_embeds, uncond_image_prompt_embeds)`; the second is
            obtained by encoding an all-zero image batch of the same shape.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]

        # CLIP
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)

        # CLIP encoder
        # The difference from the base class: this uses hidden states instead of
        # the image embedding. They are the intermediate features produced after
        # each CLIP encoder layer, and the second-to-last entry is used.
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)

        # The unconditional branch encodes a zero image rather than zeroing the embedding.
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)

        return image_prompt_embeds, uncond_image_prompt_embeds

In [None]:
class DDIMSampler(object):
    """Partial DDIM sampler: forward noising (`q_sample`) and reverse decoding (`decode`).

    NOTE(review): the schedule attributes read here (`sqrt_alphas_cumprod`,
    `sqrt_one_minus_alphas_cumprod`, `ddim_timesteps`, `ddpm_num_timesteps`) are
    not set anywhere in the visible code, and `default` /
    `extract_into_tensor` are external helpers.
    """

    def q_sample(self, x_start, t, noise=None):
        """Return `sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise`.

        Gaussian noise shaped like `x_start` is drawn when `noise` is not given.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False,
                      use_original_steps=False, input_noise=None,
                      quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """Single reverse DDIM step; expected to return `(x_prev, pred_x0)`.

        Raises:
            NotImplementedError: always. The step was left as an empty stub that
                returned undefined names; it now fails explicitly instead.
        """
        # TODO: implement the denoising update.
        raise NotImplementedError("DDIMSampler.p_sample_ddim is not implemented yet")

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0,
               unconditional_conditioning=None,
               use_original_steps=False, input_noise=None,
               *, unconditional_guiodnace_scale=None):
        """Run `p_sample_ddim` backwards over the first `t_start` timesteps.

        Args:
            x_latent: starting latent; its first dimension is the batch size.
            cond: conditioning passed to every step.
            t_start: number of timesteps, counted from the start of the
                schedule, to step back through.
            unconditional_guidance_scale: forwarded to every step.
            unconditional_conditioning: forwarded to every step.
            use_original_steps: use `arange(ddpm_num_timesteps)` instead of
                `ddim_timesteps` as the schedule.
            input_noise: forwarded to every step.
            unconditional_guiodnace_scale: deprecated misspelling of
                `unconditional_guidance_scale`, kept as a keyword-only alias;
                it wins when given.

        Returns:
            The latent after all steps (`x_latent` itself when `t_start` is 0).
        """
        if unconditional_guiodnace_scale is not None:
            unconditional_guidance_scale = unconditional_guiodnace_scale

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        # Walk from the noisiest selected timestep down to the first one.
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        x_dec = x_latent
        for i, step in enumerate(time_range):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), int(step), device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index,
                                          use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning,
                                          input_noise=input_noise)

        # Return only after every step has run (it used to return inside the loop).
        return x_dec



class EmbeddingManager(nn.Module):
    """Holds the placeholder-string-to-token mapping restored from a checkpoint.

    NOTE(review): this class is a skeleton. The constructor had no body and a
    dangling `def` followed `load`. Only what `load` needs is initialised here.
    """

    def __init__(self,):
        super().__init__()
        # Filled by load(); the defaults keep attribute access safe beforehand.
        self.string_to_token_dict = {}
        self.attention = None

    def load(self, ckpt_path):
        """Restore `string_to_token_dict` and, if present, `attention` from `ckpt_path`.

        `attention` is reset to `None` when the checkpoint has no such key.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.string_to_token_dict = ckpt["string_to_token"]

        if 'attention' in ckpt.keys():
            self.attention = ckpt["attention"]
        else:
            self.attention = None

    # TODO: an unfinished method definition stood here. `LatentDiffusion` calls
    # `embedding_parameters()` on its embedding manager, so that method still
    # needs to be written.




class LatentDiffusion(DDPM):
    """Latent diffusion model with a frozen conditioning stage and an optional embedding manager.

    NOTE(review): `DDPM`, `instantiate_from_config`, `disabled_train`,
    `DiagonalGaussianDistribution` and `instantiate_first_stage` are not defined
    in the visible code; `self.model` and `self.learning_rate` are presumably
    provided by `DDPM` -- TODO confirm.
    """

    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 personalization_config=None,
                 cond_stage_trainable=False,
                 cond_stage_forward=None,
                 **kwargs):
        """Build both stages, freeze the conditioning model and set up the embedding manager.

        Args:
            first_stage_config: config handed to `instantiate_first_stage`.
            cond_stage_config: config handed to `instantiate_cond_stage`.
            personalization_config: config for the embedding manager; `None`
                means no manager, in which case `configure_optimizers` trains
                `self.model` instead.
            cond_stage_trainable: read by `instantiate_cond_stage` and
                `configure_optimizers`.
            cond_stage_forward: optional name of the conditioning-model method
                used by `get_learned_conditining`.
            **kwargs: forwarded unchanged to `DDPM.__init__`.
        """
        # NOTE(review): the base constructor was never called. Assuming DDPM is
        # an nn.Module, it must run before any submodule is assigned -- confirm
        # which arguments DDPM requires.
        super().__init__(**kwargs)

        # These flags are read by the methods below but were never set.
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_forward = cond_stage_forward

        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)

        # The conditioning model is frozen regardless of `cond_stage_trainable`.
        # It can be None for the '__is_unconditional__' config.
        if self.cond_stage_model is not None:
            self.cond_stage_model.train = disabled_train
            for param in self.cond_stage_model.parameters():
                param.requires_grad = False

        if personalization_config is not None:
            self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model)
        else:
            self.embedding_manager = None

        self.emb_ckpt_counter = 0
        self.mse_loss = nn.MSELoss()

        if self.embedding_manager is not None:
            for param in self.embedding_manager.embedding_parameters():
                param.requires_grad = True

    def instantiate_embedding_manager(self, config, embedder):
        """Create the embedding manager and load its checkpoint when the config names one."""
        model = instantiate_from_config(config, embedder=embedder)

        # The key matches the attribute read on the next line.
        if config.params.get("embedding_manager_ckpt", None):
            model.load(config.params.embedding_manager_ckpt)

        return model

    def instantiate_cond_stage(self, config):
        """Set `self.cond_stage_model` from `config`.

        When the stage is not trainable, `'__is_first_stage__'` reuses the first
        stage model, `'__is_unconditional__'` stores `None`, and any other
        config is instantiated, put in eval mode and frozen. When it is
        trainable, the instantiated model is stored as is.
        """
        if not self.cond_stage_trainable:
            if config == '__is_first_stage__':
                self.cond_stage_model = self.first_stage_model
            elif config == '__is_unconditional__':
                self.cond_stage_model = None
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def get_learned_conditining(self, c, x=None):
        """Encode conditioning `c` (optionally with input image `x`).

        Without `cond_stage_forward`, the model's `encode` is used when it
        exists and is callable, and a `DiagonalGaussianDistribution` result is
        reduced to its mode. Otherwise the model is called directly. With
        `cond_stage_forward`, the method of that name is called instead.
        """
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager, input_img = x)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c, input_img=x)
        else:
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, input_img=x)
        return c

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    get_learned_conditioning = get_learned_conditining

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainable parameter set.

        With an embedding manager only its embedding parameters are optimised.
        Otherwise `self.model`'s parameters are, plus the conditioning model's
        when `cond_stage_trainable` is set.
        """
        lr = self.learning_rate

        if self.embedding_manager is not None:
            params = list(self.embedding_manager.embedding_parameters())
        else:
            params = list(self.model.parameters())
            if self.cond_stage_trainable:
                # Previously an empty list() was appended, which added nothing.
                params = params + list(self.cond_stage_model.parameters())

        opt = torch.optim.AdamW(params, lr=lr)
        return opt
    



# Image-to-image sampling settings.
ddim_steps = 50
strength = 0.5
scale=10.0
# Number of DDIM steps to noise forward and then decode back through.
t_enc = int(strength * ddim_steps)

all_samples = list()

# NOTE(review): `model`, `sampler`, `init_latent`, `batch_size`, `device`, `c`,
# `uc` and `rearrange` are not defined in any visible cell; this cell only runs
# when they are already in the kernel namespace.
t_batch = torch.tensor([t_enc]*batch_size).to(device)

x_noisy = model.q_sample(x_start=init_latent, t=t_batch)
model_output = model.apply_model(x_noisy, t_batch, c)
z_enc = sampler.stochastic_encode(init_latent, t_batch,
                                  noise = model_output, use_original_steps=True)

# p sample
# The guidance scale is passed positionally: the keyword used before was
# misspelled and matched no parameter of `decode`.
samples = sampler.decode(z_enc, c, t_enc, scale,
                         unconditional_conditioning=uc,)


x_samples = model.decode_first_stage(samples)
# Map from [-1, 1] to [0, 1].
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

# Per-sample arrays in HWC layout scaled to [0, 255]. The loop that computed
# these used to discard every result.
sample_images = [
    255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
    for x_sample in x_samples
]
all_samples.append(x_samples)

grid = torch.stack(all_samples, 0)