In [1]:
# Pick the compute device: the default CUDA GPU when one is visible, otherwise the CPU.
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
print()

# On CUDA, also report which GPU is in use and how much memory PyTorch's
# allocator currently holds on it (in GiB, rounded to one decimal).
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print("Memory Usage:")
    bytes_per_gib = 1024**3
    memory_report = (
        ("Allocated:", torch.cuda.memory_allocated(0)),
        ("Cached:   ", torch.cuda.memory_reserved(0)),
    )
    for label, n_bytes in memory_report:
        print(label, round(n_bytes / bytes_per_gib, 1), "GB")

Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader

from torch.utils import data


import copy
import os
import random
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import functools
from tqdm import tqdm
from datetime import datetime
import numpy as np
from functools import partial
# from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator
from einops import rearrange, reduce, pack, unpack
import sys

## test rendering

In [12]:
def findAllFile(base):
    """
    Collect the path of every file below ``base``, descending into subdirectories.

    Symbolic links to directories are followed (``followlinks=True``); note that
    ``os.walk`` does not guard against symlink cycles in that mode.

    Args:
        base (str): Directory at which the walk starts. A non-existent
            directory yields an empty list.

    Returns:
        list: Paths built by joining each visited directory with the names of
        the files it contains, in ``os.walk`` traversal order.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base, followlinks=True)
        for filename in filenames
    ]

In [13]:
from utils.motion_processing.hml_process import recover_from_ric, recover_root_rot_pos
import utils.vis_utils.plot_3d_global as plot_3d
import matplotlib.pyplot as plt

In [47]:
from core.param_dataclasses import pattern_providers
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.utils import instantiate_from_config, get_obj_from_str
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider,ConditionFuser
from core.models.generation.lm import LMModel, MotionGen

from configs.config import cfg, get_cfg_defaults
from configs.config_streaming import get_cfg_defaults as strm_get_cfg_defaults


In [15]:
# Streaming-generation config: library defaults overlaid with the checkpoint's YAML,
# then frozen (intended to make it read-only for the rest of the notebook).
GEN_CONFIG_PATH = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_streaming/motion_streaming.yaml"

gen_cfg = strm_get_cfg_defaults()
gen_cfg.merge_from_file(GEN_CONFIG_PATH)
gen_cfg.freeze()

In [16]:

def _load_vq_cfg(config_path):
    """Return a fresh default config with the YAML file at ``config_path`` merged in."""
    part_cfg = get_cfg_defaults()
    part_cfg.merge_from_file(config_path)
    return part_cfg


# One VQ-VAE config per body part, each referenced by the generator config.
body_cfg = _load_vq_cfg(gen_cfg.vqvae.body_config)
left_cfg = _load_vq_cfg(gen_cfg.vqvae.left_hand_config)
right_cfg = _load_vq_cfg(gen_cfg.vqvae.right_hand_config)

In [17]:
from core.models.resnetVQ.vqvae import HumanVQVAE


In [78]:
def _load_vqvae(part_cfg):
    """Build a HumanVQVAE from ``part_cfg.vqvae`` on ``device`` in eval mode, then
    load the ``vqvae_motion.pt`` checkpoint found in ``part_cfg.output_dir``."""
    vq_model = HumanVQVAE(part_cfg.vqvae).to(device).eval()
    vq_model.load(os.path.join(part_cfg.output_dir, "vqvae_motion.pt"))
    return vq_model


# Load order matches the original cell: left hand, right hand, then body.
left_hand_model = _load_vqvae(left_cfg)
right_hand_model = _load_vqvae(right_cfg)
body_model = _load_vqvae(body_cfg)



### VQVAE

In [60]:
from core.datasets.vq_dataset import VQSMPLXMotionDataset
from core.datasets.vq_dataset import load_dataset as load_dataset_vq
from core.datasets.vq_dataset import simple_collate as simple_collate_vq


In [68]:
def get_decoded(model, motion, window=120):
    """
    Tokenize and reconstruct a long motion sequence with a VQ-VAE, chunk by chunk.

    The sequence is split along dim 1 (time) into consecutive, non-overlapping
    chunks of ``window`` frames. As soon as at most ``2 * window`` frames remain,
    the whole remainder is encoded as one final chunk, so the tail chunk is never
    shorter than ``window`` frames unless the entire sequence is.

    Each chunk is moved to the notebook-level ``device`` before it is passed to
    ``model``. The forward passes run under ``torch.no_grad()`` because this
    helper is used for inference only. Keeping an autograd graph for every chunk
    of a multi-thousand-frame sequence only wastes GPU memory.

    Args:
        model: Callable (e.g. ``HumanVQVAE``) whose output exposes
            ``decoded_motion`` and ``indices`` attributes.
        motion (torch.Tensor): Motion with time on dim 1; this notebook passes
            tensors shaped (1, T, D).
        window (int): Chunk length in frames. Defaults to 120.

    Returns:
        tuple: ``(decoded, indices)`` — decoded motion concatenated along dim 1
        and code indices concatenated along the last dim.

    Raises:
        ValueError: If ``window`` is not positive or ``motion`` has no frames.
    """
    if window <= 0:
        raise ValueError(f"window must be positive, got {window}")
    num_frames = motion.shape[1]
    if num_frames == 0:
        raise ValueError("motion has no frames along dim 1")

    decoded_chunks = []
    index_chunks = []
    start = 0
    with torch.no_grad():
        while start < num_frames:
            # Absorb a short tail into the last chunk instead of encoding a tiny one.
            end = num_frames if start + 2 * window >= num_frames else start + window
            out = model(motion[:, start:end, :].to(device))
            decoded_chunks.append(out.decoded_motion)
            index_chunks.append(out.indices)
            start = end
    return torch.cat(decoded_chunks, 1), torch.cat(index_chunks, -1)

In [61]:
# Collate-time conditioner for the VQ pass: only motion arguments are given
# (full-body representation, "max_length" padding).
condition_provider2 = ConditionProvider(motion_rep=MotionRep("full"), motion_padding="max_length")



In [62]:
# NOTE(review): presumably -1 disables fixed-length windowing so the VQ dataset
# yields whole sequences instead of crops — confirm in core.datasets.vq_dataset.
# Must run before the dataset is built from body_cfg.dataset.
body_cfg.dataset.window_size = -1

In [88]:
# Build the VQ dataset from the body config. Despite the variable name this is the
# *train* split; the other two return values are discarded.
# (To restrict to one source, pass e.g. dataset_names=["choreomaster"].)
test_ds, _, _ = load_dataset_vq(dataset_args=body_cfg.dataset, split="train")

Total number of motions animation: 312
Total number of motions humanml: 19984
Total number of motions perform: 451
Total number of motions GRAB: 1268
Total number of motions idea400: 11886
Total number of motions humman: 706
Total number of motions beat: 1458
Total number of motions game_motion: 9699
Total number of motions music: 3385
Total number of motions aist: 1396
Total number of motions fitness: 15886
Total number of motions moyo: 161
Total number of motions choreomaster: 34
Total number of motions dance: 153
Total number of motions kungfu: 987
Total number of motions EgoBody: 931
Total number of motions HAA500: 4969


In [89]:
# Iterate the VQ dataset one motion at a time, in order, collating each sample
# with `condition_provider2`.
collate_vq = partial(simple_collate_vq, conditioner=condition_provider2)
dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_vq)

In [None]:
# Encode every motion in `dl` with the body VQ-VAE and save its code indices
# under INDICES_DIR, mirroring the dataset-relative name (e.g.
# "choreomaster/1160.npy"). Files that already exist are skipped, so re-running
# this cell resumes an interrupted extraction.
INDICES_DIR = "/srv/hays-lab/scratch/sanisetty3/motionx/indices/body_rv"

errors = []         # names of motions whose extraction failed
error_details = {}  # name -> repr of the exception, for debugging
for batch in tqdm(dl):
    name = batch["names"][0]
    sve = os.path.join(INDICES_DIR, name)
    # np.save appends ".npy" when missing; mirror that so the skip check
    # looks for the file that is actually written.
    if not sve.endswith(".npy"):
        sve += ".npy"
    if os.path.exists(sve):
        continue
    os.makedirs(os.path.dirname(sve), exist_ok=True)
    try:
        # Encode the motion from *this* batch (not a leftover global), for
        # inference only, so no autograd graph is kept.
        with torch.no_grad():
            enc_b, inb = get_decoded(body_model, batch["motion"][0].to(device))
        np.save(sve, inb.cpu().numpy())
    except Exception as exc:  # skip bad samples, but let Ctrl-C stop the loop
        errors.append(name)
        error_details[name] = repr(exc)
    

 30%|███████████████████████▍                                                      | 22110/73666 [3:12:26<8:07:42,  1.76it/s]

In [66]:
# Take the first batch from `dl` for interactive inspection. Raises StopIteration
# on an empty loader instead of silently leaving `inputs` undefined.
inputs = next(iter(dl))
    

(6466, 192)
(1, 6466, 192)


In [67]:
# Shape of the motion tensor in the inspected batch — presumably (batch, frames, features).
inputs["motion"][0].shape

torch.Size([1, 6466, 192])

In [79]:
# Encode the whole sequence in one pass (no chunking). Inference only, so skip
# autograd bookkeeping for this long sequence.
with torch.no_grad():
    codes = body_model.encode(inputs["motion"][0].to(device))

In [80]:
# Decode the single-pass codes back to motion. Inference only, so no autograd graph.
with torch.no_grad():
    out = body_model.decode(codes)

In [81]:
# Shape of the single-pass reconstruction. In the recorded run it came back two
# frames shorter than the input (6464 vs 6466). Presumably the encoder's temporal
# downsampling drops trailing frames that do not fill a full stride — confirm
# before comparing frame by frame.
out.shape

torch.Size([1, 6464, 192])

In [83]:
# Chunked reconstruction and code indices for the same sequence, for comparison
# with the single-pass result. Inference only, so no autograd graph.
with torch.no_grad():
    out2, cde2 = get_decoded(body_model, inputs["motion"][0].to(device))

In [84]:
# Render the first 500 frames of the chunked reconstruction to a GIF using the
# first sub-dataset's renderer.
RENDER_PATH = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/choreo2.gif"
preview_frames = out2.detach().squeeze().cpu()[:500]
test_ds.datasets[0].render_hml(preview_frames, RENDER_PATH)

In [86]:
# Dataset-relative name(s) of the motion(s) in the inspected batch.
inputs["names"]

array(['choreomaster/1160.npy'], dtype='<U21')

100%|████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:19<00:00,  1.73it/s]


In [60]:
# Work on a private copy of the LM sub-config. Popping "target" from
# gen_cfg.transformer_lm directly would mutate the shared (frozen) config, and a
# second run of this cell would then fail with KeyError: 'target'.
lm_args = copy.deepcopy(gen_cfg.transformer_lm)
target = lm_args.pop("target")  # split off so it is not part of lm_args
fuse_config = gen_cfg.fuser
pattern_args = gen_cfg.codebooks_pattern
dataset_args = gen_cfg.dataset

In [61]:
# Build the motion generator from the LM, fuser and codebook-pattern configs,
# then move it to the selected device and switch it to eval mode.
model_gen = MotionGen(lm_args, fuse_config, pattern_args)
model_gen = model_gen.to(device).eval()

In [65]:
# Conditioner for the generator: representations and padding modes come from the
# dataset config, and the pad id comes from the LM. Max lengths are in seconds.
# NOTE(review): 30 / 4 presumably means 30 fps motion downsampled 4x into tokens — confirm.
MAX_LENGTH_S = 10
TOKEN_FPS = 30 / 4

condition_provider = ConditionProvider(
    motion_rep=MotionRep(dataset_args.motion_rep),
    audio_rep=AudioRep(dataset_args.audio_rep),
    text_rep=TextRep(dataset_args.text_rep),
    motion_padding=dataset_args.motion_padding,
    audio_padding=dataset_args.audio_padding,
    motion_max_length_s=MAX_LENGTH_S,
    audio_max_length_s=MAX_LENGTH_S,
    pad_id=model_gen.model.pad_token_id,
    fps=TOKEN_FPS,
)



In [66]:
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
# dset = MotionIndicesAudioTextDataset("chroeomaster" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "full", split = "train" , fps = 30/4  )


In [69]:
# Multimodal training set restricted to two sources, batched by 4 with the
# dataset's own sampler; the last incomplete batch is dropped.
GEN_DATASET_NAMES = ["animation", "choreomaster"]

train_ds, sampler_train, weights_train = load_dataset_gen(
    dataset_args=dataset_args,
    split="train",
    dataset_names=GEN_DATASET_NAMES,
)
gen_collate = partial(simple_collate, conditioner=condition_provider, permute=True)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=4,
    sampler=sampler_train,
    collate_fn=gen_collate,
    drop_last=True,
)

Total number of motions animation: 194 and texts 194
Total number of motions choreomaster: 34 and texts 34


In [70]:
# Take the first training batch for inspection. Raises StopIteration on an empty
# loader instead of silently leaving `inputs`/`conditions` undefined.
inputs, conditions = next(iter(train_loader))
    

In [71]:
# Split the batch's motion entry into the token ids/motions and their mask, then
# read off the dimensions (K presumably the number of codebooks — confirm).
motion_field = inputs["motion"]
motions_or_ids, input_mask = motion_field[0], motion_field[1]
B, K, T = motions_or_ids.shape