In [11]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [12]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
from IPython.display import Image


In [14]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import pack, rearrange, reduce, repeat, unpack


In [15]:
from core.datasets.text_encoders import T5Conditioner,MPNETConditioner

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# t5 = T5Conditioner("t5-large")
# sent = " a man walks forward, to the end of the platform, then turns counterclockwise on his heel, before walking back to his starting point."
# tokn = t5.tokenize(sent)
# embed, msk = t5.get_text_embedding(tokn)

In [16]:
mp = MPNETConditioner()
# sent = " a man walks forward, to the end of the platform, then turns counterclockwise on his heel, before walking back to his starting point."
# tokn = mp.tokenize(sent)
# embed, msk = mp.get_text_embedding(tokn)

In [152]:
embed.shape

torch.Size([1, 768])

In [17]:
from core.models.utils import instantiate_from_config, get_obj_from_str
from core.datasets.tmr_dataset import TMRDataset, load_dataset, simple_collate
from core import MotionRep, TextRep, AudioRep
from core.datasets.conditioner import ConditionProvider,ConditionFuser
from configs.config_tmr import get_cfg_defaults
from core.models.TMR.tmr import TMR

In [18]:
tmr_cfg = get_cfg_defaults()
tmr_cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/tmr/tmr.yaml")
tmr_cfg.freeze()
dataset_args = tmr_cfg.dataset
tmr_parms = tmr_cfg.tmr

In [19]:
_ = tmr_parms.pop("target")
motion_encoder = instantiate_from_config(tmr_cfg.motion_encoder).to(device)
text_encoder = instantiate_from_config(tmr_cfg.text_encoder).to(device)
motion_decoder = instantiate_from_config(tmr_cfg.motion_decoder).to(device)
tmr = TMR(motion_encoder , text_encoder , motion_decoder , lr = tmr_cfg.train.learning_rate, **tmr_parms).to(device)

In [20]:
pkg = torch.load("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/tmr/checkpoints/tmr.40000.pt")

In [22]:
tmr.load_state_dict(pkg["model"])

<All keys matched successfully>

In [11]:
train_ds, sampler_train, weights_train  = load_dataset(dataset_names = ["idea400" , "animation"] , dataset_args=dataset_args, split = "test")


Total number of motions idea400: 577 and texts 577
Total number of motions animation: 2 and texts 2


In [12]:
condition_provider = ConditionProvider(
            text_conditioner_name = dataset_args.text_conditioner_name,
            motion_rep=MotionRep(dataset_args.motion_rep),
            audio_rep=AudioRep(dataset_args.audio_rep),
            text_rep=TextRep(dataset_args.text_rep),
            motion_padding=dataset_args.motion_padding,
            motion_max_length_s=dataset_args.motion_max_length_s,
            fps=30,
            # only_motion = True
        )

In [13]:
train_loader = torch.utils.data.DataLoader(
        train_ds,
        4,
        sampler=sampler_train,
        # shuffle = False,
        collate_fn=partial(simple_collate , conditioner = condition_provider),
        # drop_last=True,
    )

In [14]:
for inputs, conditions in (train_loader):
    break
    

In [15]:
motion = inputs["motion"][0]
motion_mask = inputs["motion"][1].to(torch.bool)
text = conditions["text"][0]
text_mask = conditions["text"][1].to(torch.bool)

In [143]:
motion_mask.device

device(type='cuda', index=0)

In [21]:
tkn = mp.tokenize(list(inputs["texts"]))
sent_embed, sent_msk = mp.get_text_embedding(tkn)
conditions["sent_emb"] = (sent_embed , sent_msk)

In [28]:
losses = tmr.compute_loss(inputs["motion"] , conditions)

In [29]:
losses

{'recons': tensor(0.6562, device='cuda:0', grad_fn=<AddBackward0>),
 'kl': tensor(8.7833, device='cuda:0', grad_fn=<AddBackward0>),
 'latent': tensor(1.3628, device='cuda:0', grad_fn=<SmoothL1LossBackward0>),
 'contrastive': tensor(1.7467, device='cuda:0', grad_fn=<DivBackward0>),
 'loss': tensor(0.8310, device='cuda:0', grad_fn=<AddBackward0>)}

In [54]:
m_motions, m_latents, m_dists = tmr(inps, mask=motion_mask, return_all=True)

In [None]:
t_motions, t_latents, t_dists = tmr(text_x_dict, mask=text_mask, return_all=True)

In [None]:
tokn = mp.tokenize(sent)
embed, msk = mp.get_text_embedding(tokn)

## Eval

In [75]:
from core.models.TMR.tmr import get_score_matrix

In [25]:
condition_provider = ConditionProvider(
            text_conditioner_name = dataset_args.text_conditioner_name,
            motion_rep=MotionRep(dataset_args.motion_rep),
            audio_rep=AudioRep(dataset_args.audio_rep),
            text_rep=TextRep(dataset_args.text_rep),
            motion_padding=dataset_args.motion_padding,
            motion_max_length_s=dataset_args.motion_max_length_s,
            fps=30,
            # only_motion = True
        )

In [26]:
from core.datasets.base_dataset import BaseMotionDataset
base_dset = BaseMotionDataset(motion_rep=MotionRep.BODY , hml_rep= "gpvc")

In [31]:
text="a man sets to do a backflips then fails back flip and falls to the ground"
mot="/srv/hays-lab/scratch/sanisetty3/motionx/motion_data/new_joint_vecs/humanml/001034.npy"

In [32]:
motion = np.load(mot)

In [37]:
motion.shape

(231, 623)

In [67]:
text_embed, text_mask = condition_provider._get_text_features(
        raw_text=text,
    )

In [39]:
processed_motion = base_dset.get_processed_motion(
            motion, motion_rep=MotionRep(dataset_args.motion_rep), hml_rep=dataset_args.hml_rep
        )

In [84]:
mot_body = torch.Tensor(processed_motion()).to(device)[None]

In [85]:
motion_x_dict = {"x": mot_body, "mask": torch.ones_like(mot_body)[...,0].to(torch.bool)}
text_x_dict = {"x": text_embed, "mask": text_mask.to(torch.bool)}

In [80]:
mot_body.shape

torch.Size([1, 120, 137])

In [81]:
lat_m = tmr.encode(motion_x_dict, sample_mean=True)[0]
lat_t = tmr.encode(text_x_dict, sample_mean=True)[0]

In [82]:
score = get_score_matrix(lat_t, lat_m).cpu()

In [83]:
score

tensor(0.7725, grad_fn=<ToCopyBackward0>)

In [None]:
with torch.inference_mode():
        # motion -> latent
        # motion_x_dict = collate_x_dict([motion_x_dict])
        lat_m = tmr.encode(motion_x_dict, sample_mean=True)[0]

        # text -> latent
        text_x_dict = collate_x_dict(text_model([text]))
        lat_t = model.encode(text_x_dict, sample_mean=True)[0]

        score = get_score_matrix(lat_t, lat_m).cpu()