In [1]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import wandb

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader

from torch.utils import data


import copy
import os
import random
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import functools
from tqdm import tqdm
from datetime import datetime
import numpy as np
from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset
from einops import rearrange, reduce, pack, unpack
import librosa

In [4]:
from utils.motion_process import recover_from_ric
import visualize.plot_3d_global as plot_3d
from glob import glob
def to_xyz(motion, mean ,std , j = 22):
    motion_xyz = recover_from_ric(motion.cpu().float()*std+mean, j)
    motion_xyz = motion_xyz.reshape(motion.shape[0],-1, j, 3)
    return motion_xyz

            
def sample_render(motion_xyz , name , save_path):
    print(f"render start")
    
    gt_pose_vis = plot_3d.draw_to_batch(motion_xyz.numpy(),None, [os.path.join(save_path,name + ".gif")])



In [5]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from core.models.motion_regressor import MotionRegressorModel


cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")




In [6]:
# reg_model = MotionRegressorModel(args = cfg.motion_trans , ignore_index=1025 ,pad_value=1025 ).eval()

In [7]:
vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()


tensor([295000.])


In [8]:
collate_fn = MotionCollator()


In [9]:
train_ds = VQVarLenMotionDataset("t2m", split = "render" , max_length_seconds = 10, data_root = "/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D")
train_loader = DATALoader(train_ds,1,collate_fn=collate_fn)

changing range to: 60 - 60


100%|██████████| 10/10 [00:00<00:00, 37.82it/s]

Total number of motions 10





In [10]:
aist_ds = VQVarLenMotionDataset("aist", split = "render" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , num_stages = 6 ,min_length_seconds=20, max_length_seconds=30)
aist_loader = DATALoader(aist_ds,1,collate_fn=collate_fn)

changing range to: 400 - 400


100%|██████████| 8/8 [00:00<00:00, 39.72it/s]

Total number of motions 8





In [11]:
aist_loader.dataset.set_stage(5)

changing range to: 400 - 600


In [12]:
for aist_batch in aist_loader:
    break
aist_batch["motion"].shape

torch.Size([1, 159, 263])

## Trans model

In [5]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist.yaml")



In [6]:
trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist/trans_motion_best_fid.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([195000.])


## Encode Decode

In [52]:
ind = vqvae_model.encode(aist_batch["motion"].cuda())
print(ind.shape)
quant , out_motion = vqvae_model.decode(ind)

torch.Size([1, 191])


In [39]:
out = torch.empty(aist_batch["motion"].shape)

In [27]:
ind = vqvae_model.encode(aist_batch["motion"][:,:400].cuda())
quant , out_motion = vqvae_model.decode(ind)

In [53]:
quant , out_motion = vqvae_model.decode(ind[:,400:].to(torch.long).cuda())

In [54]:
out[:,400:] = out_motion


In [53]:
sample_render(to_xyz(aist_batch["motion"][0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_og_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [57]:
sample_render(to_xyz(out[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion_ind_400" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [47]:
indices = torch.randint(0,1024,(1,60))
quant , out_motion = vqvae_model.decode(indices.cuda())

In [54]:
sample_render(to_xyz(out_motion[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


## Music Eval stuff

In [7]:
from utils.motion_process import recover_from_ric
from utils.aist_metrics import calculate_fid_scores
from utils.aist_metrics.calculate_fid_scores import calculate_avg_distance, extract_feature,calculate_frechet_feature_distance,calculate_frechet_distance
from utils.aist_metrics.features import kinetic,manual
from utils.aist_metrics.calculate_beat_scores import motion_peak_onehot,alignment_score

In [8]:
from core.datasets.vqa_motion_dataset import MotionCollatorConditional, TransMotionDatasetConditional,VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset


In [9]:
from utils.eval_music import evaluate_music_motion_vqvae, evaluate_music_motion_generative,evaluate_music_motion_trans

In [14]:
for aist_batch in tqdm(aist_loader):
    break

  0%|          | 0/40 [00:00<?, ?it/s]


### Const len trained transformer

In [15]:
from core.models.motion_regressor import MotionRegressorModel

trans_model = MotionRegressorModel(args = cfg_trans.motion_trans , ignore_index=1025 ,pad_value=1025 ).eval()
pkg_trans = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/const_len/trans_768_768_aist/vqvae_motion.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([85000.])


In [39]:
0.23/0.243 * 0.292

0.2763786008230453

## Evaluate Music Motion transformer

In [8]:
encodec = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist/var_len_768_768_aist.yaml"
encodec_sine = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_sine_aist/var_len_768_768_sine_aist.yaml"
librosa = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_35/var_len_768_768_aist_35.yaml"
encodec_prob50 = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_mask_prob50/trans_768_768_albi_aist_mask_prob50.yaml"


In [17]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


trans_option = "encodec_sine"

cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file(encodec_sine)


trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"{os.path.dirname(encodec_sine)}/trans_motion_best_fid.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([145000.])


In [18]:
audio_encoding_dir = "/srv/scratch/sanisetty3/music_motion/AIST/music"
audio_features_dir = "/srv/scratch/sanisetty3/music_motion/AIST/audio_features/"
use35 = False
if trans_option == "librosa":
    audio_encoding_dir = audio_features_dir
    use35 = True

In [19]:
use35

False

## Evaluate Music Motion Generative

In [12]:
from utils.eval_music import evaluate_music_motion_vqvae, evaluate_music_motion_generative,evaluate_music_motion_generative2,evaluate_music_motion_trans

In [40]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 1700.32it/s]

Total number of motions 40





In [49]:
from utils.eval_music import evaluate_music_motion_generative
print("pretrained mix")
best_fid_k = []
best_fid_g = []
best_div_k = []
best_div_g = []
best_beat_align = []

for i in range(1):

    a,b,c,d,e = evaluate_music_motion_generative(aist_loader , vqvae_model= vqvae_model ,net = trans_model,use35=use35)
    best_fid_k.append(a)
    best_fid_g.append(b)
    best_div_k.append(c)
    best_div_g.append(d)
    best_beat_align.append(e)

    
print("best_fid_k" , np.mean(best_fid_k))
print("best_fid_g" , np.mean(best_fid_g))
print("best_div_k" , np.mean(best_div_k))
print("best_div_g" , np.mean(best_div_g))
print("best_beat_align" , np.mean(best_beat_align))



  2%|▏         | 7/400 [00:00<00:06, 64.18it/s]

pretrained mix


100%|██████████| 400/400 [00:06<00:00, 60.42it/s]
100%|██████████| 400/400 [00:06<00:00, 60.73it/s]
100%|██████████| 400/400 [00:06<00:00, 60.44it/s]
100%|██████████| 400/400 [00:06<00:00, 60.62it/s]
100%|██████████| 400/400 [00:06<00:00, 60.56it/s]
100%|██████████| 400/400 [00:06<00:00, 60.68it/s]
100%|██████████| 400/400 [00:06<00:00, 60.08it/s]
100%|██████████| 400/400 [00:06<00:00, 59.90it/s]
100%|██████████| 400/400 [00:06<00:00, 60.57it/s]
100%|██████████| 400/400 [00:06<00:00, 60.68it/s]
100%|██████████| 400/400 [00:06<00:00, 60.56it/s]
100%|██████████| 400/400 [00:06<00:00, 60.67it/s]
100%|██████████| 400/400 [00:06<00:00, 60.68it/s]
100%|██████████| 400/400 [00:06<00:00, 60.73it/s]
100%|██████████| 400/400 [00:06<00:00, 60.83it/s]
100%|██████████| 400/400 [00:06<00:00, 60.69it/s]
100%|██████████| 400/400 [00:06<00:00, 60.64it/s]
100%|██████████| 400/400 [00:06<00:00, 60.73it/s]
100%|██████████| 400/400 [00:06<00:00, 60.77it/s]
100%|██████████| 400/400 [00:06<00:00, 60.50it/s]


FID_k:  6.403103264620739 Diversity_k: 10.598105856088491
FID_g:  11.891712829203748 Diversity_g: 7.263086085441785

Beat score on real data: 0.244


Beat score on generated data: 0.212

\PFC score on real data: 2.113

\PFC score on generated data: 2.808

best_fid_k 6.403103264620739
best_fid_g 11.891712829203748
best_div_k 10.598105856088491
best_div_g 7.263086085441785
best_beat_align 0.24413006020223502





In [None]:
FID_k:  5.848359006888529 Diversity_k: 9.930449794347469
FID_g:  10.19790933574938 Diversity_g: 7.301737963236295

Beat score on real data: 0.244


Beat score on generated data: 0.233

\PFC score on real data: 2.113

\PFC score on generated data: 1.169

In [13]:
aist_ds_train = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader_train = DATALoader(aist_ds_train,1,collate_fn=None)

100%|██████████| 1910/1910 [00:01<00:00, 1568.45it/s]

Total number of motions 1910





In [20]:

print("pretrained mix")
best_fid_k = []
best_fid_g = []
best_div_k = []
best_div_g = []
best_beat_align = []
seq_len = 900
for i in range(1):

    a,b,c,d,e,f,g = evaluate_music_motion_generative2(aist_loader_train , vqvae_model= vqvae_model ,net = trans_model, use35 = use35,seq_len = seq_len)
    best_fid_k.append(a)
    best_fid_g.append(b)
    best_div_k.append(c)
    best_div_g.append(d)
    best_beat_align.append(e)

    
print("best_fid_k" , np.mean(best_fid_k))
print("best_fid_g" , np.mean(best_fid_g))
print("best_div_k" , np.mean(best_div_k))
print("best_div_g" , np.mean(best_div_g))
print("best_beat_align" , np.mean(best_beat_align))



  1%|          | 7/900 [00:00<00:14, 62.63it/s]

pretrained mix


100%|██████████| 900/900 [00:20<00:00, 42.96it/s]
100%|██████████| 900/900 [00:20<00:00, 43.03it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.70it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.74it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.69it/s]it]
100%|██████████| 900/900 [00:21<00:00, 42.81it/s]it]
100%|██████████| 900/900 [00:21<00:00, 42.37it/s]]  
100%|██████████| 900/900 [00:21<00:00, 42.73it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.74it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.77it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.46it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.69it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.79it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.13it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.60it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.56it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.75it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.85it/s]]
100%|██████████| 900/900 [00:21<00:00, 42.64it/s]t]
100%|██████████| 900/900 

FID_k:  6.394353653410235 Diversity_k: 10.558077923262992
FID_g:  8.632308692625102 Diversity_g: 7.294414699949869
\PFC score on real data: 1.677

\PFC score on generated data: 1.239


Beat score on real data: 0.170


Beat score on generated data: 0.164

best_fid_k 6.394353653410235
best_fid_g 8.632308692625102
best_div_k 10.558077923262992
best_div_g 7.294414699949869
best_beat_align 0.17011690074139812





In [33]:
0.243*0.169 / 0.182

0.22564285714285715

In [None]:
Sine 800

FID_k:  3.3585897511335645 Diversity_k: 10.084781883693323
FID_g:  13.196066750688807 Diversity_g: 7.4867808908950995
\PFC score on real data: 1.910

\PFC score on generated data: 0.598


Beat score on real data: 0.180


Beat score on generated data: 0.178

In [None]:
FID_k:  5.514617609935726 Diversity_k: 10.241459463863839
FID_g:  14.856035117858639 Diversity_g: 6.822329583618699
\PFC score on real data: 2.361

\PFC score on generated data: 0.116

## Sinusoidal pos emb

FID_k:  3.8717503038434415 Diversity_k: 10.580599220288105
FID_g:  14.323117971786502 Diversity_g: 7.208645957555526

Beat score on real data: 0.244


Beat score on generated data: 0.196

### Albi

FID_k:  5.974822501678858 Diversity_k: 9.894211104435799
FID_g:  10.945797702730552 Diversity_g: 7.33098736114991

Beat score on real data: 0.244


Beat score on generated data: 0.207

## Style motion generative

In [59]:
from utils.eval_music import evaluate_music_motion_generative_style, evaluate_music_motion_generative_style2

In [None]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)


In [56]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 40.00it/s]

Total number of motions 40





In [52]:
import clip
clip_model, clip_preprocess = clip.load("ViT-B/32", device=torch.device('cuda'), jit=False)  # Must set jit=False for training
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False


In [51]:
encodec_style = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/generator/var_len/trans_768_768_albi_aist_style/var_len_768_768_aist_style.yaml"


In [57]:
from configs.config import cfg, get_cfg_defaults
from core.models.motion_regressor import MotionRegressorModel


cfg_trans = get_cfg_defaults()
cfg_trans.merge_from_file(encodec_style)



trans_model = MotionRegressorModel(args = cfg_trans.motion_trans,pad_value=1025 ).eval()
pkg_trans = torch.load(f"{os.path.dirname(encodec_style)}/trans_motion.pt", map_location = 'cpu')
print(pkg_trans["steps"])
trans_model.load_state_dict(pkg_trans["model"])
trans_model =trans_model.cuda()


tensor([70000.])


In [58]:
evaluate_music_motion_generative_style(aist_loader , vqvae_model= vqvae_model ,net = trans_model,clip_model=clip_model,seq_len = 800)


100%|██████████| 800/800 [00:19<00:00, 40.98it/s]
100%|██████████| 800/800 [00:19<00:00, 41.33it/s]
100%|██████████| 800/800 [00:19<00:00, 40.70it/s]
100%|██████████| 800/800 [00:19<00:00, 41.77it/s]
100%|██████████| 800/800 [00:19<00:00, 41.00it/s]
100%|██████████| 800/800 [00:19<00:00, 41.29it/s]
100%|██████████| 800/800 [00:19<00:00, 41.86it/s]
100%|██████████| 800/800 [00:19<00:00, 41.85it/s]
100%|██████████| 800/800 [00:19<00:00, 41.46it/s]
100%|██████████| 800/800 [00:19<00:00, 41.29it/s]
100%|██████████| 800/800 [00:19<00:00, 41.94it/s]
100%|██████████| 800/800 [00:19<00:00, 41.01it/s]
100%|██████████| 800/800 [00:19<00:00, 41.38it/s]
100%|██████████| 800/800 [00:19<00:00, 41.67it/s]
100%|██████████| 800/800 [00:19<00:00, 41.45it/s]
100%|██████████| 800/800 [00:19<00:00, 41.13it/s]
100%|██████████| 800/800 [00:19<00:00, 40.22it/s]
100%|██████████| 800/800 [00:20<00:00, 39.98it/s]
100%|██████████| 800/800 [00:19<00:00, 41.99it/s]
100%|██████████| 800/800 [00:18<00:00, 42.36it/s]


FID_k:  4.266941177652797 Diversity_k: 10.405532233531659
FID_g:  12.157996814590398 Diversity_g: 7.164521244856028

Beat score on real data: 0.244


Beat score on generated data: 0.183






(4.266941177652797,
 12.157996814590398,
 10.405532233531659,
 7.164521244856028,
 0.24413006020223502)

In [60]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 1910/1910 [00:22<00:00, 84.98it/s] 

Total number of motions 1910





In [61]:
evaluate_music_motion_generative_style2(aist_loader , vqvae_model= vqvae_model ,net = trans_model,clip_model=clip_model,seq_len = 800)


100%|██████████| 800/800 [00:18<00:00, 43.81it/s]
100%|██████████| 800/800 [00:18<00:00, 43.57it/s]]
100%|██████████| 800/800 [00:18<00:00, 43.35it/s]t]
100%|██████████| 800/800 [00:18<00:00, 43.43it/s]t]
100%|██████████| 800/800 [00:18<00:00, 43.57it/s]t]
100%|██████████| 800/800 [00:18<00:00, 43.32it/s]t]
100%|██████████| 800/800 [00:18<00:00, 43.45it/s]t]
100%|██████████| 800/800 [00:18<00:00, 43.29it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.91it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.63it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.25it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.44it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.56it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.40it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.53it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.26it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.35it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.34it/s]it]
100%|██████████| 800/800 [00:18<00:00, 43.71it/s]it]
100

FID_k:  3.755731383900468 Diversity_k: 10.429904168698846
FID_g:  7.751290849860666 Diversity_g: 7.238621158134647

Beat score on real data: 0.179


Beat score on generated data: 0.174






(3.755731383900468,
 7.751290849860666,
 10.429904168698846,
 7.238621158134647,
 0.17897518020423844)

In [62]:
0.243*0.174 / 0.179

0.23621229050279327

## Render

In [7]:
audio_encoding_dir = "/srv/scratch/sanisetty3/music_motion/AIST/music"

genre_dict = {
"mBR" : "Break",
"mPO" : "Pop",
"mLO" : "Lock",
"mMH" : "Middle Hip-hop",
"mLH" : "LA style Hip-hop",
"mHO" : "House",    
"mWA" : "Waack",
"mKR" : "Krump",
"mJS" : "Street Jazz",
"mJB" : "Ballet Jazz",
}



In [42]:
style = None

In [49]:
for i,aist_batch in enumerate(tqdm(aist_loader)):
    break
motion_name = aist_batch["names"][0]

music_name = motion_name.split('_')[-2]
music_encoding=  np.load(os.path.join(audio_encoding_dir , music_name + ".npy"))

print(genre_dict.get(music_name[:3]))

  0%|          | 0/40 [00:00<?, ?it/s]

Ballet Jazz





In [50]:


mot_len = aist_batch["motion_lengths"][0]
motion_name = aist_batch["names"][0]

music_name = motion_name.split('_')[-2]
music_encoding=  np.load(os.path.join(audio_encoding_dir , music_name + ".npy"))

print(genre_dict.get(music_name[:3]))

genre = (genre_dict.get(music_name[:3])) if style is None else style

text = clip.tokenize([genre], truncate=True).cuda()
style_embeddings = clip_model.encode_text(text).cpu().float().reshape(-1) if clip_model is not None else None
gen_motion_indices = torch.randint(0 , 1024 , (1,1))
gen_motion_indices = trans_model.generate(start_tokens =gen_motion_indices.cuda(),\
                                        seq_len=400 , \
                                        context = torch.Tensor(music_encoding)[None,...].cuda(), \
                                        context_mask=torch.ones((1 ,music_encoding.shape[0]) , dtype = torch.bool).cuda(),\
                                        style_context = torch.Tensor(style_embeddings.reshape(-1))[None,...].cuda(),
                                        )
gen_motion_indices = gen_motion_indices[gen_motion_indices<1024][None,...]

quant , out_motion = vqvae_model.decode(gen_motion_indices)

Ballet Jazz


100%|██████████| 400/400 [00:07<00:00, 56.84it/s]


In [52]:
aist_batch["motion"].shape

torch.Size([1, 147, 263])

In [44]:
genre

'Slow'

In [19]:
sample_render(to_xyz(aist_batch["motion"][0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "style_gt" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [51]:
sample_render(to_xyz(out_motion[:,:mot_len].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "style_none" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/style/")

render start


### Music VQ

In [13]:

load_path_mix = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/checkpoints/vqvae_motion.295000.pt"
load_path_hml = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768/vqvae_motion.pt"
load_path_aist = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_aist/vqvae_motion.pt"


In [14]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from utils.eval_music import evaluate_music_motion_vqvae

cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/var_len_768_768_aist_vq.yaml")

load_path = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt"


In [15]:

vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"{load_path_mix}", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()



tensor([295000.])


In [17]:
aist_ds = VQFullMotionDataset("aist", split = "test" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 40/40 [00:00<00:00, 1496.39it/s]

Total number of motions 40





In [21]:
print("pretrained only t2m")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

pretrained only t2m


100%|██████████| 40/40 [00:41<00:00,  1.04s/it]

FID_k:  3.2996737750179364 Diversity_k: 10.26604298215646
FID_g:  10.78302279919913 Diversity_g: 7.181474344852643
FID_k_real:  -7.86550347697812e-06 Diversity_k_real: 10.195780532558759
FID_g_real:  -1.9184653865522705e-13 Diversity_g_real: 7.348854861503992

Beat score on real data: 0.245


Beat score on generated data: 0.176






(3.2996737750179364,
 10.78302279919913,
 10.26604298215646,
 7.181474344852643,
 0.24494051462936942)

In [18]:

# print("pretrained mix")
# best_fid_k = []
# best_fid_g = []
# best_div_k = []
# best_div_g = []
# best_beat_align = []

# for i in range(20):

#     a,b,c,d,e = evaluate_music_motion_vqvae(aist_loader,vqvae_model)
#     best_fid_k.append(a)
#     best_fid_g.append(b)
#     best_div_k.append(c)
#     best_div_g.append(d)
#     best_beat_align.append(e)

    
# print("best_fid_k" , np.mean(best_fid_k))
# print("best_fid_g" , np.mean(best_fid_g))
# print("best_div_k" , np.mean(best_div_k))
# print("best_div_g" , np.mean(best_div_g))
# print("best_beat_align" , np.mean(best_beat_align))

In [19]:
print("mix")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)


mix


100%|██████████| 40/40 [00:39<00:00,  1.00it/s]

FID_k:  2.635362356342995 Diversity_k: 10.163608189500295
FID_g:  7.295345718653849 Diversity_g: 7.234946262225127
FID_k_real:  -7.757004908626186e-06 Diversity_k_real: 10.205963216454554
FID_g_real:  -1.903529778246593e-09 Diversity_g_real: 7.344472836225461

Beat score on real data: 0.244


Beat score on generated data: 0.234






(2.635362356342995,
 7.295345718653849,
 10.163608189500295,
 7.234946262225127,
 0.244130060202235)

In [100]:
print("pretrained only t2m, finetuned only aist")
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

pretrained only t2m, finetuned only aist


  5%|▌         | 2/40 [00:01<00:37,  1.01it/s]


KeyboardInterrupt: 

In [28]:
### Mixture
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

100%|██████████| 1910/1910 [29:56<00:00,  1.06it/s]


FID_k:  0.010750366559051372 Diversity_k: 9.172959109891172
FID_g:  1.2350136226828567 Diversity_g: 7.381343867734843

Beat score on real data: 0.249


Beat score on generated data: 0.250



(0.010750366559051372,
 1.2350136226828567,
 9.172959109891172,
 7.381343867734843,
 0.24940255611332512)

## Generating token dataset

In [10]:
(0.11*255, 0.53*255, 0.8*255, 0.5*255)

(28.05, 135.15, 204.0, 127.5)

In [64]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 1910/1910 [00:01<00:00, 1244.08it/s]

Total number of motions 1910





In [67]:
for batch in tqdm(aist_loader):
    
    n = int(batch["motion_lengths"])
    name = str(batch["names"][0])
    if n< 400:
        ind = vqvae_model.encode(batch["motion"].cuda())
    else:
#         ind = vqvae_model.encode(batch["motion"][:,:400].cuda())
        inds = []
        for i in range(0 , n, 200):
            ii = vqvae_model.encode(batch["motion"][:,i:i+200].cuda())
            inds.append(ii[0])
#             print(ii.shape)
        
        ind = torch.concatenate(inds)[None,...]
        
#     print(ind.shape)
    
    np.save(os.path.join("/srv/scratch/sanisetty3/music_motion/AIST/joint_indices" , name+".npy"),ind.cpu().numpy()[0])
        
#         quant , out_motion = vqvae_model.decode(ind)
    

100%|██████████| 1910/1910 [01:47<00:00, 17.82it/s]


In [68]:
mot_list = glob("/srv/scratch/sanisetty3/music_motion/AIST/joint_indices/*.npy")

In [69]:
np.load(mot_list[0]).shape[0]

174

In [71]:
lens = []
for i in mot_list:
    lens.append(np.load(i).shape[0])

In [72]:
max(lens)

959

In [11]:
hlm_ds = VQFullMotionDataset("t2m", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/" , window_size = -1)
hlm_loader = DATALoader(hlm_ds,1,collate_fn=None)

100%|██████████| 23384/23384 [08:53<00:00, 43.85it/s] 

Total number of motions 23384





In [12]:
for batch in tqdm(hlm_loader):
    break

  0%|          | 0/23384 [00:00<?, ?it/s]


In [13]:
n = int(batch["motion_lengths"])
name = str(batch["names"][0])
print(n,name)

199 M003397


In [14]:
ind = vqvae_model.encode(batch["motion"].cuda())

In [15]:
ind.shape

torch.Size([1, 199])

In [16]:
for batch in tqdm(hlm_loader):
    
    n = int(batch["motion_lengths"])
    name = str(batch["names"][0])
    if n< 400:
        ind = vqvae_model.encode(batch["motion"].cuda())
    else:
        #ind = vqvae_model.encode(batch["motion"][:,:400].cuda())
#         out_motion = torch.zeros((batch["motion"].shape[0] ,gen_motion_indices.shape[-1] , aist_batch["motion"].shape[-1]))
        inds = []
        for i in range(0 , n, 200):
            inds.append(vqvae_model.encode(batch["motion"].cuda()))
        
        ind = torch.stack(inds)
    
    np.save(os.path.join("/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/joint_indices" , name+".npy"), ind.cpu().numpy()[0])
    
        
    

  0%|          | 0/23384 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: '/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D/joint_indices/007648.npy'

In [20]:
sample_render(to_xyz(batch["motion"][0:1].detach().cpu(),mean = hlm_ds.mean , std = hlm_ds.std), "rnd_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


## T2M Eval

In [6]:
import utils.utils_model as utils_model
from core.datasets import dataset_TM_eval
import utils.eval_trans as eval_trans
from core.models.evaluator_wrapper import EvaluatorModelWrapper
from utils.word_vectorizer import WordVectorizer
from utils.eval_trans import evaluation_vqvae_loss,evaluation_vqvae
from utils.eval_trans import calculate_R_precision,calculate_activation_statistics,calculate_diversity,calculate_frechet_distance
from tqdm import tqdm

In [7]:
w_vectorizer = WordVectorizer('/srv/scratch/sanisetty3/music_motion/T2M-GPT/glove', 'our_vab')
eval_wrapper = EvaluatorModelWrapper(cfg.eval_model)
tm_eval = dataset_TM_eval.DATALoader("t2m", True, 20, w_vectorizer, unit_length=4)


Loading Evaluation Model Wrapper (Epoch 28) Completed!!


100%|██████████| 4384/4384 [02:23<00:00, 30.56it/s]

4648 4648
Pointer Pointing at 0





In [15]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from core.models.motion_regressor import MotionRegressorModel

load_path = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_aist/vqvae_motion.pt"
cfg_vq = get_cfg_defaults()
cfg_vq.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")

vqvae_model = VQMotionModel(cfg_vq.vqvae).eval()
pkg = torch.load(f"{load_path}", map_location = 'cpu')
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model =vqvae_model.cuda()



tensor([275000.])


In [13]:
### Pretrained on t2m only
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

100%|██████████| 232/232 [00:43<00:00,  5.31it/s]


--> 	 Eva. Iter 0 :, FID. 0.0668, Diversity Real. 9.5584, Diversity. 9.9187, R_precision_real. [0.60193966 0.78189655 0.86810345], R_precision. [0.59439655 0.77801724 0.85991379], matching_score_real. 2.9862875124503825, matching_score_pred. 3.028119134902954


In [16]:
print("pretrained on t2m only and finetuned on aist")
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

pretrained on t2m only and finetuned on aist


100%|██████████| 232/232 [00:43<00:00,  5.30it/s]


--> 	 Eva. Iter 0 :, FID. 3.2204, Diversity Real. 9.3818, Diversity. 7.4288, R_precision_real. [0.59698276 0.78728448 0.86551724], R_precision. [0.43512931 0.64073276 0.75625   ], matching_score_real. 2.9870332890543443, matching_score_pred. 4.043809700012207


In [37]:
## Pretrained on a mix of aist and t2m
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

100%|██████████| 232/232 [00:42<00:00,  5.50it/s]


--> 	 Eva. Iter 0 :, FID. 0.0637, Diversity Real. 9.4620, Diversity. 9.4266, R_precision_real. [0.61616379 0.79482759 0.86982759], R_precision. [0.60172414 0.78405172 0.86077586], matching_score_real. 2.9635313979510602, matching_score_pred. 3.0240239735307366
