In [1]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import os
import random
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils import data
from tqdm import tqdm



In [4]:
from configs.config import get_cfg_defaults
from core.datasets.dataset_loading_utils import load_dataset
from core.datasets.vq_dataset import DATALoader
from utils.vis_utils import plot_3d_global
from core.models.conformer_vqvae import ConformerVQMotionModel, Encoder
from torch.utils import data
from core.datasets.vq_dataset import DATALoader, MotionCollator
from einops import pack, rearrange, reduce, repeat, unpack

def pack_one(t, pattern):
    return pack([t], pattern)


In [5]:
import utils.vis_utils.plot_3d_global as plot_3d
from utils.motion_processing.hml_process import recover_from_ric


## Dataset creation

In [6]:
from utils.vis_utils.render_final import Renderer
renderer = Renderer(device)

In [6]:
path = "/srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/conformer_768_1024_hmlvec/conformer_768_1024_hmlvec.yaml"
cfg = get_cfg_defaults()
print("loading config from:", path)
cfg.merge_from_file(path)
cfg.freeze()

ckpt = torch.load("/srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/conformer_768_1024_hmlvec/vqvae_motion.pt" , map_location="cpu")
print(ckpt["steps"])

loading config from: /srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/conformer_768_1024_hmlvec/conformer_768_1024_hmlvec.yaml
tensor([200000.])


In [9]:
import importlib

In [10]:
("core.models.conformer_vqvae.ConformerVQMotionModel").rsplit("." , 1)

['core.models.conformer_vqvae', 'ConformerVQMotionModel']

In [11]:
model = getattr(importlib.import_module('core.models.conformer_vqvae'), 'ConformerVQMotionModel')

In [12]:
convvq = model(cfg.vqvae).to(device).eval()


Sync is turned on False


In [7]:


from core.models.conformer_vqvae import ConformerVQMotionModel, Encoder
convvq = ConformerVQMotionModel(cfg.vqvae).to(device).eval()
convvq.load("/srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/conformer_768_1024_hmlvec/vqvae_motion.pt")


Sync is turned on False


In [12]:
# cd ../motion_vqvae/

In [13]:
# from configs.config import get_cfg_defaults

# path = "/srv/hays-lab/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/conv_vq/convq_512_512/convq_512_512.yaml"
# cfg = get_cfg_defaults()
# print("loading config from:", path)
# cfg.merge_from_file(path)
# cfg.freeze()

# ckpt = torch.load("/srv/hays-lab/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/conv_vq/convq_512_512/vqvae_motion.pt" , map_location="cpu")
# print(ckpt["steps"])

# from motion_vqvae.core.models.conv_vqvae import ConvVQMotionModel
# convvq = ConvVQMotionModel(cfg.vqvae).to(device).eval()

# convvq.load_state_dict(ckpt["model"])

In [32]:
from glob import glob
alll = glob("/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/AIST/new_joint_vecs/*")

In [43]:
with open("/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotionSMPL/AIST_SMPL/all.txt" , "w") as f:
    for line in alll:
        f.write(f'{line.split("/")[-1].split(".")[0]}\n')

In [14]:
from core.datasets.vq_dataset import VQMotionDataset


In [15]:
train_ds = VQMotionDataset("t2m" , "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion" , window_size = -1, split = "test")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4384/4384 [00:02<00:00, 1781.73it/s]

Total number of motions 4384





In [16]:
train_dl = DATALoader(
            train_ds,
            batch_size=1,
            shuffle=True,
            collate_fn=None,
        )

In [17]:
# for batch in train_dl:
#     break

In [12]:
dest = "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotionIndices/HumanML3D/joint_indices"
os.makedirs(dest, exist_ok=True)

In [45]:
for i, batch in enumerate(tqdm(train_dl)):
    # if i < 12300:
    #     continue
    
    gt_motion = batch["motion"].to(device)
    if gt_motion.shape[1] > 100:
        ind = []
        for m in range(0, gt_motion.shape[1], 100):
            indics = convvq.encode(gt_motion[:, m:m+100])
            ind.append(indics[0])
        indices = torch.cat(ind)[None]
    else:
        indices = convvq.encode(gt_motion)
    np.save(os.path.join(dest , batch["names"][0]+".npy") , indices.detach().cpu().numpy())
    del indices
    del gt_motion
    torch.cuda.empty_cache()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2726/2726 [02:14<00:00, 20.30it/s]


In [None]:
"/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotionSMPLIndices/AIST/joint_indices/M_gJS_sBM_cAll_d03_mJS3_ch02"

In [None]:
og = f"/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/new_joint_vecs/{batch['names'][0]}.npy"

In [18]:
for batch in train_dl:
    break
gt_motion = batch["motion"][:,:1000]
batch["motion"].shape

torch.Size([1, 199, 263])

In [58]:
if gt_motion.shape[1] > 200:
    ind = []
    for m in range(0, gt_motion.shape[1], 80):
        indics = convvq.encode(gt_motion[:, m:m+80].to(device))
        ind.append(indics[0])
    indices = torch.cat(ind)[None]

In [19]:
indices = convvq.encode(gt_motion.to(device))
indices.shape

torch.Size([1, 50])

In [21]:
indices

tensor([[ 304,    1,  587,  411,  889,    0,    0,  532,  778,  616,  587,  529,
          300,  300,  927, 1020,  721,  402,  995,  389,  801,  147,   41,  870,
          995,  995,  763,  310,  147,   41,  922,  411,  347,  616,  389,  332,
          217,  411,  704,  995,  402,  332,  821,  153,  393,  347,  274,  680,
          644,  992]], device='cuda:0')

In [20]:
quantized, decoded_motion_features = convvq.decode(indices.long())
decoded_motion_features.shape

torch.Size([1, 200, 263])

In [22]:
quantized.shape

torch.Size([1, 50, 768])

In [61]:
gt_motion1 = (
    train_ds.inv_transform(gt_motion.cpu())
    .squeeze()
    .float()
)
pred_motion = (
    train_ds.inv_transform(decoded_motion_features.cpu())
    .squeeze()
    .float()
)

save_file = "/srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/renders"
gt_motion_xyz = recover_from_ric(gt_motion1, 22)
pred_motion_xyz = recover_from_ric(pred_motion, 22)

gt_pose_vis = plot_3d.draw_to_batch(
    gt_motion_xyz.numpy().squeeze()[None],
    None,
    [os.path.join(save_file, "t" + "_gt.gif")],
)
pred_pose_vis = plot_3d.draw_to_batch(
    pred_motion_xyz.numpy().squeeze()[None],
    None,
    [os.path.join(save_file, "t" + "_pred.gif")],
)

In [14]:
from core.models.evaluator_wrapper import EvaluatorModelWrapper
from utils.word_vectorizer import WordVectorizer
from core.datasets import dataset_TM_eval


In [15]:
w_vectorizer = WordVectorizer(
   "/srv/hays-lab/scratch/sanisetty3/music_motion/T2M-GPT/glove", "our_vab"
)
eval_wrapper = EvaluatorModelWrapper(cfg.eval_model)
tm_eval = dataset_TM_eval.DATALoader(
    32,
    w_vectorizer,
    unit_length=4,
)

Loading Evaluation Model Wrapper (Epoch 28) Completed!!


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4384/4384 [00:04<00:00, 1066.12it/s]

4248 4248
Pointer Pointing at 0





In [16]:
from utils.eval_trans import calculate_R_precision, calculate_multimodality, calculate_diversity, calculate_frechet_distance, calculate_activation_statistics

In [115]:
renderer.render(
    motion_vec=ogg[:1000,:135],
    outdir="./renders/",
    step=0,
    name=f"00_000007_og",
)

In [3]:
import numpy as np

In [13]:
hml = "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/Mean.npy"
hml2 = "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/AIST/Mean.npy"
hml3 = "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/Choreomaster/Mean.npy"

In [14]:
np.save("/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/Mean.npy" , np.mean([np.load(hml) + np.load(hml2) + np.load(hml3)] , 0))

In [138]:
from glob import glob

In [139]:
pths = sorted(glob("/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/new_joint_vecs/*"))

In [145]:
pths[25000]

'/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/new_joint_vecs/M008676.npy'

In [144]:
int(pths[25000].split("/")[-1].split(".")[0][-6:])

8676

In [156]:
add = []
for p in tqdm(pths):
    nm = int(p.split("/")[-1].split(".")[0][-6:])
    if nm > 14616:
        add.append(p.split("/")[-1].split(".")[0])
        

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 32648/32648 [00:00<00:00, 964613.99it/s]


In [158]:
with open(r'/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/train.txt', 'a') as fp:
    for item in add:
        fp.write("%s\n" % item)
    print('Done')

Done


In [157]:
len(add)

3418

In [153]:
add = []
for p in tqdm(pths):
    n = p.split("/")[-1].split(".")[0]
    add.append(n)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 32648/32648 [00:00<00:00, 1130149.03it/s]


Done


# BERT

In [6]:
from core.datasets.dataset_loading_utils import load_dataset_bert
from core.datasets.motion_bert_dataset import BERTMotionDataset, DATALoader
from core.models.BERT import BERT, BERTParams
from core.optimizer import get_optimizer
from torch.utils.data import DataLoader


  from .autonotebook import tqdm as notebook_tqdm


In [38]:
path = "/srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/bert_12_768/bert_12_768.yaml"
cfg = get_cfg_defaults()
print("loading config from:", path)
cfg.merge_from_file(path)
cfg.freeze()

loading config from: /srv/hays-lab/scratch/sanisetty3/music_motion/TGM3D/checkpoints/bert_12_768/bert_12_768.yaml


In [39]:
train_ds, sampler_train, weights_train = load_dataset_bert(
                dataset_names=["t2m"],
                args=cfg,
                split="test",
                weight_scale=[1],
            )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4384/4384 [00:02<00:00, 2117.85it/s]

Total number of motions 4376





In [44]:
dl = DATALoader(
            train_ds,
            batch_size=20,
            shuffle=False,
        )

In [45]:
for data in dl:
    break

In [None]:
self.pretrained_embedding = torch.nn.Embedding.from_pretrained(glove_embeddings)
self.trainable_embedding = torch.nn.Embedding(
    how_many_tokens_not_present, glove_embeddings.shape[1]
)

In [32]:
torch.nn.Embedding.from_pretrained(convvq.vq.codebook)

Embedding(1024, 768)

In [34]:
params = BERTParams()
bert = BERT(params).cuda()

In [99]:
a = [False , True]
a.extend([True])

In [36]:
torch.save(convvq.vq.codebook , "./checkpoints/conformer_768_1024_hmlvec/codebook.pt")

In [39]:
data["bert_input"]

tensor([[1026, 1024,  611,  242, 1024,  845,  193,  591,  888, 1024, 1024,  699,
          153,  193,  822,  591,  197,  699,  845,  360,   33, 1024, 1024,  358,
           33, 1024, 1024,  142,  591,  197,  699, 1024,  360,  591, 1024,  897,
         1024,  699,  447,  193,  177, 1024, 1024,   74,  148,  999,  674,  142,
          552,  627,  699, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025],
        [1026,  565, 1024,  565,  692,  692, 1024,  129, 1011,  482,  632,  791,
          288,   52,  204, 1024,  241, 1025, 1025, 

In [44]:
mask = data["bert_input"] != 1024
mask

torch.Size([2, 128])

In [72]:
v = torch.randint(0,10 , (5 , 1)).float()
v

tensor([[8.],
        [2.],
        [8.],
        [2.],
        [3.]])

In [None]:
mask = torch.ones(128)
mask[:p] = 0
np.random.shuffle(mask)

In [76]:
e = torch.nn.Embedding(1027 , 768 , padding_idx=1025)
e.weight.data[:1024] = convvq.vq.codebook

In [34]:
len(convvq.vq.codebook)

1024

In [88]:
torch.zeros(10).bool()

tensor([False, False, False, False, False, False, False, False, False, False])

In [78]:
convvq.vq.codebook

torch.Size([1024, 768])

In [84]:
e.weight.data[:1024] = convvq.vq.codebook

In [66]:

dots= torch.randint(0,10 , (5,5)).float()
dots

tensor([[6., 7., 1., 9., 8.],
        [5., 9., 1., 5., 3.],
        [3., 5., 2., 3., 2.],
        [3., 8., 9., 7., 2.],
        [2., 5., 7., 5., 1.]])

In [67]:
mask = torch.Tensor([False,True,False,False,True]).bool()

In [68]:
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

In [69]:
dots

tensor([[-3.4028e+38,  7.0000e+00, -3.4028e+38, -3.4028e+38,  8.0000e+00],
        [-3.4028e+38,  9.0000e+00, -3.4028e+38, -3.4028e+38,  3.0000e+00],
        [-3.4028e+38,  5.0000e+00, -3.4028e+38, -3.4028e+38,  2.0000e+00],
        [-3.4028e+38,  8.0000e+00, -3.4028e+38, -3.4028e+38,  2.0000e+00],
        [-3.4028e+38,  5.0000e+00, -3.4028e+38, -3.4028e+38,  1.0000e+00]])

In [75]:
torch.softmax(dots , -1)@v

tensor([[2.7311],
        [2.0025],
        [2.0474],
        [2.0025],
        [2.0180]])

In [35]:
mlm_loss_fnc = torch.nn.CrossEntropyLoss(
            ignore_index=bert.pad_index
        )

In [36]:
mask_lm_output = bert.forward(data["bert_input"].cuda())

mask_loss = mlm_loss_fnc(
    mask_lm_output.transpose(1, 2), data["bert_label"].cuda()
)
loss = cfg.bert.loss_mlm * mask_loss

In [37]:
data["bert_input"].shape

torch.Size([2, 128])

In [50]:
data["bert_label"]

tensor([[1025, 1025, 1025,  242, 1025, 1025,  193, 1025, 1025, 1025, 1025, 1025,
         1025, 1025,  922, 1025,  197, 1025,  845,  360, 1025, 1025,  552, 1025,
          148, 1025, 1025,  142, 1025, 1025,  699, 1025, 1025,  591, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,  999, 1025,  142,
         1025, 1025,  699, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025,
         1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025],
        [1025, 1025, 1025, 1025, 1025, 1025, 1025,  129, 1025, 1025,  632, 1025,
         1025,   52, 1025,   52, 1025, 1025, 1025, 

In [39]:
mask_lm_output.shape

torch.Size([2, 128, 1028])

In [44]:
mlm = torch.nn.functional.softmax(mask_lm_output , -1)

In [46]:
mlm[0,0]

tensor([0.0028, 0.0008, 0.0013,  ..., 0.0005, 0.0005, 0.0004], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [None]:
np.where()

In [57]:
msk = data["bert_label"]!=1025

In [58]:
data["bert_label"][msk]

tensor([242, 193, 922, 197, 845, 360, 552, 148, 142, 699, 591, 999, 142, 699,
        129, 632,  52,  52])

In [61]:
correct = mlm.argmax(-1).cpu()[msk]  == data["bert_label"][msk]
correct = correct.sum()/len(correct)

In [62]:
correct

tensor(0.)

In [29]:
nsp_loss_fnc = torch.nn.NLLLoss(ignore_index=0)
mlm_loss_fnc = torch.nn.CrossEntropyLoss(ignore_index=-1)

In [31]:
mask_lm_output = bert.forward(data["bert_input"].cuda())

In [32]:
bert.token_emb

Embedding(1028, 768, padding_idx=1025)

In [16]:
135 + 3*22 + 22*3

267

In [17]:
135 + 3*22

201

In [18]:
22*9

198

In [41]:
next_sent_output.shape

torch.Size([2, 2])

In [43]:
data["is_next"]

tensor([[1],
        [1]])

In [33]:
# next_loss = nsp_loss_fnc(next_sent_output, data["is_next"].cuda().reshape(-1))

# 2-2. NLLLoss of predicting masked token word
mask_loss = mlm_loss_fnc(
    mask_lm_output.transpose(1, 2), data["bert_label"].cuda()
)

In [39]:
mask_loss

tensor(7.0331, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [58]:
correct = (
                    next_sent_output.argmax(dim=-1).eq(data["is_next"].cuda()).sum().item()
                )

In [64]:
next_sent_output.argmax(dim=-1)

tensor([0, 0], device='cuda:0')

In [65]:
data["is_next"].shape

torch.Size([2, 1])

In [15]:
from functools import reduce


In [16]:
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob


def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask


def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = mask.cumsum(dim=-1) > (num_tokens * prob).ceil()
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

In [20]:
import math

In [34]:
mask_ignore_token_ids = set([1025, 1026, 1027])

In [64]:
seq = torch.Tensor([1026] + list(np.random.randint(0, 1024 , (10))) + [1027]).long()[None]

In [65]:
no_mask = mask_with_tokens(seq, mask_ignore_token_ids)
mask = get_mask_subset_with_prob(~no_mask, 0.3)

In [67]:
mask

tensor([[False, False, False,  True, False,  True, False,  True,  True, False,
         False, False]])

In [68]:
labels = seq.masked_fill(~mask, 1024)

In [69]:
labels

tensor([[1024, 1024, 1024,   90, 1024,  310, 1024,  914,  344, 1024, 1024, 1024]])

## text

In [7]:
hml_txs = "/srv/hays-lab/scratch/sanisetty3/music_motion/HumanMotion/HumanML3D/texts/000000.txt"

In [11]:
from glob import glob
import codecs as cs
from scipy.spatial.transform import Rotation as R
def get_caption(path):
    text_data = []
    captions = []
    flag = False
    with cs.open(path) as f:
        for line in f.readlines():
            text_dict = {}
            line_split = line.strip().split('#')
            caption = line_split[0]
            captions.append(caption)
            tokens = line_split[1].split(' ')
            f_tag = float(line_split[2])
            to_tag = float(line_split[3])
            f_tag = 0.0 if np.isnan(f_tag) else f_tag
            to_tag = 0.0 if np.isnan(to_tag) else to_tag
    
            text_dict['caption'] = caption
            text_dict['tokens'] = tokens
            if f_tag == 0.0 and to_tag == 0.0:
                flag = True
                text_data.append(text_dict)
            else:
                try:
                    n_motion = motion[int(f_tag*fps) : int(to_tag*fps)]
                    if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
                        continue
                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                    while new_name in data_dict:
                        new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                    data_dict[new_name] = {'motion': n_motion,
                                           'length': len(n_motion),
                                           'text':[text_dict]}
                    new_name_list.append(new_name)
                    length_list.append(len(n_motion))
                except:
                    print(line_split)
                    print(line_split[2], line_split[3], f_tag, to_tag, name)
                    # break
    return captions

In [12]:
captions = get_caption(hml_txs)

In [13]:
captions

['a man kicks something or someone with his left leg.',
 'the standing person kicks with their left foot before going back to their original stance.',
 'a man kicks with something or someone with his left leg.',
 'he is flying kick with his left leg']

In [None]:
t5 = T5(128)