In [None]:
def resize_pad(x, frame_len=None):
    """Force a landmark sequence to exactly ``frame_len`` frames along dim 0.

    Sequences shorter than ``frame_len`` are zero-padded at the end of the
    frame axis. Sequences with ``frame_len`` frames or more are bilinearly
    resampled along the frame axis. The landmark and coordinate axes keep
    their size in both cases.

    Args:
        x: floating-point tensor of shape (frames, landmarks, coords).
        frame_len: target number of frames. When omitted, the module-level
            ``FRAME_LEN`` is used.

    Returns:
        Tensor of shape (frame_len, landmarks, coords).
    """
    if frame_len is None:
        frame_len = FRAME_LEN

    num_frames, num_landmarks = x.shape[0], x.shape[1]

    if num_frames < frame_len:
        # The pad tuple is read from the LAST dimension backwards, so the
        # frame axis (dim 0 of a 3-D tensor) is the third pair, not the second.
        return torch.nn.functional.pad(x, (0, 0, 0, 0, 0, frame_len - num_frames))

    # Bilinear interpolate expects (N, C, H, W). Move coords to the channel
    # slot so that H is the frame axis and W is the landmark axis.
    image = x.permute(2, 0, 1).unsqueeze(0)
    image = torch.nn.functional.interpolate(
        image, size=(frame_len, num_landmarks), mode="bilinear", align_corners=False
    )
    # Back to (frames, landmarks, coords).
    return image.squeeze(0).permute(1, 2, 0).contiguous()


def frames_preprocess(x):
    """Build the single-hand feature tensor for one sequence of landmark frames.

    The hand whose columns contain NaNs in fewer frames is kept, together
    with the pose columns of the same side. The left side is chosen only
    when the right hand has strictly more NaN frames; in that case its x
    coordinates are mirrored as ``1 - x``. Hand coordinates are then
    standardised per frame and per axis across the hand's landmarks; pose
    coordinates are left unscaled. The stacked result goes through
    ``resize_pad`` and any remaining NaNs are replaced by zeros.

    Args:
        x: array-like indexed as ``x[:, column]``. Assumed to be
            (num_frames, len(FEATURE_COLUMNS)) -- TODO confirm with caller.

    Returns:
        The tensor produced by ``resize_pad`` with NaNs set to 0.
    """
    x = torch.tensor(x)

    def split_axes(flat, width):
        # Columns are laid out as [all x | all y | all z], ``width`` each.
        return [flat[:, k * width : (k + 1) * width] for k in range(3)]

    def stack_axes(axes):
        # (frames, width) x 3 -> (frames, width, 3)
        return torch.cat([axis.unsqueeze(-1) for axis in axes], dim=-1)

    right_hand, left_hand = x[:, RHAND_IDX], x[:, LHAND_IDX]
    right_pose, left_pose = x[:, RPOSE_IDX], x[:, LPOSE_IDX]

    # Count the frames in which each hand has at least one NaN coordinate.
    right_missing = torch.isnan(right_hand).any(dim=1).sum()
    left_missing = torch.isnan(left_hand).any(dim=1).sum()
    use_left = bool(right_missing > left_missing)

    if use_left:
        hand_flat, pose_flat = left_hand, left_pose
    else:
        hand_flat, pose_flat = right_hand, right_pose

    hand_axes = split_axes(hand_flat, len(LHAND_IDX) // 3)
    pose_axes = split_axes(pose_flat, len(LPOSE_IDX) // 3)

    if use_left:
        # Mirror the left side horizontally.
        hand_axes[0] = 1 - hand_axes[0]
        pose_axes[0] = 1 - pose_axes[0]

    hand = stack_axes(hand_axes)
    centre = torch.mean(hand, dim=1, keepdim=True)
    spread = torch.std(hand, dim=1, keepdim=True)
    hand = (hand - centre) / spread

    pose = stack_axes(pose_axes)

    # Hand landmarks first, then pose landmarks, along the landmark axis.
    x = torch.cat([hand, pose], dim=1)
    x = resize_pad(x)
    x = torch.where(torch.isnan(x), torch.zeros_like(x), x)

    # NOTE: a reshape to (FRAME_LEN, len(LHAND_IDX) + len(LPOSE_IDX)) was
    # tried here at some point and is disabled; the tensor is returned with
    # whatever shape resize_pad produced.
    return x

In [1]:
"""doc
"""

# Column layout and index tables for the landmark features used downstream.
#
# FRAME_LEN        frame count that resize_pad pads or resamples towards.
# LPOSE / RPOSE    pose landmark ids kept for each side (presumably the
#                  left / right arm -- confirm against the landmark spec).
# FEATURE_COLUMNS  all x columns, then all y columns, then all z columns.
# *_IDX            positions inside FEATURE_COLUMNS for each column group.
FRAME_LEN = 128

LPOSE = [13, 15, 17, 19, 21]
RPOSE = [14, 16, 18, 20, 22]
POSE = LPOSE + RPOSE

# One list per axis, each ordered: 21 right-hand, 21 left-hand, then POSE.
X, Y, Z = (
    [f"{axis}_right_hand_{n}" for n in range(21)]
    + [f"{axis}_left_hand_{n}" for n in range(21)]
    + [f"{axis}_pose_{n}" for n in POSE]
    for axis in ("x", "y", "z")
)

FEATURE_COLUMNS = X + Y + Z


def _positions(keep):
    """Return the positions in FEATURE_COLUMNS whose name satisfies ``keep``."""
    return [pos for pos, name in enumerate(FEATURE_COLUMNS) if keep(name)]


X_IDX = _positions(lambda name: "x_" in name)
Y_IDX = _positions(lambda name: "y_" in name)
Z_IDX = _positions(lambda name: "z_" in name)

RHAND_IDX = _positions(lambda name: "right" in name)
LHAND_IDX = _positions(lambda name: "left" in name)

# The landmark id is parsed from the last two characters of the column name,
# which relies on every id in POSE having exactly two digits.
RPOSE_IDX = _positions(lambda name: "pose" in name and int(name[-2:]) in RPOSE)
LPOSE_IDX = _positions(lambda name: "pose" in name and int(name[-2:]) in LPOSE)


In [4]:
# Sanity check: how many feature columns were selected for the right and the
# left hand (21 landmarks x 3 axes each, per the column construction).
len(RHAND_IDX), len(LHAND_IDX)

(63, 63)

In [None]:
def read_file(file, file_id, landmarks_metadata_path):
    """Collect the usable (frames, phrase) pairs stored in one parquet file.

    A sequence is kept only when the better-tracked hand has more NaN-free
    frames than twice the number of characters in its phrase.

    Args:
        file: parquet source handed to ``pq.read_table``.
        file_id: value matched against the ``file_id`` column of the metadata.
        landmarks_metadata_path: CSV providing the ``file_id``,
            ``sequence_id`` and ``phrase`` columns.

    Returns:
        ``(frames_list, phrase_list)``: two parallel lists holding, for each
        kept sequence, its frames as a NumPy array and its phrase.
    """
    metadata = pd.read_csv(landmarks_metadata_path)
    wanted = metadata.loc[metadata["file_id"] == file_id]
    # Rows are later selected via the DataFrame index; this assumes
    # to_pandas() yields a frame indexed by sequence_id -- TODO confirm.
    landmarks = pq.read_table(
        file, columns=["sequence_id"] + FEATURE_COLUMNS
    ).to_pandas()

    def complete_frames(frames, hand_idx):
        # Number of frames in which none of the given hand columns is NaN.
        nan_per_frame = np.sum(np.isnan(frames[:, hand_idx]), axis=1)
        return np.sum(nan_per_frame == 0)

    frames_list, phrase_list = [], []
    for seq_id, phrase in zip(wanted.sequence_id, wanted.phrase):
        frames = landmarks[landmarks.index == seq_id].to_numpy()
        usable = max(
            complete_frames(frames, RHAND_IDX),
            complete_frames(frames, LHAND_IDX),
        )
        # Skip sequences with too few fully tracked frames for the phrase.
        if usable > 2 * len(phrase):
            frames_list.append(frames)
            phrase_list.append(phrase)
    return (frames_list, phrase_list)
