# Initialize Notebook

## 📖 Citation

This notebook is a replication of the [PCDM demo](https://github.com/tencent-ailab/PCDMs/blob/main/pcdms_demo.ipynb) and is based on the following paper:

> Shen, F., Ye, H., Zhang, J., Wang, C., Han, X., & Wei, Y.  
> *Advancing Pose-Guided Image Synthesis with Progressive Conditional Diffusion Models*.  
> The Twelfth International Conference on Learning Representations (ICLR).

```bibtex
@inproceedings{shenadvancing,
  title={Advancing Pose-Guided Image Synthesis with Progressive Conditional Diffusion Models},
  author={Shen, Fei and Ye, Hu and Zhang, Jun and Wang, Cong and Han, Xiao and Wei, Yang},
  booktitle={The Twelfth International Conference on Learning Representations}
}
```


In [None]:
!pip install open_clip_torch controlnet-aux mediapipe > /dev/null
!pip install -U diffusers > /dev/null

In [None]:
# import basics
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
# basics pytorch
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
# import image encoder
from transformers import CLIPImageProcessor, Dinov2Model
# import diffusion models
from diffusers import (
    AutoencoderKL,           # Autoencoder model
    DDIMScheduler,           # Scheduler for diffusion steps
    UNet2DConditionModel     # Conditional U-Net model
)
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding

# Shared constants used throughout the notebook.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seeded generator that is later handed to the pipeline call (seed fixed at 42).
generator = torch.Generator(device=device).manual_seed(42)

## Clone Repo

In [None]:
!git clone https://github.com/tencent-ailab/PCDMs.git
!mv PCDMs/* .
!ls

In [None]:
# fix a stale import path in the cloned repo
# replace 'diffusers.models.unet_2d_blocks' with 'diffusers.models.unets.unet_2d_blocks' in 'stage2_inpaint_unet_2d_condition.py'
!sed -i 's/diffusers\.models\.unet_2d_blocks/diffusers.models.unets.unet_2d_blocks/g' ./src/models/stage2_inpaint_unet_2d_condition.py

## DWpose

I will use `easy-dwpose` instead of `mmpose`; it is more lightweight because it runs on ONNX.

**DON'T** move the download cell upward; doing so will cause an error.

In [None]:
!pip install easy-dwpose > /dev/null

In [None]:
# replace single_extract_pose.py
from easy_dwpose  import DWposeDetector

def inference_pose(img_path, image_size=(1024, 1024)):
    """Render the DWpose skeleton of the image stored at ``img_path``.

    Args:
        img_path: Path (or anything ``PIL.Image.open`` accepts) of the image.
        image_size: Size handed to PIL's ``resize`` before detection; its
            second entry is also used as the detector's ``detect_resolution``.

    Returns:
        A PIL image built from the array the detector returns.
    """
    # A new detector is built on every call, on CUDA when available.
    if torch.cuda.is_available():
        run_on = torch.device("cuda")
    else:
        run_on = torch.device("cpu")
    detector = DWposeDetector(device=run_on)

    # Load, force three channels, then bring the image to the requested size.
    rgb_image = Image.open(img_path).convert("RGB")
    resized_image = rgb_image.resize(image_size, Image.BICUBIC)

    pose_array = detector(resized_image, output_type='np', detect_resolution=image_size[1])
    return Image.fromarray(pose_array)


## Download Checkpoint

In [None]:
! gdown "1JFFy_FBxOFuGFBcB6xMIVwcQb8bfnpO9" -O "pcdms_ckpt.pt"

# Load Checkpoint

In [None]:
def load_model(ckpt_path = "./pcdms_ckpt.pt"):
    """Split the PCDMs checkpoint into one state dict per sub-model.

    The checkpoint stores every weight under ``ckpt['module']`` with keys of
    the form ``<sub_model>.<parameter path>``. The leading ``<sub_model>``
    segment is stripped and the remainder becomes the key in the matching
    output dict.

    Args:
        ckpt_path: Path of the checkpoint file to read.

    Returns:
        Tuple ``(unet_dict, pose_proj_dict, image_proj_model_dict)``.

    Raises:
        FileNotFoundError: If a key belongs to a sub-model other than
            ``unet``, ``pose_proj`` or ``image_proj_model`` (exception type
            kept for backward compatibility), or if the file does not exist.
    """
    # SECURITY: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" keeps the raw weights off the GPU: the checkpoint then
    # also loads on CPU-only machines, and load_state_dict() copies the values
    # onto each model's own device later, so no duplicate copy stays on the GPU.
    model_ckpt = torch.load(ckpt_path, map_location="cpu")

    state_dicts = {'unet': {}, 'pose_proj': {}, 'image_proj_model': {}}

    for key, value in model_ckpt['module'].items():
        # "unet.conv_in.weight" -> ("unet", "conv_in.weight")
        model_name, _, model_key = key.partition('.')
        if model_name not in state_dicts:
            raise FileNotFoundError(
                f"no model called that: {model_name!r} (from checkpoint key {key!r})"
            )
        state_dicts[model_name][model_key] = value

    return state_dicts['unet'], state_dicts['pose_proj'], state_dicts['image_proj_model']


unet_dict,  pose_proj_dict,  image_proj_model_dict = load_model()

## Load Components

In [None]:
# Load the Stable Diffusion components.

from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel

# Build the stage-2 inpainting UNet from the "stable-diffusion-2-1-base" UNet weights.
# in_channels=9 differs from the pretrained model, hence ignore_mismatched_sizes=True;
# the PCDMs checkpoint weights are loaded over this model in a later cell.
unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16,subfolder="unet",in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)

# A standalone VAE is not loaded here: the notebook only ever uses `pipe.vae`,
# which comes with the pipeline built below.
# vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base",subfolder="vae").to(device, dtype=torch.float16)

# DDIM scheduler handed to the pipeline.
# NOTE(review): these look like the usual Stable Diffusion schedule settings — confirm against the upstream demo.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

In [None]:
# image encoder
image_encoder = Dinov2Model.from_pretrained("facebook/dinov2-giant").to(device, dtype=torch.float16)


# ImageProjModel will project `embeddings` output from `image_encoder` to input to SD
class ImageProjModel(torch.nn.Module):
    """MLP that projects image-encoder embeddings to the width fed to the SD UNet.

    Layer order (the indices are part of the state-dict keys, so the order
    must not change): Linear -> GELU -> Dropout -> LayerNorm -> Linear -> Dropout.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
        super().__init__()

        # Built as a list first; unpacking keeps the same 0..5 child indices.
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the projection stack to ``x`` and return the result."""
        projected = self.net(x)
        return projected
# Projection head in half precision on the target device.
# NOTE(review): in_dim=1536 presumably matches the hidden size of the DINOv2 encoder
# loaded above and out_dim=1024 the UNet's cross-attention width — confirm.
image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)

# Pose encoder: turns the 3-channel pose image into a 320-channel conditioning feature map.
pose_proj_model = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,
    block_out_channels=(16, 32, 96, 256),
    conditioning_channels=3).to(device).to(dtype=torch.float16)

In [None]:
from src.pipelines.PCDMs_pipeline import PCDMsPipeline

# Stage-2 pipeline: the remaining components come from "stable-diffusion-2-1-base",
# while the custom inpainting UNet and the DDIM scheduler defined above are injected.
# The feature extractor and safety checker are disabled.
pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet,  torch_dtype=torch.float16, scheduler=noise_scheduler,feature_extractor=None,safety_checker=None).to(device)

## Load checkpoint weights

In [None]:
# Overwrite the freshly built models with the weights split out of the PCDMs checkpoint.
unet.load_state_dict(unet_dict)
pose_proj_model.load_state_dict(pose_proj_dict)
image_proj_model.load_state_dict(image_proj_model_dict)

# Inference Step By Step

In [None]:
# Data preprocessing transforms.
# The CLIP image processor prepares the source image for the image encoder.
clip_image_processor = CLIPImageProcessor()
# Transform applied to the images that go to the VAE and to the pose encoder.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]), # map pixel values from [0, 1] to [-1, 1]
])

We will load image sample 1 and process it step by step.

In [None]:
# Parameters for the step-by-step walkthrough.
num_samples = 1
# Passed to PIL's resize, i.e. interpreted as (width, height).
image_size = (512, 512)
# Source image and target pose image.
s_img_path = './imgs/img1.png'
target_pose_img = './imgs/pose1.png'

## Preparing Inputs
- In-painting input (image || black) + mask of (white || black)
- Pose condition (source pose || target pose)
- Image embeddings (source image → CLIP image processor → DINOv2 → projection)

### In-painting input

In [None]:
# Source image, forced to RGB and resized to the working resolution.
s_img = Image.open(s_img_path).convert("RGB").resize(image_size, Image.BICUBIC)
# Black placeholder standing in for the target image that is still to be generated.
black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)

# Double-width canvas: source on the left half, black placeholder on the right half.
s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
s_img_t_mask.paste(s_img, (0, 0))
s_img_t_mask.paste(black_image, (s_img.width, 0))

# Normalise and add a leading batch dimension; this tensor is the VAE input.
vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
print("s_img_t_mask (vae input) image shape: ", vae_image.shape)

s_img_t_mask

### Mask (In-Painting input)

In [None]:
# Inpainting mask at latent resolution: ones over the source half, zeros over
# the half to be generated, concatenated along the width axis (dim=3).
# `image_size` is (width, height) — it is passed to PIL's resize and the
# pipeline is called with height=image_size[1] — so the tensor's height comes
# from image_size[1] and its width from image_size[0]. The original indexing
# had them swapped, which only happens to work for square sizes.
# NOTE(review): the divisor 8 is presumably the VAE's spatial downsampling factor — confirm.
latent_h, latent_w = image_size[1] // 8, image_size[0] // 8
mask1 = torch.ones((1, 1, latent_h, latent_w)).to(device, dtype=torch.float16)
mask0 = torch.zeros((1, 1, latent_h, latent_w)).to(device, dtype=torch.float16)
mask = torch.cat([mask1, mask0], dim=3)

print("mask shape: ", mask.shape)
plt.imshow(mask[0].detach().cpu().permute(1, 2, 0), cmap="gray")

### Pose Condition

Put the `dwpose` of the source image beside the `openpose` of the target.

In [None]:
# Pose of the source image, estimated with DWpose and resized to the working resolution.
# NOTE(review): the size is passed to inference_pose with its two entries swapped and the
# result is then resized with `image_size` as-is; this is harmless for the square size
# used here — confirm the intended order before using a non-square size.
s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))

# Target pose is read from a ready-made pose image on disk.
t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)

# Double-width canvas: source pose on the left half, target pose on the right half.
st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
st_pose.paste(s_pose, (0, 0))
st_pose.paste(t_pose, (s_pose.width, 0))

# Normalise and add a leading batch dimension; this tensor feeds the pose encoder.
cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
print("st_pose (condition) shape: ", cond_st_pose.shape)
st_pose

### Clip image encoder processor  -> input to Dino :XD

In [None]:
# ??clip_image_processor

In [None]:
# Preprocess the source image with the CLIP image processor; the result is fed to the DINOv2 encoder later.
clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values # [1, 3, 224, 224]
print("clip shape: ", clip_s_img.shape)
# NOTE(review): pixel_values are presumably mean/std-normalised, so imshow may clip
# out-of-range values and the preview can look off — confirm.
plt.imshow(clip_s_img[0].permute(1, 2, 0))

## Inference

In [None]:
# just to make sure I don't have replication of models on GPU
torch.cuda.empty_cache()

In [None]:
# Preprocessing step: prepare the pose condition, the latents and the image embeddings.

with torch.inference_mode():
    # 1. Encode the (source pose || target pose) image with the pose encoder.
    cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
    # 2. Encode the (source || black) image with the pipeline's VAE.
    # NOTE(review): sample() is called without a generator, so this draw is not
    # controlled by the seeded `generator` — confirm if exact reproducibility matters.
    simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
    # NOTE(review): 0.18215 is presumably the Stable Diffusion VAE latent scaling factor — confirm.
    simg_mask_latents = simg_mask_latents * 0.18215
    # 3. Image embeddings: the conditional branch uses the encoder features, the
    #    unconditional branch projects an all-zero tensor of the same shape.
    images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
    image_prompt_embeds = image_proj_model(images_embeds)
    uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))


bs_embed, seq_len, _ = image_prompt_embeds.shape
# Tile both embeddings `num_samples` times and fold the copies into the batch dimension
# (one copy per requested sample; a no-op for num_samples == 1).
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

In [None]:
# Parameters you can play with.
num_inference_steps = 50
guidance_scale = 2.0
# NOTE(review): the call below hard-codes num_images_per_prompt=1 and does not read num_samples.
num_samples = 1

output = pipe(
    simg_mask_latents= simg_mask_latents,
    mask = mask,
    cond_pose = cond_pose,
    prompt_embeds=image_prompt_embeds,
    # embeddings of the all-zero image features act as the negative / unconditional prompt
    negative_prompt_embeds=uncond_image_prompt_embeds,
    height=image_size[1],
    width=image_size[0]*2, # double width: source half plus the half to be generated
    num_images_per_prompt=1,
    guidance_scale=guidance_scale,
    generator=generator,
    num_inference_steps=num_inference_steps,
)

In [None]:
output.images[0]

# Inference in Function

In [None]:
def get_inpainting_inputs(s_img):
    """Encode the "source | black placeholder" image into scaled VAE latents.

    Args:
        s_img: Source PIL image.

    Returns:
        Latents sampled from the pipeline VAE's posterior for the double-width
        image, multiplied by 0.18215.
    """
    src_w, src_h = s_img.size

    # Double-width black canvas; the source goes on the left and the right half
    # stays black as the placeholder for the image to be generated.
    canvas = Image.new("RGB", (src_w * 2, src_h), (0, 0, 0))
    canvas.paste(s_img, (0, 0))

    # Normalise and add the batch dimension.
    pixels = img_transform(canvas).unsqueeze(0)

    with torch.inference_mode():
        posterior = pipe.vae.encode(pixels.to(device, dtype=torch.float16)).latent_dist
        scaled_latents = posterior.sample() * 0.18215

    return scaled_latents

def get_inpainting_cond(s_img, t_pose):
    """Build the pose condition from the source image and the target pose.

    The source pose is estimated with DWpose directly from ``s_img``. The
    previous version read the module-level ``s_img_path`` instead, so the pose
    came from whichever path was last assigned globally rather than from the
    image being processed.

    Args:
        s_img: Source PIL image, already at the working resolution.
        t_pose: Target pose PIL image; it is pasted as-is, so it should have
            the same size as ``s_img``.

    Returns:
        Output of ``pose_proj_model`` for the (source pose || target pose)
        image, computed in half precision under ``torch.inference_mode``.
    """
    # 1. Estimate the source pose from the image that was actually passed in.
    detector = DWposeDetector(device=device)
    pose_array = detector(s_img.convert("RGB"), output_type='np', detect_resolution=s_img.size[1])
    s_pose = Image.fromarray(pose_array).resize(s_img.size, Image.BICUBIC)

    # 2. Put the source pose on the left half and the target pose on the right half.
    st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
    st_pose.paste(s_pose, (0, 0))
    st_pose.paste(t_pose, (s_pose.width, 0))

    # 3. Normalise, add the batch dimension and project with the pose encoder.
    cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
    with torch.inference_mode():
        cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))

    return cond_pose

def get_image_embeddings(s_img, num_samples = 1):
    """Compute conditional and unconditional image prompt embeddings.

    Args:
        s_img: Source PIL image.
        num_samples: How many times each embedding is tiled into the batch.

    Returns:
        Tuple ``(image_prompt_embeds, uncond_image_prompt_embeds)``; the
        unconditional one is the projection of an all-zero feature tensor.
    """
    pixel_values = clip_image_processor(images=s_img, return_tensors="pt").pixel_values

    with torch.inference_mode():
        encoder_input = pixel_values.to(device, dtype=torch.float16)
        features = image_encoder(encoder_input).last_hidden_state
        cond_embeds = image_proj_model(features)
        uncond_embeds = image_proj_model(torch.zeros_like(features))

    batch, tokens, _ = cond_embeds.shape

    def _tile(embeds):
        # One copy per requested sample, folded into the batch dimension.
        return embeds.repeat(1, num_samples, 1).view(batch * num_samples, tokens, -1)

    return _tile(cond_embeds), _tile(uncond_embeds)

In [None]:
def inference_one_image(s_img_path = './imgs/img1.png', target_pose_path = './imgs/pose1.png', 
                        image_size = (512, 512),
                        num_inference_steps = 50,
                        guidance_scale = 2.0):
    """Run the stage-2 pipeline for one source image and one target pose.

    Args:
        s_img_path: Path of the source image.
        target_pose_path: Path of the target pose image.
        image_size: ``(width, height)`` both inputs are resized to.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The pipeline output; the generated image is in ``.images``.
    """
    # ======================== Preprocessing ==========================================
    # 1. read the source image and the target pose at the working resolution
    s_img = Image.open(s_img_path).convert("RGB").resize(image_size, Image.BICUBIC)
    t_pose = Image.open(target_pose_path).convert("RGB").resize(image_size, Image.BICUBIC)
    # 2. get inpainting input
    simg_mask_latents = get_inpainting_inputs(s_img)
    # 3. get conditional pose
    cond_pose = get_inpainting_cond(s_img, t_pose)
    # 4. get image embeddings
    image_prompt_embeds, uncond_image_prompt_embeds = get_image_embeddings(s_img)
    # 5. build the inpainting mask for THIS call's image_size. The function used
    #    to read the notebook-level `mask`, which is only correct for the
    #    walkthrough's size and fails if that cell was never run.
    #    Ones cover the source half, zeros the half to generate.
    #    NOTE(review): the divisor 8 is presumably the VAE downsampling factor — confirm.
    latent_h, latent_w = image_size[1] // 8, image_size[0] // 8
    keep_half = torch.ones((1, 1, latent_h, latent_w), device=device, dtype=torch.float16)
    fill_half = torch.zeros((1, 1, latent_h, latent_w), device=device, dtype=torch.float16)
    inpaint_mask = torch.cat([keep_half, fill_half], dim=3)

    # ======================== Pipeline  ==========================================
    return pipe(
            simg_mask_latents= simg_mask_latents,
            mask = inpaint_mask,
            cond_pose = cond_pose,
            prompt_embeds=image_prompt_embeds,
            negative_prompt_embeds=uncond_image_prompt_embeds,
            height=image_size[1],
            width=image_size[0]*2, # double width: source half plus the half to be generated
            num_images_per_prompt=1,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )

# Outputs

- It's clear that the model changes the source image and does not preserve it (I don't know whether this is due to the inpainting or not).

In [None]:
inference_one_image().images[-1]

In [None]:
s_img_path = './imgs/img2.png'
target_pose_path = './imgs/pose1.png'
output = inference_one_image(s_img_path, target_pose_path).images[-1]

output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img3.png', target_pose_path).images[-1]
output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img4.png', target_pose_path).images[-1]
output.resize((1024, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img5.png', './imgs/pose2.png').images[-1]
output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img6.png', './imgs/pose3.png').images[-1]
output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img7.png', './imgs/pose4.png').images[-1]
output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img8.png', './imgs/pose5.png').images[-1]
output.resize((512, 512), Image.BICUBIC)

In [None]:
output = inference_one_image('./imgs/img9.png', './imgs/pose6.png').images[-1]
output.resize((512, 512), Image.BICUBIC)

# Play with Poses

In [None]:
from controlnet_aux import OpenposeDetector
openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
openpose(s_img)

In [None]:
s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))
s_pose