In [1]:

%load_ext autoreload
%autoreload 2
import os
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from typing import Callable, List, Optional, Union
import torchvision as tv

from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPVisionModelWithProjection

from exp.animate.face_animate import FaceAnimatePipeline
from exp.animate.face_animate_static import StaticPipeline

from exp.models.audio_proj import AudioProjModel
from exp.models.face_locator import FaceLocator
from exp.models.mutual_self_attention import ReferenceAttentionControl
from exp.models.unet_2d_condition import UNet2DConditionModel
from exp.models.unet_3d import UNet3DConditionModel
from exp.models.image_proj import ImageProjModel
from emo.models.speed_encoder import SpeedEncoder

from exp.datasets.audio_processor import AudioProcessor
from exp.datasets.image_processor import ImageProcessor
# from exp.datasets.mask_image import EMODataset
# from exp.datasets.talk_video import TalkingVideoDataset

from exp.utils.util import tensor_to_video
from exp.utils.params import (
    param_optim, 
    create_optim_params, 
    negate_params, 
    create_optimizer_params, 
    handle_trainable_modules,
    freeze_params
)
from exp.models.lora_handler import LoraHandler

# from scripts.train_stage3_emo import Net, load_config, process_audio_emb
from scripts.train_stage_exp_lora import Net, load_config, process_audio_emb
from omegaconf import OmegaConf

from IPython.display import Video

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Path to a previously rendered clip (a file path, despite the "_dir" name).
# NOTE(review): absolute, user-specific path; it is not read by any cell visible here.
vid_dir = '/home/sihun.cha/Work/hallo/.cache/output_7.mp4'
# Video('/home/sihun.cha/Work/hallo/.cache/output_7.mp4')
# NOTE(review): HTML is only referenced by the commented-out snippet below.
from IPython.display import HTML

# HTML("""
#     <video alt="test" controls>
#         <source src="/home/sihun.cha/Work/hallo/.cache/output_7.mp4">
#     </video>
# """)


In [3]:
# Load the stage-2 experiment config; every later cell reads its settings from `cfg`.
cfg = load_config("./configs/train/exp/stage2.yaml")

# Translate the config's precision tag into the torch dtype used for all modules.
_DTYPE_BY_TAG = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}
if cfg.weight_dtype not in _DTYPE_BY_TAG:
    raise ValueError(
        f"Do not support weight dtype: {cfg.weight_dtype} during training"
    )
weight_dtype = _DTYPE_BY_TAG[cfg.weight_dtype]

# Scheduler kwargs come from the config as a plain dict so they can be amended.
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
if cfg.enable_zero_snr:
    # Extra scheduler options applied when zero-SNR is enabled in the config.
    sched_kwargs["rescale_betas_zero_snr"] = True
    sched_kwargs["timestep_spacing"] = "trailing"
    sched_kwargs["prediction_type"] = "v_prediction"
val_noise_scheduler = DDIMScheduler(**sched_kwargs)

# Applied after the scheduler was built, so it does not affect val_noise_scheduler.
# NOTE(review): nothing visible in this notebook consumes sched_kwargs afterwards.
sched_kwargs["beta_schedule"] = "scaled_linear"

In [4]:
# Select which pipeline variant the rest of the notebook builds.
# use_pipeline = 'emo'
use_pipeline = 'hallo'

# Every module below is moved to the GPU in the precision chosen by the config.
on_gpu = {"device": "cuda", "dtype": weight_dtype}

# The numbered prints are load-progress markers; their order matches the saved output.
unet_extra_kwargs = OmegaConf.to_container(cfg.unet_additional_kwargs)
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    cfg.base_model_path,
    cfg.motion_module_path,
    subfolder="unet",
    unet_additional_kwargs=unet_extra_kwargs,
    use_landmark=False,
    # NOTE(review): spelling kept as-is; confirm it matches the keyword
    # that from_pretrained_2d actually expects.
    vis_atttn=True,
)
denoising_unet = denoising_unet.to(**on_gpu)
print("3")

vae = AutoencoderKL.from_pretrained(cfg.vae_model_path)
vae = vae.to(**on_gpu)
print("1")

reference_unet = UNet2DConditionModel.from_pretrained(
    cfg.base_model_path,
    subfolder="unet",
)
reference_unet = reference_unet.to(**on_gpu)
print("2")

if use_pipeline == 'hallo':
    # Modules that are only created for the 'hallo' variant.
    face_locator = FaceLocator(
        conditioning_embedding_channels=320,
        conditioning_channels=3,
    )
    face_locator = face_locator.to(**on_gpu)
    print("5")

    audioproj = AudioProjModel(
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    )
    audioproj = audioproj.to(**on_gpu)
    print("6")

    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )
    imageproj = imageproj.to(**on_gpu)

##### for EMO (but no checkpoints...) #########################
elif use_pipeline == 'emo':
    # NOTE(review): this branch does not create face_locator, imageproj or
    # audioproj, although later cells reference them unconditionally.
    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        cfg.image_encoder_path,
    )
    image_enc = image_enc.to(**on_gpu)
    print("4")

    speed_enc = SpeedEncoder(
        num_speed_buckets=8,  # ? not sure if it is correct
        speed_embedding_dim=768,
    )
    speed_enc = speed_enc.to(**on_gpu)
    print("7")


The config attributes {'center_input_sample': False} were passed to UNet3DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


Load motion module params from pretrained_models/motion_module/mm_sd_v15_v2.ckpt
3


The config attributes {'center_input_sample': False, 'out_channels': 4} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


1


Some weights of the model checkpoint were not used when initializing UNet2DConditionModel: 
 ['conv_norm_out.bias, conv_norm_out.weight, conv_out.bias, conv_out.weight']


2
5
6


In [5]:
### No gradients !!!!
# Inference-only notebook: hand every loaded module to the project's
# freeze_params helper (the next cell confirms that no parameter of
# denoising_unet requires grad afterwards).
# NOTE(review): face_locator, imageproj and audioproj only exist when
# use_pipeline == 'hallo'.
frozen_modules = [
    vae,
    reference_unet,
    denoising_unet,
    face_locator,
    imageproj,
    audioproj,
]
freeze_params(frozen_modules)

In [6]:
# Count the parameters of the denoising UNet that still require gradients.
i = sum(p.requires_grad for p in denoising_unet.parameters())
print(i)

0


In [7]:
# Settings shared by the writer and the reader control.
_ref_attn_common = dict(
    do_classifier_free_guidance=False,
    fusion_blocks="full",
)

# "write" control wraps the 2D reference UNet, "read" control wraps the 3D
# denoising UNet.
# NOTE(review): both are created here, but Net receives None for them in
# the next cell; confirm whether construction alone has side effects on the UNets.
reference_control_writer = ReferenceAttentionControl(
    reference_unet, mode="write", **_ref_attn_common
)
reference_control_reader = ReferenceAttentionControl(
    denoising_unet, mode="read", **_ref_attn_common
)

In [8]:
# Assemble the training-time wrapper so the combined checkpoint can be loaded
# into all sub-modules at once. Arguments are positional and must keep this
# order; the two None slots stand in for the reference-control writer/reader.
net = Net(
    reference_unet,
    denoising_unet,
    face_locator,
    None,  # reference_control_writer
    None,  # reference_control_reader
    imageproj,
    audioproj,
).to(dtype=weight_dtype)

# The checkpoint is read onto the CPU and passed straight to load_state_dict
# (not bound to a name) so the loaded dict is not kept alive afterwards.
net_ckpt_path = os.path.join(cfg.audio_ckpt_dir, "net.pth")
# m / u receive the two fields of load_state_dict's result
# (missing keys, unexpected keys).
m, u = net.load_state_dict(torch.load(net_ckpt_path, map_location="cpu"))

In [9]:
# # attn feature map analysis
# from exp.utils.attention_map import (
#     register_cross_attention_hook,
# )
# denoising_unet, attn_maps = register_cross_attention_hook(denoising_unet)



In [10]:
# # select layers to visualize attn feature map
# for name, module in denoising_unet.named_modules():
#     # if not name.split('.')[-1].startswith('attn2'):
    
#     if name.split('.')[-1] == 'attn2':
#         print(name)
#     if name.split('.')[-1] == 'attention_blocks':
#         print(name)
#         # print(module.processor)
        
#     # if name.find('attn2') > 0:
#     #     print(name)
#     # else:
#     #     continue

# # if isinstance(net.denoising_unet, UNet3DConditionModel):
# #     print(True)

In [12]:
# denoising_unet
# NOTE(review): mid-notebook import; the LoraInjected* names are only used by
# the commented-out search_class alternative below.
from exp.models.lora import LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d

# Leaf layer types searched for inside each ancestor in the cells below.
search_class = [nn.Linear, nn.Conv2d, nn.Conv3d]
#search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]

# Class names (matched as strings) of the container blocks to inspect.
# ancestor_class = ['VersatileAttention'] # motion_module attn
# ancestor_class = ['TemporalTransformerBlock'] # motion_module
ancestor_class = ['TemporalBasicTransformerBlock'] # spatial

if ancestor_class is None:
    # this, incase you want to naively iterate over all modules.
    ancestors = list(denoising_unet.modules())
    print('2')
else:
    # Keep every sub-module whose class name is listed, printing each match.
    ancestors = []
    for name, module in denoising_unet.named_modules():
        if type(module).__name__ not in ancestor_class:
            continue
        ancestors.append(module)
        # print(name, module.__class__.__name__)
        print(name, module)
        print('-'*10)
    print('1')
    
# for name, module in denoising_unet.named_modules():
#     # print(name)
#     # print('-----')
#     if any([isinstance(module, _class) for _class in search_class]):
#         print(module)
#     # for nn, mm in denoising_unet.named_modules():
#     #     print(nn)
#     #     print('-----')
#     #     print(mm)

#     # # if module.__class__.__name__ == 'motion_modules':
#     # break
# # module #.__dict__.keys()


down_blocks.0.attentions.0.transformer_blocks.0 TemporalBasicTransformerBlock(
  (attn1): Attention(
    (to_q): Linear(in_features=320, out_features=320, bias=False)
    (to_k): Linear(in_features=320, out_features=320, bias=False)
    (to_v): Linear(in_features=320, out_features=320, bias=False)
    (to_out): ModuleList(
      (0): Linear(in_features=320, out_features=320, bias=True)
      (1): Dropout(p=0.0, inplace=False)
    )
  )
  (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (attn2): Attention(
    (to_q): Linear(in_features=320, out_features=320, bias=False)
    (to_k): Linear(in_features=768, out_features=320, bias=False)
    (to_v): Linear(in_features=768, out_features=320, bias=False)
    (to_out): ModuleList(
      (0): Linear(in_features=320, out_features=320, bias=True)
      (1): Dropout(p=0.0, inplace=False)
    )
  )
  (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (ff): FeedForward(
    (net): ModuleList(
      (0): GEGLU(
       

In [57]:
# List the searched layer types that sit under attn1 / ff inside each ancestor,
# together with the (parent path, leaf name) split of their qualified name.
for ancestor in ancestors:
    for name, module in ancestor.named_modules():
        # Substring match on the qualified name, e.g. "ff.net.0.proj.linear".
        if not ('attn1' in name or 'ff' in name):
            continue
        for _class in search_class:
            if not isinstance(module, _class):
                continue
            print(module, name, _class)
            parts = name.split(".")
            path, n_ = parts[:-1], parts[-1]
            print(path, n_)
        # for nn, mm in denoising_unet.named_modules():
        #     print(nn)
        #     print('-----')
        #     print(mm)

Linear(in_features=320, out_features=2560, bias=True) ff.net.0.proj.linear <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '0', 'proj'] linear
Linear(in_features=320, out_features=32, bias=False) ff.net.0.proj.lora_down <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '0', 'proj'] lora_down
Linear(in_features=32, out_features=2560, bias=False) ff.net.0.proj.lora_up <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '0', 'proj'] lora_up
Linear(in_features=1280, out_features=320, bias=True) ff.net.2.linear <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '2'] linear
Linear(in_features=1280, out_features=32, bias=False) ff.net.2.lora_down <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '2'] lora_down
Linear(in_features=32, out_features=320, bias=False) ff.net.2.lora_up <class 'torch.nn.modules.linear.Linear'>
['ff', 'net', '2'] lora_up
Linear(in_features=320, out_features=2560, bias=True) ff.net.0.proj.linear <class 'torch.nn.modules.linear.Linear'>
['ff', 'n

In [50]:
# Same probe as the previous cell, extended to attn2 and printing the matched
# class instead of the qualified name.
for ancestor in ancestors:
    for name, module in ancestor.named_modules():
        wanted = 'attn1' in name or 'attn2' in name or 'ff' in name
        if not wanted:
            continue
        for _class in search_class:
            if not isinstance(module, _class):
                continue
            print(module, _class)
            parts = name.split(".")
            path, n_ = parts[:-1], parts[-1]
            print(path, n_)
        # for nn, mm in denoising_unet.named_modules():
        #     print(nn)
        #     print('-----')
        #     print(mm)

In [16]:
# Print the qualified name of every sub-module whose type is in search_class.
target_types = tuple(search_class)
for ancestor in ancestors:
    for fullname, module in ancestor.named_modules():
        if isinstance(module, target_types):
            print(fullname)

In [None]:
# NOTE(review): nn.Module.modules() returns an iterator, so this cell only
# displays the iterator's repr rather than the modules themselves.
denoising_unet.modules()

In [9]:
# Toggle LoRA injection into the denoising UNet.
use_lora = True

if not use_lora:
    print('not using lora')
else:
    print('using lora')
    # denoising_unet.add_adapter(unet_lora_config)
    # denoising_unet.unload_lora()

    lora_cfg = cfg.lora
    lora_path = './_tmp'

    # Only the listed block classes are searched for layers to wrap.
    lora_manager_temporal = LoraHandler(
        use_unet_lora=lora_cfg.use_unet_lora,
        unet_replace_modules=[
            "TemporalBasicTransformerBlock", # for spatial-attention
            # "TemporalTransformerBlock", # for motion_module
        ],
    )

    # NOTE(review): the saved output of this cell ends with
    # "IndexError: pop from empty list" raised below this call; the root cause
    # is inside the project's LoRA helper and is not visible from here.
    unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
        lora_cfg.use_unet_lora,
        denoising_unet,
        lora_manager_temporal.unet_replace_modules,
        lora_cfg.lora_unet_dropout,
        f"{lora_path}/lora/",
        r=lora_cfg.lora_rank,
    )

using lora
inject
attn1.to_q
to_q
attn1.to_k
to_k
attn1.to_v
to_v
attn1.to_out.0
0
attn2.to_q
to_q
attn2.to_k
to_k
attn2.to_v
to_v
attn2.to_out.0
0
ff.net.0.proj
proj
ff.net.2
2
attn1.to_q
to_q
attn1.to_k
to_k
attn1.to_v
to_v
attn1.to_out.0
0
attn2.to_q
to_q
attn2.to_k
to_k
attn2.to_v
to_v
attn2.to_out.0
0
ff.net.0.proj
proj
ff.net.2
2
attn1.to_q
to_q
attn1.to_k
to_k
attn1.to_v
to_v
attn1.to_out.0
0
attn2.to_q
to_q
attn2.to_k
to_k
attn2.to_v
to_v
attn2.to_out.0
0
ff.net.0.proj
proj
ff.net.2
2
attn1.to_q
to_q
attn1.to_k
to_k
attn1.to_v
to_v
attn1.to_out.0
0
attn2.to_q
to_q
attn2.to_k
to_k
attn2.to_v
to_v
attn2.to_out.0
0
ff.net.0.proj
proj
ff.net.2
2
attn1.to_q
to_q
attn1.to_k
to_k
attn1.to_v
to_v


IndexError: pop from empty list

In [59]:
# denoising_unet

In [12]:
# Re-count the denoising UNet parameters that require gradients.
# NOTE(review): the saved output is 0 although this cell follows the LoRA
# injection cell; check the execution order before relying on it.
trainable = [p for p in denoising_unet.parameters() if p.requires_grad]
i = len(trainable)
print(i)

0


In [13]:
# Optional export of the LoRA weights; disabled by default.
save_lora = False
# denoising_unet.get_submodule('down_blocks.0.motion_modules')
if save_lora:
    print('saving lora')
    # The helper receives a deep copy of `net`, not the live model.
    lora_manager_temporal.save_lora_weights(
        model=copy.deepcopy(net),
        save_path='./_tmp',
        step=0,
    )

In [21]:
# Display the module tree; in the saved output the attention projections
# appear wrapped as LoraInjectedLinear.
denoising_unet

UNet3DConditionModel(
  (conv_in): InflatedConv3d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock3D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer3DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): TemporalBasicTransformerBlock(
              (attn1): Attention(
                (to_q): LoraInjectedLinear(
                  (linear): Linear(in_features=320, out_features=320, bias=False)
                  (lora_down): Linear(in_features=320, out_features=32, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
  

In [15]:
# Echo which pipeline variant is currently selected.
use_pipeline

'hallo'

In [16]:

# Build the inference pipeline for the selected variant and move it to the
# target device in the configured precision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if use_pipeline == 'emo':
    print('use_pipeline: emo')
    # NOTE(review): image_enc / speed_enc exist only when the model-loading cell
    # ran with use_pipeline == 'emo'; face_locator is created only in its
    # 'hallo' branch.
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        face_locator=face_locator,
        image_encoder=image_enc,
        speed_encoder=speed_enc,
        scheduler=val_noise_scheduler,
    )
elif use_pipeline == 'hallo':
    print('use_pipeline: hallo')
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        face_locator=face_locator,
        image_proj=imageproj,
        scheduler=val_noise_scheduler,
    )
else:
    # Fail loudly instead of hitting a NameError below, or silently reusing a
    # `pipeline` left over from an earlier execution of this cell.
    raise ValueError(
        f"Unknown use_pipeline: {use_pipeline!r} (expected 'emo' or 'hallo')"
    )

pipeline.to(device=device, dtype=weight_dtype)
torch.cuda.empty_cache()

use_pipeline: hallo


In [17]:
# pixel_values_ref_img = torch.randn([1, 3, 3, 512, 512]).to(device=device, dtype=weight_dtype)
# audio_tensor = torch.randn([1, 16, 32, 768]).to(device=device, dtype=weight_dtype)
# source_image_face_emb = torch.randn([1, 512]).to(device=device, dtype=weight_dtype)
# source_image_face_region = torch.randn([1, 3, 512, 512]).to(device=device, dtype=weight_dtype)
# source_image_clip_img = torch.randn([1, 3, 224, 224]).to(device=device, dtype=weight_dtype)
# head_speed = torch.randn([1, 16]).to(device=device, dtype=weight_dtype)
# source_image_face_emb = torch.randn([1, 512]).to(device=device, dtype=weight_dtype)
# img_size = 512, 512
# clip_length = 16
# generator = torch.manual_seed(42)

In [18]:

# Output resolution handed to the image processor.
width, height = 512, 512

image_processor = ImageProcessor((width, height), "./pretrained_models/face_analysis")

# The vocal-separator model is passed as separate directory and file-name arguments.
separator_dir, separator_file = os.path.split(
    "./pretrained_models/audio_separator/Kim_Vocal_2.onnx"
)
# Arguments are positional and must keep this order.
# NOTE(review): 16000 / 25 / False are presumably sample rate, fps and a
# feature-selection flag; confirm against AudioProcessor's signature.
audio_processor = AudioProcessor(
    16000,
    25,
    "./pretrained_models/wav2vec/wav2vec2-base-960h",
    False,
    separator_dir,
    separator_file,
    os.path.join('./', '.cache', "audio_preprocess"),
)


Applied providers: ['CUDAExecutionProvider', 'CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}, 'CUDAExecutionProvider': {'prefer_nhwc': '0', 'enable_skip_layer_norm_strict_mode': '0', 'tunable_op_enable': '0', 'enable_cuda_graph': '0', 'tunable_op_max_tuning_duration_ms': '0', 'tunable_op_tuning_enable': '0', 'cudnn_conv_use_max_workspace': '1', 'use_tf32': '1', 'cudnn_conv1d_pad_to_nc1d': '0', 'do_copy_in_default_stream': '1', 'cudnn_conv_algo_search': 'EXHAUSTIVE', 'gpu_external_empty_cache': '0', 'gpu_external_free': '0', 'gpu_external_alloc': '0', 'gpu_mem_limit': '18446744073709551615', 'arena_extend_strategy': 'kNextPowerOfTwo', 'user_compute_stream': '0', 'has_user_compute_stream': '0', 'use_ep_level_unified_stream': '0', 'device_id': '0'}}
find model: ./pretrained_models/face_analysis/models/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0
Applied providers: ['CUDAExecutionProvider', 'CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}, '

Some weights of Wav2VecModel were not initialized from the model checkpoint at ./pretrained_models/wav2vec/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-10-18 19:34:12,327 - INFO - separator - Separator version 0.17.2 instantiating with output_dir: ./.cache/audio_preprocess, output_format: WAV
2024-10-18 19:34:12,343 - INFO - separator - Operating System: Linux #1 SMP PREEMPT_DYNAMIC Wed Jul 31 15:28:35 UTC 2024
2024-10-18 19:34:12,344 - INFO - separator - System: Linux Node: research-workstation-849 Release: 5.14.0-427.28.1.el9_4.x86_64 Machine: x86_64 Proc: x86_64
2024-10-18 19:34:12,344 - INFO - separator - Python Version: 3.10.15
2024-10-18 19:34:12,344 - INFO - separator - PyTorch Version: 2

In [19]:
# RUN hallo
# Generate a talking-head video clip by clip: each clip is conditioned on the
# reference image plus the last frames of the previous clip, and all clips are
# concatenated, trimmed to the audio length and written to an mp4.

# --- inputs -----------------------------------------------------------------
idx = 0
ref_img_path = "examples/reference_images/1.jpg"
# The driving audio and the face expand ratio both come from the config.
audio_path = cfg.audio_path[idx]
clip_length = 16
save_dir = './'


def _tile_masks(masks):
    """Apply ``mask.repeat(clip_length, 1)`` to every mask in ``masks``."""
    return [mask.repeat(clip_length, 1) for mask in masks]


# --- preprocessing ----------------------------------------------------------
(
    source_image_pixels,
    source_image_face_region,
    source_image_face_emb,
    source_image_full_mask,
    source_image_face_mask,
    source_image_lip_mask,
) = image_processor.preprocess(
    ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length)
audio_emb = process_audio_emb(audio_emb)

# Add a leading batch / frame dimension to the image inputs.
source_image_pixels = source_image_pixels.unsqueeze(0)
source_image_face_region = source_image_face_region.unsqueeze(0)
source_image_face_emb = torch.tensor(source_image_face_emb.reshape(1, -1))

source_image_full_mask = _tile_masks(source_image_full_mask)
source_image_face_mask = _tile_masks(source_image_face_mask)
source_image_lip_mask = _tile_masks(source_image_lip_mask)

# --- clip-by-clip generation ------------------------------------------------
# Only whole clips are generated; any remainder shorter than clip_length is
# not processed (presumably the audio processor pads to a multiple — confirm).
times = audio_emb.shape[0] // clip_length
n_motion_frames = cfg.data.n_motion_frames
video_clips = []
# One generator for the whole run, so successive clips draw different noise.
generator = torch.manual_seed(42)

torch.cuda.empty_cache()
with torch.no_grad():
    for t in range(times):
        print(f"[{t+1}/{times}]")

        if not video_clips:
            # First clip: no previous frames yet, so the reference image
            # itself is repeated to fill the motion-frame slots.
            motion_frames = source_image_pixels.repeat(n_motion_frames, 1, 1, 1)
        else:
            # Later clips: take the trailing frames of the previous clip.
            motion_frames = video_clips[-1][0].permute(1, 0, 2, 3)
            motion_frames = motion_frames[-n_motion_frames:]
            # Rescale before feeding the frames back in (presumably from
            # [0, 1] to [-1, 1] — confirm against the pipeline's output range).
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype,
                device=source_image_pixels.device,
            )

        # Reference image first, motion frames after it, then a batch dimension.
        pixel_values_ref_img = torch.cat(
            [source_image_pixels, motion_frames], dim=0
        ).unsqueeze(0)

        # Slice this clip's audio embeddings (slicing clamps at the end) and
        # project them with the audio projection model.
        audio_tensor = audio_emb[t * clip_length:(t + 1) * clip_length]
        audio_tensor = audio_tensor.unsqueeze(0).to(
            device=audioproj.device,
            dtype=audioproj.dtype,
        )
        audio_tensor = audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=cfg.data.train_width,
            height=cfg.data.train_height,
            video_length=clip_length,
            num_inference_steps=25,
            # num_inference_steps=cfg.inference_steps,
            guidance_scale=cfg.cfg_scale,
            generator=generator,
            store_attn_map=True,
        )
        video_clips.append(pipeline_output.videos)

# --- assemble and save ------------------------------------------------------
# Join the clips along dim 2, drop the batch dimension, and trim the second
# dimension to the audio length.
tensor_result = torch.cat(video_clips, dim=2).squeeze(0)
tensor_result = tensor_result[:, :audio_length]

audio_name = os.path.basename(audio_path).split('.')[0]
ref_name = os.path.basename(ref_img_path).split('.')[0]
output_file = os.path.join(save_dir,f"test_{ref_name}_{audio_name}.mp4")
# save the result after all iteration
tensor_to_video(tensor_result, output_file, audio_path)

I0000 00:00:1729280056.506137   35656 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1729280056.592072   36623 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 550.54.17), renderer: NVIDIA A10G/PCIe/SSE2
W0000 00:00:1729280056.594760   35656 face_landmarker_graph.cc:174] Sets FaceBlendshapesGraph acceleration to xnnpack by default.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1729280056.620806   36625 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1729280056.630640   36628 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
2024-10-18 19:34:16,754 - INFO - separator - Starting separation process for audio_file_path: examples/driving_audios/1.wav


Processed and saved: ./.cache/1_sep_background.png
Processed and saved: ./.cache/1_sep_face.png


100%|██████████| 3/3 [00:06<00:00,  2.06s/it]
100%|██████████| 3/3 [00:00<00:00, 12.00it/s]
2024-10-18 19:34:24,121 - INFO - mdx_separator - Saving Vocals stem to 1_(Vocals)_Kim_Vocal_2.wav...
2024-10-18 19:34:24,325 - INFO - common_separator - Clearing input audio file paths, sources and stems...
2024-10-18 19:34:24,326 - INFO - separator - Separation duration: 00:00:07


[1/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.56s/it]
100%|██████████| 16/16 [00:01<00:00, 15.37it/s]


[2/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[3/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[4/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[5/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[6/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[7/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[8/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[9/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[10/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[11/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


[12/12]
store attn map


100%|██████████| 25/25 [00:38<00:00,  1.55s/it]
100%|██████████| 16/16 [00:01<00:00, 15.67it/s]


Moviepy - Building video ./test_1_1.mp4.
MoviePy - Writing audio in test_1_1TEMP_MPY_wvf_snd.mp4


                                                        

MoviePy - Done.
Moviepy - Writing video ./test_1_1.mp4



                                                               

Moviepy - Done !
Moviepy - video ready ./test_1_1.mp4


In [None]:
# Inspect the attention maps held on the pipeline (the inference cell above
# called it with store_attn_map=True). This only displays the stored object;
# no plotting happens here.
torch.cuda.empty_cache()
pipeline.attn_maps

## Debugging: re-running inference and stepping through the pipeline call by hand

In [None]:
# Debug copy of the input preparation, using hard-coded paths and ratio
# instead of the config values.
ref_img_path = "examples/reference_images/1.jpg"
audio_path = "examples/driving_audios/1.wav"
clip_length = 16
img_size = 512, 512
face_expand_ratio = 1.2

# NOTE(review): this unpacks FOUR values, whereas the executed inference cell
# unpacks SIX from the same image_processor.preprocess call; with that
# processor this line would raise "too many values to unpack". This cell looks
# like it targets a different ImageProcessor variant (one returning a CLIP
# image) — confirm before running.
source_image_pixels, \
source_image_face_region, \
source_image_face_emb, \
source_image_clip_img = image_processor.preprocess(
    ref_img_path, 
    os.path.join('./', '.cache'), 
    face_expand_ratio
)
audio_emb, audio_length = audio_processor.preprocess(audio_path, clip_length)

audio_emb = process_audio_emb(audio_emb)

# Add a leading dimension to the image inputs and turn the face embedding
# into a (1, -1) tensor.
source_image_pixels = source_image_pixels.unsqueeze(0)
source_image_face_region = source_image_face_region.unsqueeze(0)
source_image_face_emb = source_image_face_emb.reshape(1, -1)
source_image_face_emb = torch.tensor(source_image_face_emb)
# source_image_clip_img = source_image_clip_img.repeat(3, 1, 1, 1)

# Number of whole clips covered by the audio embeddings.
times = audio_emb.shape[0] // clip_length
print(times)

# motion_zeros = source_image_pixels.repeat(2, 1, 1, 1)
# motion_zeros = motion_zeros.to(
#         dtype=source_image_pixels.dtype, 
#         device=source_image_pixels.device
#     )
# pixel_values_ref_img = torch.cat(
#     [source_image_pixels, motion_zeros], 
#     dim=0
# )  # concat the ref image and the first motion frames
torch.cuda.empty_cache()

In [None]:

# Debug variant of the clip-by-clip inference loop, using the speed-embedding /
# CLIP-image call signature and a hard-coded count of 2 motion frames.
# NOTE(review): depends on speed_enc and source_image_clip_img, which are only
# defined on the 'emo' path / by the debug preparation cell. Unlike the main
# inference cell, this loop is not wrapped in torch.no_grad().
tensor_result = []
for t in range(times):
    print(f"[{t+1}/{times}]")

    if len(tensor_result) == 0:
        # The first iteration: no previous clip, so the reference image is
        # repeated to fill the two motion-frame slots.
        motion_zeros = source_image_pixels.repeat(2, 1, 1, 1)
        motion_zeros = motion_zeros.to(
            dtype=source_image_pixels.dtype, device=source_image_pixels.device)
        pixel_values_ref_img = torch.cat(
            [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
    else:
        # Later iterations: reuse the last two frames of the previous clip.
        motion_frames = tensor_result[-1][0]
        motion_frames = motion_frames.permute(1, 0, 2, 3)
        motion_frames = motion_frames[0 - 2:]
        # Rescale before feeding back (presumably [0, 1] -> [-1, 1] — confirm).
        motion_frames = motion_frames * 2.0 - 1.0
        motion_frames = motion_frames.to(
            dtype=source_image_pixels.dtype, device=source_image_pixels.device)
        pixel_values_ref_img = torch.cat(
            [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

    pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

    # Slice this clip's audio embeddings and project them.
    audio_tensor = audio_emb[
        t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
    ]
    audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.to(
        device=audioproj.device, dtype=audioproj.dtype
    )
    audio_tensor = audioproj(audio_tensor)

    # Constant speed input: all ones, shaped like the first two dimensions of
    # the projected audio tensor.
    speed_emb = torch.ones(audio_tensor.shape[:2]).to(
        dtype=speed_enc.dtype, device=speed_enc.device
    )
    # print(speed_emb.shape)
    # break
    # print(pixel_values_ref_img.shape)
    # print(source_image_face_emb.shape)
    # print(audio_tensor.shape)
    # print(source_image_face_region.shape)
    # print(source_image_clip_img.shape)
    # print(width)
    # print(height)
    # print(speed_emb.shape)
    # print(clip_length)

    # NOTE(review): the generator is re-seeded with 42 on every iteration,
    # whereas the main inference cell creates one generator before its loop.
    pipeline_output = pipeline(
        ref_image=pixel_values_ref_img,
        face_emb=source_image_face_emb,
        audio_tensor=audio_tensor,
        face_mask=source_image_face_region,
        clip_img=source_image_clip_img,
        width=width,
        height=height,
        speed_emb=speed_emb,
        video_length=clip_length,
        num_inference_steps=25,
        guidance_scale=3.5,
        generator=torch.manual_seed(42),
    )

    tensor_result.append(pipeline_output.videos)
    # break

# Join the clips along dim 2, drop the batch dimension, and trim the second
# dimension to the audio length.
tensor_result = torch.cat(tensor_result, dim=2)
tensor_result = tensor_result.squeeze(0)
tensor_result = tensor_result[:, :audio_length]
audio_name = os.path.basename(audio_path).split('.')[0]
ref_name = os.path.basename(ref_img_path).split('.')[0]

output_file = os.path.join('./',f"test_{ref_name}_{audio_name}.mp4")
# save the result after all iteration
tensor_to_video(tensor_result, output_file, audio_path)
torch.cuda.empty_cache()


# clip_img torch.Size([1, 3, 224, 224])
# latents torch.Size([1, 4, 16, 64, 64])
# ref_image_latents torch.Size([3, 4, 64, 64])
# face_mask torch.Size([2, 320, 16, 64, 64])
# audio_tensor torch.Size([2, 16, 32, 768])
# speed_emb torch.Size([2, 16, 768])
# encoder_hidden_states torch.Size([2, 1, 768])
# encoder_hidden_states.repeat torch.Size([6, 1, 768])


In [None]:
# Inspect the reader control stored on the pipeline.
# NOTE(review): no visible cell assigns this attribute; presumably the
# pipeline sets it during a call — confirm.
pipeline.reference_control_reader

In [None]:
# Play back the rendered clip; the file name follows the same
# test_<ref>_<audio>.mp4 pattern used when saving, relative to the working dir.
Video(f"test_{ref_name}_{audio_name}.mp4")

In [12]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Check the shape of the reference + motion-frame stack left by the last loop iteration.
pixel_values_ref_img.shape

In [None]:
# Shape check for the speed input.
# NOTE(review): speed_emb is only defined by the debug inference loop; the
# commented lines below show how to rebuild it by hand.
# audio_tensor.shape
# speed_emb.shape
# speed_emb = torch.ones(audio_tensor.shape[:2]).to(
#     dtype=speed_enc.dtype, device=speed_enc.device
# )
print(speed_emb.shape)

In [None]:
# Manual walk-through of the first steps of the pipeline call: the call's
# arguments are bound as plain variables (with `self` pointing at the pipeline)
# so its body can be executed statement by statement.
# NOTE(review): uses the image-encoder / speed-embedding variant
# (self.image_encoder, clip_img, speed_emb).
self = pipeline
ref_image = pixel_values_ref_img  # [1, 3, 3, 512, 512]
audio_tensor = audio_tensor  # [1, 16, 32, 768]
face_emb = source_image_face_emb  # [1, 512]
face_mask = source_image_face_region  # [1, 3, 512, 512]
clip_img = source_image_clip_img  # [1, 3, 224, 224]
speed_emb = speed_emb
width = width
height = height
video_length = clip_length
num_inference_steps = 25
guidance_scale = 3.5
generator = torch.manual_seed(42)
num_images_per_prompt = 1
output_type: Optional[str] = "tensor"
eta: float = 0.0


# Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

device = self._execution_device

# Classifier-free guidance is active whenever the scale exceeds 1.
do_classifier_free_guidance = guidance_scale > 1.0
# do_classifier_free_guidance = False

with torch.no_grad():
    # Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

batch_size = 1

# prepare clip image embeddings
with torch.no_grad():
    clip_image_embeds = clip_img.to(self.image_encoder.device, self.image_encoder.dtype)
    encoder_hidden_states = self.image_encoder(
        clip_image_embeds
    ).image_embeds.unsqueeze(1)
    # The unconditional branch encodes an all-zero image. It needs the same
    # unsqueeze(1) as the conditional branch so the two can be concatenated
    # along dim 0 below; unsqueeze() without a dim raises a TypeError.
    uncond_encoder_hidden_states = self.image_encoder(
        torch.zeros_like(clip_image_embeds)
    ).image_embeds.unsqueeze(1)

if do_classifier_free_guidance:
    # Unconditional embedding first, conditional second.
    encoder_hidden_states = torch.cat(
        [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
    )


In [None]:
# Writer hooks the reference UNet, reader hooks the denoising UNet; apart from
# the wrapped model and the mode, both controllers take the same settings.
_ref_ctrl_kwargs = dict(
    do_classifier_free_guidance=do_classifier_free_guidance,
    batch_size=batch_size,
    fusion_blocks="full",
    # return_modules=True,
)
reference_control_writer = ReferenceAttentionControl(
    self.reference_unet, mode="write", **_ref_ctrl_kwargs
)
reference_control_reader = ReferenceAttentionControl(
    self.denoising_unet, mode="read", **_ref_ctrl_kwargs
)

In [14]:

# Channel count of the latent tensor is read from the denoising UNet.
num_channels_latents = self.denoising_unet.in_channels

# Draw the initial latents through the pipeline's own helper.
# NOTE(review): arguments are positional and `width` is passed before `height` —
# confirm this matches the parameter order of `prepare_latents`.
with torch.no_grad():
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        width,
        height,
        video_length,
        clip_image_embeds.dtype,
        device,
        generator,
    )

In [15]:
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# Prepare ref image latents
with torch.no_grad():
    # Fold the reference-frame axis into the batch so each frame is encoded as its own image.
    ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
    ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width)  # (bs, c, width, height)
    ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
    # Take the mean of the VAE posterior instead of sampling from it.
    ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
    # NOTE(review): 0.18215 looks like the Stable Diffusion v1 VAE scaling factor —
    # confirm it matches self.vae.config.scaling_factor.
    ref_image_latents = ref_image_latents * 0.18215  # (b, 4, h, w)

# Turn the single face-region image into a per-frame conditioning feature.
with torch.no_grad():
    # Insert a frame axis at dim 1 and move to the face locator's dtype/device.
    face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
    # Tile the frame axis `video_length` times.
    face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
    face_mask = face_mask.transpose(1, 2)  # (bs, c, f, H, W)
    face_mask = self.face_locator(face_mask)
    # Under guidance, an all-zero feature is stacked in front as the unconditional half.
    # NOTE: `face_mask` is overwritten in place, so re-running this cell feeds the
    # already-encoded feature back in.
    face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask

In [16]:

# Build the guidance batch for the audio condition: an all-zero copy is stacked
# in front of the real embedding along dim 0.
# NOTE: `audio_tensor` is rebound, so re-running this cell stacks a second time.
if do_classifier_free_guidance:
    uncond_audio_tensor = torch.zeros_like(audio_tensor)
    audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
# Cast/move regardless of guidance; previously this only happened inside the
# branch, leaving the tensor on its original dtype/device when guidance is off.
audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)

In [17]:
# Speed conditioning, encoded by `self.speed_encoder`.
# (An earlier commented-out draft of this logic, which zeroed the unconditional
# half, has been removed; the live code below supersedes it.)
if speed_emb is None:
    # Default: a speed of 1.0 laid out as (batch, frames), the same layout as a
    # caller-supplied `speed_emb` (noted as [1, 16] on the pipeline call).
    # The previous default, torch.ones(audio_tensor.shape[0]), was 1-D and sized
    # by the batch axis, which may already be doubled for guidance.
    # NOTE(review): assumes dim 1 of `audio_tensor` is the frame axis — confirm.
    speed_emb = torch.ones(batch_size, audio_tensor.shape[1]).to(
        dtype=self.speed_encoder.dtype, device=self.speed_encoder.device
    )

if do_classifier_free_guidance:
    # Both guidance halves receive the same speed values (no zeroed branch).
    # uncond_speed_emb = torch.zeros_like(speed_emb)
    # speed_emb = torch.cat([uncond_speed_emb, speed_emb], dim=0)
    speed_emb = speed_emb.repeat(2, 1)

# NOTE: `speed_emb` is rebound to the encoder output; re-running this cell
# would feed the encoded tensor back into the encoder.
with torch.no_grad():
    speed_emb = self.speed_encoder(speed_emb)

In [None]:
# Shape report for every tensor that feeds the two UNets.
_shape_report = (
    ("clip_img", clip_img),
    ("latents", latents),
    ("ref_image_latents", ref_image_latents),
    ("face_mask", face_mask),
    ("audio_tensor", audio_tensor),
    ("speed_emb", speed_emb),
    ("encoder_hidden_states", encoder_hidden_states),
    # The image embedding tiled once per row of `ref_image_latents`.
    (
        "encoder_hidden_states.repeat",
        encoder_hidden_states.repeat(ref_image_latents.shape[0], 1, 1),
    ),
)
for _label, _value in _shape_report:
    print(_label, _value.shape)

# clip_img torch.Size([1, 3, 224, 224])
# latents torch.Size([1, 4, 16, 64, 64])
# ref_image_latents torch.Size([3, 4, 64, 64])
# face_mask torch.Size([2, 320, 16, 64, 64])
# audio_tensor torch.Size([2, 16, 32, 768])
# speed_emb torch.Size([2, 16, 768])
# encoder_hidden_states torch.Size([2, 1, 768])

# encoder_hidden_states.repeat torch.Size([6, 1, 768])


In [19]:
# timesteps

# denoising loop
# Number of leading timesteps beyond `num_inference_steps * scheduler.order`;
# the loop cell only advances the progress bar once `i + 1` exceeds this value.
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

In [25]:

# Manual walk through the first iteration (i = 0) of the denoising loop.
# with self.progress_bar(total=num_inference_steps) as progress_bar:
i = 0
t = timesteps[i]
# for i, t in enumerate(timesteps):
# Forward reference image
if i == 0:
    with torch.no_grad():
        # The reference latents are tiled twice under guidance; the image
        # embedding is tiled once per reference latent so both batch sizes agree.
        # The reference UNet is run at timestep 0; its return value is discarded.
        self.reference_unet(
            ref_image_latents.repeat(
                (2 if do_classifier_free_guidance else 1), 1, 1, 1
            ),
            torch.zeros_like(t),
            encoder_hidden_states=encoder_hidden_states.repeat(
                ref_image_latents.shape[0], 1, 1
            ),
            return_dict=False,
        )
    # Hand whatever the writer collected over to the reader.
    reference_control_reader.update(reference_control_writer)

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

with torch.no_grad():
    # Element 0 of the returned tuple is kept as the noise prediction.
    noise_pred = self.denoising_unet(
        latent_model_input,
        t,
        encoder_hidden_states=encoder_hidden_states,
        mask_cond_fea=face_mask,
        audio_embedding=audio_tensor,
        speed_embedding=speed_emb,
        return_dict=False,
    )[0]
torch.cuda.empty_cache()

In [None]:
# Shape check of the (possibly guidance-stacked) image embedding.
encoder_hidden_states.shape

In [23]:
# Keep the pair returned by `update` for inspection in the following cells.
# NOTE(review): the controllers are built with `return_modules=True` commented
# out — confirm `update` returns a (reader, writer) pair in that configuration.
reader_attn_modules, writer_attn_modules = reference_control_reader.update(reference_control_writer)

In [None]:
# Display the `bank` attribute of the first reader-side module.
reader_attn_modules[0].bank

In [None]:
# Shape of the first entry in `attn_score` on the first writer-side module.
writer_attn_modules[0].attn_score[0].shape

In [34]:
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    SpatialNorm,
)
from diffusers.models.attention import (
    AdaLayerNorm, 
    AdaLayerNormZero,
    Attention, 
    FeedForward
)

class CustomAttention(Attention):
    r"""
    Subclass of diffusers' `Attention` that adds no behavior of its own; every
    constructor argument is forwarded to the parent by keyword.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
        out_dim (`int`, *optional*, defaults to `None`):
            Forwarded unchanged to `Attention`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: Optional[int] = None,
    ):
        # Forward by keyword, not by position: if the parent's parameter list is
        # ever reordered or extended in the middle, positional forwarding would
        # silently bind values to the wrong parameters.
        super().__init__(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            bias=bias,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
            cross_attention_norm=cross_attention_norm,
            cross_attention_norm_num_groups=cross_attention_norm_num_groups,
            added_kv_proj_dim=added_kv_proj_dim,
            norm_num_groups=norm_num_groups,
            spatial_norm_dim=spatial_norm_dim,
            out_bias=out_bias,
            scale_qk=scale_qk,
            only_cross_attention=only_cross_attention,
            eps=eps,
            rescale_output_factor=rescale_output_factor,
            residual_connection=residual_connection,
            _from_deprecated_attn_block=_from_deprecated_attn_block,
            processor=processor,
            out_dim=out_dim,
        )
        

class CustomAttnProcessor:
    r"""
    Attention processor using explicit `bmm` attention that can optionally hand
    back the attention probabilities alongside the output.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        return_attn_score: Optional[bool] = False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run one attention pass through `attn`'s projections.

        Returns the output tensor, or `(output, attention_probs)` when
        `return_attn_score` is truthy. Extra positional/keyword arguments are
        accepted and ignored.
        """
        # Untouched input, kept for the optional residual add at the end.
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs (B, C, H, W) are flattened to a token sequence (B, H*W, C)
        # and restored to 4-D before returning.
        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized from the key/value source: the encoder states when
        # given, otherwise the hidden states themselves.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no encoder states are supplied; otherwise
        # cross-attention, normalising the encoder states if `attn` asks for it.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (
            attn.head_to_batch_dim(query),
            attn.head_to_batch_dim(key),
            attn.head_to_batch_dim(value),
        )

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        attended = out_dropout(out_proj(attended))

        if is_feature_map:
            attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            attended = attended + skip_input

        attended = attended / attn.rescale_output_factor

        return (attended, attention_probs) if return_attn_score else attended
    
# Smoke-test instantiation: self-attention (no cross_attention_dim) with
# 8 heads of 32 channels on a 32-channel query. No processor is passed, so the
# parent class picks its default one.
aa = CustomAttention(
            query_dim=32,
            heads=8,
            dim_head=32,
            dropout=0.0,
            bias=False,
            cross_attention_dim=None,
            upcast_attention=False,
            )

In [27]:
# Stand-alone forward pass of the reference UNet at timestep 0; output is discarded.
# NOTE(review): unlike the single-step cell, `encoder_hidden_states` is passed
# here without `.repeat(ref_image_latents.shape[0], 1, 1)` — confirm the batch
# sizes of the two inputs agree when there is more than one reference latent.
with torch.no_grad():
    self.reference_unet(
        ref_image_latents.repeat(
            (2 if do_classifier_free_guidance else 1), 1, 1, 1
        ),
        torch.zeros_like(t),
        encoder_hidden_states=encoder_hidden_states,
        return_dict=False,
    )

In [None]:

# Full denoising loop, followed by decoding of the final latents.
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # Forward reference image
        # Only on the first step: run the reference UNet once at timestep 0 and
        # hand what the writer collected over to the reader.
        # NOTE(review): `encoder_hidden_states` is not repeated per reference
        # latent here, unlike the single-step cell — confirm batch sizes agree.
        if i == 0:
            with torch.no_grad():
                self.reference_unet(
                    ref_image_latents.repeat(
                        (2 if do_classifier_free_guidance else 1), 1, 1, 1
                    ),
                    torch.zeros_like(t),
                    encoder_hidden_states=encoder_hidden_states,
                    return_dict=False,
                )
            reference_control_reader.update(reference_control_writer)
        
        # Shapes noted during an earlier debugging session:
        # encoder_hidden_states :: [2, 4, 768]
        # ref_image_latents.shape :: [3, 4, 64, 64]
        # ref_image_latents.repeat :: [6, 4, 64, 64]
        # latents :: [1, 4, 14, 64, 64]
        # audio_tensor :: [2, 14, 32, 768]
        # import pdb;pdb.set_trace()

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = self.denoising_unet(
                latent_model_input,
                t,
                encoder_hidden_states=encoder_hidden_states,
                mask_cond_fea=face_mask,
                audio_embedding=audio_tensor,
                speed_embedding=speed_emb,
                return_dict=False,
            )[0]
        # perform guidance
        # The first half of the batch is the unconditional prediction, the second
        # the conditional one; extrapolate from the former towards the latter.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        # call the callback, if provided
        # `and` binds tighter than `or`: update on the last step, or once past
        # the warmup steps on every `scheduler.order`-th step.
        if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
            progress_bar.update()

    # Reset both controllers once the loop has finished.
    reference_control_reader.clear()
    reference_control_writer.clear()

# Post-processing
images = self.decode_latents(latents)  # (b, c, f, h, w)

# Convert to tensor
# `decode_latents` hands back something `torch.from_numpy` accepts.
if output_type == "tensor":
    images = torch.from_numpy(images)
torch.cuda.empty_cache()

In [None]:
# Display whether guidance is currently enabled.
do_classifier_free_guidance

In [40]:
# Re-capture the pair returned by `update` after the full loop has run.
# NOTE(review): the loop cell calls `clear()` on both controllers before this —
# confirm what `update` transfers afterwards.
reader_attn_modules, writer_attn_modules = reference_control_reader.update(reference_control_writer)

In [None]:
# len(reader_attn_modules)
# Display the `bank` attribute of the first writer-side module.
writer_attn_modules[0].bank

In [50]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# Value range of the decoded frames after scaling by 255 (sanity check before writing video).
print(images.min()*255, images.max()*255)

In [23]:
# images.shape
# Write the first sample of the batch to disk.
# `images` is (b, c, f, h, w); images[0].permute(1, 2, 3, 0) gives (f, h, w, c),
# the frame-major, channel-last layout `write_video` expects.
# Scale to 0..255, clamp, and cast to uint8 explicitly rather than relying on an
# implicit float -> uint8 conversion inside the writer.
_video_frames = (images[0].permute(1, 2, 3, 0) * 255).clamp(0, 255).to(torch.uint8)
# The notebook's import cell binds torchvision as `tv`; the bare name
# `torchvision` is not bound by those imports, so use the alias.
tv.io.write_video('tmp.mp4', _video_frames, fps=30)

In [None]:

# Embed the clip written by the previous cell.
Video('tmp.mp4')

In [None]:
# End-to-end call of the pipeline with the same inputs the manual cells use.
# NOTE: if the manual cells were run first, `audio_tensor` and `speed_emb` have
# been rebound there (guidance-stacked / encoded); the shape comments below
# describe the original, untouched inputs.
pipeline_output = pipeline(
    ref_image=pixel_values_ref_img, # [1, 3, 3, 512, 512]
    audio_tensor=audio_tensor, # [1, 16, 32, 768]
    face_emb=source_image_face_emb, # [1, 512]
    face_mask=source_image_face_region, # [1, 3, 512, 512]
    clip_img=source_image_clip_img, # [1, 3, 224, 224]
    speed_emb=speed_emb, # [1, 16]
    width=width,
    height=height,
    video_length=clip_length,
    num_inference_steps=25,
    guidance_scale=3.5,
    generator=generator,
)

# self.reference_unet(ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1),torch.zeros_like(t),encoder_hidden_states=encoder_hidden_states,return_dict=False,)

In [43]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# warping self-attention feature
# try Curved Diffusion

In [None]:
# Two flat, length-64 grids built from 8-point linspaces.
num_speed_buckets = 8

# `aa`: the linspace over [-pi, pi], tiled 8 times back to back.
_centers = torch.linspace(-math.pi, math.pi, num_speed_buckets)
aa = torch.cat([_centers] * 8)
print(aa.shape)
print(aa)

# `bb`: the linspace over [0.01, pi], with each value repeated 8 times in a row.
_radii = torch.linspace(0.01, math.pi, num_speed_buckets)
bb = _radii.unsqueeze(1).expand(-1, 8).reshape(-1)
print(bb.shape)
print(bb)

In [None]:
# One random speed value per row, shape (num_speed_buckets, 1).
head_rotation_speed = torch.randn(num_speed_buckets,1)
# Broadcasts the column of speeds against the flat `aa` / `bb` grids, giving
# one row per speed value and one column per grid entry.
cc = torch.tanh((head_rotation_speed - aa)/bb)
# Only the last expression of a cell is displayed; this line shows nothing.
cc.shape
cc

In [2]:

from emo.models.speed_encoder import SpeedEncoder

In [None]:
# Probe how SpeedEncoder's output shape reacts to 1-D and 2-D inputs.
speed_enc = SpeedEncoder(4,768)
for _probe_shape in ((4,), (4, 1), (4, 2)):
    out = speed_enc(torch.randn(*_probe_shape))
    print(out.shape)
    # print(out)

In [None]:

from scripts.train_stage1_emo import Net
from hallo.models.face_locator import FaceLocator
from hallo.models.mutual_self_attention import ReferenceAttentionControl
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from diffusers import AutoencoderKL, DDIMScheduler

In [None]:
# Build the 3D denoising UNet from the 2D Stable Diffusion v1.5 checkpoint
# directory, with the motion module and temporal attention switched off, and
# move it to the GPU.
# NOTE(review): the second positional argument is an empty string — confirm
# what `from_pretrained_2d` expects in that position.
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        "./pretrained_models/stable-diffusion-v1-5/",
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
        use_landmark=False
    ).to(device="cuda")

In [18]:
class CFG:
    """Bare attribute holder for checkpoint locations and the face-locator flag."""

    def __init__(self):
        defaults = {
            "base_model_path": "./pretrained_models/stable-diffusion-v1-5/",
            "vae_model_path": "./pretrained_models/sd-vae-ft-mse",
            "face_analysis_model_path": "./pretrained_models/face_analysis",
            "face_locator_pretrained": False,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


cfg = CFG()
# Attribute access only; it is not the cell's last expression, so nothing is displayed.
cfg.face_locator_pretrained
# dtype used for every model created below.
weight_dtype = torch.float32

In [None]:
# `FaceLocator_EMO` is used below, but the notebook only imports it in a later
# cell; import it here so this cell also runs top-to-bottom on a fresh kernel.
from hallo.models.face_locator import FaceLocator_EMO

# create model
vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
    "cuda", dtype=weight_dtype
)
reference_unet = UNet2DConditionModel.from_pretrained(
    cfg.base_model_path,
    subfolder="unet",
).to(device="cuda", dtype=weight_dtype)
# The 3D denoising UNet is initialised from the same 2D checkpoint, with the
# motion module and temporal attention switched off.
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    cfg.base_model_path,
    "",
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": False,
        "unet_use_temporal_attention": False,
    },
    use_landmark=False
).to(device="cuda", dtype=weight_dtype)

face_locator = FaceLocator_EMO(out_channels=320).to(device="cuda", dtype=weight_dtype)

# Freeze
vae.requires_grad_(False)
denoising_unet.requires_grad_(False)
reference_unet.requires_grad_(False)
face_locator.requires_grad_(False)

# Writer hooks the reference UNet, reader hooks the denoising UNet; guidance is
# disabled for both.
reference_control_writer = ReferenceAttentionControl(
    reference_unet,
    do_classifier_free_guidance=False,
    mode="write",
    fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
    denoising_unet,
    do_classifier_free_guidance=False,
    mode="read",
    fusion_blocks="full",
)

In [46]:

# Replace the `face_locator` created earlier with a `FaceLocator` that takes a
# 1-channel conditioning input and emits 320 channels.
# NOTE: the replacement is created with default `requires_grad` settings; the
# earlier freeze applied to the object it replaces.
# face_locator_emo = FaceLocator_EMO(out_channels=320).to(device="cuda", dtype=weight_dtype)
face_locator = FaceLocator(
    conditioning_embedding_channels=320,
    conditioning_channels=1,
    act='relu',
).to(device="cuda", dtype=weight_dtype)

In [49]:
# Bundle both UNets, the face locator and the two attention controllers into
# the training wrapper. Arguments are positional, so their order must match
# `Net.__init__`.
net = Net(
    reference_unet,
    denoising_unet,
    face_locator,
    reference_control_writer,
    reference_control_reader
).to(dtype=weight_dtype)

In [None]:
# Random single-frame dummy inputs for a smoke test of `net`.
# NOTE: this rebinds `timesteps`, `ref_image_latents` and `face_mask`, names
# that the pipeline walkthrough cells also use.
noisy_latents = torch.randn([1, 4, 1, 64, 64]) * 0.18215
timesteps = torch.randint(0, 1000, (1,)).long()
ref_image_latents = torch.randn([1 ,4, 64, 64]) * 0.18215

face_mask = torch.zeros(1, 1, 1, 512, 512)
# Index from the right so the slices hit the two trailing 512-sized axes.
# With `[:, :, 100:300, 100:300]` the first slice landed on the size-1 axis at
# dim 2, selected nothing, and left the mask all zeros.
face_mask[..., 100:300, 100:300] = 1

print(noisy_latents.shape)
print(timesteps.shape)
print(ref_image_latents.shape)
print(face_mask.shape)

In [None]:
# Run the face locator on the dummy mask (moved to the GPU) and show the feature shape.
face_mask_feature = face_locator(face_mask.cuda())
print(face_mask_feature.shape)

In [None]:
# Forward the dummy inputs through the wrapper.
# The raw `face_mask` is passed here, not `face_mask_feature`.
# NOTE(review): presumably `net` applies the face locator itself — confirm.
output = net(
        noisy_latents.cuda(),
        timesteps.cuda(),
        ref_image_latents.cuda(),
        face_mask.cuda(),
        uncond_fwd=True
    )

In [None]:
from IPython.display import Video, Audio
# Embed a previously generated clip.
# NOTE: absolute, user-specific path — this cell only works on the original machine.
Video('/home/sihun.cha/Work/hallo/.cache/output.mp4')
# Audio('/home/sihun.cha/Work/hallo/.cache/audio_preprocess/1_(Vocals)_Kim_Vocal_2.wav')

In [None]:
# Embed a clip from the local toy dataset (absolute, user-specific path).
Video('/home/sihun.cha/Work/toydata/vid_256/AdamKinzinger0_0.mp4')

In [None]:
# Two identical 1-channel 512x512 images with a filled square at rows/cols 100..299.
image = torch.zeros(2, 1, 512, 512)
image[..., 100:300, 100:300] = 1
# Show the first image in channel-last (H, W, C) layout.
plt.imshow(image[0].permute(1, 2, 0).numpy())

In [None]:


from hallo.models.face_locator import FaceLocator_EMO

class FaceLocator(nn.Module):
    """
    Small convolutional encoder: three conv + ReLU + 2x2 max-pool stages followed
    by a 1x1 convolution.

    Maps float32 images of shape (B, in_channel, H, W) to raw logits of shape
    (B, out_channel, H/8, W/8). No sigmoid or upsampling is applied.
    """

    def __init__(self, in_channel=1, out_channel=4):
        super().__init__()
        # Channel widths of the three 3x3 stages: in_channel -> 16 -> 32 -> 64.
        widths = (in_channel, 16, 32, 64)
        self.conv1, self.conv2, self.conv3 = (
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)
            for c_in, c_out in zip(widths, widths[1:])
        )
        # 1x1 projection from 64 feature channels to `out_channel` logits.
        self.final_conv = nn.Conv2d(64, out_channel, kernel_size=1)
        # One pooling layer shared by all three stages; it halves H and W each time.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, images):
        """Return logits for `images`; asserts float32 dtype and 4 dimensions."""
        assert images.dtype == torch.float32, 'Images must be of type torch.float32'
        assert images.ndim == 4, 'Images must have 4 dimensions [B, C, H, W]'

        # conv -> ReLU -> pool, three times: (B, 64, H/8, W/8) afterwards.
        x = images
        for stage in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(stage(x)))

        assert x.size(1) == 64, f"Input to final conv layer has {x.size(1)} channels, expected 64."

        # Raw logits, (B, out_channel, H/8, W/8); left at the reduced resolution.
        return self.final_conv(x)

# def normalize(v):
#     norm = np.linalg.norm(v)
#     if norm == 0: 
#        return v
#     return v / norm
def normalize(V):
    """
    Rescale an image array for display so each channel is centred on 0.5.

    `V` is a NumPy array, either (H, W, C) or (H, W). The minimum and maximum
    are taken over the first two axes. Each channel is shifted so the midpoint
    of its range maps to 0.5, and all channels are divided by the largest
    channel range, so the result lies within [0, 1].

    A constant input (range 0) returns 0.5 everywhere instead of dividing by
    zero. The input array is not modified.
    """
    V = np.asarray(V)
    hi = V.max(0).max(0)
    lo = V.min(0).min(0)
    # np.max accepts both the per-channel vector of an (H, W, C) input and the
    # scalar of an (H, W) input; the builtin max() raised TypeError on the latter.
    span = np.max(hi - lo)
    if span == 0:
        # Every value equals its channel midpoint, so the numerator is zero
        # already; any non-zero divisor yields the intended flat 0.5 image.
        span = 1
    return (V - (hi + lo) * 0.5) / span + 0.5

In [None]:
# Run the locally defined FaceLocator (default channels) on the square-mask
# images and show the first output as an image.
foo = FaceLocator()
output = foo(image)
print(output.shape)

# detach -> NumPy -> channel-last, then rescale for display.
plt.imshow(
    normalize(output[0].detach().numpy().transpose(1,2,0))
)

In [None]:
# Same experiment with the project's FaceLocator_EMO, constructed with
# positional arguments (1, 4).
foo = FaceLocator_EMO(1,4)
output = foo(image)
print(output.shape)

In [None]:

# Show the first output from the previous cell, channel-last and rescaled for display.
plt.imshow(
    normalize(output[0].detach().numpy().transpose(1,2,0))
)