Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling
Raunaq Bhirangi, Chenyu Wang, Venkatesh Pattabiraman, Carmel Majidi, Abhinav Gupta, Tess Hellebrekers and Lerrel Pinto
Paper: https://arxiv.org/abs/2402.10211
Website: https://hiss-csp.github.io/
HiSS is a simple technique that stacks deep state space models, such as S4 and Mamba, to reason over continuous sequences of sensory data at multiple temporal hierarchies. We also release CSP-Bench, a benchmark for sequence-to-sequence prediction from sensory data.
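The following toy two-level sketch illustrates stacking SSMs over temporal hierarchies. It rests on a few assumptions. The input is split into fixed-size chunks, and a low-level model runs within each chunk. The last output of each chunk serves as that chunk's feature, and a high-level model then runs over the sequence of chunk features. The scalar linear recurrence stands in for S4 or Mamba layers. The names scalar_ssm and hierarchical_ssm and the coefficients are illustrative assumptions, not the repository's code.

```python
from typing import List, Sequence


def scalar_ssm(inputs: Sequence[float], a: float, b: float, c: float) -> List[float]:
    """Toy scalar linear SSM: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t.

    Hypothetical stand-in for a deep SSM layer (S4, Mamba), which would use
    multichannel inputs and learned dynamics instead of fixed scalars.
    """
    h = 0.0
    outputs = []
    for x in inputs:
        h = a * h + b * x
        outputs.append(c * h)
    return outputs


def hierarchical_ssm(sequence: Sequence[float], chunk_size: int) -> List[float]:
    """Two-level hierarchy: a low-level SSM runs inside each chunk, and a
    high-level SSM runs over one summary feature per chunk."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunk_features = []
    for start in range(0, len(sequence), chunk_size):
        chunk = sequence[start:start + chunk_size]  # last chunk may be shorter
        low_level = scalar_ssm(chunk, a=0.9, b=1.0, c=1.0)
        chunk_features.append(low_level[-1])  # last output summarizes the chunk
    # The high-level model sees a sequence roughly chunk_size times shorter.
    return scalar_ssm(chunk_features, a=0.5, b=1.0, c=1.0)
```

For example, hierarchical_ssm([1.0] * 8, chunk_size=4) reduces eight inputs to two high-level outputs, one per chunk. The high-level model therefore reasons at a coarser time scale than the low-level one.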
- Clone the repository.
- Create a conda environment from the provided env.yml file: conda env create -f env.yml
- Install Mamba based on the official instructions.
  Note: If you run into CUDA issues while installing Mamba, run export CUDA_HOME=$CONDA_PREFIX and try again. If you still have problems, install both causal_conv1d and mamba-ssm from source.
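After installation, you can sanity-check the environment with the following hypothetical helper, which is not part of the repository. It reports whether CUDA_HOME and CONDA_PREFIX are set, and whether the mamba_ssm and causal_conv1d modules can be found. Those are assumed to be the import names of the mamba-ssm and causal_conv1d packages. importlib.util.find_spec only checks that a module is installed. It does not import it, so a CUDA error at import time would not show up here.

```python
import importlib.util
import os
from typing import Dict


def check_mamba_environment() -> Dict[str, object]:
    """Report environment variables and module availability relevant to the
    Mamba install note. Hypothetical helper, not from the repository."""
    report: Dict[str, object] = {
        "CUDA_HOME": os.environ.get("CUDA_HOME"),        # None if unset
        "CONDA_PREFIX": os.environ.get("CONDA_PREFIX"),  # set by conda activate
    }
    # Assumed import names for the mamba-ssm and causal_conv1d packages.
    for module in ("mamba_ssm", "causal_conv1d"):
        # find_spec locates the module on the import path without importing it.
        report[module] = importlib.util.find_spec(module) is not None
    return report


if __name__ == "__main__":
    for key, value in check_mamba_environment().items():
        print(f"{key}: {value}")
```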
- Refer to data_processing/README to download and extract the required dataset.
- Set the DATA_DIR variable in the hiss/utils/__init__.py file. This is the path to the parent directory that contains folders corresponding to every dataset.
- Process the datasets into a format compatible with training:
  - Marker Writing: python data_processing/process_reskin_data.py -dd marker_writing_<hiss/full>_dataset
  - Intrinsic Slip: python data_processing/process_reskin_data.py -dd intrinsic_slip_<hiss/full>_dataset
  - Joystick Control: python data_processing/process_xela_data.py -dd joystick_control_<hiss/full>_dataset
  - RoNIN: python data_processing/process_ronin_data.py
  - VECtor: python data_processing/process_vector_data.py
  - TotalCapture: python data_processing/process_total_capture_data.py
- Run create_dataset.py for the respective dataset to preprocess the data and resample it at the desired frequencies:
  - Marker Writing: python create_dataset.py --config-name marker_writing_config
  - Intrinsic Slip: python create_dataset.py --config-name intrinsic_slip_config
  - Joystick Control: python create_dataset.py --config-name joystick_control_config
  - RoNIN:
    python create_dataset.py --config-name ronin_train_config
    python create_dataset.py --config-name ronin_test_config
  - VECtor: python create_dataset.py --config-name vector_config
  - TotalCapture:
    python create_dataset.py --config-name total_capture_train_config
    python create_dataset.py --config-name total_capture_test_config
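The resampling step can be illustrated with a minimal sketch. This is not the logic in the repository's create_dataset.py. It assumes a single 1-D signal with strictly increasing timestamps in seconds, and it uses linear interpolation onto a uniform time grid, which is one common choice. The name resample_linear and its signature are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple


def resample_linear(
    timestamps: Sequence[float],
    values: Sequence[float],
    frequency_hz: float,
) -> Tuple[List[float], List[float]]:
    """Resample a 1-D signal onto a uniform grid at frequency_hz by linear
    interpolation between neighbouring samples (hypothetical helper)."""
    if len(timestamps) != len(values) or len(timestamps) < 2:
        raise ValueError("need at least two (timestamp, value) pairs")
    if any(t1 <= t0 for t0, t1 in zip(timestamps, timestamps[1:])):
        raise ValueError("timestamps must be strictly increasing")
    if frequency_hz <= 0:
        raise ValueError("frequency_hz must be positive")
    start, end = timestamps[0], timestamps[-1]
    # Small epsilon guards against float error dropping the final grid point.
    num_samples = int((end - start) * frequency_hz + 1e-9) + 1
    grid = [start + i / frequency_hz for i in range(num_samples)]
    resampled = []
    for t in grid:
        # j: first original sample strictly after t, clamped to the last index.
        j = min(bisect_right(timestamps, t), len(timestamps) - 1)
        i = j - 1
        w = (t - timestamps[i]) / (timestamps[j] - timestamps[i])
        resampled.append(values[i] + w * (values[j] - values[i]))
    return grid, resampled
```

For example, resampling samples taken at 0, 1, and 2 seconds at 2 Hz yields five points spaced 0.5 seconds apart. Interpolating every sensor stream onto a shared grid this way is one way to align streams recorded at different native rates.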
To train HiSS models for sequential prediction, use train.py. For each dataset, the conf/ directory provides a <dataset_name>_hiss_config.yaml file containing the model parameters of the best-performing HiSS model for that dataset. To train the model, run:
python train.py --config-name <dataset_name>_hiss_config
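To train on several datasets one after another, a small driver can loop over the command above. This is a hypothetical helper, not part of the repository. The dataset names you pass are assumptions and must match the <dataset_name>_hiss_config.yaml files in conf/. It uses sys.executable in place of a bare python so the runs use the current interpreter. With dry_run=True (the default), it only prints the commands.

```python
import shlex
import subprocess
import sys
from typing import List, Sequence


def train_commands(dataset_names: Sequence[str]) -> List[List[str]]:
    """Build one `python train.py --config-name <name>_hiss_config` command
    per dataset name (hypothetical helper)."""
    return [
        [sys.executable, "train.py", "--config-name", f"{name}_hiss_config"]
        for name in dataset_names
    ]


def run_all(dataset_names: Sequence[str], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run is set, run it in order."""
    for cmd in train_commands(dataset_names):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raise on the first failed run


if __name__ == "__main__":
    # Placeholder names; replace with the dataset names used in conf/.
    run_all(["dataset_a", "dataset_b"], dry_run=True)
```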
New datasets can be added by creating a corresponding Task object, following the tasks defined in vt_state/tasks, and a config file in conf/data_env/<data_env_name>.
