face embedding + controlnet train

In [1]:
import math

import torch
import torch.nn as nn
import torchvision


from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm

In [None]:
to_tensor = torchvision.transforms.ToTensor()


class MyDataset(torch.utils.data.Dataset):
    """Image/caption dataset yielding diffusion, CLIP and face-recognition inputs.

    Each record of ``json_file`` is read as a dict with the keys ``"text"``
    (caption) and ``"image_file"`` (path relative to ``image_root_path``).

    Args:
        json_file: Path to a JSON file holding the list of records.
        tokenizer: Tokenizer used to turn the caption into ``input_ids``.
        size: Side length the training image is resized / centre-cropped to.
        t_drop_rate: Probability of dropping only the text condition.
        i_drop_rate: Probability of dropping only the image condition.
        ti_drop_rate: Probability of dropping both conditions.
        image_root_path: Directory that ``image_file`` entries are relative to.
    """

    def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # __getitem__ / __len__ read the three attributes below; without them
        # the dataset fails with AttributeError on first use.
        with open(json_file, encoding="utf-8") as f:
            self.data = json.load(f)  # list of {"image_file": ..., "text": ...}

        # Training image for the VAE: square crop scaled to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()

    def fr_image_preprocess(self, raw_image):
        """Convert a PIL image into the face-recognition input tensor.

        The image is resized to 256x256, converted to a tensor, rescaled with
        ``x * 2 - 1`` and given a leading batch dimension.

        Args:
            raw_image: PIL image to preprocess.

        Returns:
            Tensor with a leading dimension of 1.
        """
        # pil to tensor
        resized = raw_image.resize((256, 256), Image.BILINEAR)
        img = to_tensor(resized)
        img = img * 2 - 1
        return torch.unsqueeze(img, 0)

    def __getitem__(self, idx):
        """Return the training sample at ``idx`` as a dict of model inputs.

        A single random draw decides whether the image condition, the text
        condition, or both are dropped for this sample.
        """
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        rgb_image = raw_image.convert("RGB")
        image = self.transform(rgb_image)
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
        fr_image = self.fr_image_preprocess(rgb_image)

        # drop
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            # Previously computed and then discarded; exposed so a training
            # loop can feed the face encoder.
            "fr_image": fr_image,
            "drop_image_embed": drop_image_embed
        }

    def __len__(self):
        """Return the number of records loaded from ``json_file``."""
        return len(self.data)

In [3]:
class Backbone(Module):
    """IR / IR-SE face-recognition backbone producing a 512-dimensional embedding.

    Args:
        input_size: Accepted for interface compatibility; not read by this class.
        num_layers: Forwarded to ``get_blocks`` to choose the block layout.
        mode: ``'ir'`` for ``bottleneck_IR`` units, ``'ir_se'`` for
            ``bottleneck_IR_SE`` units.
        drop_ratio: Dropout probability used in the output head.
        affine: Whether the final ``BatchNorm1d`` has learnable affine parameters.

    Raises:
        ValueError: If ``mode`` is neither ``'ir'`` nor ``'ir_se'``.
    """

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super().__init__()

        # Fail early with a clear message; previously an unknown mode left
        # `unit_module` unbound and surfaced later as an unrelated error.
        if mode not in ('ir', 'ir_se'):
            raise ValueError(f"mode must be 'ir' or 'ir_se', got {mode!r}")

        blocks = get_blocks(num_layers)
        unit_module = bottleneck_IR if mode == 'ir' else bottleneck_IR_SE

        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        # The Linear layer expects a 512 x 7 x 7 feature map from the body.
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * 7 * 7, 512),
                                       BatchNorm1d(512, affine=affine))

        # Kept as a ModuleList (not Sequential) so forward() can collect the
        # output of every unit.
        self.body = nn.ModuleList([])
        for block in blocks:
            for bottleneck in block:
                self.body.append(unit_module(bottleneck.in_channel,
                                             bottleneck.depth,
                                             bottleneck.stride))

    def forward(self, x):
        """Run the backbone.

        Returns:
            A pair ``(embedding, samples)`` where ``embedding`` is the output
            head result passed through ``l2_norm`` and ``samples`` is a tuple
            with the output of every body unit, in order.
        """
        x = self.input_layer(x)

        samples = []
        for body_block in self.body:
            x = body_block(x)
            samples.append(x)
        x = self.output_layer(x)
        return l2_norm(x), tuple(samples)

In [7]:
# Smoke test: build the IR-SE backbone and push one random input through it.
ie = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
ie.eval()  # inference mode for the Dropout / BatchNorm layers

# Random stand-in for a batch of one 3-channel 112x112 image.
x = torch.randn(1, 3, 112, 112)

# Backbone.forward returns (l2_norm'ed embedding, tuple of per-unit feature maps).
out, feat = ie(x)

In [8]:
# Display the embedding shape (the recorded cell output is [1, 512]).
out.shape

torch.Size([1, 512])

In [10]:
class FRImageProjModel(torch.nn.Module):
    """Project a face-recognition embedding into extra cross-attention tokens."""

    def __init__(self, cross_attention_dim=1024, fr_embeddings_dim=512, fr_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.fr_extra_context_tokens = fr_extra_context_tokens

        # A single linear layer emits every token at once; forward() splits
        # the flat output back into individual tokens.
        flat_token_width = fr_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(fr_embeddings_dim, flat_token_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, fr_embeds):
        """Turn embeddings into layer-normalised context tokens.

        Args:
            fr_embeds: Tensor of shape ``[B, fr_embeddings_dim]``.

        Returns:
            Tensor of shape ``[B, fr_extra_context_tokens, cross_attention_dim]``.
        """
        token_shape = (-1, self.fr_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(fr_embeds).reshape(token_shape)
        return self.norm(tokens)

In [12]:
# Stand-alone projection model: 512-d face embedding -> 4 context tokens of width 1024.
image_proj_model = FRImageProjModel(
    cross_attention_dim=1024,
    fr_embeddings_dim=512,
    fr_extra_context_tokens=4,
)

In [None]:
class IPAdapter(nn.Module):
    """Training wrapper tying a UNet to an image projection model.

    Args:
        unet: Denoising network; called as ``unet(latents, timesteps, context)``
            and expected to return an object with a ``.sample`` attribute.
        image_proj_model: Module mapping image embeddings to context tokens.
        adapter_modules: Container of the adapter's attention processors, kept
            as an attribute so they are registered as submodules.
    """

    def __init__(self, unet, image_proj_model, adapter_modules):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual conditioned on text and image tokens.

        The image tokens are appended to the text hidden states along dim 1
        before the UNet is called.
        """
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual and compute loss
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

In [25]:
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image


from transformers import CLIPImageProcessor

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection


def is_torch2_available():
    """Return True when ``torch.nn.functional`` offers fused scaled-dot-product attention."""
    fused_attention_name = "scaled_dot_product_attention"
    return hasattr(F, fused_attention_name)

if is_torch2_available():
    from attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, IPFRAttnProcessor2_0 as IPFRAttnProcessor
else:
    from attention_processor import IPAttnProcessor, AttnProcessor, IPFRAttnProcessor

# Load the frozen SD 1.5 UNet whose attention layers will receive adapter processors.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet",
                                            cache_dir='/home/tyk/hf_cache')
unet.requires_grad_(False)
# To be collected in an nn.ModuleList, every processor must be an nn.Module subclass,
# which is why the attention processors are re-implemented locally.

attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no conditioning); every other processor is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # The block index is the single character that follows the "<kind>_blocks." prefix.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        # Seed the adapter's key/value projections from the layer's existing weights.
        layer_name = name.split(".processor")[0]
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        attn_procs[name] = IPFRAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
                                             fr_embeds_dim=512)
        # load_state_dict() was previously called with no arguments, which raises
        # TypeError and leaves `weights` unused.
        # NOTE(review): strict=False because the FR processor presumably owns
        # parameters beyond to_k_ip / to_v_ip; confirm against IPFRAttnProcessor.
        attn_procs[name].load_state_dict(weights, strict=False)


# adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

TypeError: diffusers.models.attention_processor.AttnProcessor2_0 is not a Module subclass

In [22]:
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image


from transformers import CLIPImageProcessor

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection


def is_torch2_available():
    """Report whether the installed torch exposes ``F.scaled_dot_product_attention``."""
    has_fused_attention = hasattr(F, "scaled_dot_product_attention")
    return has_fused_attention

if is_torch2_available():
    from attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, IPFRAttnProcessor2_0 as IPFRAttnProcessor
else:
    from attention_processor import IPAttnProcessor, AttnProcessor, IPFRAttnProcessor

In [None]:
# args = parse_args()
# logging_dir = Path(args.output_dir, args.logging_dir)

# accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

# accelerator = Accelerator(    
#     mixed_precision=args.mixed_precision,
#     log_with=args.report_to,
#     project_config=accelerator_project_config,
# )

# if accelerator.is_main_process:
#     if args.output_dir is not None:
#         os.makedirs(args.output_dir, exist_ok=True)

# Argument parsing is disabled above, so `args` does not exist in this notebook.
# Use the same SD 1.5 checkpoint id the other cells of this notebook load.
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
# image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# image_encoder.requires_grad_(False)  # disabled together with the image_encoder load above

#ip-adapter
# Keyword names must match FRImageProjModel.__init__ (fr_*, not clip_*).
image_proj_model = FRImageProjModel(
    cross_attention_dim=unet.config.cross_attention_dim,
    fr_embeddings_dim=512,
    fr_extra_context_tokens=4,
)


In [None]:
# init adapter modules
# Every UNet attention layer gets a processor: plain AttnProcessor for
# self-attention (attn1), IPAttnProcessor seeded from the layer's own key/value
# projections for cross-attention.
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors:
    # Channel width of the block this processor belongs to.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = unet.config.block_out_channels[::-1][block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if name.endswith("attn1.processor"):
        cross_attention_dim = None
        attn_procs[name] = AttnProcessor()
        continue

    cross_attention_dim = unet.config.cross_attention_dim
    layer_name = name.split(".processor")[0]
    processor = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    processor.load_state_dict({
        "to_k_ip.weight": unet_sd[f"{layer_name}.to_k.weight"],
        "to_v_ip.weight": unet_sd[f"{layer_name}.to_v.weight"],
    })
    attn_procs[name] = processor

unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)

Inference

In [None]:
def image_grid(imgs, rows, cols):
    """Paste images into a ``rows`` x ``cols`` grid, filled row by row.

    The cell size is taken from the first image.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not imgs:
        raise ValueError("image_grid() needs at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(f"expected rows*cols = {rows * cols} images, got {len(imgs)}")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# inference
class IPAdapterPlusFR:
    """Inference wrapper pairing a Stable Diffusion pipeline with a face encoder.

    NOTE(review): still a work in progress. No adapter weights are loaded from
    ``ip_ckpt`` yet, and ``image_process`` is a stub.

    Args:
        sd_pipe: Diffusion pipeline; moved to ``device`` and patched in place.
        image_encoder_path: Stored on the instance; not read by this class.
        ip_ckpt: Stored on the instance; not read by this class yet.
        device: Device the pipeline and adapter modules are placed on.
        num_tokens: Number of extra context tokens produced per image.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # forward() needs a projection model; build it the way the training
        # cells of this notebook do.
        # TODO(review): load its trained weights from `ip_ckpt`.
        self.image_proj_model = FRImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            fr_embeddings_dim=512,
            fr_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=self.pipe.unet.dtype)

        self.image_encoder = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # map_location keeps loading independent of the device the checkpoint was saved from.
        self.image_encoder.load_state_dict(
            torch.load("/workspace/ddgm/functions/arcface/model_ir_se50.pth", map_location="cpu"))
        self.pool = nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        # Was `self.facenet.eval()`; the encoder is stored as `image_encoder`.
        self.image_encoder.eval()

    def image_process(self, ):
        """Placeholder for face-image preprocessing (left empty by the author)."""
        raise NotImplementedError("IPAdapterPlusFR.image_process is not implemented yet")

    def set_ip_adapter(self):
        """Install adapter attention processors on the UNet (and ControlNet, if any)."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # Same per-layer selection as the training setup in this notebook:
            # attn1 is self-attention, everything else is cross-attention.
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                # The pipeline was already moved to the device, so new modules
                # must follow it explicitly.
                attn_procs[name] = IPFRAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    fr_embeds_dim=512,
                ).to(self.device, dtype=unet.dtype)

        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            # Neither name is imported at file level, so import them here.
            # NOTE(review): assumes the local attention_processor module provides
            # CNAttnProcessor / CNAttnProcessor2_0 like upstream IP-Adapter; confirm.
            from diffusers.pipelines.controlnet import MultiControlNetModel
            if is_torch2_available():
                from attention_processor import CNAttnProcessor2_0 as CNAttnProcessor
            else:
                from attention_processor import CNAttnProcessor

            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual with image tokens appended to the text context."""
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # The UNet lives on the pipeline; this class has no `self.unet`.
        noise_pred = self.pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred


In [None]:
import torch
from diffusers import StableDiffusionControlNetPipeline, DDIMScheduler, AutoencoderKL, ControlNetModel
from PIL import Image

# Model identifiers / local paths used by the inference demo.
controlnet_model_path = "lllyasviel/control_v11p_sd15_openpose"
base_model_path = "runwayml/stable-diffusion-v1-5"
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "models/image_encoder/"
ip_ckpt = "models/ip-adapter_sd15.bin"
device = "cuda"


noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
# cache_dir is a from_pretrained() option; it was previously passed to .to(),
# which does not accept it.
vae = AutoencoderKL.from_pretrained(vae_model_path, cache_dir='/home/tyk/hf_cache').to(dtype=torch.float16)


controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16, cache_dir='/home/tyk/hf_cache')
# load SD pipeline
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None,
    cache_dir='/home/tyk/hf_cache'
)

In [None]:
# Reference face image used as the image prompt.
image = Image.open("assets/images/girl_face.png")
# Preview only: resize() returns a new image, so `image` itself keeps its original size.
image.resize((256, 256))

In [None]:
# Pose map passed to the ControlNet as the structural condition.
openpose_image = Image.open("assets/structure_controls/openpose.png")
# Preview only: resize() returns a new image, so `openpose_image` keeps its original size.
openpose_image.resize((256, 384))

In [None]:
# load ip-adapter
# The IPAdapter class defined in this notebook is the training module and takes
# (unet, image_proj_model, adapter_modules). The wrapper whose constructor takes
# (sd_pipe, image_encoder_path, ip_ckpt, device) is IPAdapterPlusFR.
ip_model = IPAdapterPlusFR(pipe, image_encoder_path, ip_ckpt, device)

In [None]:
# generate
# NOTE(review): neither IPAdapter nor IPAdapterPlusFR as defined in this notebook
# has a `generate` method; this call needs one to be implemented or imported.
images = ip_model.generate(pil_image=image, image=openpose_image, width=512, height=768, num_samples=4, num_inference_steps=50, seed=42)
# Lay the 4 samples out as 1 row x 4 columns and display the result.
grid = image_grid(images, 1, 4)
grid