In [1]:
# Pick the compute device: CUDA when a GPU is visible to torch, otherwise fall back to CPU.
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# GPU name plus allocated / reserved memory of device 0 (rounded GiB) only make sense on CUDA.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2 ** 30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2 ** 30, 1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the cell output.
%matplotlib inline

In [5]:
from IPython.display import Image


In [2]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import pack, rearrange, reduce, repeat, unpack


In [3]:
def findAllFile(base):
    """Recursively collect every file below ``base``, following symlinked directories.

    Returns a list of full paths (``os.path.join(dirpath, filename)``) in ``os.walk``
    order. A missing or empty ``base`` yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base, followlinks=True)
        for filename in filenames
    ]



## Helpers

In [6]:
from configs.config import cfg, get_cfg_defaults
def _load_vqvae_part(config_path):
    """Build one VQ-VAE from the YAML at ``config_path`` and load ``<output_dir>/vqvae_motion.pt``.

    Returns ``(model, cfg)``. The model is moved to the notebook-global ``device`` and put in
    eval mode before its weights are loaded.
    """
    part_cfg = get_cfg_defaults()
    part_cfg.merge_from_file(config_path)
    part_model = instantiate_from_config(part_cfg.vqvae).to(device).eval()
    part_model.load(os.path.join(part_cfg.output_dir, "vqvae_motion.pt"))
    return part_model, part_cfg


def load_vqvae(gen_cfg):
    """Load the body VQ-VAE and, when configured, the left/right hand VQ-VAEs.

    Args:
        gen_cfg: generator config whose ``vqvae`` node holds ``body_config``,
            ``left_hand_config`` and ``right_hand_config`` (YAML paths; hand entries may be None).

    Returns:
        ``(body_model, body_cfg)`` when neither hand config is set. Otherwise
        ``(body_model, left_hand_model, right_hand_model, body_cfg, left_cfg, right_cfg)``,
        where an unconfigured hand contributes ``None`` for both its model and its cfg.
    """
    body_model, body_cfg = _load_vqvae_part(gen_cfg.vqvae.body_config)

    left_config = gen_cfg.vqvae.left_hand_config
    right_config = gen_cfg.vqvae.right_hand_config
    if left_config is None and right_config is None:
        # Body-only setup keeps the original 2-tuple return shape.
        return body_model, body_cfg

    # Default both hands to None so configuring only one hand no longer raises
    # UnboundLocalError for the other hand's cfg at the return statement.
    left_hand_model, left_cfg = None, None
    if left_config is not None:
        left_hand_model, left_cfg = _load_vqvae_part(left_config)

    right_hand_model, right_cfg = None, None
    if right_config is not None:
        right_hand_model, right_cfg = _load_vqvae_part(right_config)

    return body_model, left_hand_model, right_hand_model, body_cfg, left_cfg, right_cfg

def _decode_first(model, inds):
    """Decode only the first batch element of ``inds`` with ``model``; return it detached on CPU."""
    return model.decode(inds[0:1]).detach().cpu()


def _insert_zero_translation(body_motion):
    """Insert two zero-valued channels right after channel 0 of the last dimension.

    Produces ``cat([x[..., 0:1], zeros(..., 2), x[..., 1:]], -1)``. Presumably this restores
    root-translation slots that the body VQ-VAE does not predict — TODO confirm.
    """
    zeros = torch.zeros(
        body_motion.shape[:-1] + (2,),
        dtype=body_motion.dtype,
        device=body_motion.device,
    )
    return torch.cat([body_motion[..., 0:1], zeros, body_motion[..., 1:]], -1)


def bkn_to_motion(codes, dset, remove_translation=True):
    """Decode VQ token indices laid out as ``(b, k, n)`` into a Motion object.

    Only the first batch element is decoded. The number of codebooks ``k`` selects the layout:

    * ``k == 1``: one part chosen by ``dset.motion_rep`` (body, left_hand or right_hand).
    * ``k == 2``: ``(left_hand, right_hand)``, combined into a ``MotionRep.HAND`` motion.
    * ``k == 3``: ``(body, left_hand, right_hand)``, combined via
      ``dset.to_full_joint_representation``.

    Uses the notebook globals ``body_model``, ``left_hand_model``, ``right_hand_model``,
    ``body_cfg``, ``left_cfg`` and ``right_cfg`` (as returned by ``load_vqvae``).

    Args:
        codes: token indices of shape ``(b, k, n)``.
        dset: dataset providing ``motion_rep``, ``toMotion`` and ``to_full_joint_representation``.
        remove_translation: if True, insert two zero channels into the decoded body motion
            (see ``_insert_zero_translation``) before building the Motion.

    Returns:
        The decoded Motion object.

    Raises:
        ValueError: if ``k`` is not 1, 2 or 3, or if ``k == 1`` and ``dset.motion_rep`` is
            not body / left_hand / right_hand, instead of silently returning None.
    """
    k = codes.shape[1]

    if k == 1:
        # Normalise so a MotionRep member and its string value both compare correctly.
        mrep = MotionRep(dset.motion_rep)

        if mrep == MotionRep("body"):
            body_motion = _decode_first(body_model, codes[:, 0])
            if remove_translation:
                body_motion = _insert_zero_translation(body_motion)
            return dset.toMotion(
                body_motion[0],
                motion_rep=MotionRep("body"),
                hml_rep=body_cfg.dataset.hml_rep,
            )

        if mrep == MotionRep("left_hand"):
            left_motion = _decode_first(left_hand_model, codes[:, 0])
            return dset.toMotion(
                left_motion[0],
                motion_rep=MotionRep(left_cfg.dataset.motion_rep),
                hml_rep=left_cfg.dataset.hml_rep,
            )

        if mrep == MotionRep("right_hand"):
            right_motion = _decode_first(right_hand_model, codes[:, 0])
            return dset.toMotion(
                right_motion[0],
                motion_rep=MotionRep(right_cfg.dataset.motion_rep),
                hml_rep=right_cfg.dataset.hml_rep,
            )

        raise ValueError(
            f"Unsupported motion_rep {dset.motion_rep!r} for a single codebook (k=1)"
        )

    if k == 2:
        # Codebook order: (left hand, right hand).
        left_motion = _decode_first(left_hand_model, codes[:, 0])
        right_motion = _decode_first(right_hand_model, codes[:, 1])

        left_M = dset.toMotion(
            left_motion[0],
            motion_rep=MotionRep(left_cfg.dataset.motion_rep),
            hml_rep=left_cfg.dataset.hml_rep,
        )
        right_M = dset.toMotion(
            right_motion[0],
            motion_rep=MotionRep(right_cfg.dataset.motion_rep),
            hml_rep=right_cfg.dataset.hml_rep,
        )
        hand_M = left_M + right_M
        hand_M.motion_rep = MotionRep.HAND
        # Keep only the hml_rep characters both hands share, in left-hand order.
        hand_M.hml_rep = "".join(ch for ch in left_M.hml_rep if ch in right_M.hml_rep)
        return hand_M

    if k == 3:
        # Codebook order: (body, left hand, right hand).
        body_motion = _decode_first(body_model, codes[:, 0])
        if remove_translation:
            body_motion = _insert_zero_translation(body_motion)
        left_motion = _decode_first(left_hand_model, codes[:, 1])
        right_motion = _decode_first(right_hand_model, codes[:, 2])

        body_M = dset.toMotion(
            body_motion[0],
            motion_rep=MotionRep("body"),
            hml_rep=body_cfg.dataset.hml_rep,
        )
        left_M = dset.toMotion(
            left_motion[0],
            motion_rep=MotionRep("left_hand"),
            hml_rep=left_cfg.dataset.hml_rep,
        )
        right_M = dset.toMotion(
            right_motion[0],
            motion_rep=MotionRep("right_hand"),
            hml_rep=right_cfg.dataset.hml_rep,
        )
        return dset.to_full_joint_representation(body_M, left_M, right_M)

    raise ValueError(f"Expected 1, 2 or 3 codebooks along dim 1 of codes, got k={k}")



## Translation model (trajectory → root orientation)

In [7]:
from core.models.utils import instantiate_from_config, get_obj_from_str


In [8]:
from configs.config_t2o import get_cfg_defaults as trans_get_cfg_defaults
# Trajectory-to-orientation model config: start from the defaults, then overlay the checkpoint's YAML.
trans_cfg = trans_get_cfg_defaults()
trans_cfg.merge_from_file(
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/"
    "simple_motion_translation/simple_motion_translation.yaml"
)

In [9]:
# Build the translation model, put it on `device` in eval mode, restore tcn_model.pt and freeze it.
trans_model = instantiate_from_config(trans_cfg.vqvae)
trans_model = trans_model.to(device).eval()
trans_model.load(os.path.join(trans_cfg.output_dir, "tcn_model.pt"))
trans_model.freeze()

Sync is turned on False
loaded model with  0.06821402907371521 tensor([40000.], device='cuda:0') steps


In [46]:
def transform(data: torch.Tensor, data_root="/srv/hays-lab/scratch/sanisetty3/motionx") -> torch.Tensor:
    """Normalise planar root displacements with the dataset's relative-position statistics.

    Loads ``motion_data/Mean_rel_pos.npy`` and ``motion_data/Std_rel_pos.npy`` under
    ``data_root`` and keeps entries 0 and 2 (presumably the ground-plane x/z axes — TODO
    confirm). Then it computes ``(data - mean) / (std + 1e-8)``.

    Args:
        data: torch tensor whose last dimension has size 2 (or broadcasts against it).
            The previous ``np.ndarray`` hint was wrong, because ``data.device`` is required.
        data_root: dataset root that contains the ``motion_data`` statistics files.

    Returns:
        The normalised tensor on ``data``'s device. Floating-point inputs keep their dtype;
        other inputs are normalised in float32.
    """
    # Match the statistics to the input's float dtype instead of always using float32.
    stats_dtype = data.dtype if data.is_floating_point() else torch.float32
    mean_pos = torch.as_tensor(
        np.load(os.path.join(data_root, "motion_data/Mean_rel_pos.npy"))[[0, 2]],
        dtype=stats_dtype,
        device=data.device,
    )
    std_pos = torch.as_tensor(
        np.load(os.path.join(data_root, "motion_data/Std_rel_pos.npy"))[[0, 2]],
        dtype=stats_dtype,
        device=data.device,
    )
    # The epsilon guards against division by zero for constant channels.
    return (data - mean_pos) / (std_pos + 1e-8)

In [196]:
def traj2orient(traj):
    """Predict root orientations for a planar root trajectory with the global ``trans_model``.

    Args:
        traj: root positions indexed as ``(..., frames, 2)``, e.g. ``(1, N, 2)``.

    Returns:
        ``(pred_orient, rel_pos)``. ``pred_orient`` is ``trans_model.predict`` applied to the
        normalised per-frame displacements (a ``(1, N, 4)`` tensor for a ``(1, N, 2)`` input
        was observed here). ``rel_pos`` holds the raw displacements, has the same shape as
        ``traj``, and is zero at frame 0.
    """
    # Difference along the frame axis (dim -2), not the coordinate axis. Slicing with
    # [..., 1:] would subtract x from z instead of frame t-1 from frame t.
    rel_pos = torch.zeros_like(traj)
    rel_pos[..., 1:, :] = traj[..., 1:, :] - traj[..., :-1, :]
    normalized = transform(rel_pos)
    with torch.no_grad():
        pred_orient = trans_model.predict(normalized.to(device))

    return pred_orient, rel_pos



    

In [34]:
# NOTE(review): `traj` is only created in a later cell (or a since-deleted one), so this
# inspection cell raises NameError on Restart & Run All.
traj.shape

torch.Size([1, 100, 2])

In [33]:
# NOTE(review): `pred_orient` is only assigned in later cells (via traj2orient), so this
# inspection cell raises NameError on Restart & Run All.
pred_orient.shape

torch.Size([1, 100, 4])

In [58]:
def reverse_recover_root_rot_pos(r_rot_quat, r_pos):
    """Convert per-frame root orientation and planar position into root velocity features.

    Args:
        r_rot_quat: quaternions of shape ``(B, N, 4)``. The per-frame yaw is read as
            ``atan2(q[..., 2], q[..., 0])``.
        r_pos: planar root positions of shape ``(B, N, 2)``. Only the first two channels
            are used.

    Returns:
        A ``(B, N, 3)`` tensor ``[rot_velocity, vel_0, vel_1]``. Frame 0 is all zeros and
        frame ``t`` holds the change from ``t-1`` to ``t``. Rotation velocities are wrapped
        into ``[-pi, pi]``, so crossing atan2's branch cut does not produce ~2*pi spikes.
    """
    # Step 1: linear velocity = frame-to-frame difference of the planar position.
    root_linear_velocity = torch.zeros_like(r_pos[..., :2])
    root_linear_velocity[..., 1:, :] = r_pos[..., 1:, :2] - r_pos[..., :-1, :2]

    # Step 2: per-frame yaw, then its frame-to-frame difference wrapped back into [-pi, pi].
    r_rot_ang = torch.atan2(r_rot_quat[..., 2], r_rot_quat[..., 0])
    root_rot_velocity = torch.zeros_like(r_rot_ang)
    ang_diff = r_rot_ang[..., 1:] - r_rot_ang[..., :-1]
    root_rot_velocity[..., 1:] = torch.atan2(torch.sin(ang_diff), torch.cos(ang_diff))

    # Step 3: stack rotation velocity and linear velocity into root_params.
    return torch.cat((root_rot_velocity.unsqueeze(-1), root_linear_velocity), dim=-1)

## Refiner model (full-body VQ-VAE)

In [158]:
# Full-body refiner VQ-VAE: load its YAML config, build the model on `device` in eval mode,
# restore vqvae_motion.pt from the config's output_dir and freeze it.
refiner_cfg = get_cfg_defaults()
refiner_cfg.merge_from_file(
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/vqvae/"
    "vqvae_full_gpvc/vqvae_full_gpvc.yaml"
)

refiner_model = instantiate_from_config(refiner_cfg.vqvae)
refiner_model = refiner_model.to(device).eval()
refiner_model.load(os.path.join(refiner_cfg.output_dir, "vqvae_motion.pt"))
refiner_model.freeze()

## MUSE motion generator

In [16]:
from core.models.generation.muse2 import generate_animation
from core import MotionTokenizerParams, pattern_providers

from core.param_dataclasses import pattern_providers
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.utils import instantiate_from_config, get_obj_from_str
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider,ConditionFuser
from core.models.generation.muse2 import MotionMuse as MotionMuse2
import einops
from configs.config_t2m import get_cfg_defaults as muse_get_cfg_defaults
from core import MotionTokenizerParams

In [17]:
# Generator config: defaults overlaid with the checkpoint YAML, then frozen.
gen_cfg = muse_get_cfg_defaults()
gen_cfg.merge_from_file(
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/"
    "motion_muse_body_hands/motion_muse_body_hands.yaml"
)
gen_cfg.freeze()

# Sub-config nodes handed to MotionMuse2 / ConditionProvider in later cells.
tranformer_config, fuse_config, pattern_config, dataset_args = (
    gen_cfg.motion_generator,
    gen_cfg.fuser,
    gen_cfg.codebooks_pattern,
    gen_cfg.dataset,
)

# Remove "target" from the transformer node before it is passed on; the popped value is
# kept in `target` but not used elsewhere in the visible notebook.
target = tranformer_config.pop("target")


In [18]:
# Build the MUSE generator and restore its weights. map_location follows the `device`
# chosen in the first cell, so loading also works when CUDA is unavailable
# (a hard-coded "cuda" would fail on the CPU fallback path).
motion_gen = MotionMuse2(tranformer_config, fuse_config, pattern_config).to(device).eval()
pkg = torch.load(
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_muse_body_hands/motion_muse.pt",
    map_location=device,
)
motion_gen.load_state_dict(pkg["model"])

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


<All keys matched successfully>

In [19]:
# Six-value unpacking is valid because this gen_cfg configures hand VQ-VAEs;
# load_vqvae returns only (body_model, body_cfg) for a body-only config.
(
    body_model,
    left_hand_model,
    right_hand_model,
    body_cfg,
    left_cfg,
    right_cfg,
) = load_vqvae(gen_cfg)

Sync is turned on False
loaded model with  0.03015906736254692 tensor([110000.], device='cuda:0') steps


In [20]:
# Conditioning front-end for the generator. Representations, padding and max lengths come
# from gen_cfg.dataset; the pad id is derived from tranformer_config.num_tokens.
condition_provider = ConditionProvider(
    text_conditioner_name=dataset_args.text_conditioner_name,
    motion_rep=MotionRep(dataset_args.motion_rep),
    audio_rep=AudioRep(dataset_args.audio_rep),
    text_rep=TextRep(dataset_args.text_rep),
    motion_padding=dataset_args.motion_padding,
    audio_padding=dataset_args.audio_padding,
    motion_max_length_s=dataset_args.motion_max_length_s,
    audio_max_length_s=dataset_args.audio_max_length_s,
    pad_id=MotionTokenizerParams(tranformer_config.num_tokens).pad_token_id,
    fps=30 / 4,  # presumably 30 fps motion downsampled 4x by the tokenizer — TODO confirm
)



In [60]:
from core.datasets.base_dataset import BaseMotionDataset
# Dataset helper used for decoding tokens into Motion objects, (inverse) transforms and rendering.
base_dset = BaseMotionDataset(
    motion_rep=MotionRep.BODY,
    hml_rep="gpvc",
)

In [62]:
# Generation prompts. Set aud_clip to an audio path to also condition on music, e.g.
# "/srv/hays-lab/scratch/sanisetty3/motionx/audio/wav/wild/despacito.mp3".
aud_clip = None
text_ = "A person walking forward"
# No trailing comma: `None,` made this the tuple (None,) rather than None.
neg_text_ = None  # e.g. "dancing"

In [64]:
# Sample an 8 s token sequence from the text/audio prompts, then decode it into a Motion.
all_ids_body = generate_animation(
    motion_gen,
    condition_provider,
    overlap=5,
    duration_s=8,
    aud_file=aud_clip,
    text=text_,
    neg_text=neg_text_,
    use_token_critic=True,
    timesteps=24,
)
gen_motion = bkn_to_motion(all_ids_body, base_dset)


100%|███████████████████████████████████████████████████████████████████| 72/72 [00:01<00:00, 43.93it/s]
100%|███████████████████████████████████████████████████████████████████| 72/72 [00:01<00:00, 43.95it/s]


In [69]:
# Shape of the decoded motion's feature tensor (presumably frames x features).
gen_motion().shape

torch.Size([220, 317])

In [70]:
# Render the generated motion to a GIF with root translation and orientation zeroed.
base_dset.render_hml(
    gen_motion,
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/gen_novel_full.gif",
    zero_trans=True,
    zero_orient=True,
)

## Refine the generated motion along a custom trajectory

In [84]:
def traj2orient(traj):
    """Predict root orientations for a planar root trajectory with the global ``trans_model``.

    NOTE(review): this notebook defines ``traj2orient`` in two cells, and whichever cell ran
    last wins. Keep the copies identical or delete one.

    Args:
        traj: root positions indexed as ``(..., frames, 2)``, e.g. ``(1, N, 2)``.

    Returns:
        ``(pred_orient, rel_pos)``. ``pred_orient`` is ``trans_model.predict`` applied to the
        normalised per-frame displacements. ``rel_pos`` holds the raw displacements, has the
        same shape as ``traj``, and is zero at frame 0.
    """
    # Difference along the frame axis (dim -2), not the coordinate axis. Slicing with
    # [..., 1:] would subtract x from z instead of frame t-1 from frame t.
    rel_pos = torch.zeros_like(traj)
    rel_pos[..., 1:, :] = traj[..., 1:, :] - traj[..., :-1, :]
    normalized = transform(rel_pos)
    with torch.no_grad():
        pred_orient = trans_model.predict(normalized.to(device))

    return pred_orient, rel_pos



    

In [111]:


# Predict root orientations for `traj` and turn them into root velocity features.
# NOTE(review): `traj` is created in a later cell; run that cell first on a fresh kernel.
# Per-frame displacement along the frame axis (dim -2); frame 0 stays zero.
rel_pos = torch.zeros_like(traj)
rel_pos[..., 1:, :] = traj[..., 1:, :] - traj[..., :-1, :]
# traj2orient returns (pred_orient, rel_pos); only the orientation goes into the converter,
# not the whole tuple.
pred_orient = traj2orient(traj)[0]
root_params_pred = reverse_recover_root_rot_pos(pred_orient, traj)


torch.Size([1, 220]) torch.Size([1, 220, 2])


In [180]:
# Unpack the predicted orientations and the raw per-frame displacements of `traj`.
(pred_orient, rel_pos) = traj2orient(traj)

In [188]:
# Re-decode the generated token ids into a Motion (default remove_translation=True).
gen_motion = bkn_to_motion(all_ids_body, dset=base_dset)


In [255]:
# Synthetic target path: 220 frames, channel 0 fixed at 0, channel 1 rising linearly 0 -> 20.
traj = torch.zeros((1, 220, 2), device=device)
traj[..., 1] = torch.linspace(0, 20, 220)
# Per-frame displacement along the frame axis; prepending frame 0 makes the first step zero.
rel_pos = torch.diff(traj, dim=1, prepend=traj[:, :1])

In [256]:
# Build the refiner input from the generated motion. Undo the dataset transform, write the
# synthetic trajectory's displacements into root_params[:, 1:3], zero root_params[:, 0:1]
# and the contact channels, then re-apply base_dset.transform.
gen_motion = bkn_to_motion(all_ids_body, base_dset)
ohpvc_full = gen_motion
ohpvc_full_inv = base_dset.inv_transform(ohpvc_full)
ohpvc_full_inv.root_params[:, 1:3] = rel_pos
ohpvc_full_inv.root_params[:, 0:1] = 0
# Zero contacts *before* transform. When this ran after the transform, it could not affect
# ohpvc_full_trns unless transform aliases its input.
ohpvc_full_inv.contact[:, :] = 0
ohpvc_full_trns = base_dset.transform(ohpvc_full_inv)

In [257]:
# Run the refiner on the edited motion (batch dim added, moved to `device`). Inference needs
# no autograd graph; this matches traj2orient's use of torch.no_grad().
with torch.no_grad():
    out = refiner_model(ohpvc_full_trns()[None].to(device))

In [258]:
# Wrap the refiner's single decoded sequence as a full-body Motion in the "gpvc" representation.
refined_Motion = base_dset.toMotion(
    out.decoded_motion[0],
    MotionRep("full"),
    hml_rep="gpvc",
)

In [259]:
# Render the refined motion to a GIF. zero_trans / zero_orient are left at render_hml's
# defaults here (unlike the earlier render), so the trajectory stays visible.
base_dset.render_hml(
    refined_Motion,
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/gen_novel_full_add_traj.gif",
)

In [1]:
# Image(open(f"/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/gen_novel_full_add_traj.gif",'rb').read())