In [1]:
# setting device on GPU if available, else CPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F

In [4]:
def findAllFile(base):
    file_path = []
    for root, ds, fs in os.walk(base, followlinks=True):
        for f in fs:
            fullname = os.path.join(root, f)
            file_path.append(fullname)
    return file_path



In [5]:
# from utils.motion_processing.hml_process import recover_from_ric, recover_root_rot_pos,recover_from_rot
import utils.vis_utils.plot_3d_global as plot_3d
import matplotlib.pyplot as plt

def vis(mot , dset , name = "motion"):

    if isinstance(mot , torch.Tensor):
        mot = dset.toMotion(mot)
    mot =dset.inv_transform(mot)



    xyz = np.array(dset.to_xyz(mot).cpu())

    print(xyz.shape)

    
    plot_3d.render(xyz , f"/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/{name}.gif")

In [6]:
from configs.config import cfg, get_cfg_defaults

vcfg = get_cfg_defaults()
vcfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/vqvae/vqvae_body/vqvae_body_rv.yaml")

In [7]:
from core.datasets.conditioner import ConditionProvider, ConditionFuser
from core.datasets.multimodal_dataset import MotionAudioTextDataset, load_dataset, simple_collate

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
from core.models.attend import Attention
from core import AttentionParams
from core import AttentionParams, TranslationTransformerParams, PositionalEmbeddingParams, PositionalEmbeddingType, MotionRep, AudioRep, TextRep


In [9]:
from core.datasets.conditioner import ConditionProvider, ConditionFuser
# from core.datasets.multimodal_dataset import MotionAudioTextDataset, load_dataset, simple_collate
dataset_args = vcfg.dataset


condition_provider2 = ConditionProvider(
            motion_rep=MotionRep(dataset_args.motion_rep),
            motion_padding=dataset_args.motion_padding,

        )

In [10]:
from core.datasets.multimodal_dataset import MotionAudioTextDataset
from core.datasets.vq_dataset import VQSMPLXMotionDataset
from core.datasets.vq_dataset import load_dataset, simple_collate as simple_collate2
from core import Motion

from utils.motion_processing.skeleton import Skeleton, t2m_kinematic_chain , body_joints_id, t2m_raw_body_offsets

In [11]:
# dset = MotionAudioTextDataset("moyo" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "body" , hml_rep = "gprvc", split = "test"   )

In [12]:
dset = VQSMPLXMotionDataset("choreomaster" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "body" , hml_rep = "rv", split = "train" , window_size = 600  )

Total number of motions choreomaster: 34


In [13]:
# test_ds, _, _ = load_dataset(
#             dataset_args=dataset_args,
#             split="test",
#         )

In [14]:
train_loader = torch.utils.data.DataLoader(
        dset,
        10,
        # sampler=sampler,
        collate_fn=partial(simple_collate2 , conditioner = condition_provider2),
        # drop_last=True,
    )

In [15]:
for inputs in train_loader:
    break
    

In [16]:
mot = inputs["motion"][0]
mot.shape

torch.Size([10, 300, 192])

In [62]:
dset.render_hml(mot[0] , "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/r.gif")

In [17]:
from core.models.resnetVQ.vqvae import HumanVQVAE
from core.models.loss import ReConsLoss

In [18]:
vqvae_args = vcfg.vqvae

In [19]:
vqvae_args.nb_joints = dset.nb_joints
vqvae_args.motion_dim = dset.motion_dim
hml_rep = dset.hml_rep
motion_rep = dset.motion_rep


In [20]:

loss_fnc = ReConsLoss("l1_smooth" , True , vqvae_args.nb_joints , hml_rep=hml_rep , motion_rep = motion_rep  )
loss_fnc2 = ReConsLoss("l1_smooth" , False , vqvae_args.nb_joints , hml_rep=hml_rep , motion_rep = motion_rep  )

In [21]:
vqvae_model = HumanVQVAE(vqvae_args).to(device)
vqvae_model.load("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/vqvae/vqvae_body/vqvae_motion_rv.pt")

In [47]:
gt_motion = inputs["motion"][0].to(device)
out = vqvae_model(
    motion=gt_motion,
    # mask=mask,
)

In [None]:
# motion = inputs["motion"][0].to(device)
# pred = vqvae_model(motion)
# loss = loss_fnc(pred.decoded_motion , motion)

In [48]:
loss_motion = loss_fnc(
            out.decoded_motion, gt_motion, mask=None
        )

In [49]:
loss_motion

tensor(0.2120, device='cuda:0', grad_fn=<AddBackward0>)

In [50]:
loss_motion2 = loss_fnc2(
            out.decoded_motion, gt_motion, mask=None
        )

In [51]:
loss_motion2

tensor(0.0545, device='cuda:0', grad_fn=<SmoothL1LossBackward0>)

In [23]:

vqvae_model.eval()
val_loss_ae = {}
cnt = 0
with torch.no_grad():
    for batch in tqdm(
        (train_loader),
        position=0,
        leave=True,
    ):
        gt_motion = batch["motion"][0].to(device)

        indices = vqvae_model.encode(
            motion=gt_motion,
            # mask=mask,
        )

        used_indices = indices.flatten().tolist()
        usage = len(set(used_indices)) / vqvae_args.codebook_size
        # print(usage)

        loss_dict = {
            # "total_loss": loss.detach().cpu(),
            # "loss_motion": loss_motion.detach().cpu(),
            # "commit_loss": vqvae_output.commit_loss.detach().cpu(),
            "usage": usage,
        }

        cnt+=1

        for key, value in loss_dict.items():
            if key in val_loss_ae:
                val_loss_ae[key] += value
            else:
                val_loss_ae[key] = value


for key in val_loss_ae.keys():
    val_loss_ae[key] = val_loss_ae[key] / cnt

100%|█████████████████████████████████████████████████████████████████████████| 49/49 [00:07<00:00,  6.17it/s]


In [24]:
val_loss_ae

{'usage': 0.5708107461734694}

### Motion MUSE

In [9]:
from core import MotionRep, AudioRep, TextRep
from core.datasets.conditioner import ConditionProvider, ConditionFuser
from core.datasets.multimodal_dataset import MotionIndicesAudioTextDataset, load_dataset_gen, simple_collate
from core.models.generation.motion_generator import Transformer, MotionMuse
from core.models.utils import instantiate_from_config, get_obj_from_str


In [10]:
from configs.config_t2m import cfg, get_cfg_defaults
from configs.config import get_cfg_defaults as get_cfg_defaults3

cfg = get_cfg_defaults()
cfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/motion_generation/motion_generation.yaml")
cfg.freeze()
mmuse_args = cfg.motion_generator
dataset_args = cfg.dataset


In [11]:
target = mmuse_args.pop("target")
motion_muse = MotionMuse(mmuse_args).to(device).eval()

In [12]:
from core.models.resnetVQ.vqvae import HumanVQVAE

In [13]:

vcfg = get_cfg_defaults3()
vcfg.merge_from_file("/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/vqvae/vqvae_body_gprvc/vqvae_body_gprvc.yaml")
vqvae_args = vcfg.vqvae
vqvae_args.nb_joints = 22
vqvae_args.motion_dim = 263

In [14]:
vqvae_model = HumanVQVAE(vqvae_args).to(device).eval()
vqvae_model.load("/srv/hays-lab/scratch/sanisetty3/music_motion/ACMG/checkpoints/smplx_resnet/vqvae_motion.pt")

In [15]:
condition_provider = ConditionProvider(
            motion_rep=MotionRep(dataset_args.motion_rep),
            audio_rep=AudioRep(dataset_args.audio_rep),
            text_rep=TextRep(dataset_args.text_rep),
            motion_padding=dataset_args.motion_padding,
            audio_padding=dataset_args.audio_padding,
            motion_max_length_s=10,
            audio_max_length_s=10,
            pad_id = motion_muse.transformer.pad_token_id,
            fps=30/4
        )

In [57]:
dset = MotionIndicesAudioTextDataset("choreomaster" , "/srv/hays-lab/scratch/sanisetty3/motionx" ,motion_rep = "body", split = "train" , fps = 30/4  )

Total number of motions choreomaster: 34 and texts 34


In [56]:
train_ds, sampler_train, weights_train  = load_dataset_gen(dataset_args=dataset_args, split = "train" , dataset_names = ["animation" , "choreomaster" ] )

Total number of motions animation: 308 and texts 308
Total number of motions choreomaster: 34 and texts 34


In [62]:
train_loader = torch.utils.data.DataLoader(
        train_ds,
        4,
        sampler=sampler_train,
        # shuffle = False,
        collate_fn=partial(simple_collate , conditioner = condition_provider),
        drop_last=True,
    )

In [72]:
for inputs, conditions in train_loader:
    break
    

In [73]:
inputs["motion"][0].shape

torch.Size([4, 75, 1])

In [74]:
conditions["audio"][0].shape

torch.Size([4, 500, 128])

In [35]:
motions = inputs["motion"][0].squeeze().to(torch.long)
motion_mask = inputs["motion"][1]

In [72]:
inputs["motion"][0].shape

torch.Size([4, 75, 1])

In [73]:
pred_indices.shape

torch.Size([4, 75])

In [36]:
loss , logits = motion_muse((motions , motion_mask) , conditions , cond_drop_prob = 0.4 , return_logits = True)

In [79]:
pred_indices = lologits.argmax(-1)
pred_motion  = vqvae_model.decode(pred_indices[:1])

In [65]:
mod_motion = torch.where(motions >= 1024 , 0 , motions)
gt_motion  = vqvae_model.decode(mod_motion)

In [71]:
gt_motion[1].shape

torch.Size([300, 263])

In [66]:
dset.render_hml(
                    gt_motion[1][:(int(sum(motion_mask[1])) *4)].detach().squeeze().cpu(),
                    "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/render/gt_motion_recon.gif"
                )

In [27]:
inputs["names"]

array(['animation/subset_0000/Ways_To_Catch_Juggling',
       'animation/subset_0003/Ways_To_Text_One_Minute',
       'animation/subset_0000/Ways_To_Go_To_Bed_Sleep_Wake_Up_Not_Enough_Blanket',
       'animation/subset_0002/Ways_To_Open_A_Christmas_Gift_Kevinbparry'],
      dtype='<U72')

In [None]:
logits2 = motion_muse.transformer.forward_with_cond_scale((motions , motion_mask) , conditions)
logits3 =motion_muse.transformer.forward_with_neg_prompt((motions , motion_mask) , conditions , conditions)

In [48]:
motions.shape

torch.Size([4, 28])

In [49]:
28*4

112

In [67]:
gen_ids = motion_muse.generate(conditions)

100%|█████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00, 14.65it/s]


In [57]:
conditions["text"][0].shape

torch.Size([4, 1, 768])

## Inference

In [None]:
text = "a man dancing"
audio = 

In [33]:
aud , am = condition_provider._get_audio_features(audio_list = [None])

In [34]:
aud.shape

(1, 1, 128)

In [82]:

sap = AttentionParams(dim = 256 , causal=True)
cap = AttentionParams(dim = 256 , causal=True , add_null_kv=True)
transformer_params = TranslationTransformerParams(self_attention_params = sap , 
                                                  cross_attention_params = cap , 
                                                  depth = 1, 
                                                  positional_embedding_params=PositionalEmbeddingParams(dim = 256) , 
                                                  positional_embedding=PositionalEmbeddingType.SINE,
                                                  fuse_method = {"cross_seperate" : ["audio" , "text"]}
                                                 )

In [87]:
embed = transformer_blocks(
            x=x_,
            mask=x_padding_mask,
            context=context,
            context_mask=context_padding_mask,
        )

In [89]:
embed = embed[:, -n:, :]

In [90]:
embed.shape

torch.Size([4, 116, 256])

In [None]:
class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

In [143]:
pos_emb = ScaledSinusoidalEmbedding(PositionalEmbeddingParams(dim = 256))

In [161]:
b , n , _ = input["motion"][0].shape

In [167]:

x = pos_emb(input["motion"][0]).repeat(b , 1 ,1)


In [173]:
x.shape

torch.Size([4, 116, 256])

In [None]:
for a, (b , c) in conditions.items():
    print(a)
    print(b.shape)

In [175]:
inputs_ , cross_inputs = condition_fuser(x , conditions  )

In [176]:
inputs_.shape

torch.Size([4, 153, 256])

In [177]:
cross_inputs.shape

torch.Size([4, 500, 256])