In [1]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F

In [4]:
def findAllFile(base):
    file_path = []
    for root, ds, fs in os.walk(base, followlinks=True):
        for f in fs:
            fullname = os.path.join(root, f)
            file_path.append(fullname)
    return file_path



In [82]:
# from utils.motion_processing.hml_process import recover_from_ric, recover_root_rot_pos,recover_from_rot
import utils.vis_utils.plot_3d_global as plot_3d
import matplotlib.pyplot as plt

def vis(mot , dset , name = "motion"):

    if isinstance(mot , torch.Tensor):
        mot = dset.toMotion(mot)
    mot =dset.inv_transform(mot)



    xyz = np.array(dset.to_xyz(mot).cpu())

    print(xyz.shape)

    
    plot_3d.render(xyz , f"/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/{name}.gif")

In [6]:
from configs.config_t2m import cfg, get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_translation/motion_translation.yaml")

In [8]:
from core.datasets.conditioner import ConditionProvider, ConditionFuser
from core.datasets.multimodal_dataset import MotionAudioTextDataset, load_dataset, simple_collate

In [9]:
from core.models.attend import Attention
from core import AttentionParams
from core import AttentionParams, TranslationTransformerParams, PositionalEmbeddingParams, PositionalEmbeddingType, MotionRep, AudioRep, TextRep


In [10]:
# from core.models.generation.translation_transformer import TranslationTransformer
# translation_tranformer = TranslationTransformer(cfg.translation_transformer).cuda()

In [64]:
dataset_args = cfg.dataset
condition_provider = ConditionProvider(
            motion_rep=MotionRep(dataset_args.motion_rep),
            audio_rep=AudioRep(dataset_args.audio_rep),
            text_rep=TextRep(dataset_args.text_rep),
            motion_padding=dataset_args.motion_padding,
            audio_padding=dataset_args.audio_padding,
            motion_max_length_s=10,
            audio_max_length_s=10,
        )

condition_provider2 = ConditionProvider(
            motion_rep=MotionRep(dataset_args.motion_rep),
            motion_padding=dataset_args.motion_padding,

        )

In [54]:
from core.datasets.multimodal_dataset import MotionAudioTextDataset
from core.datasets.vq_dataset import VQSMPLXMotionDataset
from core.datasets.vq_dataset import simple_collate as simple_collate2


from utils.motion_processing.skeleton import Skeleton, t2m_kinematic_chain , body_joints_id, t2m_raw_body_offsets

In [85]:
dset = MotionAudioTextDataset("moyo" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "body" , hml_rep = "gprvc", split = "test"   )

Total number of motions moyo: 9 and texts 9


In [86]:
dset2 = VQSMPLXMotionDataset("moyo" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "body" , hml_rep = "gprvc", split = "test" , window_size = 120  )

Total number of motions moyo: 9


In [87]:
mmm = next(iter(dset2))

In [88]:
mmm["motion"].root_params.shape

(120, 4)

In [89]:
train_loader = torch.utils.data.DataLoader(
        dset2,
        4,
        # sampler=sampler,
        collate_fn=partial(simple_collate2 , conditioner = condition_provider2),
        drop_last=True,
    )

In [90]:
for inputs in train_loader:
    break
    

In [91]:
mot = inputs["motion"][0]

In [92]:
mot.shape

torch.Size([4, 120, 263])

In [72]:
mottt = dset.toMotion(mot)

In [73]:
mottt().shape

torch.Size([4, 120, 263])

In [93]:
vis(mot[0] , dset2)

(120, 22, 3)


In [82]:

sap = AttentionParams(dim = 256 , causal=True)
cap = AttentionParams(dim = 256 , causal=True , add_null_kv=True)
transformer_params = TranslationTransformerParams(self_attention_params = sap , 
                                                  cross_attention_params = cap , 
                                                  depth = 1, 
                                                  positional_embedding_params=PositionalEmbeddingParams(dim = 256) , 
                                                  positional_embedding=PositionalEmbeddingType.SINE,
                                                  fuse_method = {"cross_seperate" : ["audio" , "text"]}
                                                 )

In [62]:
from core.models.generation.translation_transformer import ClassifierFreeGuidanceDropout, TransformerBlock

In [63]:
dim = 256

In [64]:
pos_emb = ScaledSinusoidalEmbedding(dim , theta = 10000)
cfg_dropout = ClassifierFreeGuidanceDropout(
            0.0
        )
condition_fuser = ConditionFuser({"cross" : ["audio"] ,   "prepend" : ["text"]})

In [65]:
transformer_blocks = TransformerBlock(sap , cap).cuda()

In [66]:
project_audio = (
    nn.Linear(128, dim).cuda()
    
)
project_text = (
    nn.Linear(768, dim).cuda()
    
)

In [67]:
motion = inputs["motion"][0]
motion_padding_mask = inputs["motion"][1]
device, b, n , d = motion.device, *motion.shape
translation = motion[... , :4]

In [68]:
x = (
    pos_emb(translation).repeat(b, 1, 1) * motion_padding_mask.unsqueeze(-1),
    motion_padding_mask,
)


In [69]:
conditions = cfg_dropout(conditions , 0.2)


In [70]:
audio_embed = project_audio(conditions["audio"][0])
text_embed = project_text(conditions["text"][0])


In [71]:
conditions["audio"] = (audio_embed, conditions["audio"][1])
conditions["text"] = (text_embed, conditions["text"][1])


In [72]:
conditions["audio"][0].shape

torch.Size([4, 1, 256])

In [75]:
inputs_, cross_inputs_ = condition_fuser(x, conditions)

In [78]:
x_ = inputs_[0]
x_padding_mask = inputs_[1]
context = cross_inputs_[0]
context_padding_mask = cross_inputs_[1]

In [81]:
x_.shape

torch.Size([4, 153, 256])

In [82]:
x_padding_mask.shape

torch.Size([4, 153])

In [83]:
context.shape

torch.Size([4, 1, 256])

In [84]:
context_padding_mask.shape

torch.Size([4, 1])

In [85]:
conditions["text"][1]

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  Tr

In [86]:
x_padding_mask

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [87]:
embed = transformer_blocks(
            x=x_,
            mask=x_padding_mask,
            context=context,
            context_mask=context_padding_mask,
        )

In [89]:
embed = embed[:, -n:, :]

In [90]:
embed.shape

torch.Size([4, 116, 256])

In [143]:
pos_emb = ScaledSinusoidalEmbedding(PositionalEmbeddingParams(dim = 256))

In [161]:
b , n , _ = input["motion"][0].shape

In [167]:

x = pos_emb(input["motion"][0]).repeat(b , 1 ,1)


In [173]:
x.shape

torch.Size([4, 116, 256])

In [None]:
for a, (b , c) in conditions.items():
    print(a)
    print(b.shape)

In [175]:
inputs_ , cross_inputs = condition_fuser(x , conditions  )

In [176]:
inputs_.shape

torch.Size([4, 153, 256])

In [177]:
cross_inputs.shape

torch.Size([4, 500, 256])