In [1]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import wandb

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader

from torch.utils import data


import copy
import os
import random
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import functools
from tqdm import tqdm
from datetime import datetime
import numpy as np
from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset
from einops import rearrange, reduce, pack, unpack
import librosa

In [4]:
from utils.motion_process import recover_from_ric
import visualize.plot_3d_global as plot_3d
from glob import glob
def to_xyz(motion, mean ,std , j = 22):
    motion_xyz = recover_from_ric(motion.cpu().float()*std+mean, j)
    motion_xyz = motion_xyz.reshape(motion.shape[0],-1, j, 3)
    return motion_xyz

            
def sample_render(motion_xyz , name , save_path):
    print(f"render start")
    
    gt_pose_vis = plot_3d.draw_to_batch(motion_xyz.numpy(),None, [os.path.join(save_path,name + ".gif")])



In [5]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from core.models.motion_regressor import MotionRegressorModel


cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")




In [6]:
vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()


tensor([295000.])


In [7]:
collate_fn = MotionCollator()


In [9]:
train_ds = VQVarLenMotionDataset("t2m", split = "render" , max_length_seconds = 10, data_root = "/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D")
train_loader = DATALoader(train_ds,1,collate_fn=collate_fn)

changing range to: 60 - 60


100%|██████████| 10/10 [00:00<00:00, 37.82it/s]

Total number of motions 10





In [10]:
aist_ds = VQVarLenMotionDataset("aist", split = "render" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , num_stages = 6 ,min_length_seconds=20, max_length_seconds=30)
aist_loader = DATALoader(aist_ds,1,collate_fn=collate_fn)

changing range to: 400 - 400


100%|██████████| 8/8 [00:00<00:00, 39.72it/s]

Total number of motions 8





In [11]:
aist_loader.dataset.set_stage(5)

changing range to: 400 - 600


In [12]:
for aist_batch in aist_loader:
    break
aist_batch["motion"].shape

torch.Size([1, 159, 263])

## Trans model

In [30]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist.yaml")



In [31]:
trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist/trans_motion_best_fid.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([210000.])


## Encode Decode

In [52]:
ind = vqvae_model.encode(aist_batch["motion"].cuda())
print(ind.shape)
quant , out_motion = vqvae_model.decode(ind)

torch.Size([1, 191])


In [39]:
out = torch.empty(aist_batch["motion"].shape)

In [27]:
ind = vqvae_model.encode(aist_batch["motion"][:,:400].cuda())
quant , out_motion = vqvae_model.decode(ind)

In [53]:
quant , out_motion = vqvae_model.decode(ind[:,400:].to(torch.long).cuda())

In [54]:
out[:,400:] = out_motion


In [53]:
sample_render(to_xyz(aist_batch["motion"][0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_og_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [57]:
sample_render(to_xyz(out[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion_ind_400" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [47]:
indices = torch.randint(0,1024,(1,60))
quant , out_motion = vqvae_model.decode(indices.cuda())

In [54]:
sample_render(to_xyz(out_motion[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


## Music Eval stuff

In [7]:
from utils.motion_process import recover_from_ric
from utils.aist_metrics import calculate_fid_scores
from utils.aist_metrics.calculate_fid_scores import calculate_avg_distance, extract_feature,calculate_frechet_feature_distance,calculate_frechet_distance
from utils.aist_metrics.features import kinetic,manual
from utils.aist_metrics.calculate_beat_scores import motion_peak_onehot,alignment_score

In [8]:
from core.datasets.vqa_motion_dataset import MotionCollatorConditional, TransMotionDatasetConditional,VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset


In [9]:
from utils.eval_music import evaluate_music_motion_vqvae, evaluate_music_motion_generative,evaluate_music_motion_trans

In [14]:
for aist_batch in tqdm(aist_loader):
    break

  0%|          | 0/40 [00:00<?, ?it/s]


### Const len trained transformer

In [15]:
from core.models.motion_regressor import MotionRegressorModel

trans_model = MotionRegressorModel(args = cfg_trans.motion_trans , ignore_index=1025 ,pad_value=1025 ).eval()
pkg_trans = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/const_len/trans_768_768_aist/vqvae_motion.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([85000.])


In [39]:
0.23/0.243 * 0.292

0.2763786008230453

## Evaluate Music Motion transformer

In [7]:
encodec = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist/var_len_768_768_aist.yaml"
encodec_sine = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_sine_aist/var_len_768_768_sine_aist.yaml"
librosa = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_35/var_len_768_768_aist_35.yaml"
encodec_prob50 = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_mask_prob50/trans_768_768_albi_aist_mask_prob50.yaml"


In [8]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


trans_option = "encodec_sine"

cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file(encodec_sine)


trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"{os.path.dirname(encodec_sine)}/trans_motion_best_fid.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([205000.])


In [9]:
audio_encoding_dir = "/srv/scratch/sanisetty3/music_motion/AIST/music"
audio_features_dir = "/srv/scratch/sanisetty3/music_motion/AIST/audio_features/"
use35 = False
if trans_option == "librosa":
    audio_encoding_dir = audio_features_dir
    use35 = True

In [10]:
use35

False

## Evaluate Music Motion Generative

In [11]:
from render_final import hml2aist

In [12]:
from utils.eval_music import evaluate_music_motion_vqvae, evaluate_music_motion_generative,evaluate_music_motion_generative2,evaluate_music_motion_trans

In [13]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 1040.35it/s]

Total number of motions 40





In [14]:
from utils.eval_music import evaluate_music_motion_generative
print("pretrained mix")
best_fid_k = []
best_fid_g = []
best_div_k = []
best_div_g = []
best_beat_align = []

for i in range(1):

    gen_motions , a,b,c,d,e = evaluate_music_motion_generative(aist_loader , vqvae_model= vqvae_model ,net = trans_model,use35=use35)
    best_fid_k.append(a)
    best_fid_g.append(b)
    best_div_k.append(c)
    best_div_g.append(d)
    best_beat_align.append(e)

    
print("best_fid_k" , np.mean(best_fid_k))
print("best_fid_g" , np.mean(best_fid_g))
print("best_div_k" , np.mean(best_div_k))
print("best_div_g" , np.mean(best_div_g))
print("best_beat_align" , np.mean(best_beat_align))



  0%|          | 0/400 [00:00<?, ?it/s]

pretrained mix


100%|██████████| 400/400 [00:14<00:00, 27.78it/s]
100%|██████████| 400/400 [00:13<00:00, 29.94it/s]
100%|██████████| 400/400 [00:13<00:00, 29.81it/s]
100%|██████████| 400/400 [00:13<00:00, 30.30it/s]
100%|██████████| 400/400 [00:13<00:00, 30.17it/s]
100%|██████████| 400/400 [00:13<00:00, 30.23it/s]
100%|██████████| 400/400 [00:13<00:00, 30.62it/s]
100%|██████████| 400/400 [00:13<00:00, 30.41it/s]
100%|██████████| 400/400 [00:13<00:00, 30.31it/s]
100%|██████████| 400/400 [00:13<00:00, 30.56it/s]
100%|██████████| 400/400 [00:12<00:00, 30.96it/s]
100%|██████████| 400/400 [00:13<00:00, 30.33it/s]
100%|██████████| 400/400 [00:13<00:00, 29.87it/s]
100%|██████████| 400/400 [00:13<00:00, 30.41it/s]
100%|██████████| 400/400 [00:13<00:00, 30.43it/s]
100%|██████████| 400/400 [00:08<00:00, 46.37it/s]
100%|██████████| 400/400 [00:07<00:00, 50.60it/s]
100%|██████████| 400/400 [00:07<00:00, 50.65it/s]
100%|██████████| 400/400 [00:07<00:00, 50.40it/s]
100%|██████████| 400/400 [00:07<00:00, 50.85it/s]


FID_k:  7.162948547746254 Diversity_k: 9.617357408618316
FID_g:  12.577118332144806 Diversity_g: 7.296678166206067

Beat score on real data: 0.244


Beat score on generated data: 0.208

\PFC score on real data: 2.113

\PFC score on generated data: 2.339

best_fid_k 7.162948547746254
best_fid_g 12.577118332144806
best_div_k 9.617357408618316
best_div_g 7.296678166206067
best_beat_align 0.244130060202235





In [35]:
real_features = {
        "kinetic": [np.load(f) for f in glob("/srv/share/datasets/AIST/aist_features/*_kinetic.npy")],
        "manual": [np.load(f) for f in glob("/srv/share/datasets/AIST/aist_features/*_manual.npy")],
    }

In [18]:
from utils.aist_metrics.features import kinetic,manual
from utils.aist_metrics.calculate_beat_scores import motion_peak_onehot,alignment_score


In [None]:
aist_motions = []
musics = []
gt_motions = []
motion_names=  []
for gt_motion , motion, motion_name in tqdm(gen_motions):
    music_name = motion_name.split('_')[-2]
    print(gt_motion.shape, motion.shape, audio_name, motion_name)
    
    motion_names.append(motion_name)
    
    motion_xyz = to_xyz(motion.detach().cpu() , mean= aist_ds.mean , std = aist_ds.std)
    aist = hml2aist(motion_xyz[0].cpu().numpy() )
    aist_motions.append(aist)
    
    motion_xyz = to_xyz(gt_motion.detach().cpu(), mean= aist_ds.mean , std = aist_ds.std)
    gt = hml2aist(motion_xyz[0].cpu().numpy())
    gt_motions.append(gt)
    
    
    audio_feature = np.load(os.path.join(audio_features_dir, f"{audio_name}.npy"))
    musics.append(audio_feature)

  0%|          | 0/40 [00:00<?, ?it/s]

torch.Size([1, 213, 263]) torch.Size([1, 400, 263]) mKR2 gPO_sBM_cAll_d11_mPO1_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


  vertices, rotations, global_orient, out, x_translations = rot2xyz(torch.tensor(motion_tensor).clone(), mask=None,


torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


  2%|▎         | 1/40 [04:10<2:42:55, 250.67s/it]

torch.Size([213, 3, 3]) torch.Size([213, 23, 3, 3]) torch.Size([1, 3, 213])
torch.Size([1, 159, 263]) torch.Size([1, 400, 263]) mKR2 gLH_sBM_cAll_d17_mLH4_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


  5%|▌         | 2/40 [07:50<2:32:55, 241.46s/it]

torch.Size([159, 3, 3]) torch.Size([159, 23, 3, 3]) torch.Size([1, 3, 159])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gJS_sBM_cAll_d01_mJS3_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


  8%|▊         | 3/40 [12:47<2:39:05, 257.98s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 159, 263]) torch.Size([1, 400, 263]) mKR2 gLH_sBM_cAll_d18_mLH4_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 10%|█         | 4/40 [18:10<2:46:34, 277.63s/it]

torch.Size([159, 3, 3]) torch.Size([159, 23, 3, 3]) torch.Size([1, 3, 159])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gLO_sBM_cAll_d15_mLO2_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 12%|█▎        | 5/40 [22:44<2:41:13, 276.39s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gKR_sBM_cAll_d28_mKR2_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 15%|█▌        | 6/40 [27:05<2:33:59, 271.76s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gMH_sBM_cAll_d24_mMH3_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 18%|█▊        | 7/40 [30:33<2:18:56, 252.61s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gBR_sBM_cAll_d04_mBR0_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 20%|██        | 8/40 [34:28<2:12:03, 247.60s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 147, 263]) torch.Size([1, 400, 263]) mKR2 gJB_sBM_cAll_d08_mJB5_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 22%|██▎       | 9/40 [39:28<2:15:58, 263.19s/it]

torch.Size([147, 3, 3]) torch.Size([147, 23, 3, 3]) torch.Size([1, 3, 147])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gMH_sBM_cAll_d24_mMH3_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 25%|██▌       | 10/40 [44:32<2:17:39, 275.33s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gWA_sBM_cAll_d25_mWA0_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 28%|██▊       | 11/40 [48:58<2:11:42, 272.51s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 147, 263]) torch.Size([1, 400, 263]) mKR2 gJB_sBM_cAll_d09_mJB5_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 30%|███       | 12/40 [53:34<2:07:45, 273.76s/it]

torch.Size([147, 3, 3]) torch.Size([147, 23, 3, 3]) torch.Size([1, 3, 147])
torch.Size([1, 141, 263]) torch.Size([1, 400, 263]) mKR2 gHO_sBM_cAll_d20_mHO5_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 32%|███▎      | 13/40 [58:23<2:05:13, 278.29s/it]

torch.Size([141, 3, 3]) torch.Size([141, 23, 3, 3]) torch.Size([1, 3, 141])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gJS_sBM_cAll_d01_mJS3_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 35%|███▌      | 14/40 [1:03:09<2:01:32, 280.50s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gJS_sBM_cAll_d03_mJS3_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 38%|███▊      | 15/40 [1:07:32<1:54:40, 275.21s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gBR_sBM_cAll_d04_mBR0_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 40%|████      | 16/40 [1:12:33<1:53:10, 282.94s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gKR_sBM_cAll_d30_mKR2_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 42%|████▎     | 17/40 [1:17:33<1:50:30, 288.29s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 147, 263]) torch.Size([1, 400, 263]) mKR2 gJB_sBM_cAll_d09_mJB5_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 45%|████▌     | 18/40 [1:22:07<1:44:07, 283.99s/it]

torch.Size([147, 3, 3]) torch.Size([147, 23, 3, 3]) torch.Size([1, 3, 147])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gLO_sBM_cAll_d15_mLO2_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 48%|████▊     | 19/40 [1:28:26<1:49:21, 312.43s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gWA_sBM_cAll_d26_mWA0_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 50%|█████     | 20/40 [1:34:12<1:47:29, 322.47s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gLO_sBM_cAll_d13_mLO2_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 52%|█████▎    | 21/40 [1:39:06<1:39:24, 313.90s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 159, 263]) torch.Size([1, 400, 263]) mKR2 gLH_sBM_cAll_d18_mLH4_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 55%|█████▌    | 22/40 [1:44:21<1:34:14, 314.12s/it]

torch.Size([159, 3, 3]) torch.Size([159, 23, 3, 3]) torch.Size([1, 3, 159])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gKR_sBM_cAll_d30_mKR2_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 57%|█████▊    | 23/40 [1:49:01<1:26:08, 304.03s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gMH_sBM_cAll_d22_mMH3_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 60%|██████    | 24/40 [1:53:30<1:18:13, 293.35s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 191, 263]) torch.Size([1, 400, 263]) mKR2 gKR_sBM_cAll_d28_mKR2_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 62%|██████▎   | 25/40 [1:58:05<1:11:59, 287.96s/it]

torch.Size([191, 3, 3]) torch.Size([191, 23, 3, 3]) torch.Size([1, 3, 191])
torch.Size([1, 159, 263]) torch.Size([1, 400, 263]) mKR2 gLH_sBM_cAll_d17_mLH4_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 65%|██████▌   | 26/40 [2:02:23<1:05:06, 279.05s/it]

torch.Size([159, 3, 3]) torch.Size([159, 23, 3, 3]) torch.Size([1, 3, 159])
torch.Size([1, 141, 263]) torch.Size([1, 400, 263]) mKR2 gHO_sBM_cAll_d20_mHO5_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 68%|██████▊   | 27/40 [2:06:34<58:36, 270.49s/it]  

torch.Size([141, 3, 3]) torch.Size([141, 23, 3, 3]) torch.Size([1, 3, 141])
torch.Size([1, 147, 263]) torch.Size([1, 400, 263]) mKR2 gJB_sBM_cAll_d08_mJB5_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 70%|███████   | 28/40 [2:10:35<52:21, 261.76s/it]

torch.Size([147, 3, 3]) torch.Size([147, 23, 3, 3]) torch.Size([1, 3, 147])
torch.Size([1, 141, 263]) torch.Size([1, 400, 263]) mKR2 gHO_sBM_cAll_d21_mHO5_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 72%|███████▎  | 29/40 [2:14:24<46:12, 252.02s/it]

torch.Size([141, 3, 3]) torch.Size([141, 23, 3, 3]) torch.Size([1, 3, 141])
torch.Size([1, 174, 263]) torch.Size([1, 400, 263]) mKR2 gMH_sBM_cAll_d22_mMH3_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 75%|███████▌  | 30/40 [2:18:19<41:06, 246.70s/it]

torch.Size([174, 3, 3]) torch.Size([174, 23, 3, 3]) torch.Size([1, 3, 174])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gBR_sBM_cAll_d05_mBR0_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 78%|███████▊  | 31/40 [2:22:42<37:45, 251.72s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gWA_sBM_cAll_d25_mWA0_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 80%|████████  | 32/40 [2:27:45<35:36, 267.10s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 213, 263]) torch.Size([1, 400, 263]) mKR2 gPO_sBM_cAll_d10_mPO1_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 82%|████████▎ | 33/40 [2:32:14<31:12, 267.56s/it]

torch.Size([213, 3, 3]) torch.Size([213, 23, 3, 3]) torch.Size([1, 3, 213])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gBR_sBM_cAll_d05_mBR0_ch01
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([400, 3, 3]) torch.Size([400, 23, 3, 3]) torch.Size([1, 3, 400])
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


 85%|████████▌ | 34/40 [2:36:42<26:46, 267.78s/it]

torch.Size([239, 3, 3]) torch.Size([239, 23, 3, 3]) torch.Size([1, 3, 239])
torch.Size([1, 239, 263]) torch.Size([1, 400, 263]) mKR2 gWA_sBM_cAll_d26_mWA0_ch02
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.


### HML2AIST

In [None]:
from smplx import SMPL
smpl = SMPL(model_path="/srv/share/datasets/AIST/smpl_models/", gender='MALE', batch_size=1)

In [None]:
real_features["kinetic"].append(extract_feature(keypoints3d_gt, "kinetic"))
real_features["manual"].append(extract_feature(keypoints3d_gt, "manual"))

result_features["kinetic"].append(extract_feature(keypoints3d_pred, "kinetic"))
result_features["manual"].append(extract_feature(keypoints3d_pred, "manual"))

real_pfc.append(calc_physical_score(keypoints3d_gt))
pred_pfc.append(calc_physical_score(keypoints3d_pred))




motion_beats = motion_peak_onehot(keypoints3d_gt[:mot_len])
# get real data music beats
audio_name = motion_name.split("_")[-2]

audio_feature = np.load(os.path.join(audio_feature_dir, f"{audio_name}.npy"))
audio_beats = audio_feature[:mot_len, -1] # last dim is the music beats
# get beat alignment scores
beat_score = alignment_score(audio_beats, motion_beats, sigma=1)
beat_scores_real.append(beat_score)


motion_beats = motion_peak_onehot(keypoints3d_pred[:mot_len])
beat_score_pred = alignment_score(audio_beats, motion_beats, sigma=1)
beat_scores_pred.append(beat_score_pred)

### Long motion test

In [52]:
aist_ds_train = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader_train = DATALoader(aist_ds_train,1,collate_fn=None)

100%|██████████| 1910/1910 [01:04<00:00, 29.50it/s]

Total number of motions 1910





In [19]:
for batch in aist_loader_train:
    break

In [20]:

print("pretrained mix")
best_fid_k = []
best_fid_g = []
best_div_k = []
best_div_g = []
best_beat_align = []
seq_len = 900
for i in range(1):

    a,b,c,d,e,f,g = evaluate_music_motion_generative2(aist_loader_train , vqvae_model= vqvae_model ,net = trans_model, use35 = use35,seq_len = seq_len)
    best_fid_k.append(a)
    best_fid_g.append(b)
    best_div_k.append(c)
    best_div_g.append(d)
    best_beat_align.append(e)

    
print("best_fid_k" , np.mean(best_fid_k))
print("best_fid_g" , np.mean(best_fid_g))
print("best_div_k" , np.mean(best_div_k))
print("best_div_g" , np.mean(best_div_g))
print("best_beat_align" , np.mean(best_beat_align))



  1%|          | 7/900 [00:00<00:14, 62.63it/s]

pretrained mix


100%|██████████| 900/900 [00:20<00:00, 42.96it/s]
100%|██████████| 900/900 [00:20<00:00, 43.03it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.70it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.74it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.69it/s]it]
100%|██████████| 900/900 [00:21<00:00, 42.81it/s]it]
100%|██████████| 900/900 [00:21<00:00, 42.37it/s]]  
100%|██████████| 900/900 [00:21<00:00, 42.73it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.74it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.77it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.46it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.69it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.79it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.13it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.60it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.56it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.75it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.85it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.64it/s]t]
100%|██████████| 900/900 

FID_k:  6.394353653410235 Diversity_k: 10.558077923262992
FID_g:  8.632308692625102 Diversity_g: 7.294414699949869
\PFC score on real data: 1.677

\PFC score on generated data: 1.239


Beat score on real data: 0.170


Beat score on generated data: 0.164

best_fid_k 6.394353653410235
best_fid_g 8.632308692625102
best_div_k 10.558077923262992
best_div_g 7.294414699949869
best_beat_align 0.17011690074139812





## Eval using evaluators

In [62]:
from core.models.evaluator_wrapper import AISTEvaluatorModelWrapper

In [63]:
from utils.eval_music import evaluate_music_motion_generative_extractors

In [64]:
from configs.config import cfg, get_cfg_defaults

cfg_eval = get_cfg_defaults()
cfg_eval.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/extractors/big/aist_extractor.yaml")

In [65]:
from core.datasets.evaluator_dataset import EvaluatorMotionCollator, EvaluatorVarLenMotionDataset, EvaluatorMotionDataset
aist_ds = EvaluatorMotionDataset(split = "test" ,data_root = "/srv/scratch/sanisetty3/music_motion/AIST" ,window_size = -1,librosa = use35 )
collate_fn = EvaluatorMotionCollator()

aist_loader = DATALoader(aist_ds,1,collate_fn=collate_fn)

100%|██████████| 40/40 [00:00<00:00, 518.24it/s]

Total number of motions 40





In [66]:
aist_evaluator = AISTEvaluatorModelWrapper(cfg_eval)

loading from:  /srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/extractors/big/extractors.pt steps:  tensor([95000.], device='cuda:0')


In [71]:
fids = []
div_R = []
div_p = []

for i in range(1):
    fid, diversity_real, diversity = evaluate_music_motion_generative_extractors(aist_loader , vqvae_model= vqvae_model ,net = trans_model, eval_wrapper = aist_evaluator, seq_len = 600)
    fids.append(fid)
    div_R.append(diversity_real)
    div_p.append(diversity)

100%|██████████| 141/141 [00:05<00:00, 25.96it/s]
100%|██████████| 191/191 [00:05<00:00, 36.13it/s]
100%|██████████| 191/191 [00:03<00:00, 55.90it/s]
100%|██████████| 213/213 [00:03<00:00, 55.80it/s]
100%|██████████| 191/191 [00:03<00:00, 55.84it/s]
100%|██████████| 239/239 [00:04<00:00, 54.01it/s]
100%|██████████| 213/213 [00:03<00:00, 54.64it/s]
100%|██████████| 159/159 [00:02<00:00, 53.85it/s]
100%|██████████| 239/239 [00:04<00:00, 53.96it/s]
100%|██████████| 159/159 [00:02<00:00, 55.07it/s]
100%|██████████| 239/239 [00:04<00:00, 55.62it/s]
100%|██████████| 174/174 [00:03<00:00, 55.22it/s]
100%|██████████| 141/141 [00:02<00:00, 55.10it/s]
100%|██████████| 191/191 [00:03<00:00, 54.62it/s]
100%|██████████| 147/147 [00:02<00:00, 54.32it/s]
100%|██████████| 147/147 [00:02<00:00, 55.58it/s]
100%|██████████| 191/191 [00:03<00:00, 54.45it/s]
100%|██████████| 174/174 [00:03<00:00, 54.54it/s]
100%|██████████| 191/191 [00:03<00:00, 51.83it/s]
100%|██████████| 213/213 [00:03<00:00, 54.35it/s]


40
70.26463 77.746445
--> 	 :, FID. 1.1213, Diversity Real. 1.2656, Diversity. 1.3185, R_precision_real. [0.075 0.15  0.175], R_precision. [0.025 0.05  0.1  ], matching_score_real. 1.7566158294677734, matching_score_pred. 1.9436611175537108





Encodec: FID. 0.7455, Diversity Real. 1.3312, Diversity. 1.2632
Sine: FID. 0.6938, Diversity Real. 1.3411, Diversity. 1.2442

In [None]:
600 big
Encodec: FID. 0.7329, Diversity Real. 1.3029, Diversity. 1.2565, R_precision_real. [0.65853659 0.92682927 1.        ], R_precision. [0.07317073 0.12195122 0.12195122], matching_score_real. 0.5742702949337843, matching_score_pred. 1.9448303594821836
Encodec sine: FID. 0.7063, Diversity Real. 1.3348, Diversity. 1.2455, R_precision_real. [0.68292683 0.90243902 0.97560976], R_precision. [0.07317073 0.12195122 0.14634146], matching_score_real. 0.5801425561672304, matching_score_pred. 1.938593980742664

In [None]:
Test abt 300 len big
Encodec sine: FID. 1.0863, Diversity Real. 1.3563, Diversity. 1.3132, R_precision_real. [0.075 0.15  0.175], R_precision. [0.1   0.1   0.125], matching_score_real. 1.7566158294677734, matching_score_pred. 1.9094511032104493
Encodec: FID. 1.1263, Diversity Real. 1.3098, Diversity. 1.3185, R_precision_real. [0.075 0.15  0.175], R_precision. [0.025 0.05  0.1  ], matching_score_real. 1.7566158294677734, matching_score_pred. 1.9520971298217773

In [None]:
small 600
Encodec sine: FID. 0.5658, Diversity Real. 1.3857, Diversity. 1.0602, R_precision_real. [0.56097561 0.68292683 0.7804878 ], R_precision. [0.04878049 0.07317073 0.09756098], matching_score_real. 0.9467918582078887, matching_score_pred. 1.7773874794564597
Encodec: FID. 0.3988, Diversity Real. 1.2957, Diversity. 1.1036, R_precision_real. [0.53658537 0.68292683 0.73170732], R_precision. [0.04878049 0.07317073 0.17073171], matching_score_real. 0.9478622994771818, matching_score_pred. 1.5739107829768484

In [None]:
Test 

In [231]:


seq_len = 800
motion_annotation_list = []
motion_pred_list = []

music_annotation_list = []
music_pred_list = []

R_precision_real = 0
R_precision = 0

nb_sample = 0

audio_dir = audio_feature_dir if use35 else audio_encoding_dir

matching_score_real = 0
matching_score_pred = 0

for i,aist_batch in enumerate(tqdm(aist_loader)):
    
    


    mot_len = int(aist_batch["motion_lengths"][0])
#     print(mot_len)
    motion_name = aist_batch["names"][0]
    
    if len(music_annotation_list)>40:
        break

    if mot_len < seq_len:
        continue

    bs, seq = aist_batch["motion"].shape[0], aist_batch["motion"].shape[1]

    music_name = motion_name.split('_')[-2]
    gen_motion_indices = torch.randint(0 , 1024 , (1,1))

#     et, em = aist_evaluator.get_co_embeddings(music = aist_batch["condition"] , motions = aist_batch["motion"], m_lens = aist_batch["motion_lengths"])

    et, em = aist_evaluator.get_co_embeddings(music = aist_batch["condition"] ,
                                            motions = aist_batch["motion"],
                                            m_lens = torch.Tensor([seq_len]) )
    
    while gen_motion_indices.shape[1]<=seq_len:
        gen_motion_indices = trans_model.generate(start_tokens =gen_motion_indices.cuda(),\
                                                    seq_len=seq_len , \
                                                    context = aist_batch["condition"].cuda(), \
                                                    context_mask=torch.ones((1 ,aist_batch["condition"].shape[1]) , dtype = torch.bool).cuda()
                                                    )


    out_motion = torch.zeros((aist_batch["motion"].shape[0] ,gen_motion_indices.shape[-1] , aist_batch["motion"].shape[-1]))
    for i in range(0 , seq_len, 200):
        quant , out_motion_= vqvae_model.decode(gen_motion_indices[:,i:i+200])
        out_motion[:,i:i+200] = out_motion_
        

#     et_pred, em_pred = aist_evaluator.get_co_embeddings(
#                                                 music = aist_batch["condition"]  , \
#                                                 motions = out_motion[:,1:int(mot_len)+1], \
#                                                 m_lens = aist_batch["motion_lengths"])

    et_pred, em_pred = aist_evaluator.get_co_embeddings(music = aist_batch["condition"] , \
                                            motions =out_motion[:,1:int(seq_len) + 1], \
                                            m_lens = torch.Tensor([seq_len]))


    motion_pred_list.append(em_pred)
    motion_annotation_list.append(em)

    music_pred_list.append(et_pred)
    music_annotation_list.append(et)



    nb_sample += bs

print(nb_sample)


motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()

music_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
music_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()


100%|██████████| 800/800 [00:17<00:00, 45.63it/s]
100%|██████████| 800/800 [00:17<00:00, 45.46it/s]
100%|██████████| 800/800 [00:17<00:00, 45.20it/s]
100%|██████████| 800/800 [00:17<00:00, 45.37it/s]
100%|██████████| 800/800 [00:17<00:00, 45.31it/s]
100%|██████████| 800/800 [00:17<00:00, 45.27it/s]]
100%|██████████| 800/800 [00:17<00:00, 45.25it/s]]
100%|██████████| 800/800 [00:17<00:00, 45.52it/s]]
100%|██████████| 800/800 [00:17<00:00, 45.25it/s]]
100%|██████████| 800/800 [00:17<00:00, 45.22it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.21it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.53it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.17it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.00it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.18it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.12it/s]]  
100%|██████████| 800/800 [00:17<00:00, 45.19it/s]]
100%|██████████| 800/800 [00:17<00:00, 45.10it/s]it]
100%|██████████| 800/800 [00:17<00:00, 45.03it/s]it]
100%|██████████| 8

41





In [189]:
out_motion.shape

torch.Size([1, 401, 263])

In [190]:
aist_batch["condition"].shape

torch.Size([1, 174, 128])

In [177]:

motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()

music_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
music_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()


In [233]:
gt_mu, gt_cov  = calculate_activation_statistics(motion_annotation_np)
mu, cov= calculate_activation_statistics(motion_pred_np)
calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

0.6694914762965667

In [234]:
diversity_real = calculate_diversity(motion_annotation_np, 100 if nb_sample > 100 else 30)
diversity = calculate_diversity(motion_pred_np, 100 if nb_sample > 100 else 30)
print(diversity_real , diversity)

1.3413532 1.2625911


## AICHOREO

In [89]:
from utils.motion_process import recover_from_ric
from utils.aist_metrics.calculate_fid_scores import calculate_avg_distance, extract_feature,calculate_frechet_feature_distance,calculate_frechet_distance
from utils.aist_metrics.features import kinetic,manual
from utils.aist_metrics.calculate_beat_scores import motion_peak_onehot,alignment_score


In [97]:


result_features = {"kinetic": [], "manual": []}
real_features = {"kinetic": [], "manual": []}

mean = aist_loader.dataset.mean
std = aist_loader.dataset.std

beat_scores_real = []
beat_scores_pred = []

real_pfc = []
pred_pfc = []

seq_len =800

audio_dir = audio_encoding_dir = "/srv/scratch/sanisetty3/music_motion/AIST/music"
audio_feature_dir = "/srv/scratch/sanisetty3/music_motion/AIST/audio_features/"


# smpl_motions_aist = glob("/srv/scratch/sanisetty3/clean/mint/evals/eval60/*.npy")
# for i in glob("/srv/scratch/sanisetty3/clean/mint/evals/eval60/hml/*.npy"):
#     smpl_motion = np.load(i)[120: , 6:]
# #     print(smpl_motion.shape)
#     seq_name = os.path.basename(i).split(".")[0]
#     smpl_motions_aist.append({"motion":smpl_motion , "name" : seq_name })



for i,path in enumerate(tqdm(smpl_motions_aist)):
    
    motion_name = os.path.basename(path)
    
    gt_motion = torch.Tensor(np.load(f"/srv/scratch/sanisetty3/music_motion/AIST/new_joint_vecs/{motion_name[:-9]}.npy")[None,...])
#     print(gt_motion.shape)
    out_motion = torch.Tensor(np.load(f"/srv/scratch/sanisetty3/clean/mint/evals/eval60/hml/{motion_name}")[None,...])
#     print(out_motion.shape)

    keypoints3d_gt = recover_from_ric(gt_motion[0,:seq_len] , 22).detach().cpu().numpy()
    keypoints3d_pred = recover_from_ric(out_motion[0,:seq_len] , 22).detach().cpu().numpy()

    real_features["kinetic"].append(extract_feature(keypoints3d_gt, "kinetic"))
    real_features["manual"].append(extract_feature(keypoints3d_gt, "manual"))

    result_features["kinetic"].append(extract_feature(keypoints3d_pred, "kinetic"))
    result_features["manual"].append(extract_feature(keypoints3d_pred, "manual"))

#     real_pfc.append(calc_physical_score(keypoints3d_gt))
#     pred_pfc.append(calc_physical_score(keypoints3d_pred))






    motion_beats = motion_peak_onehot(keypoints3d_gt[:seq_len])
    # get real data music beats
    audio_name = motion_name.split("_")[-3]

    audio_feature = np.load(os.path.join(audio_feature_dir, f"{audio_name}.npy"))
    audio_beats = audio_feature[:seq_len, -1] # last dim is the music beats
    # get beat alignment scores
    beat_score = alignment_score(audio_beats, motion_beats, sigma=1)
    beat_scores_real.append(beat_score)


    motion_beats = motion_peak_onehot(keypoints3d_pred[:seq_len])
    beat_score_pred = alignment_score(audio_beats, motion_beats, sigma=1)
    beat_scores_pred.append(beat_score_pred)


FID_k, Dist_k = calculate_frechet_feature_distance(real_features["kinetic"], result_features["kinetic"])
FID_g, Dist_g = calculate_frechet_feature_distance(real_features["manual"], result_features["manual"])


print("FID_k: ",FID_k,"Diversity_k:",Dist_k)
print("FID_g: ",FID_g,"Diversity_g:",Dist_g)

# print ("\PFC score on real data: %.3f\n" % (np.mean(real_pfc)))
# print ("\PFC score on generated data: %.3f\n" % (np.mean(pred_pfc)))


print ("\nBeat score on real data: %.3f\n" % (np.mean(beat_scores_real)))
print ("\nBeat score on generated data: %.3f\n" % (np.mean(beat_scores_pred)))



best_fid_k = FID_k if FID_k < best_fid_k else best_fid_k
best_fid_g = FID_g if FID_g < best_fid_g else best_fid_g
best_div_k = Dist_k if Dist_k > best_div_k else best_div_k
best_div_g = Dist_g if Dist_g > best_div_g else best_div_g

best_beat_align = np.mean(beat_scores_real) if np.mean(beat_scores_real) > best_beat_align else best_beat_align 



# return best_fid_k, best_fid_g,best_div_k,best_div_g,best_beat_align , np.mean(real_pfc), np.mean(pred_pfc)



100%|██████████| 44/44 [01:47<00:00,  2.44s/it]

FID_k:  7.528682831845714 Diversity_k: 10.279027991012589
FID_g:  7.848392292979781 Diversity_g: 7.421904608008726

Beat score on real data: 0.310


Beat score on generated data: 0.225






NameError: name 'best_fid_k' is not defined

## Transformed HML

In [143]:
keypoints3d_hml = np.load("/srv/scratch/sanisetty3/clean/mint/evals/eval60/keypoints3D_hml60.npy").reshape(-1,22,3)
keypoints3d_aist = np.load("/srv/scratch/sanisetty3/clean/mint/evals/eval60/keypoints3D_aist.npy")[:52756,:22,:]

from pycpd import RigidRegistration, DeformableRegistration
import numpy as np


In [149]:
reg = RigidRegistration(X=keypoints3d_aist[:100].reshape(-1,3), Y=keypoints3d_hml[:100].reshape(-1,3))

In [150]:
TY, (s_reg, R_reg, t_reg) = reg.register()

In [151]:
np.mean(TY - keypoints3d_aist[:100].reshape(-1,3))

6.646927100911472e-05

NameError: name 'points_to_transform' is not defined

In [162]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
val_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 1387.74it/s]

Total number of motions 40





In [152]:
for batch in aist_loader:
    break

In [174]:
keypoints3d_gt = recover_from_ric(batch["motion"][0,:100] , 22).detach().cpu().numpy()

In [176]:
np.save("/srv/scratch/sanisetty3/music_motion/motion_vqvae/paper_renders/smpl/keypoints3d_gt.npy" , reg.transform_point_cloud(Y=keypoints3d_gt.reshape(-1,3)).reshape(-1,22,3))

In [172]:
result_features = {"kinetic": [], "manual": []}
real_features = {"kinetic": [], "manual": []}

mean = val_loader.dataset.mean
std = val_loader.dataset.std

beat_scores_real = []
beat_scores_pred = []

for i,aist_batch in enumerate(tqdm(val_loader)):


    mot_len = aist_batch["motion_lengths"][0]
    motion_name = aist_batch["names"][0]

    music_name = motion_name.split('_')[-2]
    music_encoding=  np.load(os.path.join(audio_encoding_dir , music_name + ".npy"))

    gen_motion_indices = torch.randint(0 , 1024 , (1,1))

    while gen_motion_indices.shape[1]<=seq_len:

       
        gen_motion_indices = trans_model.generate(start_tokens =gen_motion_indices.cuda(),\
                                                    seq_len=400 , \
                                                    context = torch.Tensor(music_encoding)[None,...].cuda(), \
                                                    context_mask=torch.ones((1 ,music_encoding.shape[0]) , dtype = torch.bool).cuda(),\
                                                    )

        gen_motion_indices = gen_motion_indices[gen_motion_indices<1024][None,...]

    try:
        out_motion = torch.zeros((aist_batch["motion"].shape[0] ,gen_motion_indices.shape[-1] , aist_batch["motion"].shape[-1]))
        for i in range(0 , seq_len, 200):
            quant , out_motion_= vqvae_model.decode(gen_motion_indices[:,i:i+200])
            out_motion[:,i:i+200] = out_motion_

        # quant , out_motion = vqvae_model.module.decode(gen_motion_indices[:,:mot_len])
    except:
        # quant , out_motion = vqvae_model.decode(gen_motion_indices[:,:mot_len])
        out_motion = torch.zeros((aist_batch["motion"].shape[0] ,gen_motion_indices.shape[-1] , aist_batch["motion"].shape[-1]))
        for i in range(0 , seq_len, 200):
            quant , out_motion_= vqvae_model.decode(gen_motion_indices[:,i:i+200])
            out_motion[:,i:i+200] = out_motion_


    keypoints3d_gt = recover_from_ric(aist_batch["motion"][0,:mot_len] , 22).detach().cpu().numpy()
    keypoints3d_gt = reg.transform_point_cloud(Y=keypoints3d_gt.reshape(-1,3)).reshape(-1,22,3)
    
    
    keypoints3d_pred = recover_from_ric(out_motion[0,:mot_len] , 22).detach().cpu().numpy()
    keypoints3d_pred = reg.transform_point_cloud(Y=keypoints3d_pred.reshape(-1,3)).reshape(-1,22,3)
    
    
    real_features["kinetic"].append(extract_feature(keypoints3d_gt, "kinetic"))
    real_features["manual"].append(extract_feature(keypoints3d_gt, "manual"))

    result_features["kinetic"].append(extract_feature(keypoints3d_pred, "kinetic"))
    result_features["manual"].append(extract_feature(keypoints3d_pred, "manual"))




    motion_beats = motion_peak_onehot(keypoints3d_gt[:mot_len])
    # get real data music beats
    audio_name = motion_name.split("_")[-2]

    audio_feature = np.load(os.path.join(audio_feature_dir, f"{audio_name}.npy"))
    audio_beats = audio_feature[:mot_len, -1] # last dim is the music beats
    # get beat alignment scores
    beat_score = alignment_score(audio_beats, motion_beats, sigma=1)
    beat_scores_real.append(beat_score)


    motion_beats = motion_peak_onehot(keypoints3d_pred[:mot_len])
    beat_score_pred = alignment_score(audio_beats, motion_beats, sigma=1)
    beat_scores_pred.append(beat_score_pred)


FID_k, Dist_k = calculate_frechet_feature_distance(real_features["kinetic"], result_features["kinetic"])
FID_g, Dist_g = calculate_frechet_feature_distance(real_features["manual"], result_features["manual"])


print("FID_k: ",FID_k,"Diversity_k:",Dist_k)
print("FID_g: ",FID_g,"Diversity_g:",Dist_g)


print ("\nBeat score on real data: %.3f\n" % (np.mean(beat_scores_real)))
print ("\nBeat score on generated data: %.3f\n" % (np.mean(beat_scores_pred)))

best_fid_k = FID_k if FID_k < best_fid_k else best_fid_k
best_fid_g = FID_g if FID_g < best_fid_g else best_fid_g
best_div_k = Dist_k if Dist_k > best_div_k else best_div_k
best_div_g = Dist_g if Dist_g > best_div_g else best_div_g

best_beat_align = np.mean(beat_scores_real) if np.mean(beat_scores_real) > best_beat_align else best_beat_align 



return best_fid_k, best_fid_g,best_div_k,best_div_g,best_beat_align


100%|██████████| 400/400 [00:08<00:00, 45.42it/s]
100%|██████████| 400/400 [00:10<00:00, 39.73it/s]
100%|██████████| 400/400 [00:08<00:00, 45.42it/s]
100%|██████████| 400/400 [00:10<00:00, 39.57it/s]
100%|██████████| 400/400 [00:08<00:00, 45.31it/s]
100%|██████████| 400/400 [00:10<00:00, 39.58it/s]
100%|██████████| 400/400 [00:08<00:00, 45.29it/s]
100%|██████████| 400/400 [00:10<00:00, 39.51it/s]
100%|██████████| 400/400 [00:08<00:00, 45.41it/s]
100%|██████████| 400/400 [00:10<00:00, 39.47it/s]
100%|██████████| 400/400 [00:08<00:00, 45.03it/s]
100%|██████████| 400/400 [00:10<00:00, 39.45it/s]
100%|██████████| 400/400 [00:08<00:00, 45.43it/s]
100%|██████████| 400/400 [00:10<00:00, 39.43it/s]
100%|██████████| 400/400 [00:08<00:00, 45.43it/s]
100%|██████████| 400/400 [00:10<00:00, 39.43it/s]
100%|██████████| 400/400 [00:08<00:00, 45.32it/s]
100%|██████████| 400/400 [00:10<00:00, 39.39it/s]
100%|██████████| 400/400 [00:08<00:00, 45.39it/s]
100%|██████████| 400/400 [00:10<00:00, 39.35it/s]


FID_k:  6.9577679600177476 Diversity_k: 10.06756729009824
FID_g:  11.167501952145543 Diversity_g: 7.1934799310488575

Beat score on real data: 0.244


Beat score on generated data: 0.263






NameError: name 'best_fid_k' is not defined

## Sinusoidal pos emb

FID_k:  3.8717503038434415 Diversity_k: 10.580599220288105
FID_g:  14.323117971786502 Diversity_g: 7.208645957555526

Beat score on real data: 0.244


Beat score on generated data: 0.196

### Albi

FID_k:  5.974822501678858 Diversity_k: 9.894211104435799
FID_g:  10.945797702730552 Diversity_g: 7.33098736114991

Beat score on real data: 0.244


Beat score on generated data: 0.207

## Style motion generative

In [79]:
from utils.eval_music import evaluate_music_motion_generative_style, evaluate_music_motion_generative_style2

In [None]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)


In [None]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

In [75]:
import clip
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False)  # Must set jit=False for training
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False


In [76]:
encodec_style = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_style/var_len_768_768_aist_style.yaml"


In [77]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file(encodec_style)



trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"{os.path.dirname(encodec_style)}/trans_motion.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([95000.])


In [78]:
evaluate_music_motion_generative_style(aist_loader , vqvae_model= vqvae_model ,net = trans_model,clip_model=clip_model,seq_len = 800)


NameError: name 'evaluate_music_motion_generative_style' is not defined

In [60]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 1910/1910 [00:22<00:00, 84.98it/s] 

Total number of motions 1910





In [80]:
evaluate_music_motion_generative_style2(aist_loader , vqvae_model= vqvae_model ,net = trans_model,clip_model=clip_model,seq_len = 800)


100%|██████████| 800/800 [00:21<00:00, 37.43it/s]
100%|██████████| 800/800 [00:21<00:00, 37.46it/s]]
100%|██████████| 800/800 [00:21<00:00, 37.45it/s]t]
100%|██████████| 800/800 [00:21<00:00, 37.42it/s]t]
100%|██████████| 800/800 [00:21<00:00, 37.37it/s]t]
100%|██████████| 800/800 [00:21<00:00, 37.33it/s]t]
100%|██████████| 800/800 [00:21<00:00, 37.35it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.35it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.30it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.34it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.30it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.31it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.32it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.32it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.30it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.31it/s]it]
100%|██████████| 800/800 [00:21<00:00, 37.25it/s]]  
100%|██████████| 800/800 [00:21<00:00, 37.31it/s]]
100%|██████████| 800/800 [00:21<00:00, 37.30it/s]]
100%|█

FID_k:  3.522845989252062 Diversity_k: 10.215643210963504
FID_g:  8.57888238870914 Diversity_g: 7.215440534382331

Beat score on real data: 0.172


Beat score on generated data: 0.184






(3.522845989252062,
 8.57888238870914,
 10.215643210963504,
 7.215440534382331,
 0.1715415597282555)

In [81]:
0.243*0.184 / 0.172

0.25995348837209303

## Render

In [7]:
audio_encoding_dir = "/srv/scratch/sanisetty3/music_motion/AIST/music"

genre_dict = {
"mBR" : "Break",
"mPO" : "Pop",
"mLO" : "Lock",
"mMH" : "Middle Hip-hop",
"mLH" : "LA style Hip-hop",
"mHO" : "House",    
"mWA" : "Waack",
"mKR" : "Krump",
"mJS" : "Street Jazz",
"mJB" : "Ballet Jazz",
}



In [42]:
style = None

In [49]:
for i,aist_batch in enumerate(tqdm(aist_loader)):
    break
motion_name = aist_batch["names"][0]

music_name = motion_name.split('_')[-2]
music_encoding=  np.load(os.path.join(audio_encoding_dir , music_name + ".npy"))

print(genre_dict.get(music_name[:3]))

  0%|          | 0/40 [00:00<?, ?it/s]

Ballet Jazz





In [50]:


mot_len = aist_batch["motion_lengths"][0]
motion_name = aist_batch["names"][0]

music_name = motion_name.split('_')[-2]
music_encoding=  np.load(os.path.join(audio_encoding_dir , music_name + ".npy"))

print(genre_dict.get(music_name[:3]))

genre = (genre_dict.get(music_name[:3])) if style is None else style

text = clip.tokenize([genre], truncate=True).cuda()
style_embeddings = clip_model.encode_text(text).cpu().float().reshape(-1) if clip_model is not None else None
gen_motion_indices = torch.randint(0 , 1024 , (1,1))
gen_motion_indices = trans_model.generate(start_tokens =gen_motion_indices.cuda(),\
                                        seq_len=400 , \
                                        context = torch.Tensor(music_encoding)[None,...].cuda(), \
                                        context_mask=torch.ones((1 ,music_encoding.shape[0]) , dtype = torch.bool).cuda(),\
                                        style_context = torch.Tensor(style_embeddings.reshape(-1))[None,...].cuda(),
                                        )
gen_motion_indices = gen_motion_indices[gen_motion_indices<1024][None,...]

quant , out_motion = vqvae_model.decode(gen_motion_indices)

Ballet Jazz


100%|██████████| 400/400 [00:07<00:00, 56.84it/s]


In [52]:
aist_batch["motion"].shape

torch.Size([1, 147, 263])

In [44]:
genre

'Slow'

In [19]:
sample_render(to_xyz(aist_batch["motion"][0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "style_gt" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [51]:
sample_render(to_xyz(out_motion[:,:mot_len].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "style_none" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/style/")

render start


### Music VQ

In [13]:

load_path_mix = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/checkpoints/vqvae_motion.295000.pt"
load_path_hml = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768/vqvae_motion.pt"
load_path_aist = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_aist/vqvae_motion.pt"


In [14]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from utils.eval_music import evaluate_music_motion_vqvae

cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/var_len_768_768_aist_vq.yaml")

load_path = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt"


In [15]:

vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"{load_path_mix}", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()



tensor([295000.])


In [17]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 1496.39it/s]

Total number of motions 40





In [21]:
print("pretrained only t2m")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

pretrained only t2m


100%|██████████| 40/40 [00:41<00:00,  1.04s/it]

FID_k:  3.2996737750179364 Diversity_k: 10.26604298215646
FID_g:  10.78302279919913 Diversity_g: 7.181474344852643
FID_k_real:  -7.86550347697812e-06 Diversity_k_real: 10.195780532558759
FID_g_real:  -1.9184653865522705e-13 Diversity_g_real: 7.348854861503992

Beat score on real data: 0.245


Beat score on generated data: 0.176






(3.2996737750179364,
 10.78302279919913,
 10.26604298215646,
 7.181474344852643,
 0.24494051462936942)

In [18]:

# print("pretrained mix")
# best_fid_k = []
# best_fid_g = []
# best_div_k = []
# best_div_g = []
# best_beat_align = []

# for i in range(20):

#     a,b,c,d,e = evaluate_music_motion_vqvae(aist_loader,vqvae_model)
#     best_fid_k.append(a)
#     best_fid_g.append(b)
#     best_div_k.append(c)
#     best_div_g.append(d)
#     best_beat_align.append(e)

    
# print("best_fid_k" , np.mean(best_fid_k))
# print("best_fid_g" , np.mean(best_fid_g))
# print("best_div_k" , np.mean(best_div_k))
# print("best_div_g" , np.mean(best_div_g))
# print("best_beat_align" , np.mean(best_beat_align))

In [19]:
print("mix")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)


mix


100%|██████████| 40/40 [00:39<00:00,  1.00it/s]

FID_k:  2.635362356342995 Diversity_k: 10.163608189500295
FID_g:  7.295345718653849 Diversity_g: 7.234946262225127
FID_k_real:  -7.757004908626186e-06 Diversity_k_real: 10.205963216454554
FID_g_real:  -1.903529778246593e-09 Diversity_g_real: 7.344472836225461

Beat score on real data: 0.244


Beat score on generated data: 0.234






(2.635362356342995,
 7.295345718653849,
 10.163608189500295,
 7.234946262225127,
 0.244130060202235)

In [100]:
print("pretrained only t2m, finetuned only aist")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

pretrained only t2m, finetuned only aist


  5%|▌         | 2/40 [00:01<00:37,  1.01it/s]


KeyboardInterrupt: 

In [28]:
### Mixture
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

100%|██████████| 1910/1910 [29:56<00:00,  1.06it/s]


FID_k:  0.010750366559051372 Diversity_k: 9.172959109891172
FID_g:  1.2350136226828567 Diversity_g: 7.381343867734843

Beat score on real data: 0.249


Beat score on generated data: 0.250



(0.010750366559051372,
 1.2350136226828567,
 9.172959109891172,
 7.381343867734843,
 0.24940255611332512)

## Generating token dataset

In [10]:
(0.11*255, 0.53*255, 0.8*255, 0.5*255)

(28.05, 135.15, 204.0, 127.5)

In [64]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 1910/1910 [00:01<00:00, 1244.08it/s]

Total number of motions 1910





In [67]:
for batch in tqdm(aist_loader):
    
    n = int(batch["motion_lengths"])
    name = str(batch["names"][0])
    if n< 400:
        ind = vqvae_model.encode(batch["motion"].cuda())
    else:
#         ind = vqvae_model.encode(batch["motion"][:,:400].cuda())
        inds = []
        for i in range(0 , n, 200):
            ii = vqvae_model.encode(batch["motion"][:,i:i+200].cuda())
            inds.append(ii[0])
#             print(ii.shape)
        
        ind = torch.concatenate(inds)[None,...]
        
#     print(ind.shape)
    
    np.save(os.path.join("/srv/scratch/sanisetty3/music_motion/AIST/joint_indices" , name+".npy"),ind.cpu().numpy()[0])
        
#         quant , out_motion = vqvae_model.decode(ind)
    

100%|██████████| 1910/1910 [01:47<00:00, 17.82it/s]


In [68]:
mot_list = glob("/srv/scratch/sanisetty3/music_motion/AIST/joint_indices/*.npy")

In [69]:
np.load(mot_list[0]).shape[0]

174

In [71]:
lens = []
for i in mot_list:
    lens.append(np.load(i).shape[0])

In [72]:
max(lens)

959

In [11]:
hlm_ds = VQFullMotionDataset("t2m", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/" , window_size = -1)
hlm_loader = DATALoader(hlm_ds,1,collate_fn=None)

100%|██████████| 23384/23384 [08:53<00:00, 43.85it/s] 

Total number of motions 23384





In [12]:
for batch in tqdm(hlm_loader):
    break

  0%|          | 0/23384 [00:00<?, ?it/s]


In [13]:
n = int(batch["motion_lengths"])
name = str(batch["names"][0])
print(n,name)

199 M003397


In [14]:
ind = vqvae_model.encode(batch["motion"].cuda())

In [15]:
ind.shape

torch.Size([1, 199])

In [16]:
for batch in tqdm(hlm_loader):
    
    n = int(batch["motion_lengths"])
    name = str(batch["names"][0])
    if n< 400:
        ind = vqvae_model.encode(batch["motion"].cuda())
    else:
        #ind = vqvae_model.encode(batch["motion"][:,:400].cuda())
#         out_motion = torch.zeros((batch["motion"].shape[0] ,gen_motion_indices.shape[-1] , aist_batch["motion"].shape[-1]))
        inds = []
        for i in range(0 , n, 200):
            inds.append(vqvae_model.encode(batch["motion"].cuda()))
        
        ind = torch.stack(inds)
    
    np.save(os.path.join("/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/joint_indices" , name+".npy"), ind.cpu().numpy()[0])
    
        
    

  0%|          | 0/23384 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/joint_indices/007648.npy'

In [20]:
sample_render(to_xyz(batch["motion"][0:1].detach().cpu(),mean = hlm_ds.mean , std = hlm_ds.std), "rnd_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


## T2M Eval

In [57]:
import utils.utils_model as utils_model
from core.datasets import dataset_TM_eval
import utils.eval_trans as eval_trans
from core.models.evaluator_wrapper import EvaluatorModelWrapper
from utils.word_vectorizer import WordVectorizer
from utils.eval_trans import evaluation_vqvae_loss,evaluation_vqvae
from utils.eval_trans import calculate_R_precision,calculate_activation_statistics,calculate_diversity,calculate_frechet_distance
from tqdm import tqdm

In [59]:
w_vectorizer = WordVectorizer('/srv/scratch/sanisetty3/music_motion/T2M-GPT/glove', 'our_vab')
eval_wrapper = EvaluatorModelWrapper(cfg.eval_model)
tm_eval = dataset_TM_eval.DATALoader("t2m", False, 20, w_vectorizer, unit_length=4)


  0%|          | 4/1460 [00:00<00:43, 33.22it/s]

Loading Evaluation Model Wrapper (Epoch 28) Completed!!


100%|██████████| 1460/1460 [00:51<00:00, 28.32it/s]

1530 1530
Pointer Pointing at 0





In [60]:
for batch in tm_eval:
    break

In [64]:
word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token, name = batch


In [70]:
sent_len

tensor([22, 22, 22, 21, 19, 17, 17, 15, 15, 15, 14, 13, 12, 12, 11, 10, 10,  8,
         8,  7])

In [66]:
word_embeddings.shapea

torch.Size([20, 22, 300])

In [15]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from core.models.motion_regressor import MotionRegressorModel

load_path = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_aist/vqvae_motion.pt"
cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")

vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"{load_path}", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()



tensor([275000.])


In [13]:
### Pretrained on t2m only
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

100%|██████████| 232/232 [00:43<00:00,  5.31it/s]


--> 	 Eva. Iter 0 :, FID. 0.0668, Diversity Real. 9.5584, Diversity. 9.9187, R_precision_real. [0.60193966 0.78189655 0.86810345], R_precision. [0.59439655 0.77801724 0.85991379], matching_score_real. 2.9862875124503825, matching_score_pred. 3.028119134902954


In [16]:
print("pretrained on t2m only and finetuned on aist")
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

pretrained on t2m only and finetuned on aist


100%|██████████| 232/232 [00:43<00:00,  5.30it/s]


--> 	 Eva. Iter 0 :, FID. 3.2204, Diversity Real. 9.3818, Diversity. 7.4288, R_precision_real. [0.59698276 0.78728448 0.86551724], R_precision. [0.43512931 0.64073276 0.75625   ], matching_score_real. 2.9870332890543443, matching_score_pred. 4.043809700012207


In [37]:
## Pretrained on a mix of aist and t2m
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

100%|██████████| 232/232 [00:42<00:00,  5.50it/s]


--> 	 Eva. Iter 0 :, FID. 0.0637, Diversity Real. 9.4620, Diversity. 9.4266, R_precision_real. [0.61616379 0.79482759 0.86982759], R_precision. [0.60172414 0.78405172 0.86077586], matching_score_real. 2.9635313979510602, matching_score_pred. 3.0240239735307366


## Render

In [4]:
from render_final import render, saveSMPL
from core.datasets.vqa_motion_dataset import TransMotionDatasetConditionalFull
from glob import glob

In [5]:
motions_list = glob("/srv/scratch/sanisetty3/clean/mint/evals/eval60/hml/*.npy" , recursive=False)


In [6]:
motions_list = glob("/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/generator/var_len/trans_768_768_albi_aist/*.npy" , recursive=False)


In [11]:
std = np.load("/srv/scratch/sanisetty3/music_motion/AIST/Std.npy")
mean = np.load("/srv/scratch/sanisetty3/music_motion/AIST/Mean.npy")

In [13]:
torch.Tensor(np.load(motions_list[0])).shape

torch.Size([1, 401, 263])

In [17]:
for i in tqdm(motions_list):
    name =os.path.basename(i).split(".")[0]
    print(name)
    motion = torch.Tensor(np.load(i))
    motion_xyz = to_xyz(torch.Tensor(motion) , mean= mean , std = std)
    saveSMPL(motion_xyz[0].numpy(), outdir= os.path.join(os.path.dirname(i) , "SMPL_dict"), name=name+"smpl_dict", pred=True)

  0%|          | 0/60 [00:00<?, ?it/s]

mWA3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


  vertices, rotations, global_orient, out, x_translations = rot2xyz(torch.tensor(motion_tensor).clone(), mask=None,
  2%|▏         | 1/60 [02:46<2:43:58, 166.75s/it]

torch.Size([1, 6890, 3, 401])
mWA4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


  3%|▎         | 2/60 [06:02<2:49:36, 175.46s/it]

torch.Size([1, 6890, 3, 401])
mHO4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


  5%|▌         | 3/60 [08:53<2:45:30, 174.23s/it]

torch.Size([1, 6890, 3, 401])
mHO3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


  7%|▋         | 4/60 [11:29<2:37:20, 168.57s/it]

torch.Size([1, 6890, 3, 401])
mMH3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


  8%|▊         | 5/60 [14:22<2:35:51, 170.04s/it]

torch.Size([1, 6890, 3, 401])
mJB4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 10%|█         | 6/60 [17:19<2:34:48, 172.01s/it]

torch.Size([1, 6890, 3, 401])
mPO3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 12%|█▏        | 7/60 [19:53<2:27:19, 166.78s/it]

torch.Size([1, 6890, 3, 401])
mPO4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 13%|█▎        | 8/60 [22:52<2:27:39, 170.37s/it]

torch.Size([1, 6890, 3, 401])
mJB3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 15%|█▌        | 9/60 [25:18<2:18:29, 162.93s/it]

torch.Size([1, 6890, 3, 401])
mMH4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 17%|█▋        | 10/60 [28:12<2:18:42, 166.44s/it]

torch.Size([1, 6890, 3, 401])
mJS4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 18%|█▊        | 11/60 [30:44<2:12:21, 162.07s/it]

torch.Size([1, 6890, 3, 401])
mBR0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 20%|██        | 12/60 [33:24<2:09:04, 161.35s/it]

torch.Size([1, 6890, 3, 401])
mJS3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 22%|██▏       | 13/60 [36:11<2:07:38, 162.94s/it]

torch.Size([1, 6890, 3, 401])
mLH3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 23%|██▎       | 14/60 [38:52<2:04:34, 162.48s/it]

torch.Size([1, 6890, 3, 401])
mLH4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 25%|██▌       | 15/60 [40:55<1:52:59, 150.66s/it]

torch.Size([1, 6890, 3, 401])
mKR2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 27%|██▋       | 16/60 [43:19<1:49:06, 148.78s/it]

torch.Size([1, 6890, 3, 401])
mLO1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 28%|██▊       | 17/60 [45:26<1:41:55, 142.21s/it]

torch.Size([1, 6890, 3, 401])
mKR5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 30%|███       | 18/60 [48:14<1:44:56, 149.92s/it]

torch.Size([1, 6890, 3, 401])
mPO5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 32%|███▏      | 19/60 [51:06<1:46:54, 156.46s/it]

torch.Size([1, 6890, 3, 401])
mMH5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 33%|███▎      | 20/60 [54:06<1:48:58, 163.47s/it]

torch.Size([1, 6890, 3, 401])
mJB2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 35%|███▌      | 21/60 [56:42<1:44:52, 161.35s/it]

torch.Size([1, 6890, 3, 401])
mJB5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 37%|███▋      | 22/60 [59:45<1:46:15, 167.79s/it]

torch.Size([1, 6890, 3, 401])
mMH2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 38%|███▊      | 23/60 [1:02:02<1:37:47, 158.57s/it]

torch.Size([1, 6890, 3, 401])
mPO2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 40%|████      | 24/60 [1:04:46<1:36:04, 160.12s/it]

torch.Size([1, 6890, 3, 401])
mHO2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 42%|████▏     | 25/60 [1:07:47<1:37:05, 166.44s/it]

torch.Size([1, 6890, 3, 401])
mHO5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 43%|████▎     | 26/60 [1:11:00<1:38:48, 174.37s/it]

torch.Size([1, 6890, 3, 401])
mWA5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 45%|████▌     | 27/60 [1:13:53<1:35:42, 174.00s/it]

torch.Size([1, 6890, 3, 401])
mWA2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 47%|████▋     | 28/60 [1:16:21<1:28:34, 166.08s/it]

torch.Size([1, 6890, 3, 401])
mLO0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 48%|████▊     | 29/60 [1:19:11<1:26:32, 167.50s/it]

torch.Size([1, 6890, 3, 401])
mKR4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 50%|█████     | 30/60 [1:21:37<1:20:27, 160.90s/it]

torch.Size([1, 6890, 3, 401])
mKR3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 52%|█████▏    | 31/60 [1:24:28<1:19:12, 163.88s/it]

torch.Size([1, 6890, 3, 401])
mLH5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 53%|█████▎    | 32/60 [1:27:17<1:17:16, 165.58s/it]

torch.Size([1, 6890, 3, 401])
mLH2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 55%|█████▌    | 33/60 [1:30:05<1:14:47, 166.20s/it]

torch.Size([1, 6890, 3, 401])
mBR1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 57%|█████▋    | 34/60 [1:32:29<1:09:05, 159.45s/it]

torch.Size([1, 6890, 3, 401])
mJS2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 58%|█████▊    | 35/60 [1:35:21<1:07:59, 163.19s/it]

torch.Size([1, 6890, 3, 401])
mJS5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 60%|██████    | 36/60 [1:38:16<1:06:43, 166.81s/it]

torch.Size([1, 6890, 3, 401])
mLO4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 62%|██████▏   | 37/60 [1:41:13<1:05:10, 170.03s/it]

torch.Size([1, 6890, 3, 401])
mKR0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 63%|██████▎   | 38/60 [1:44:13<1:03:23, 172.89s/it]

torch.Size([1, 6890, 3, 401])
mLO3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 65%|██████▌   | 39/60 [1:46:39<57:41, 164.84s/it]  

torch.Size([1, 6890, 3, 401])
mLH1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 67%|██████▋   | 40/60 [1:49:16<54:12, 162.63s/it]

torch.Size([1, 6890, 3, 401])
mBR5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 68%|██████▊   | 41/60 [1:52:10<52:34, 166.04s/it]

torch.Size([1, 6890, 3, 401])
mJS1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 70%|███████   | 42/60 [1:55:01<50:15, 167.52s/it]

torch.Size([1, 6890, 3, 401])
mBR2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 72%|███████▏  | 43/60 [1:57:49<47:28, 167.57s/it]

torch.Size([1, 6890, 3, 401])
mMH1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 73%|███████▎  | 44/60 [2:00:50<45:43, 171.45s/it]

torch.Size([1, 6890, 3, 401])
mPO1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 75%|███████▌  | 45/60 [2:04:06<44:45, 179.01s/it]

torch.Size([1, 6890, 3, 401])
mJB1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 77%|███████▋  | 46/60 [2:07:03<41:35, 178.22s/it]

torch.Size([1, 6890, 3, 401])
mHO1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 78%|███████▊  | 47/60 [2:09:42<37:22, 172.52s/it]

torch.Size([1, 6890, 3, 401])
mWA1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 80%|████████  | 48/60 [2:12:37<34:40, 173.37s/it]

torch.Size([1, 6890, 3, 401])
mJS0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 82%|████████▏ | 49/60 [2:15:44<32:32, 177.52s/it]

torch.Size([1, 6890, 3, 401])
mBR3
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 83%|████████▎ | 50/60 [2:18:16<28:16, 169.68s/it]

torch.Size([1, 6890, 3, 401])
mBR4
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 85%|████████▌ | 51/60 [2:20:55<24:59, 166.67s/it]

torch.Size([1, 6890, 3, 401])
mLH0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 87%|████████▋ | 52/60 [2:23:58<22:50, 171.35s/it]

torch.Size([1, 6890, 3, 401])
mLO2
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 88%|████████▊ | 53/60 [2:27:17<20:58, 179.77s/it]

torch.Size([1, 6890, 3, 401])
mLO5
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 90%|█████████ | 54/60 [2:30:14<17:53, 178.92s/it]

torch.Size([1, 6890, 3, 401])
mKR1
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 92%|█████████▏| 55/60 [2:33:06<14:44, 176.83s/it]

torch.Size([1, 6890, 3, 401])
mWA0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 93%|█████████▎| 56/60 [2:36:24<12:12, 183.10s/it]

torch.Size([1, 6890, 3, 401])
mHO0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 95%|█████████▌| 57/60 [2:39:20<09:03, 181.00s/it]

torch.Size([1, 6890, 3, 401])
mJB0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 97%|█████████▋| 58/60 [2:41:35<05:34, 167.20s/it]

torch.Size([1, 6890, 3, 401])
mMH0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


 98%|█████████▊| 59/60 [2:44:13<02:44, 164.42s/it]

torch.Size([1, 6890, 3, 401])
mPO0
cuda:0
./body_models/
Running SMPLify, it may take a few minutes.
torch.Size([1, 25, 6, 401]) dict_keys(['pose', 'betas', 'cam'])


100%|██████████| 60/60 [2:46:40<00:00, 166.68s/it]

torch.Size([1, 6890, 3, 401])





In [14]:
smpl_dict_list = glob("/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/generator/var_len/trans_768_768_albi_aist/SMPL_dict/*pt")

In [19]:
smpl_dict = torch.load(smpl_dict_list[0], map_location = "cpu")

In [20]:
smpl_dict["global_orient"].shape

torch.Size([401, 3, 3])

In [21]:
smpl_dict["rotations"].shape

torch.Size([401, 23, 3, 3])

In [24]:
rots = np.concatenate((smpl_dict["global_orient"][:,None,:,:] ,smpl_dict["rotations"] ) , axis = 1).reshape(smpl_dict["global_orient"].shape[0] , -1)

In [34]:
smpl_dict["x_translations"].shape

torch.Size([1, 3, 401])

In [27]:
trans = smpl_dict["x_translations"][0].cpu().numpy().T
trans.shape

(401, 3)

In [31]:
aist = np.concatenate((np.zeros((trans.shape[0] , 6)) , trans , rots) , axis = 1)

In [32]:
aist.shape

(401, 225)