In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt


def preprocess_sequence(sequence, target_length, padding_value=0):
    """
    Clip or pad a sequence along its first (time) axis to a fixed length.

    Args:
    - sequence (np.ndarray): Input sequence whose first axis is time, e.g.
      (sequence_length, num_joints, 3) or (sequence_length, num_joints * 3).
      All trailing dimensions are preserved unchanged.
    - target_length (int): The fixed length to which the sequence should be clipped or padded.
    - padding_value (float): The value used to pad shorter sequences.

    Returns:
    - np.ndarray: Array of shape (target_length, *sequence.shape[1:]).
      Longer inputs are truncated (the result is a view of the input).
      Inputs of exactly target_length are returned as-is.
      Shorter inputs have padding frames appended at the end.
    """
    current_length = sequence.shape[0]
    if current_length > target_length:
        # Clip the sequence if it's longer than the target length
        return sequence[:target_length]
    if current_length < target_length:
        # Build padding frames with the same per-frame shape as the input, so
        # this works for any rank (not only a trailing axis of size 3).
        # result_type keeps e.g. float32 input as float32 and only widens
        # when padding_value cannot be represented in the input dtype.
        padding = np.full(
            (target_length - current_length,) + sequence.shape[1:],
            padding_value,
            dtype=np.result_type(sequence.dtype, padding_value),
        )
        return np.concatenate([sequence, padding], axis=0)
    return sequence

# Build a fixed-length motion dataset from the per-recording .npy files in
# `root` and save it as a single .npz archive (`sequences`, `labels`).
# The class name is taken from the second "_"-separated token of each file
# name, e.g. "motion_come_<timestamp>.npy" -> "come".
class_names = ['keep', 'come', 'stop', 'ring']
labels_to_idx = {label: idx for idx, label in enumerate(class_names)}
sequences = []
labels = []
target_length = 100

root = './datasets'
# os.listdir returns entries in arbitrary order; sort so the sample order
# (and therefore the saved dataset) is reproducible across runs/machines.
for file in sorted(os.listdir(root)):
    # Restrict to raw .npy recordings. This also skips the
    # motion_datasets.npz archive this script writes into the same directory,
    # which would otherwise be picked up (and crash) on a second run.
    if 'motion' in file and file.endswith('.npy'):
        motion = np.load(os.path.join(root, file))
        print(file, motion.shape)
        motion = preprocess_sequence(motion, target_length)
        sequences.append(motion)
        labels.append(labels_to_idx[file.split('_')[1]])

if not sequences:
    raise ValueError(f"No motion .npy files found in {root!r}")

# Flatten each frame's per-joint values into one feature vector per time step
# (63 features for the 21x3 recordings this was written for); the feature
# width is inferred rather than hard-coded.
stacked = np.stack(sequences)
sequences = stacked.reshape(stacked.shape[0], target_length, -1)
labels = np.array(labels)
print(sequences.shape, labels.shape)
np.savez(os.path.join(root, 'motion_datasets.npz'), sequences=sequences, labels=labels)

        # plt.plot(np.sum(motion, axis=1)[:, 1])
# motion = np.load('datasets/motion_come_04_18_2024_03_09_44.npy')
# print(motion.shape)
# plt.figure(figsize=(15,5))
# plt.plot(np.sum(motion, axis=1))