In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import tyro
import numpy as np
import imageio
import cv2

from src.modules.motion_extractor import MotionExtractor
from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
from src.live_portrait_wrapper import LivePortraitWrapper
from src.modules.vqvae import VQVae
from train_tokenizer import Dataset
from IPython.display import Video, Image

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Deserialize onto the CPU so the checkpoint opens no matter which device it
# was saved from; load_state_dict later copies the values into the model's own
# parameters wherever they live.
ckpt = torch.load("models/checkpoints/last.ckpt", map_location="cpu")

# Keep only the VQ-VAE weights and drop their "vqvae." prefix. Selecting with
# startswith (instead of a substring test) guarantees that the prefix being
# sliced off is really there, so a key that merely contains "vqvae" somewhere
# else can no longer be mangled by a fixed-width slice.
VQVAE_PREFIX = "vqvae."
vqvae_params = {
    k[len(VQVAE_PREFIX):]: v
    for k, v in ckpt['state_dict'].items()
    if k.startswith(VQVAE_PREFIX)
}

In [None]:
# Hyper-parameters have to match the ones used for the checkpoint, because the
# weights are loaded with strict=True below.
vqvae = VQVae(
    nfeats=72,
    code_num=512,
    code_dim=512,
    output_emb_width=512,
    down_t=3,
    stride_t=2,
    width=512,
    depth=3,
    dilation_growth_rate=3,
    norm=None,
    activation="relu",
    codebook_logger=None,
).to('cuda')
# This notebook only runs inference, so switch to eval mode like the other
# networks loaded here. NOTE(review): some VQ quantizers update their codebook
# on every forward pass while in train mode - confirm whether VQVae does.
vqvae.eval()
# Left as the last expression so the cell shows the key-matching report.
vqvae.load_state_dict(vqvae_params, strict=True)

In [4]:
# Motion extractor: build it, restore the pretrained weights, then move it to
# the GPU and switch it to inference mode.
m_extr = MotionExtractor(num_kp=21, backbone='convnextv2_tiny')
m_extr.load_pretrained(
    init_path="pretrained_weights/liveportrait/base_models/motion_extractor.pth"
)
m_extr.to('cuda').eval()
print()  # the cell ends on print() so no module repr is echoed




In [None]:
# Appearance feature extractor: build it, restore the pretrained weights, then
# move it to the GPU and switch it to inference mode.
appearance_feature_extractor = AppearanceFeatureExtractor(
    image_channel=3, block_expansion=64, num_down_blocks=2, max_features=512,
    reshape_channel=32, reshape_depth=16, num_resblocks=6,
)
appearance_feature_extractor.load_state_dict(
    torch.load(
        "pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth"
    )
)
appearance_feature_extractor.to('cuda').eval()
print()  # the cell ends on print() so no module repr is echoed

### Load source image

In [6]:
from src.utils.io import load_image_rgb, resize_to_limit

# Read the source portrait and pass it straight through resize_to_limit
# (arguments 1280 and 2 - presumably the size cap and the size multiple).
img = resize_to_limit(load_image_rgb("mark.png"), 1280, 2)

### Load driving video

In [8]:
# Dataset is the class imported from train_tokenizer at the top of the
# notebook; it is pointed at the local "dataset" directory.
ds = Dataset(data_path="dataset")

In [10]:
def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
    """
    Adjust the size of the image so that its largest dimension does not exceed
    max_dim and both its height and width are multiples of division.

    NOTE: defining this here replaces the resize_to_limit imported from
    src.utils.io earlier in the notebook.

    :param img: the image to be processed, laid out as (height, width, ...).
    :param max_dim: the maximum dimension constraint; values <= 0 disable the downscaling step.
    :param division: the number that height and width need to be multiples of; values below 1 are treated as 1.
    :return: the adjusted image. The input object itself is returned when nothing has to change.
    """
    h, w = img.shape[:2]

    # Downscale, keeping the aspect ratio, when the longest side exceeds max_dim.
    if max_dim > 0 and max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        # Truncation can bring the short side of a very elongated image down
        # to 0, which cv2.resize rejects; keep at least one pixel.
        if h > w:
            new_h = max_dim
            new_w = max(1, int(w * scale))
        else:
            new_w = max_dim
            new_h = max(1, int(h * scale))
        img = cv2.resize(img, (new_w, new_h))  # cv2 takes (width, height)

    # Crop from the bottom/right edge so both dimensions are multiples of division.
    division = max(division, 1)
    new_h = img.shape[0] - (img.shape[0] % division)
    new_w = img.shape[1] - (img.shape[1] % division)

    if new_h == 0 or new_w == 0:
        # A side shorter than division would be cropped away entirely, so the
        # image is returned without any cropping.
        return img

    if new_h != img.shape[0] or new_w != img.shape[1]:
        img = img[:new_h, :new_w]

    return img

In [9]:
# Peek at the first sample the dataset returns.
ds[0]

(tensor([[[[[0.0863, 0.0863, 0.0863,  ..., 0.0784, 0.0784, 0.0784],
            [0.0863, 0.0863, 0.0863,  ..., 0.0784, 0.0784, 0.0784],
            [0.0863, 0.0863, 0.0863,  ..., 0.0784, 0.0784, 0.0784],
            ...,
            [0.6824, 0.6824, 0.6824,  ..., 0.5961, 0.6039, 0.5686],
            [0.6824, 0.6824, 0.6824,  ..., 0.5686, 0.5647, 0.5608],
            [0.6824, 0.6824, 0.6824,  ..., 0.5608, 0.5373, 0.5608]],
 
           [[0.0902, 0.0902, 0.0902,  ..., 0.0941, 0.0941, 0.0941],
            [0.0902, 0.0902, 0.0902,  ..., 0.0941, 0.0941, 0.0941],
            [0.0902, 0.0902, 0.0902,  ..., 0.0941, 0.0941, 0.0941],
            ...,
            [0.6745, 0.6745, 0.6745,  ..., 0.5961, 0.6039, 0.5686],
            [0.6745, 0.6745, 0.6745,  ..., 0.5686, 0.5725, 0.5686],
            [0.6745, 0.6745, 0.6745,  ..., 0.5608, 0.5451, 0.5686]],
 
           [[0.1020, 0.1020, 0.1020,  ..., 0.1098, 0.1098, 0.1098],
            [0.1020, 0.1020, 0.1020,  ..., 0.1098, 0.1098, 0.1098],
        

In [None]:
with torch.no_grad():
    # The sample displayed earlier in this notebook is a tuple whose first
    # element is the frame tensor, and a tuple has no `.to`. Unpack it before
    # moving the frames to the GPU; a bare tensor sample is still accepted.
    sample = ds[0]
    frames = sample[0] if isinstance(sample, (tuple, list)) else sample
    batch = frames.to('cuda')

    # Per-frame outputs of the motion extractor, collected one list per field.
    kp_infos = {}
    kps = []
    t = []
    exp = []
    pitch = []
    roll = []
    yaw = []
    for image in batch:
        motion = m_extr(image)

        # Only 'kp' has its leading axis squeezed before being collected.
        kp = motion['kp'].squeeze(0)
        kps.append(kp)
        t.append(motion['t'])
        exp.append(motion['exp'])
        pitch.append(motion['pitch'])
        roll.append(motion['roll'])
        yaw.append(motion['yaw'])

    # Stack along a new leading (frame) axis.
    kp_infos['kp'] = torch.stack(kps)
    kp_infos['t'] = torch.stack(t)
    kp_infos['exp'] = torch.stack(exp)
    # The angle predictions drop the axis at dim 1 after stacking - presumably
    # the per-frame batch axis of size 1; TODO confirm.
    kp_infos['pitch'] = torch.stack(pitch).squeeze(1)
    kp_infos['roll'] = torch.stack(roll).squeeze(1)
    kp_infos['yaw'] = torch.stack(yaw).squeeze(1)


In [11]:
from src.utils.camera import get_rotation_matrix, headpose_pred_to_degree

In [9]:
def process_kps(kp_infos: dict):
    """Post-process a keypoint-info dict in place and return it.

    The dict passed in is modified; the very same object is returned.

    :param kp_infos: dict holding 'kp', 'exp', 'pitch', 'yaw' and 'roll' entries; the first axis of 'kp' is taken as the batch size.
    :return: ``kp_infos`` itself, with each angle entry run through ``headpose_pred_to_degree`` and given a trailing axis (Bx1), and 'kp' / 'exp' reshaped to BxNx3.
    """
    batch_size = kp_infos['kp'].shape[0]

    for angle in ('pitch', 'yaw', 'roll'):
        kp_infos[angle] = headpose_pred_to_degree(kp_infos[angle])[:, None]  # Bx1

    for points in ('kp', 'exp'):
        kp_infos[points] = kp_infos[points].reshape(batch_size, -1, 3)  # BxNx3

    return kp_infos

In [None]:
# process_kps modifies the dict it receives and returns that same object, so
# kp_info and kp_infos refer to one dict after this call. Re-running the cell
# would post-process the already processed entries again.
kp_info = process_kps(kp_infos)

In [None]:
# Shape of the keypoints once a leading axis of size 1 is added.
kp_infos['kp'].unsqueeze(0).shape

In [17]:
# Forward the keypoints through the VQ-VAE without tracking gradients. A
# leading axis of size 1 is added first; the model's three return values are
# unpacked under the names below.
with torch.no_grad():
    reconstr, commit_loss, perplexity = vqvae(kp_infos['kp'].unsqueeze(0))


In [13]:
# Feed the pitch / yaw / roll entries of the processed dict to
# get_rotation_matrix (presumably one rotation matrix per frame - TODO confirm).
rot = get_rotation_matrix(kp_info['pitch'], kp_info['yaw'], kp_info['roll'])

In [2]:
import pickle

In [3]:
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only open pickle files that come from a trusted source.
with open("dataset/pickles/_BeHUjskbZo_2.pkl", "rb") as f:
    motion_template = pickle.load(f)

In [21]:
# Shape of the 't' entry of the first element stored under 'motion'.
motion_template['motion'][0]['t'].shape

(1, 3)

In [3]:
from src.dataset import Dataset

In [30]:
# Rebuild `ds` with the Dataset class imported from src.dataset just above,
# which replaces the train_tokenizer one bound at the top of the notebook.
ds = Dataset(data_path="dataset")

Loaded 2596 train samples


In [31]:
# Index the dataset once and read every field from that single sample. Each
# ds[98] lookup runs the dataset's __getitem__ again, which repeats the loading
# work and would hand back mismatched fields if the dataset applies any
# randomness per access (TODO confirm whether it does).
sample = ds[98]
kp = sample['kp']
exp = sample['exp']
x_s = sample['x_s']
t = sample['t']
R = sample['R']
scale = sample['scale']
c_eyes_lst = sample['c_eyes_lst']

In [34]:
# Flatten each per-frame feature to a vector, keeping the frame axis first.
n_frames = kp.size(0)
kp, exp = (feat.reshape(n_frames, -1) for feat in (kp, exp))

# Join keypoints and expressions along the feature axis; left as the last
# expression so the cell displays the result.
torch.cat((kp, exp), dim=1)


torch.Size([149, 126])