In [1]:
# Pick the compute device once: CUDA when torch can see a GPU, otherwise the CPU.
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# When running on CUDA, also report the GPU name and the memory currently
# allocated / reserved on device 0 (bytes divided by 1024**3, one decimal).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 1024**3, 1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import pack, rearrange, reduce, repeat, unpack


In [4]:
def findAllFile(base):
    """Return the paths of all files found under ``base``, recursively.

    Directory symlinks are followed. Each returned path is ``base``-prefixed
    (``os.path.join(root, filename)``), in ``os.walk`` order. A missing or empty
    ``base`` yields an empty list.
    """
    return [
        os.path.join(root, filename)
        for root, _dirs, filenames in os.walk(base, followlinks=True)
        for filename in filenames
    ]



In [5]:
# from utils.motion_processing.hml_process import recover_from_ric, recover_root_rot_pos,recover_from_rot
import utils.vis_utils.plot_3d_global as plot_3d
import matplotlib.pyplot as plt

def vis(mot , dset , name = "motion", out_dir = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render"):
    """Render a motion to ``<out_dir>/<name>.gif``.

    Args:
        mot: the motion to draw. A ``torch.Tensor`` is first converted with
            ``dset.toMotion``; anything else is passed to ``dset.inv_transform``
            as-is.
        dset: dataset object providing ``toMotion``, ``inv_transform`` and
            ``to_xyz`` (the latter must return something with ``.cpu()``).
        name: file stem of the GIF that is written.
        out_dir: directory the GIF is written to. Defaults to the render
            directory this notebook has always used, so existing calls are
            unaffected.

    Prints the shape of the xyz array that is handed to the renderer.
    """
    if isinstance(mot, torch.Tensor):
        mot = dset.toMotion(mot)
    # inv_transform is applied in both cases (tensor input or not).
    mot = dset.inv_transform(mot)

    xyz = np.array(dset.to_xyz(mot).cpu())
    print(xyz.shape)

    plot_3d.render(xyz, os.path.join(out_dir, f"{name}.gif"))

In [6]:
from configs.config import cfg, get_cfg_defaults
from configs.config_streaming import get_cfg_defaults as strm_get_cfg_defaults

In [7]:
# Streaming-generation config: start from the defaults of configs.config_streaming,
# overlay the checkpoint's YAML, then call freeze() on the result
# (presumably makes the config read-only — confirm against the config class).
gen_cfg = strm_get_cfg_defaults()
gen_cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_streaming/motion_streaming.yaml")
gen_cfg.freeze()

## VQVAE

In [8]:
from core.models.resnetVQ.vqvae import HumanVQVAE


In [9]:

def _load_part_cfg(yaml_path):
    """Return a fresh default config with the YAML at ``yaml_path`` merged in."""
    part_cfg = get_cfg_defaults()
    part_cfg.merge_from_file(yaml_path)
    return part_cfg


# One VQ-VAE config per body part; the YAML locations come from the streaming config.
body_cfg = _load_part_cfg(gen_cfg.vqvae.body_config)
left_cfg = _load_part_cfg(gen_cfg.vqvae.left_hand_config)
right_cfg = _load_part_cfg(gen_cfg.vqvae.right_hand_config)

In [10]:
def _load_vqvae(part_cfg):
    """Build a HumanVQVAE on the CPU in eval mode and load its ``vqvae_motion.pt``.

    The checkpoint is read from ``part_cfg.output_dir``.
    """
    vqvae = HumanVQVAE(part_cfg.vqvae).to("cpu").eval()
    vqvae.load(os.path.join(part_cfg.output_dir, "vqvae_motion.pt"))
    return vqvae


# Same construction order as before: left hand, right hand, body.
left_hand_model = _load_vqvae(left_cfg)
right_hand_model = _load_vqvae(right_cfg)
body_model = _load_vqvae(body_cfg)



## Motion Gen

In [12]:
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider, ConditionFuser
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.generation.motion_generator import Transformer, MotionMuse
from core.models.utils import instantiate_from_config, get_obj_from_str


  from .autonotebook import tqdm as notebook_tqdm


In [13]:
from configs.config_t2m import cfg, get_cfg_defaults
from configs.config import get_cfg_defaults as get_cfg_defaults3

# NOTE: this rebinds `cfg`, replacing the `cfg` object imported from
# configs.config_t2m at the top of this cell; `get_cfg_defaults` here is likewise
# the config_t2m version, not the configs.config one imported earlier.
cfg = get_cfg_defaults()
cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_generation/motion_generation.yaml")
cfg.freeze()
# Sub-configs used by the MotionMuse cells below.
mmuse_args = cfg.motion_generator
dataset_args = cfg.dataset


In [14]:
# pop() removes "target" from mmuse_args before the remaining args go to MotionMuse.
# NOTE(review): this mutates mmuse_args, so the cell is not re-runnable as-is
# (a second pop presumably raises KeyError — confirm).
target = mmuse_args.pop("target")
# Model is moved to the device selected in the first cell and put in eval mode.
motion_muse = MotionMuse(mmuse_args).to(device).eval()

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


In [9]:
from core.models.resnetVQ.vqvae import HumanVQVAE

In [57]:

# VQ-VAE config for decoding: defaults from configs.config (imported as
# get_cfg_defaults3) overlaid with the vqvae_body_gprvc YAML.
vcfg = get_cfg_defaults3()
vcfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/vqvae/vqvae_body_gprvc/vqvae_body_gprvc.yaml")
vqvae_args = vcfg.vqvae
# Manual overrides applied on top of the YAML values.
vqvae_args.nb_joints = 22
vqvae_args.motion_dim = 263

In [None]:
# NOTE(review): unfinished scratch expression that cannot run — numpy has no
# `cusum` attribute (probably `cumsum` was meant) and np.split also needs an array
# plus split indices. Intent is not recoverable from this notebook; delete or finish.
np.split(np.cusum)

In [14]:
# VQ-VAE used to decode MotionMuse token ids back to motion features.
vqvae_model = HumanVQVAE(vqvae_args).to(device).eval()
# NOTE(review): this checkpoint lives under `.../music_motion/ACMG/.../smplx_resnet/`,
# while the config above and the other paths in this notebook use `.../ATCMG/...` —
# confirm that this is the intended checkpoint for vqvae_args.
vqvae_model.load("/srv/hays-lab/scratch/sanisetty3/music_motion/ACMG/checkpoints/smplx_resnet/vqvae_motion.pt")

In [15]:
# Condition provider for the MotionMuse cells: representations and padding come from
# dataset_args, both max lengths are set to 10 (seconds, going by the parameter names),
# and the pad id is the one used by the MotionMuse transformer.
condition_provider = ConditionProvider(
    motion_rep=MotionRep(dataset_args.motion_rep),
    audio_rep=AudioRep(dataset_args.audio_rep),
    text_rep=TextRep(dataset_args.text_rep),
    motion_padding=dataset_args.motion_padding,
    audio_padding=dataset_args.audio_padding,
    motion_max_length_s=10,
    audio_max_length_s=10,
    pad_id=motion_muse.transformer.pad_token_id,
    fps=30 / 4,  # presumably 30 fps motion at 4x temporal downsampling — TODO confirm
)



In [16]:
# bod_ind = np.load("/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/aist/subset_0000/Dance_Break_3_Step_clip_1.npy")
# lh_ind = np.load("/srv/hays-lab/scratch/sanisetty3/motionx/indices/left_hand/aist/subset_0000/Dance_Break_3_Step_clip_1.npy")
# rh_ind = np.load("/srv/hays-lab/scratch/sanisetty3/motionx/indices/right_hand/aist/subset_0000/Dance_Break_3_Step_clip_1.npy")

In [18]:
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate

# Training split of the "animation" subset, read from the motionx root.
dset = MotionIndicesAudioTextDataset(
    "animation",
    "/srv/hays-lab/scratch/sanisetty3/motionx",
    motion_rep="full",
    split="train",
    fps=30 / 4,
)

Total number of motions animation: 265 and texts 265


In [54]:
inpss  = next(iter(dset))

In [58]:
train_ds, sampler_train, weights_train  = load_dataset_gen(dataset_args=dataset_args, split = "train" , dataset_names = ["animation" , "choreomaster" ] )

Total number of motions animation: 73 and texts 73
Total number of motions choreomaster: 34 and texts 34


In [19]:
# Batches of 4 from `dset`; simple_collate receives the condition provider so each
# batch comes out as an (inputs, conditions) pair. Incomplete final batches are dropped.
# No sampler or shuffle is passed here.
train_loader = torch.utils.data.DataLoader(
    dataset=dset,
    batch_size=4,
    collate_fn=partial(simple_collate, conditioner=condition_provider),
    drop_last=True,
)

In [20]:
# Grab the first batch. next(iter(...)) raises StopIteration on an empty loader
# instead of silently leaving `inputs` / `conditions` unbound or stale from an
# earlier run, which is what the `for ...: break` form did.
inputs, conditions = next(iter(train_loader))
    

In [21]:
inputs["motion"][0].shape

torch.Size([4, 52, 3])

In [22]:
inputs["motion"][1].shape

torch.Size([4, 52])

In [23]:
conditions["audio"][0].shape

torch.Size([4, 1, 128])

In [24]:
conditions["text"][0].shape

torch.Size([4, 1, 768])

In [35]:
# Token ids as int64. NOTE: squeeze() without a dim drops *every* size-1 axis, so a
# batch of size 1 would also lose its batch axis.
motions = inputs["motion"][0].squeeze().to(torch.long)
# Second element of inputs["motion"] is used as the mask throughout this notebook.
motion_mask = inputs["motion"][1]

In [164]:
fuse_method = {"cross": ["audio"], "prepend": ["text"]}
condition_fuser = ConditionFuser(fuse_method)

In [None]:
# NOTE(review): this cell looks pasted from inside a model method — `self` is not
# defined at notebook level, and `input` resolves to the Python builtin unless
# another cell rebinds it. It cannot run as-is (it was never executed: `In [None]`).
audio_embed = self.project_audio(conditions["audio"][0])
text_embed = self.project_text(conditions["text"][0])

# Fuse the projected conditions with the input sequence; each condition is passed as
# an (embedding, mask) pair.
inputs_, cross_inputs = self.condition_fuser(
    input,
    {
        "text": (text_embed, conditions["text"][1]),
        "audio": (audio_embed, conditions["audio"][1]),
    },
)

## MotionMuse

In [36]:
loss , logits = motion_muse((motions , motion_mask) , conditions , cond_drop_prob = 0.4 , return_logits = True)

In [79]:
# Greedy decoding of the MotionMuse logits: take the arg-max over the last axis,
# then decode only the first sample of the batch with the VQ-VAE.
# (Fixes the `lologits` typo, which raised NameError; the tensor returned by
# motion_muse(...) in the previous cell is named `logits`.)
pred_indices = logits.argmax(-1)
pred_motion  = vqvae_model.decode(pred_indices[:1])

In [65]:
# Replace every id >= 1024 with 0 before decoding (presumably 1024 is the codebook
# size and larger ids are pad/special tokens the VQ-VAE cannot decode — TODO confirm).
mod_motion = torch.where(motions >= 1024 , 0 , motions)
gt_motion  = vqvae_model.decode(mod_motion)

In [71]:
gt_motion[1].shape

torch.Size([300, 263])

In [66]:
# Render the decoded ground-truth motion of sample 1, cut to 4x the number of set
# mask entries (presumably 4 decoded frames per token — TODO confirm).
dset.render_hml(
    gt_motion[1][: int(motion_mask[1].sum()) * 4].detach().squeeze().cpu(),
    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/gt_motion_recon.gif",
)

In [None]:
logits2 = motion_muse.transformer.forward_with_cond_scale((motions , motion_mask) , conditions)
logits3 =motion_muse.transformer.forward_with_neg_prompt((motions , motion_mask) , conditions , conditions)

In [48]:
motions.shape

torch.Size([4, 28])

In [49]:
28*4

112

In [67]:
gen_ids = motion_muse.generate(conditions)

100%|█████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00, 14.65it/s]


## Streaming transformer

In [8]:

from core.param_dataclasses import pattern_providers
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.utils import instantiate_from_config, get_obj_from_str


  from .autonotebook import tqdm as notebook_tqdm


In [9]:
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider,ConditionFuserStreamer
from core.models.generation.lm import LMModel, MotionGen
import einops

In [10]:
# Rebuilds the streaming config from scratch (same YAML as the earlier gen_cfg cell),
# so the pops performed on its sub-configs further down start from a fresh copy.
gen_cfg = strm_get_cfg_defaults()
gen_cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_streaming/motion_streaming.yaml")
gen_cfg.freeze()

In [11]:
# Sub-configs for the streaming LM. NOTE: pop() mutates lm_args, so re-running this
# cell without re-creating gen_cfg presumably fails on the missing "target" key.
lm_args = gen_cfg.transformer_lm
target = lm_args.pop("target")
fuse_config = gen_cfg.fuser
pattern_args = gen_cfg.codebooks_pattern
# Rebinds dataset_args (previously taken from the motion-generation config).
dataset_args = gen_cfg.dataset

In [12]:
# Streaming generator on the selected device, switched to eval mode.
model_gen = MotionGen(lm_args, fuse_config, pattern_args).to(device).eval()

In [64]:
# "modeling" names which entry of `pattern_providers` to instantiate; it is popped
# (mutating pattern_args), and the provider is built for lm_args.n_q codebooks with
# the delay / flatten / empty-initial settings from the same config.
modeling = pattern_args.pop("modeling")
pattern_provider = pattern_providers[modeling](lm_args.n_q, delays = pattern_args.delays , flatten_first = pattern_args.flatten_first , empty_initial = pattern_args.empty_initial )

In [65]:
fuse_method = fuse_config.pop("fuse_method")


In [66]:
# NOTE: hard-coded override — this discards the fuse_method value popped from
# fuse_config in the previous cell.
fuse_method = {'cross': ['text'], 'input_interpolate': ['audio']}

In [68]:
# fuse_method may arrive wrapped in a list; in that case only its first entry is used.
if isinstance(fuse_method, list):
    fuse_method = fuse_method[0]
# Remaining fuse_config entries are forwarded as keyword arguments.
condition_fuser = ConditionFuserStreamer(fuse_method, **fuse_config)

In [69]:
model = LMModel(
            pattern_provider=pattern_provider,
            fuser=condition_fuser,
            **lm_args
        ).to(device)

In [13]:
# Condition provider for the streaming model: same settings as the MotionMuse one,
# except that the pad id comes from model_gen's LM (model_gen.model.pad_token_id).
condition_provider = ConditionProvider(
    motion_rep=MotionRep(dataset_args.motion_rep),
    audio_rep=AudioRep(dataset_args.audio_rep),
    text_rep=TextRep(dataset_args.text_rep),
    motion_padding=dataset_args.motion_padding,
    audio_padding=dataset_args.audio_padding,
    motion_max_length_s=10,
    audio_max_length_s=10,
    pad_id=model_gen.model.pad_token_id,
    fps=30 / 4,
)



In [14]:
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
# dset = MotionIndicesAudioTextDataset("chroeomaster" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "full", split = "train" , fps = 30/4  )


In [15]:
# NOTE: despite the `train_*` names, this loads the "test" split of both datasets.
train_ds, sampler_train, weights_train = load_dataset_gen(
    dataset_args=dataset_args,
    split="test",
    dataset_names=["animation", "choreomaster"],
)

# Batches of 4, collated with permute=True; the sampler returned above is not used
# and no shuffle is requested. Incomplete final batches are dropped.
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=4,
    collate_fn=partial(simple_collate, conditioner=condition_provider, permute=True),
    drop_last=True,
)

Total number of motions animation: 8 and texts 8
Total number of motions choreomaster: 2 and texts 2


In [21]:
fles = findAllFile("/srv/hays-lab/scratch/sanisetty3/motionx/indices/body")

In [29]:
np.load(fles[2000]).shape

(1, 5)

In [30]:
# Collect the last-axis length of every index file; files that cannot be read
# (or whose array has no axes) are reported by path and skipped.
lenss = []
for p in tqdm(fles):
    try:
        lenss.append(np.load(p).shape[-1])
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt / SystemExit
        # and made this 80k-file scan impossible to interrupt cleanly.
        print(p)

 31%|████████████████████████▋                                                       | 24685/80067 [00:17<00:35, 1582.06it/s]

/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/animation/subset_0003/Ways_To_Stand_Groomsman.npy
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/animation/subset_0003/Ways_To_Pick_Up_A_Dollar_Cartwheel.npy
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/animation/subset_0003/Ways_To_Pick_Up_A_Dollar_Dramatic.npy
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/animation/subset_0000/Ways_To_Enter_A_Room_Candy_Store.npy
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/animation/subset_0002/Ways_To_Jump_In_Swim_Get_Out_Of_A_Pool_Too_Cold.npy


100%|████████████████████████████████████████████████████████████████████████████████| 80067/80067 [00:53<00:00, 1493.14it/s]


In [16]:
# Grab the first batch. next(iter(...)) raises StopIteration on an empty loader
# instead of silently leaving `inputs` / `conditions` unbound or stale from an
# earlier run, which is what the `for ...: break` form did.
inputs, conditions = next(iter(train_loader))
    

In [17]:
inputs["names"]

array(['animation/subset_0000/Ways_To_Catch_Juggling',
       'animation/subset_0003/Ways_To_Text_One_Minute',
       'animation/subset_0001/Ways_To_Jump_Sit_Fall_Cool_Teacher',
       'animation/subset_0003/Ways_To_Sit_Robot'], dtype='<U56')

In [42]:
input_mask = inputs["motion"][1]
# Move the ids to the device selected in the first cell rather than hard-coding
# .cuda(), which raises on a CPU-only machine and can disagree with `device`.
motions_or_ids = inputs["motion"][0].to(device)
# The id tensor is 3-D; the axes are named batch / codebook / time here.
B, K, T = motions_or_ids.shape

In [18]:
# Token ids as int64. NOTE: squeeze() without a dim drops *every* size-1 axis, so a
# batch of size 1 (or a single codebook) would silently lose that axis.
motions = inputs["motion"][0].squeeze().to(torch.long)
motion_mask = inputs["motion"][1]


In [19]:
motions.shape

torch.Size([4, 3, 28])

In [99]:
conditions["text"][0].shape

torch.Size([4, 24, 768])

In [100]:
conditions["audio"][0].shape

torch.Size([4, 300, 128])

In [101]:
conditions["audio"][1]

tensor([[False, False, False,  ..., False, False, False],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [103]:
# Resample the audio condition along time to 52 steps (F.interpolate's default mode).
# `cond` is left in (b, d, t) layout. 52 presumably matches the motion token
# length of the current batch — TODO confirm.
cond = F.interpolate(
    einops.rearrange(conditions["audio"][0], "b t d -> b d t"),
    size=52,
)

# The mask takes the same resampling: bool -> float with a channel axis, interpolate,
# then drop the channel axis and cast back to bool.
cond_mask = conditions["audio"][1].unsqueeze(1).to(torch.float)
cond_mask = F.interpolate(cond_mask, size=52).squeeze(1).to(torch.bool)

In [105]:
out = model_gen((motions, motion_mask), conditions)

In [33]:
B , N , C = conditions["audio"][0].shape

In [84]:
out = model.compute_predictions(inputs["motion"] , conditions)

In [38]:
audio_embed = model.project_audio(conditions["audio"][0])
text_embed = model.project_text(conditions["text"][0])

In [39]:
audio_embed.shape

torch.Size([4, 500, 512])

In [43]:
input_ = sum([model.emb[k](motions_or_ids[:, k]) for k in range(K)])

In [85]:
motions.shape

torch.Size([4, 3, 52])

In [None]:
dset = train_ds.datasets[0]

In [None]:
# Split the (B, K, T) id tensor into its three codebook streams. The order
# body / left hand / right hand is taken from the variable names — TODO confirm
# against the collate function.
body_inds = motions_or_ids[:,0]
left_inds = motions_or_ids[:,1]
right_inds = motions_or_ids[:,2]
# Decode only the first sample of the batch with each part's VQ-VAE.
# NOTE(review): the part models were moved to "cpu" when loaded, while
# motions_or_ids was moved to the GPU — confirm decode() handles the device mismatch.
body_motion = body_model.decode(body_inds[0:1]).detach()
left_motion = left_hand_model.decode(left_inds[0:1]).detach()
right_motion = right_hand_model.decode(right_inds[0:1]).detach()
# Wrap each decoded part with the representation settings of its own config,
# then join the three parts into one full-body motion.
body_M = dset.toMotion(body_motion[0] , motion_rep = MotionRep(body_cfg.dataset.motion_rep) , hml_rep = body_cfg.dataset.hml_rep , )
left_M = dset.toMotion(left_motion[0] , motion_rep = MotionRep(left_cfg.dataset.motion_rep) , hml_rep = left_cfg.dataset.hml_rep , )
right_M = dset.toMotion(right_motion[0] , motion_rep = MotionRep(right_cfg.dataset.motion_rep) , hml_rep = right_cfg.dataset.hml_rep , )
full_M = dset.to_full_joint_representation(body_M , left_M , right_M )

In [39]:
dset.render_hml(full_M , "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/full_joined.gif", from_rotation = True)

In [100]:
choreo = np.load("/srv/hays-lab/scratch/sanisetty3/motionx/motion_data/new_joint_vecs/choreomaster/1160.npy")
full_og = dset.toMotion(choreo , motion_rep = MotionRep("full") , hml_rep = "gprvc" )
xyz = dset.to_xyz(full_og, from_rotation=True).cpu()

In [104]:
import utils.vis_utils.plot_3d_global as plot_3d
plot_3d.render(
            np.array(xyz)[:300],
            "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/full_og.gif",
        )

In [102]:
choreo.shape

(6466, 623)

In [105]:
# NOTE: rebinds full_M in place of the joined motion, so re-running this cell applies
# inv_transform a second time.
full_M = dset.inv_transform(full_M)
# Return value is discarded — presumably converts the motion's fields to tensors
# in place; confirm against the Motion class.
full_M.tensor()
# Borrow the root parameters of the original choreomaster clip, truncated to the
# number of rotation frames in full_M.
full_M.root_params = full_og.root_params[:full_M.rotations.shape[0]]
xyz_ = dset.to_xyz(full_M, from_rotation=True).cpu()

In [106]:
xyz_.shape

torch.Size([6448, 52, 3])

In [107]:

plot_3d.render(
            np.array(xyz_)[:300],
            "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/full_joined.gif",
        )

In [48]:
null_conditions = model.cfg_dropout(conditions, 1.0)

In [43]:
out = model.compute_predictions(inputs["motion"] , conditions)



In [45]:
model = model.eval()

In [49]:
outt = model.generate(conditions = null_conditions , two_step_cfg =True)

In [50]:
outt.shape

torch.Size([4, 3, 225])

In [23]:
import time

# Time 10 generate() calls. CUDA work is launched asynchronously, so synchronise
# before reading the clock on both sides; perf_counter is the monotonic
# high-resolution clock intended for measuring intervals (time.time() is wall-clock).
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for i in range(10):
    outt = model.generate(conditions = conditions , two_step_cfg =True)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print('custom attention took {} seconds'.format(end - start))

custom attention took 31.624648332595825 seconds


In [26]:
import time

# Time 10 generate() calls. CUDA work is launched asynchronously, so synchronise
# before reading the clock on both sides; perf_counter is the monotonic
# high-resolution clock intended for measuring intervals (time.time() is wall-clock).
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for i in range(10):
    outt = model.generate(conditions = conditions , two_step_cfg =True)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print('custom attention without flash took {} seconds'.format(end - start))

custom attention without flash took 34.69120717048645 seconds


In [41]:
import time

# Time 10 generate() calls. CUDA work is launched asynchronously, so synchronise
# before reading the clock on both sides; perf_counter is the monotonic
# high-resolution clock intended for measuring intervals (time.time() is wall-clock).
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for i in range(10):
    outt = model.generate(conditions = conditions , two_step_cfg =True)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print('torch attention took {} seconds'.format(end - start))

torch attention took 22.785192251205444 seconds


In [35]:
# NOTE(review): `cfg_s` is not defined in any cell of this notebook — this line
# relies on leftover kernel state and fails on a fresh kernel.
model_cfg = cfg_s.transformer_lm

In [36]:
# NOTE(review): `MotionTokenizerParams` is never imported or defined in this
# notebook — add the import or drop this cell.
params = MotionTokenizerParams(model_cfg.card)

In [104]:
# Number of codebook streams used by the mask-building loop further down.
n_q = 3
# NOTE(review): the loader above already collates with permute=True, and the recorded
# output of `codes.shape` below is [4, 52, 3] — i.e. this extra permute appears to put
# time on axis 1 and codebooks on axis 2. Confirm which layout is intended.
codes = inputs["motion"][0].permute(0,2,1).to(torch.long)
code_mask = inputs["motion"][1]

In [105]:
# NOTE(review): with the recorded codes.shape of [4, 52, 3] this yields K=52, T=3,
# which matches the `timesteps=3` pattern printed below — likely K and T are swapped.
B, K, T = codes.shape
codes = codes.contiguous()

In [106]:
codes.shape

torch.Size([4, 52, 3])

In [107]:
pattern = pattern_provider.get_pattern(T)

In [108]:
pattern

Pattern(layout=[[], [LayoutCoord(t=0, q=0)], [LayoutCoord(t=1, q=0), LayoutCoord(t=0, q=1), LayoutCoord(t=0, q=2)], [LayoutCoord(t=2, q=0), LayoutCoord(t=1, q=1), LayoutCoord(t=1, q=2)], [LayoutCoord(t=3, q=0), LayoutCoord(t=2, q=1), LayoutCoord(t=2, q=2)]], timesteps=3, n_q=3)

In [46]:
# Lay the codes out according to the codebook pattern. 1025 is passed as the second
# positional argument (presumably the special/filler token id for positions the
# pattern leaves empty — TODO confirm against build_pattern_sequence).
sequence_codes, sequence_indexes, sequence_mask = (
            pattern.build_pattern_sequence(
                codes,
                1025,
                keep_only_valid_steps=True,
            )
)

In [48]:
sequence_codes.shape

torch.Size([4, 3, 53])

In [47]:
code_mask.shape

torch.Size([4, 52])

In [49]:
sequence_mask.shape

torch.Size([3, 53])

In [50]:
# Start from an all-ones mask shaped like sequence_mask, repeated B times along a new
# leading batch axis.
new_mask = torch.ones_like(sequence_mask).repeat(B , 1 , 1)
# For codebook row i, copy the per-sample code mask shifted right by i+1 steps:
# positions [i+1:] receive the first T-i mask entries, positions [:i+1] stay set.
# The slice lengths only agree when the sequence axis has length T+1.
for i in range(n_q):
    new_mask[:,i,i+1:] = code_mask[:,0:T-i]

In [53]:
new_mask.shape

torch.Size([4, 3, 53])

In [56]:
new_new_mask = (new_mask.sum(1) == new_mask.shape[1])

In [59]:
new_new_mask[0]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False])

In [6]:
from core.datasets.audio_encoders import EncodecConditioner

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
audio_list = findAllFile("/srv/hays-lab/scratch/sanisetty3/motionx/audio/wav/")

In [8]:
audenc = EncodecConditioner(target_sr = 9600)
# audlib = AudioConditionerLibrosa()



In [20]:
# Encode every wav file and store the embedding as .npy under the parallel
# ".../encodec/..." tree.
for src in tqdm(audio_list):
    # Encode the file being iterated. The previous code always encoded
    # audio_list[0], so every output file received the first clip's embedding.
    embs = audenc(src)
    dst = src.replace("/wav" , "/encodec").replace(".wav" , ".npy")
    # np.save does not create missing parent directories.
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    np.save(dst , embs.detach().cpu().numpy())

100%|████████████████████████████████████████████████████████████████████████████████████| 1724/1724 [07:27<00:00,  3.85it/s]


In [19]:
audio_list[0]

'/srv/hays-lab/scratch/sanisetty3/motionx/audio/wav/aist/mJS5.wav'

In [15]:
emb = audenc(audio_list[0])

In [16]:
emb.shape

torch.Size([887, 128])

In [17]:
np.load("/srv/hays-lab/scratch/sanisetty3/motionx/audio/encodec/aist/mBR0.npy").shape

(2700, 128)

In [40]:
emb = audenc("/srv/hays-lab/scratch/sanisetty3/motionx/audio/wav/choreomaster/0071.wav")

In [50]:
# NOTE(review): `audlib` is only created in a commented-out line
# (`# audlib = AudioConditionerLibrosa()`), so this cell depends on leftover kernel state.
emb2 = audlib("/srv/hays-lab/scratch/sanisetty3/motionx/audio/wav/choreomaster/0071.wav")

In [51]:
emb2.shape

(5234, 35)

In [52]:
emb.shape

torch.Size([8722, 128])

In [53]:
emb1 = np.load("/srv/hays-lab/scratch/sanisetty3/motionx/audio/encodec/choreomaster/0071.npy")

In [54]:
emb1.shape

(8722, 128)