In [1]:
# Choose where tensors live: the GPU when CUDA is usable, the CPU otherwise.
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# Extra diagnostics for CUDA runs: name of GPU 0 and its memory counters,
# converted from bytes by dividing by 2**30 and rounded to one decimal.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2**30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2**30, 1), 'GB')


Using device: cuda

GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
# Notebook setup: load the autoreload extension and set it to mode 2 so imported
# modules are re-imported before each cell executes (edits to project .py files
# are picked up without restarting the kernel), and draw matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import wandb

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader

from torch.utils import data


import copy
import os
import random
import cv2
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import functools
from tqdm import tqdm
from datetime import datetime
import numpy as np
from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset
from einops import rearrange, reduce, pack, unpack
import librosa

In [4]:
from utils.motion_process import recover_from_ric
import visualize.plot_3d_global as plot_3d
from glob import glob
def to_xyz(motion, mean ,std , j = 22):
    """Convert a batch of normalised motion features into per-joint xyz positions.

    The input is moved to the CPU, cast to float32 and de-normalised as
    ``motion * std + mean``; the result is passed to ``recover_from_ric`` together
    with the joint count, and its output is reshaped to ``(batch, -1, j, 3)``.

    Args:
        motion: tensor whose first dimension is the batch.
        mean: offset added after scaling (must broadcast against ``motion``).
        std: scale applied before the offset (must broadcast against ``motion``).
        j: number of joints; defaults to 22.

    Returns:
        The ``recover_from_ric`` output viewed as ``(motion.shape[0], -1, j, 3)``.
    """
    denormalised = motion.cpu().float() * std + mean
    joints = recover_from_ric(denormalised, j)
    batch_size = motion.shape[0]
    return joints.reshape(batch_size, -1, j, 3)

            
def sample_render(motion_xyz , name , save_path):
    """Render joint positions to ``<save_path>/<name>.gif`` via ``plot_3d.draw_to_batch``.

    Args:
        motion_xyz: object exposing ``.numpy()`` (e.g. the CPU tensor returned by
            ``to_xyz``); the resulting array is handed to the renderer.
        name: file stem of the output; ``".gif"`` is appended.
        save_path: directory that receives the GIF. It is created when missing.

    Returns:
        None.
    """
    print("render start")

    # Make sure the target directory exists up front, in case the renderer does not
    # create it itself. An empty save_path means "current directory": nothing to create.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    plot_3d.draw_to_batch(motion_xyz.numpy(), None, [os.path.join(save_path, name + ".gif")])



In [5]:
from configs.config import cfg, get_cfg_defaults
from core.models.vqvae import VQMotionModel
from core.models.motion_regressor import MotionRegressorModel


# Start from the project's default configuration and overlay the values stored in
# the AIST variable-length VQ YAML file.
# NOTE: this assignment rebinds the `cfg` name imported from configs.config above.
cfg = get_cfg_defaults()
cfg.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")



In [6]:
reg_model = MotionRegressorModel(args = cfg.motion_trans , ignore_index=1025 ,pad_value=1025 ).eval()

In [6]:
vqvae_model = VQMotionModel(cfg.vqvae).eval()

In [7]:
# Build the VQ-VAE and switch it to inference mode, restore the weights stored in the
# `vqvae_motion_best_fid.pt` checkpoint (deserialised onto the CPU first), report the
# step counter saved with it, then move the model to the GPU.
vqvae_model = VQMotionModel(cfg.vqvae)
vqvae_model.eval()
pkg = torch.load(
    "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt",
    map_location="cpu",
)
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model = vqvae_model.cuda()


tensor([295000.])


In [43]:
collate_fn = MotionCollator()


In [9]:
train_ds = VQVarLenMotionDataset("t2m", split = "render" , max_length_seconds = 10, data_root = "/srv/scratch/sanisetty3/music_motion/HumanML3D/HumanML3D")
train_loader = DATALoader(train_ds,1,collate_fn=collate_fn)

changing range to: 60 - 60


100%|██████████| 10/10 [00:00<00:00, 1169.54it/s]

Total number of motions 10





In [44]:
aist_ds = VQVarLenMotionDataset("aist", split = "render" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , num_stages = 6 ,min_length_seconds=20, max_length_seconds=30)
aist_loader = DATALoader(aist_ds,1,collate_fn=collate_fn)

changing range to: 400 - 400


100%|██████████| 8/8 [00:00<00:00, 2111.27it/s]

Total number of motions 8





In [45]:
aist_loader.dataset.set_stage(5)

changing range to: 400 - 600


In [51]:
# Pull exactly one batch to experiment with. next(iter(...)) raises StopIteration when
# the loader is empty, whereas the `for ...: break` form would silently leave a stale
# `aist_batch` from an earlier cell in place.
aist_batch = next(iter(aist_loader))
aist_batch["motion"].shape

torch.Size([1, 191, 263])

In [None]:
#gBR_sFM_cAll_d06_mBR3_ch17 589

In [26]:
aist_batch["names"]

array(['gBR_sFM_cAll_d06_mBR3_ch17'], dtype='<U26')

In [142]:
for batch in train_loader:
    break

## Encode Decode

In [52]:
ind = vqvae_model.encode(aist_batch["motion"].cuda())
print(ind.shape)
quant , out_motion = vqvae_model.decode(ind)

torch.Size([1, 191])


In [39]:
# Buffer for the stitched reconstruction, same shape as the source clip. It is
# zero-filled rather than torch.empty so that any frames no later cell writes are
# zeros instead of uninitialised memory when `out` is rendered.
out = torch.zeros(aist_batch["motion"].shape)

In [27]:
ind = vqvae_model.encode(aist_batch["motion"][:,:400].cuda())
quant , out_motion = vqvae_model.decode(ind)

In [53]:
quant , out_motion = vqvae_model.decode(ind[:,400:].to(torch.long).cuda())

In [54]:
out[:,400:] = out_motion


In [53]:
sample_render(to_xyz(aist_batch["motion"][0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_og_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [57]:
sample_render(to_xyz(out[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion_ind_400" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [47]:
indices = torch.randint(0,1024,(1,60))
quant , out_motion = vqvae_model.decode(indices.cuda())

In [54]:
sample_render(to_xyz(out_motion[0:1].detach().cpu(),mean = aist_ds.mean , std = aist_ds.std), "rnd_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


## Music Eval stuff

In [55]:
from utils.motion_process import recover_from_ric
from aist_features import calculate_fid_scores
from aist_features.calculate_fid_scores import calculate_avg_distance, extract_feature,calculate_frechet_feature_distance,calculate_frechet_distance
from aist_features.features import kinetic,manual
from aist_features.calculate_beat_scores import motion_peak_onehot,alignment_score

In [9]:
from core.datasets.vqa_motion_dataset import VQMotionDataset,DATALoader,VQVarLenMotionDataset,MotionCollator,VQFullMotionDataset


In [28]:
from utils.eval_music import evaluate_music_motion_vqvae

In [55]:
(batch["motion_lengths"]).shape

(1,)

In [11]:
aist_ds = VQFullMotionDataset("aist", split = "train" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = 400)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 1910/1910 [00:01<00:00, 1282.88it/s]

Total number of motions 1910





In [184]:
for aist_batch in tqdm(aist_loader):
    break

  0%|          | 0/1910 [00:00<?, ?it/s]


In [29]:
evaluate_music_motion_vqvae(aist_loader,vqvae_model)

100%|██████████| 1910/1910 [46:44<00:00,  1.47s/it] 


FID_k:  0.028328037164357056 Diversity_k: 9.260323240545652
FID_g:  1.2220429949585352 Diversity_g: 7.300148475695237

Beat score on real data: 0.250


Beat score on generated data: 0.246



(0.028328037164357056, 1.2220429949585352, 100, 100, 100)

## Generating token dataset

In [50]:
aist_ds = VQFullMotionDataset("aist", split = "render" , data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , window_size = -1)
aist_loader = DATALoader(aist_ds,1,collate_fn=None)

100%|██████████| 8/8 [00:00<00:00, 1513.10it/s]

Total number of motions 8





In [41]:
# Encode every clip in `aist_loader` to VQ code indices and cache them as <name>.npy.
# Each clip is cut to at most 400 entries along dim 1 before encoding; slicing a
# shorter clip returns it unchanged, so no explicit length check is needed.
indices_dir = "/srv/scratch/sanisetty3/music_motion/AIST/joint_indices_max_400"
os.makedirs(indices_dir, exist_ok=True)  # np.save does not create missing directories

with torch.no_grad():  # inference only: skip building autograd graphs
    for batch in tqdm(aist_loader):
        name = str(batch["names"][0])
        ind = vqvae_model.encode(batch["motion"][:, :400].cuda())
        # The loader's batch size is 1, so store the single row of indices.
        np.save(os.path.join(indices_dir, name + ".npy"), ind.cpu().numpy()[0])
    

100%|██████████| 8/8 [00:00<00:00, 14.90it/s]


## Motion generation

In [7]:
from core.datasets.vqa_motion_dataset import MotionCollatorConditional,VQVarLenMotionDatasetConditional, TransMotionDatasetConditional
cfg = get_cfg_defaults()
cfg.merge_from_file("/srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml")



In [7]:
# pkg2 = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_aist/vqvae_motion.pt", map_location = 'cpu')
# print(pkg2["steps"])
# vqvae_model_prev.load_state_dict(pkg2["model"])
# vqvae_model_prev =vqvae_model_prev


In [8]:
from core.models.motion_regressor import MotionRegressorModel
trans_model = MotionRegressorModel(args = cfg.motion_trans , ignore_index=1025 ,pad_value=1025 ).eval()


In [9]:
# pkg3 = torch.load(f"/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/const_len/", map_location = 'cpu')
# trans_model.load_state_dict(pkg3["model"])


In [10]:
# Build the VQ-VAE and switch it to inference mode, restore the weights stored in the
# `vqvae_motion_best_fid.pt` checkpoint (deserialised onto the CPU first), report the
# step counter saved with it, then move the model to the GPU.
vqvae_model = VQMotionModel(cfg.vqvae)
vqvae_model.eval()
pkg = torch.load(
    "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/var_len/vq_768_768_mix/vqvae_motion_best_fid.pt",
    map_location="cpu",
)
print(pkg["steps"])
vqvae_model.load_state_dict(pkg["model"])
vqvae_model = vqvae_model.cuda()


tensor([295000.])


In [78]:
train_ds = TransMotionDatasetConditional("aist", split = "render",data_root = "/srv/scratch/sanisetty3/music_motion/AIST" , datafolder="joint_indices_max_400", window_size = 400,force_len=True)


100%|██████████| 8/8 [00:00<00:00, 1076.05it/s]

Total number of motions 8





In [79]:
collate_fn2 = MotionCollatorConditional(dataset_name = "aist" , bos = 1024, pad = 1025, eos = 1026)


In [80]:
dl = DATALoader(train_ds , batch_size = 1,collate_fn=collate_fn2)


In [81]:
# Pull exactly one conditional batch. next(iter(...)) raises StopIteration when the
# loader is empty instead of silently keeping a stale `reg_batch` from an earlier run.
reg_batch = next(iter(dl))

In [82]:
for k,v in reg_batch.items():
    print(k , v.shape)

motion torch.Size([1, 240])
motion_lengths torch.Size([1])
motion_mask torch.Size([1, 240])
names (1,)
condition torch.Size([1, 239, 128])
condition_mask torch.Size([1, 239])


In [83]:
inp, target = reg_batch["motion"][:, :-1], reg_batch["motion"][:, 1:]

In [84]:
quant , out_motion = vqvae_model.decode(target[target<1024][None,...].cuda())

In [85]:
sample_render(to_xyz(out_motion[0:1].detach().cpu(),mean = train_ds.mean , std = train_ds.std), "rnd_og_motion" , "/srv/scratch/sanisetty3/music_motion/motion_vqvae/evals/decode_test")

render start


In [88]:
inp[:,:10]

tensor([[234, 587, 137,  19, 115, 182, 117,  29,   0, 113]])

In [89]:
gen_motion_indices = trans_model.generate(start_tokens = inp[:,:1], seq_len=100 , context = reg_batch["condition"], context_mask = reg_batch["condition_mask"])

100%|██████████| 100/100 [00:17<00:00,  5.79it/s]


In [90]:
gen_motion_indices

tensor([[ 234,  851,  496,  413,  997,  505,  681,  796,  585,  158,  413,  264,
          250,  850,  524,  244,  157,   74,  270,  844,   72,  817,  468,  740,
          851,  806,  608,  263,  924,  740,  845,  924,  608,  960,  314,  700,
          352,  804,  157,   82,  846,   40,  616,   97,  967,  650,  162,  997,
          740,  314,  851,  616,  180,  966,  250,  314,  642,  960,   52,  779,
          393,  665,  496,  772,  264,  270,  426,  524,  920,  608,  817,  772,
          806,  208,  997,  585,  214,  287,  560,  817,  844,  112,  833,   72,
          985,  848,  800,  266,  264,  266,  814,  985,  740,  180,  496,  985,
          112, 1022,  353,   97,   97]])

In [32]:
# Forward pass over the shifted input tokens (`inp` is reg_batch["motion"][:, :-1],
# so the mask is sliced the same way). Every argument comes from `reg_batch`:
# the original read the context from `batch`, a leftover from an unrelated loader,
# so the condition did not belong to the motion being scored.
logits = trans_model(
    motion=inp,
    mask=reg_batch["motion_mask"][:, :-1],
    context=reg_batch["condition"],
    context_mask=reg_batch["condition_mask"],
)

In [56]:
logits.shape

torch.Size([1, 213, 1027])

In [57]:
ce = torch.nn.CrossEntropyLoss(ignore_index=1025)

In [58]:
(reg_batch["motion_lengths"])

tensor([213.])

In [59]:
probs = torch.softmax(logits[0][:int(reg_batch["motion_lengths"][0])], dim=-1)

In [60]:
probs.shape

torch.Size([213, 1027])

In [61]:
_ ,indx =torch.max(probs, dim=-1)


In [63]:
# Number of positions where the arg-max prediction equals the target token, restricted
# to the first `motion_lengths[0]` positions. Uses the tensor reduction instead of the
# Python builtin `sum`, which would iterate the tensor one element at a time.
(indx == target[0][:int(reg_batch["motion_lengths"][0])]).sum()

tensor(0)

In [64]:
dist = torch.distributions.Categorical(probs)

In [67]:
target[0][:int(reg_batch["motion_lengths"][0])][target[0][:int(reg_batch["motion_lengths"][0])] < 1024]

tensor([614, 672,  90, 735, 735,  40, 107, 349, 349,  51, 786, 258, 173, 246,
         12, 216, 463, 303, 114, 114, 140, 593, 905, 278, 278, 358, 212, 140,
        216, 252, 472, 472, 593, 675, 675,  51, 472, 553, 135, 221, 216, 205,
        252, 463, 419, 586, 114,  41,  58, 104, 612, 903, 303, 246, 540, 205,
        252, 634,  40,  59, 240, 240,  51, 258, 258, 757, 840,  12, 840, 252,
        463, 485, 114, 840, 593, 757, 612, 903, 405, 787, 212, 205, 430, 485,
        699, 886,  51, 349, 349,  59, 472, 553, 135, 409, 882, 205, 303, 463,
        787, 988, 114, 268, 258, 877, 612, 405, 303, 429, 540, 485, 485, 603,
        798,  59, 240,  51, 634, 798, 757, 634, 246, 391, 840, 252, 252, 634,
        840, 840, 593, 877, 612, 903, 463, 212, 140, 205, 485, 485,  10, 593,
        240, 240, 240, 472, 553, 135, 421, 882, 429, 429, 252, 813, 840,   4,
        111, 586, 104, 405, 405, 114, 136, 121, 840, 485, 593, 798,  51, 349,
        240,  59, 798, 786, 258, 315, 391,  12, 485, 303,  41, 8

In [158]:
pred_indx = dist.sample()
pred_indx

tensor([ 568,   79, 1006,  518,  571,  592,  956,  880,  978,   89,  465,  712,
         858,  314,  136,  134,  429,  383,  312,  142,  475,  684,  786,  434,
         484,  259,  169,  635,  281,   98,   13,  680,  793,  465,  984,  750,
         812,   55,  346,  356,  864,  988,  825,  845,  714,   39,  901,  747,
         537,  587, 1004,  827,  274,  632, 1008,  592,  257,  494,  113,  456,
         605,  184,  188, 1005,  984,  484,  663,  576,  221,  537,  690,  366,
         916,  314,   94,  156,  237,   99,  439,  370,  817,  483,  958,  306,
         276,  476,  654,  367,  350,   25,  175,  362,   40,  873,  283,  610,
         575,  967,  586,  357,  934,  969,  199,  234,   36,   65,  752,   45,
        1007,  608, 1026,  632,  766,  273,  579,  847,  883,  778,  779,  268,
        1014,  964,   89,  642,  160,   36,  645,  908,  313,  311,  942,  265,
         498,  846,  561,   51,  294,  264,   27,  196,  617,  365,  539,  353,
         357,  724,  424,  267,  390,  7

In [162]:
pred_indx[:111]

tensor([ 568,   79, 1006,  518,  571,  592,  956,  880,  978,   89,  465,  712,
         858,  314,  136,  134,  429,  383,  312,  142,  475,  684,  786,  434,
         484,  259,  169,  635,  281,   98,   13,  680,  793,  465,  984,  750,
         812,   55,  346,  356,  864,  988,  825,  845,  714,   39,  901,  747,
         537,  587, 1004,  827,  274,  632, 1008,  592,  257,  494,  113,  456,
         605,  184,  188, 1005,  984,  484,  663,  576,  221,  537,  690,  366,
         916,  314,   94,  156,  237,   99,  439,  370,  817,  483,  958,  306,
         276,  476,  654,  367,  350,   25,  175,  362,   40,  873,  283,  610,
         575,  967,  586,  357,  934,  969,  199,  234,   36,   65,  752,   45,
        1007,  608, 1026])

In [159]:
(pred_indx == 1025).nonzero().flatten().tolist()

[]

In [160]:
(pred_indx == 1026).nonzero().flatten().tolist()

[110]

In [107]:
# Position of the first pad (1025) or eos (1026) token in the sampled sequence,
# capped at 180. The candidates are passed as a single list: when neither token
# occurs, the star-unpacked form collapses to min(180), which raises TypeError
# because a lone int is not iterable.
min([
    *(pred_indx == 1025).nonzero().flatten().tolist(),
    *(pred_indx == 1026).nonzero().flatten().tolist(),
    180,
])

29

In [31]:
from utils.eval_music import evaluate_music_motion_trans
evaluate_music_motion_trans(dl, vqvae_model, trans_model)

  0%|          | 0/1910 [00:00<?, ?it/s]

(158, 22, 3) (158, 22, 3)


  0%|          | 1/1910 [00:01<42:51,  1.35s/it]

(399, 22, 3) (2, 22, 3)


  return np.linalg.norm(average_acceleration / current_window)
  0%|          | 2/1910 [00:03<1:05:07,  2.05s/it]

(158, 22, 3) (158, 22, 3)


  0%|          | 3/1910 [00:05<54:48,  1.72s/it]  

(399, 22, 3) (399, 22, 3)


  0%|          | 4/1910 [00:08<1:18:44,  2.48s/it]

(399, 22, 3) (399, 22, 3)


  0%|          | 4/1910 [00:11<1:28:19,  2.78s/it]


KeyboardInterrupt: 

## T2M Eval

In [12]:
import utils.utils_model as utils_model
from core.datasets import dataset_TM_eval
import utils.eval_trans as eval_trans
from core.models.evaluator_wrapper import EvaluatorModelWrapper
from utils.word_vectorizer import WordVectorizer
from utils.eval_trans import evaluation_vqvae_loss,evaluation_vqvae
from utils.eval_trans import calculate_R_precision,calculate_activation_statistics,calculate_diversity,calculate_frechet_distance
from tqdm import tqdm

In [13]:
w_vectorizer = WordVectorizer('/srv/scratch/sanisetty3/music_motion/T2M-GPT/glove', 'our_vab')
eval_wrapper = EvaluatorModelWrapper(cfg.eval_model)
tm_eval = dataset_TM_eval.DATALoader("t2m", True, 20, w_vectorizer, unit_length=4)


Loading Evaluation Model Wrapper (Epoch 28) Completed!!


100%|██████████| 4384/4384 [00:04<00:00, 912.21it/s]

4648 4648
Pointer Pointing at 0





In [37]:
metrics = evaluation_vqvae_loss(val_loader = tm_eval, net= vqvae_model,nb_iter= 0, eval_wrapper = eval_wrapper,save = False,)

100%|██████████| 232/232 [00:42<00:00,  5.50it/s]


--> 	 Eva. Iter 0 :, FID. 0.0637, Diversity Real. 9.4620, Diversity. 9.4266, R_precision_real. [0.61616379 0.79482759 0.86982759], R_precision. [0.60172414 0.78405172 0.86077586], matching_score_real. 2.9635313979510602, matching_score_pred. 3.0240239735307366
