In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

if "../" not in sys.path:
    sys.path.append("../")
    print("[sys.path]:", sys.path)

[sys.path]: ['/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/nadir/motion-linner/.motion-linner.venv/lib/python3.10/site-packages', '../']


In [3]:
import os
import tqdm
import torch
import datasets

import numpy as np

from dotenv import load_dotenv
from huggingface_hub import login

from src.data.datasets import BabelDataset

  from .autonotebook import tqdm as notebook_tqdm


[.env]: True


In [4]:
print("[.env]:", load_dotenv())

[.env]: True


In [10]:
from src.constants import (
    HUGGING_FACE_TOKEN,
    BABEL_REMOTE_DATASET_NAME,
    HML3D_REMOTE_DATASET_NAME,
)

In [11]:
login(token=HUGGING_FACE_TOKEN)

In [12]:
print("[BABEL]:", datasets.get_dataset_config_names(BABEL_REMOTE_DATASET_NAME, trust_remote_code=True))
print("[HML3D]:", datasets.get_dataset_config_names(HML3D_REMOTE_DATASET_NAME, trust_remote_code=True))

[.env]: True
[BABEL]: ['full_joint_vecs', 'labels_only', 'motion_all', 'motion_joint_vecs', 'motion_joints', 'full_all_motion', 'full_joints']
[HML3D]: ['full_joint_vecs', 'labels_only', 'motion_all', 'motion_joint_vecs', 'motion_joints', 'full_all_motion', 'full_joints']


In [13]:
from src.data.datasets import HML3DDataset, BabelDataset

babel_dataset = datasets.load_dataset(
    BABEL_REMOTE_DATASET_NAME,
    trust_remote_code=True,
    name="full_all_motion"
)

Repo card metadata block was not found. Setting CardData to empty.


In [9]:
babel_dataset["train"][0]

{'sid': '7905',
 'amass_file_relative_path': 'KIT/1226/Trial_70_poses.npz',
 'duration': 10.569999694824219,
 'sequence_annotations': {'labels': {'act_cat': [['raising body part',
     'hand movements']],
   'proc_label': ['hand rise'],
   'raw_label': ['hand rise'],
   'start_t': [0.0],
   'end_t': [10.569999694824219]}},
 'frame_annotations': {'labels': {'act_cat': [],
   'proc_label': [],
   'raw_label': [],
   'start_t': [],
   'end_t': []}},
 'motion': {'new_joint_vecs': [[-0.03423130512237549,
    -3.4834381949622184e-05,
    -0.00027727760607376695,
    0.8576548099517822,
    0.057211533188819885,
    0.7793337106704712,
    0.014336563646793365,
    -0.060422807931900024,
    0.7718049883842468,
    0.005111958831548691,
    0.0028455667197704315,
    0.9732496738433838,
    0.042833615094423294,
    0.12428243458271027,
    0.43814119696617126,
    0.03800967335700989,
    -0.13072486221790314,
    0.4359435439109802,
    0.04511795938014984,
    0.011742609553039074,
    1.1

---

In [None]:
from src.data.filtering import create_babel_filter_fn
from src.constants import DEFAULT_FPS, DEFAULT_SEED

filtering_function = create_babel_filter_fn(
    seed=DEFAULT_SEED,
    fps=DEFAULT_FPS,
    min_motion_frames=20,
    max_motion_frames=4096,
    min_prompts_per_sample=1,
    max_prompts_per_sample=4,
    # TODO: instead we should just pass a lambda function that receives the prompt text and returns a boolean indicatin whether we should keep it or not
    # prompt_text_blacklist=[
    #     "transition",    
    # ],
    prompt_text_filter_fn=None, # TODO: use a predefined one from the filtering file
    min_span_frames=1,
    max_span_frames=32,
    max_spans_per_prompt=8,
    debug=False,
)

In [17]:
dsds = babel_dataset.map(
    filtering_function,
    batched=True,
    batch_size=32   
)

Map: 100%|██████████| 6615/6615 [06:17<00:00, 17.50 examples/s]
Map: 100%|██████████| 2193/2193 [02:30<00:00, 14.57 examples/s]


In [22]:
print("[#train]", len(dsds["train"]))
print("[#validation]", len(dsds["validation"]))

[#train] 3667
[#validation] 1251


---

<div class="alert alert-info">

Supported Types: [None, 'arrow', 'numpy', 'pandas', 'custom', 'torch'].

```python
babel_dataset["train"].set_format("ds")
```

</div>

---

In [13]:
from src.model.motion_encoders.tmr import TMR

In [15]:
encoder = TMR(
    latent_dim=256
)

In [None]:
inputs = next(iter(train_dataloader))
outputs = encoder(inputs)

In [None]:
cls_token, final = outputs

In [None]:
print("[inputs.motion.shape]:", inputs["motion"].shape)

print("--- --- ---")

print("[cls_token.shape]:", cls_token.shape)
print("[final.shape]:", final.shape)

[inputs.motion.shape]: torch.Size([32, 916, 263])
--- --- ---
[cls_token.shape]: torch.Size([32, 1, 256])
[final.shape]: torch.Size([32, 917, 256])


---

In [None]:
from src.data.typing import RawBatch, ProcessedBatch

from src.data.batching import hml3d_create_raw_batch_collate_fn

from src.data.batching import PromptGenerationMode, babel_create_raw_batch_collate_fn, babel_augment_and_split_batch

In [25]:
new_set = babel_dataset["train"].map(augment_and_split_batch, batched=True, batch_size=8)

Map: 100%|██████████| 6615/6615 [03:16<00:00, 33.70 examples/s]


In [28]:
print(len(babel_dataset["train"]))
print(len(new_set))

6615
10600


In [None]:
from src.data.typing import RawBatch

my_collate_fn = create_raw_batch_collate_fn(
    fps=20,
    mode=PromptGenerationMode.BOTH,
)

train_dataloader = torch.utils.data.DataLoader(
    babel_dataset["train"],
    batch_size=8,
    collate_fn=my_collate_fn,
    shuffle=True
)

In [13]:
for batch in tqdm.tqdm(train_dataloader):
    # print(batch)
    # print(batch["ds"])
    # print(batch["amass_relative_path"])
    # print(batch["sid"])
    # print(batch["dataset_name"])
    # print(batch["motion"])
    # print(batch["motion_length"])
    # print(batch["motion_start_end"])
    # print(batch["motion_start_end_length"])
    break

  0%|          | 0/827 [00:00<?, ?it/s]


In [14]:
batch = next(iter(train_dataloader))

In [15]:
print("[batch]:", list(batch.__dataclass_fields__.keys()))
print("[batch.raw_motion]:", batch.raw_motion.shape)
print("[batch.transformed_motion]:", batch.transformed_motion.shape)
print("[batch.motion_mask]:", batch.motion_mask.shape)
print("[batch.prompts]:", batch.prompts)
batch

[batch]: ['sid', 'dataset_name', 'amass_relative_path', 'raw_motion', 'transformed_motion', 'motion_mask', 'prompts']
[batch.raw_motion]: torch.Size([8, 453, 22, 3])
[batch.transformed_motion]: torch.Size([8, 453, 263])
[batch.motion_mask]: torch.Size([8, 453])
[batch.prompts]: [[('walk', 0, 452, True)], [('run', 0, 27, True)], [('sidestep', 0, 163, True), ('side step to left', 54, 116, True), ('side step to right', 0, 54, True), ('side step to right', 116, 160, True), ('stand', 160, 163, True)], [('turn', 0, 95, True), ('turn right', 32, 58, True), ('stand', 0, 15, True), ('stand', 79, 95, True), ('walk', 15, 32, True), ('walk', 58, 79, True)], [('walk with support', 0, 285, True), ('walk forward', 23, 246, True), ('stand', 0, 15, True), ('transition', 15, 23, True), ('transition', 246, 285, True), ('apose', 285, 285, True), ('use left handrail', 61, 242, True), ('use right handrail', 35, 250, True)], [('wave', 0, 59, True)], [('walk', 0, 391, True)], [('sneak', 0, 114, True), ('stand

RawBatch(sid=['1189', '3473', '5809', '7127', '8156', '2996', '3668', '410'], dataset_name=['babel', 'babel', 'babel', 'babel', 'babel', 'babel', 'babel', 'babel'], amass_relative_path=['BioMotionLab_NTroje/rub078/0002_treadmill_slow_poses.npz', 'CMU/02/02_03_poses.npz', 'BMLmovi/Subject_18_F_MoSh/Subject_18_F_11_poses.npz', 'KIT/9/LeftTurn06_poses.npz', 'KIT/675/walk_slow_with_handrail_table_beam_left03_poses.npz', 'BMLmovi/Subject_65_F_MoSh/Subject_65_F_7_poses.npz', 'BioMotionLab_NTroje/rub053/0001_treadmill_fast_poses.npz', 'CMU/74/74_11_poses.npz'], raw_motion=tensor([[[[ 0.0000e+00,  8.3621e-01,  0.0000e+00],
          [ 6.3360e-02,  7.5964e-01,  1.7598e-04],
          [-5.4326e-02,  7.4410e-01,  3.2346e-03],
          ...,
          [-2.7796e-01,  1.0306e+00,  8.4422e-03],
          [ 2.3049e-01,  8.0514e-01, -1.7474e-02],
          [-2.3149e-01,  8.4153e-01, -1.2744e-01]],

         [[-1.2018e-03,  8.3441e-01,  9.5276e-03],
          [ 6.2813e-02,  7.5840e-01,  1.0637e-02],
   

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

preprocessed_batch = ProcessedBatch.from_raw_batch(
    batch,
    tokenizer,
)

print("[preprocessed_batch.keys()]:", list(preprocessed_batch.__dataclass_fields__.keys()))

print(preprocessed_batch.target_spans.shape)odel 
print(preprocessed_batch.prompt_input_ids[0].shape)
preprocessed_batch.prompt_input_ids[0][0]
preprocessed_batch.prompt_attention_mask[0][0]