In [3]:
# Choose where tensors and models will live: the first visible CUDA GPU when
# PyTorch can use one, otherwise the CPU.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device, end="\n\n")

# Extra diagnostics when a GPU was selected: card name and current memory use in GiB.
if device.type != "cpu":
    print(torch.cuda.get_device_name(0))
    print("Memory Usage:")
    print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024**3, 1), "GB")
    print("Cached:   ", round(torch.cuda.memory_reserved(0) / 1024**3, 1), "GB")

Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [5]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import DataLoader

from torch.utils import data


import copy
import os
import random
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import functools
from tqdm import tqdm
from datetime import datetime
import numpy as np
from functools import partial
# from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator
from einops import rearrange, reduce, pack, unpack
import sys

## Test rendering

In [6]:
def findAllFile(base):
    """
    Collect every file path below ``base``, descending into all subdirectories.

    Directory symlinks are followed during the walk.

    Args:
        base (str): Directory at which the recursive search starts.

    Returns:
        list: Paths built by joining each visited directory with each file name,
        in the order ``os.walk`` produces them. A missing ``base`` gives ``[]``.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(base, followlinks=True)
        for filename in filenames
    ]

In [5]:
from utils.motion_processing.hml_process import recover_from_ric, recover_root_rot_pos
import utils.vis_utils.plot_3d_global as plot_3d
import matplotlib.pyplot as plt

In [6]:
from core.param_dataclasses import pattern_providers
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.utils import instantiate_from_config, get_obj_from_str
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider,ConditionFuser

from configs.config import cfg, get_cfg_defaults
from configs.config_streaming import get_cfg_defaults as strm_get_cfg_defaults


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Streaming-generation config. Later cells read the body VQ-VAE config path from it
# (gen_cfg.vqvae.body_config).
GEN_CFG_PATH = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_streaming/motion_streaming.yaml"

gen_cfg = strm_get_cfg_defaults()
gen_cfg.merge_from_file(GEN_CFG_PATH)
# NOTE(review): presumably makes the config read-only (yacs-style) — confirm.
gen_cfg.freeze()

In [8]:

# Body-part VQ-VAE config: the project defaults, overlaid with the YAML file whose
# path is stored in the streaming-generation config.
body_cfg = get_cfg_defaults()
body_cfg.merge_from_file(gen_cfg.vqvae.body_config)
# Hand-part configs are disabled in this notebook; only the body VQ-VAE is used.
# left_cfg = get_cfg_defaults()
# left_cfg.merge_from_file(gen_cfg.vqvae.left_hand_config)
# right_cfg = get_cfg_defaults()
# right_cfg.merge_from_file(gen_cfg.vqvae.right_hand_config)

In [9]:
from core.models.resnetVQ.vqvae import HumanVQVAE


In [10]:
# Hand-part VQ-VAEs are disabled in this notebook; only the body model is built.
# left_hand_model = HumanVQVAE(left_cfg.vqvae).to(device).eval()
# left_hand_model.load(os.path.join(left_cfg.output_dir, "vqvae_motion.pt"))

# right_hand_model = HumanVQVAE(right_cfg.vqvae).to(device).eval()
# right_hand_model.load(os.path.join(right_cfg.output_dir, "vqvae_motion.pt"))

# Body VQ-VAE: built from body_cfg.vqvae, moved to `device` and put in eval mode.
# Its weights are then loaded from <body_cfg.output_dir>/vqvae_motion.pt through the
# model's own `load` method (NOTE(review): checkpoint/device handling lives in
# HumanVQVAE.load — not visible here).
body_model = HumanVQVAE(body_cfg.vqvae).to(device).eval()
body_model.load(os.path.join(body_cfg.output_dir, "vqvae_motion.pt"))



### VQ-VAE

In [11]:
from core.datasets.vq_dataset import VQSMPLXMotionDataset
from core.datasets.vq_dataset import load_dataset as load_dataset_vq
from core.datasets.vq_dataset import simple_collate as simple_collate_vq


In [12]:
def get_decoded(model, motion, chunk_size=120, device=None):
    """
    Run ``model`` over ``motion`` in consecutive time windows and stitch the outputs.

    The sequence (time on dim 1) is cut into windows of ``chunk_size`` frames. Once
    at most ``2 * chunk_size`` frames remain, everything left is processed as one
    final window, so no short remnant is ever encoded on its own. Inference runs
    under ``torch.no_grad()``, so no autograd graph is built.

    Args:
        model: Callable whose result has ``decoded_motion`` and ``indices``
            attributes (here the body ``HumanVQVAE``).
        motion (torch.Tensor): Motion sliced as ``motion[:, start:stop, :]``.
            Assumed (batch, frames, features) — TODO confirm for other callers.
        chunk_size (int): Frames per window. The default 120 matches the original
            hard-coded value.
        device: Device each window is moved to before calling ``model``. ``None``
            uses the notebook-level ``device`` when it exists (original behaviour),
            otherwise ``motion.device``.

    Returns:
        tuple: A pair ``(decoded, indices)``. ``decoded`` concatenates the decoded
        windows along dim 1; ``indices`` concatenates the indices along the last dim.

    Raises:
        ValueError: If ``motion`` has no frames or ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    n_frames = motion.shape[1]
    if n_frames == 0:
        # The original failed later inside torch.cat([]) with an opaque message.
        raise ValueError("get_decoded: motion has no frames (dim 1 is empty)")
    if device is None:
        # Preserve the original implicit use of the notebook's global `device`.
        device = globals().get("device", motion.device)

    decoded_chunks, index_chunks = [], []
    start = 0
    with torch.no_grad():  # pure inference: avoid building per-chunk autograd graphs
        while True:
            # When <= 2 windows' worth remains, absorb the whole tail into this window.
            is_last = start + 2 * chunk_size >= n_frames
            stop = n_frames if is_last else start + chunk_size
            out = model(motion[:, start:stop, :].to(device))
            decoded_chunks.append(out.decoded_motion)
            index_chunks.append(out.indices)
            if is_last:
                break
            start += chunk_size
    return torch.cat(decoded_chunks, 1), torch.cat(index_chunks, -1)

In [13]:
# Body-only condition provider, later handed to the VQ collate function.
# NOTE(review): "max_length" presumably pads motions to a fixed maximum length — confirm.
condition_provider2 = ConditionProvider(motion_rep=MotionRep("body"), motion_padding="max_length")



In [14]:
# Override the dataset window size before the dataset is built from body_cfg.dataset.
# NOTE(review): -1 presumably means "no cropping / full-length sequences" — confirm.
body_cfg.dataset.window_size = -1

In [15]:
# Dataset subsets to load. Only HumanML is enabled; the rest are kept commented out
# so they can be toggled back on.
dataset_names_default = [
    # "animation",
    "humanml",
    # "perform",
    # "GRAB",
    # "idea400",
    # "humman",
    # "beat",
    # "game_motion",
    # "music",
    # "aist",
    # "fitness",
    # "moyo",
    # "choreomaster",
    # "dance",
    # "kungfu",
    # "EgoBody",
    # "HAA500",
]

In [16]:
# NOTE(review): despite its name, `test_ds` holds the *train* split. The name is kept
# because later cells refer to it. Only the first of the three returned values is used.
test_ds, _, _ = load_dataset_vq(
    split="train",
    dataset_names=dataset_names_default,
    dataset_args=body_cfg.dataset,
)

Total number of motions humanml: 19984


In [17]:
# One motion per batch, in dataset order. Batches are collated through the body-only
# condition provider.
vq_collate = partial(simple_collate_vq, conditioner=condition_provider2)
dl = torch.utils.data.DataLoader(
    test_ds, batch_size=1, shuffle=False, sampler=None, collate_fn=vq_collate
)

In [18]:
# Encode every motion with the body VQ-VAE and save its token indices to
# <INDICES_DIR>/<name>.npy. Files that already exist are skipped, so the cell can be
# re-run to resume after an interruption.
INDICES_DIR = "/srv/hays-lab/scratch/sanisetty3/motionx/indices/body"

errors = []  # names of motions whose encoding or saving failed
for batch in tqdm(dl):
    name = batch["names"][0]
    # np.save appends ".npy" when it is missing, so the existence check must include
    # the extension. Without it, nothing was ever skipped on a re-run.
    sve = os.path.join(INDICES_DIR, name + ".npy")
    if os.path.exists(sve):
        continue
    os.makedirs(os.path.dirname(sve), exist_ok=True)
    try:
        with torch.no_grad():  # inference only; don't build autograd graphs
            enc_b, inb = get_decoded(body_model, batch["motion"][0].to(device))
        np.save(sve, inb.cpu().numpy())
    except Exception as exc:  # record and continue, but let Ctrl-C interrupt the loop
        errors.append(name)
        tqdm.write(f"failed on {name}: {exc!r}")
    

100%|████████████████████████████████████████████████████████████████████████████████| 19984/19984 [1:02:28<00:00,  5.33it/s]


In [19]:
# Number of motions that failed to encode or save in the extraction loop.
len(errors)

0

In [7]:
# Recursive listing of every saved HumanML body-index file.
fles = findAllFile("/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/")

In [13]:
# Flag index files with fewer than MIN_TOKENS entries along the last axis. Their base
# names are collected in `err`, and each path is printed with its array shape.
MIN_TOKENS = 10

err = []
for path in fles:
    shape = np.load(path).shape  # load each file once (previously re-read to print)
    if shape[-1] < MIN_TOKENS:
        err.append(os.path.basename(path))
        print(path, shape)

/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/006859.npy (1, 9)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/013180.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/M005024.npy (1, 9)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/M006988.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/008362.npy (1, 8)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/M011994.npy (1, 8)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/010908.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/009136.npy (1, 7)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/013310.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/M006950.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/003216.npy (1, 6)
/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/011011.npy (1, 8)
/srv/hays-lab/scratch/sanisetty3/mot

In [14]:
# Base names of the index files flagged as too short by the scan above.
err

['006859.npy',
 '013180.npy',
 'M005024.npy',
 'M006988.npy',
 '008362.npy',
 'M011994.npy',
 '010908.npy',
 '009136.npy',
 '013310.npy',
 'M006950.npy',
 '003216.npy',
 '011011.npy',
 'M005928.npy',
 'M004294.npy',
 'M003544.npy',
 'M012941.npy',
 '013096.npy',
 '005821.npy',
 'M010908.npy',
 '011994.npy',
 'M008362.npy',
 '006988.npy',
 '005024.npy',
 'M006859.npy',
 'M013180.npy',
 '006950.npy',
 'M009136.npy',
 'M013310.npy',
 '003544.npy',
 '004294.npy',
 '005928.npy',
 'M003216.npy',
 'M011011.npy',
 'M005821.npy',
 '012941.npy',
 'M013096.npy',
 '000254.npy',
 'M005529.npy',
 '005117.npy',
 'M000908.npy',
 'M005289.npy',
 '011224.npy',
 '009757.npy',
 'M013920.npy',
 'M010564.npy',
 'M004292.npy',
 '000696.npy',
 'M013025.npy',
 '011216.npy',
 '011211.npy',
 '007050.npy',
 'M002393.npy',
 'M010194.npy',
 'M008391.npy',
 '003928.npy',
 'M001250.npy',
 '001505.npy',
 '003196.npy',
 '005289.npy',
 '000908.npy',
 'M005117.npy',
 'M000254.npy',
 '005529.npy',
 'M000696.npy',
 '004292

In [21]:
# Spot-check the saved index array shape for a single motion.
np.load("/srv/hays-lab/scratch/sanisetty3/motionx/indices/body/humanml/M012813.npy").shape

(1, 74)

In [15]:
"005485.npy" in err

False

In [None]:
['humanml/005485']
['humanml/006501']
['humanml/M005485']

In [20]:
['humanml/M006117']
['humanml/012400']
  6%|####7                                                                                | 554/9852 [00:11<03:03, 50.70it/s]['humanml/002707']
  9%|#######2                                                                             | 840/9852 [00:16<02:52, 52.25it/s]['humanml/M001419']
 11%|#########1                                                                          | 1068/9852 [00:20<02:36, 56.17it/s]['humanml/M011284']
 13%|###########1                                                                        | 1302/9852 [00:24<02:30, 56.79it/s]['humanml/M013511']
 14%|###########5                                                                        | 1356/9852 [00:25<02:30, 56.57it/s]['humanml/008560']
 16%|#############3                                                                      | 1561/9852 [00:30<03:20, 41.25it/s]['humanml/002178']
 18%|##############7                                                                     | 1731/9852 [00:33<02:40, 50.57it/s]['humanml/M000299']
 20%|################8                                                                   | 1977/9852 [00:38<02:35, 50.66it/s]['humanml/M004412']
 23%|###################5                                                                | 2294/9852 [00:44<02:30, 50.08it/s]['humanml/M006836']
 26%|#####################7                                                              | 2552/9852 [00:49<02:13, 54.52it/s]['humanml/M008371']
 28%|#######################4                                                            | 2750/9852 [00:53<02:11, 53.98it/s]['humanml/M014110']
 31%|#########################6                                                          | 3011/9852 [00:58<02:31, 45.17it/s]['humanml/005485']

 32%|##########################7                                                         | 3141/9852 [01:01<02:32, 44.00it/s]['humanml/007554']
 33%|###########################5                                                        | 3236/9852 [01:03<02:30, 43.94it/s]['humanml/009716']
 35%|#############################2                                                      | 3426/9852 [01:07<02:29, 43.09it/s]['humanml/M006163']
 35%|#############################3                                                      | 3443/9852 [01:08<02:09, 49.57it/s]['humanml/010850']
 37%|###############################1                                                    | 3659/9852 [01:12<01:52, 54.85it/s]['humanml/M009402']
 38%|################################1                                                   | 3767/9852 [01:14<01:51, 54.46it/s]['humanml/M012250']

 44%|#####################################2                                              | 4367/9852 [01:25<01:40, 54.78it/s]
['humanml/M012813']
 45%|#####################################4                                              | 4397/9852 [01:25<01:39, 54.71it/s]
['humanml/004574']
 47%|#######################################8                                            | 4673/9852 [01:30<01:34, 54.74it/s]
['humanml/009908']
 48%|########################################3                                           | 4727/9852 [01:31<01:33, 54.74it/s]
['humanml/012791']
 51%|##########################################5                                         | 4985/9852 [01:36<01:29, 54.62it/s]
['humanml/012813']
 58%|################################################7                                   | 5715/9852 [01:51<01:33, 44.21it/s]
['humanml/M012791']
 59%|#################################################4                                  | 5795/9852 [01:53<01:33, 43.40it/s]
['humanml/M009908']
 61%|###################################################1                                | 5995/9852 [01:57<01:29, 43.00it/s]
['humanml/M004574']
 62%|###################################################6                                | 6060/9852 [01:59<01:09, 54.32it/s]['humanml/M007554']
 64%|#####################################################4                              | 6270/9852 [02:02<01:05, 54.88it/s]['humanml/M005485']

['humanml/014110']
['humanml/M006501']
['humanml/012250']
['humanml/009402']
['humanml/M010850']
['humanml/006163']
['humanml/M009716']
['humanml/000299']
['humanml/M002178']
['humanml/M008560']
['humanml/013511']
['humanml/008371']
['humanml/006836']
['humanml/004412']
['humanml/M002707']
['humanml/M012400']
['humanml/006117']
['humanml/011284']
['humanml/001419']



torch.Size([1, 198, 192])

In [None]:
90

In [26]:
# Render the first RENDER_MAX_FRAMES frames of every motion in `dl` to a GIF named
# after the motion. The output directory is created if missing.
# NOTE(review): always uses the first sub-dataset's render_hml. Confirm this is right
# if `dl` mixes datasets. The 33-item progress bar also suggests `dl` was rebuilt in a
# cell not shown here.
RENDER_DIR = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/true_all"
RENDER_MAX_FRAMES = 300

os.makedirs(RENDER_DIR, exist_ok=True)
for batch in tqdm(dl):
    gif_name = f"{os.path.basename(batch['names'][0])}.gif"
    test_ds.datasets[0].render_hml(
        batch["motion"][0][0][:RENDER_MAX_FRAMES],
        save_path=os.path.join(RENDER_DIR, gif_name),
    )

100%|████████████████████████████████████████████████████████████████████████████████████████| 33/33 [10:38<00:00, 19.34s/it]


In [38]:
# Inspect which fields the collate function put into a batch.
batch.keys()

dict_keys(['names', 'motion'])

In [23]:
batch["motion"][0].shape

torch.Size([1, 299, 462])

In [41]:
# Render the first 300 frames of whatever `batch` is currently in the kernel to a single GIF.
test_ds.datasets[0].render_hml(batch["motion"][0][0][:300] , save_path = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_streaming/samples/0/beatgtgt.gif" )