In [1]:
# Choose the compute device: the GPU when CUDA is usable, the CPU otherwise.
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# The extra diagnostics below only make sense on a CUDA device.
if device.type == 'cuda':
    bytes_per_gib = 1024**3  # divisor used to report byte counts in GB
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/bytes_per_gib, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/bytes_per_gib, 1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
# Load the autoreload extension and set it to mode 2 so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
from IPython.display import Image


In [4]:
import torch
import numpy as np
import os
import torch.nn as nn
from tqdm import tqdm
import json
from functools import partial
from torch import einsum, nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from einops import pack, rearrange, reduce, repeat, unpack


In [5]:
from core.datasets.text_encoders import T5Conditioner,MPNETConditioner

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Run the train_tmr.py script in a shell subprocess; its stdout is captured below.
!python train_tmr.py

loading config from: /srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/tmr/tmr.yaml
output_dir:  /srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/tmr
contrastive: 0.1
fact: None
kl: 1e-05
latent: 1e-05
recons: 1.0
temperature: 0.1
threshold_selfsim: 0.8
threshold_selfsim_metrics: 0.95
vae: True
Total training params: 16.13M
Total number of motions animation: 120 and texts 120
Total number of motions idea400: 10949 and texts 10949
Total number of motions animation: 2 and texts 2
Total number of motions idea400: 577 and texts 577
Total number of motions animation: 1 and texts 1
Total number of motions idea400: 2 and texts 2
training with training 11069 and test dataset of 579 and render 3 samples
[34m[1mwandb[0m: Currently logged in as: [33msohananisetty[0m ([33mai-choreo[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.16.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --up

In [6]:
# t5 = T5Conditioner("t5-large")
# sent = " a man walks forward, to the end of the platform, then turns counterclockwise on his heel, before walking back to his starting point."
# tokn = t5.tokenize(sent)
# embed, msk = t5.get_text_embedding(tokn)

In [7]:
# Instantiate the project's MPNETConditioner with its default arguments.
# `mp` is reused further down to embed the batch captions.
mp = MPNETConditioner()

# Embed one example sentence so that `sent`, `tokn`, `embed` and `msk` exist
# for the later cells that read them (`embed.shape`, and the cell that
# re-tokenizes `sent`). These lines were previously commented out, which left
# those cells dependent on stale kernel state.
sent = " a man walks forward, to the end of the platform, then turns counterclockwise on his heel, before walking back to his starting point."
tokn = mp.tokenize(sent)
embed, msk = mp.get_text_embedding(tokn)

In [152]:
# Display the shape of the embedding tensor `embed`.
embed.shape

torch.Size([1, 768])

In [8]:
from core.models.utils import instantiate_from_config, get_obj_from_str
from core.datasets.tmr_dataset import TMRDataset, load_dataset, simple_collate
from core import MotionRep, TextRep, AudioRep
from core.datasets.conditioner import ConditionProvider,ConditionFuser
from configs.config_tmr import get_cfg_defaults
from core.models.TMR.tmr import TMR

In [9]:
# Location of the experiment YAML that is layered over the default config.
TMR_CONFIG_PATH = "/srv/hays-lab/scratch/sanisetty3/music_motion/ATCMG/checkpoints/tmr/tmr.yaml"

# Start from the defaults, merge the experiment file on top, then freeze.
tmr_cfg = get_cfg_defaults()
tmr_cfg.merge_from_file(TMR_CONFIG_PATH)
tmr_cfg.freeze()

# Keep handles on the two sub-configs used by the following cells.
dataset_args, tmr_parms = tmr_cfg.dataset, tmr_cfg.tmr

In [10]:
# Drop the "target" entry so it is not forwarded to TMR(...) as a keyword
# argument. Passing a default keeps this cell re-runnable: a bare
# pop("target") raises KeyError the second time the cell is executed.
tmr_parms.pop("target", None)

# Build the three sub-networks from their config sections and move each to `device`.
motion_encoder = instantiate_from_config(tmr_cfg.motion_encoder).to(device)
text_encoder = instantiate_from_config(tmr_cfg.text_encoder).to(device)
motion_decoder = instantiate_from_config(tmr_cfg.motion_decoder).to(device)

# Assemble the TMR model; the remaining `tmr_parms` entries are passed as keyword arguments.
tmr = TMR(
    motion_encoder,
    text_encoder,
    motion_decoder,
    lr=tmr_cfg.train.learning_rate,
    **tmr_parms,
).to(device)

In [11]:
# Load the "idea400" and "animation" datasets together with a sampler and weights.
# NOTE: the variables carry "train" names, but the split requested here is "test".
train_ds, sampler_train, weights_train = load_dataset(
    dataset_names=["idea400", "animation"],
    dataset_args=dataset_args,
    split="test",
)


Total number of motions idea400: 577 and texts 577
Total number of motions animation: 2 and texts 2


In [12]:
# Build the condition provider from the dataset section of the config.
# fps is fixed at 30 here rather than read from the config, and the
# `only_motion` option is left at its default.
condition_provider = ConditionProvider(
    text_conditioner_name=dataset_args.text_conditioner_name,
    motion_rep=MotionRep(dataset_args.motion_rep),
    audio_rep=AudioRep(dataset_args.audio_rep),
    text_rep=TextRep(dataset_args.text_rep),
    motion_padding=dataset_args.motion_padding,
    motion_max_length_s=dataset_args.motion_max_length_s,
    fps=30,
)

In [13]:
# Bind the condition provider into the collate function ahead of time.
collate_with_conditions = partial(simple_collate, conditioner=condition_provider)

# Batches of 4 drawn through `sampler_train`; `shuffle` and `drop_last`
# are left at their DataLoader defaults.
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=4,
    sampler=sampler_train,
    collate_fn=collate_with_conditions,
)

In [14]:
# Pull exactly one batch from the loader for interactive inspection.
# Unlike a `for ... break` loop, next(iter(...)) raises StopIteration on an
# empty loader instead of silently leaving `inputs` / `conditions` undefined
# (or stale from a previous run of this cell).
inputs, conditions = next(iter(train_loader))
    

In [15]:
# Each entry is indexed as a (values, mask) pair; split the pairs apart.
motion_pair = inputs["motion"]
text_pair = conditions["text"]

motion, text = motion_pair[0], text_pair[0]

# Masks are cast to boolean tensors.
motion_mask = motion_pair[1].to(torch.bool)
text_mask = text_pair[1].to(torch.bool)

In [143]:
# Display which device the motion mask tensor is on.
motion_mask.device

device(type='cuda', index=0)

In [21]:
# Tokenize the batch's raw texts and compute their embeddings with `mp`.
batch_texts = list(inputs["texts"])
tkn = mp.tokenize(batch_texts)
sent_embed, sent_msk = mp.get_text_embedding(tkn)

# Attach the (embedding, mask) pair to the conditions under "sent_emb".
conditions["sent_emb"] = (sent_embed, sent_msk)

In [28]:
# Compute the TMR losses for the sampled batch: the motion entry of `inputs`
# plus the full `conditions` mapping (which includes "sent_emb" once the
# sentence-embedding cell has been run).
losses = tmr.compute_loss(inputs["motion"] , conditions)

In [29]:
# Display the values returned by compute_loss.
losses

{'recons': tensor(0.6562, device='cuda:0', grad_fn=<AddBackward0>),
 'kl': tensor(8.7833, device='cuda:0', grad_fn=<AddBackward0>),
 'latent': tensor(1.3628, device='cuda:0', grad_fn=<SmoothL1LossBackward0>),
 'contrastive': tensor(1.7467, device='cuda:0', grad_fn=<DivBackward0>),
 'loss': tensor(0.8310, device='cuda:0', grad_fn=<AddBackward0>)}

In [54]:
# Call the TMR model with the motion mask and return_all=True, unpacking the
# three values it returns.
# NOTE(review): `inps` is not assigned in any cell visible here — confirm
# where it is defined before relying on this cell.
m_motions, m_latents, m_dists = tmr(inps, mask=motion_mask, return_all=True)

In [None]:
# Call the TMR model with the text mask and return_all=True, unpacking the
# three values it returns.
# NOTE(review): `text_x_dict` is not assigned in any cell visible here —
# confirm where it is defined before relying on this cell.
t_motions, t_latents, t_dists = tmr(text_x_dict, mask=text_mask, return_all=True)

In [None]:
# Tokenize `sent` with the `mp` conditioner and compute its embedding and mask.
# Relies on `sent` and `mp` already existing in the kernel.
tokn = mp.tokenize(sent)
embed, msk = mp.get_text_embedding(tokn)

In [30]:
pwd

'/coc/scratch/sanisetty3/music_motion/ATCMG'