We introduce a novel diffusion-state space model-based policy (DiSPo) that leverages a state-space model, Mamba, to learn from diverse coarse demonstrations and generate multi-scale actions.
This repository is tested on Ubuntu 20.04 with CUDA 12.1 in a Docker environment.
First, clone this repository:
git clone --recurse-submodules https://github.com/rirolab/DiSPo.git
Install dependencies:
sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
Install MuJoCo:
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
tar -zxvf mujoco210-linux-x86_64.tar.gz
mkdir -p ~/.mujoco
mv mujoco210 ~/.mujoco
- download link
- make folder ~/.mujoco/mujoco210
- unzip downloaded file using tar -zxvf
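As a quick sanity check after the MuJoCo steps, the sketch below tests whether the extracted folder ended up under ~/.mujoco/mujoco210. The helper name mujoco_installed is hypothetical and not part of this repository, and looking for a bin/ subfolder is an assumption about the layout of the mujoco210 archive.

```python
from pathlib import Path


def mujoco_installed(root=None):
    """Return True if <root>/mujoco210/bin exists (root defaults to ~/.mujoco).

    Hypothetical helper: the bin/ check assumes the usual mujoco210 layout.
    """
    base = Path(root) if root is not None else Path.home() / ".mujoco"
    return (base / "mujoco210" / "bin").is_dir()


if __name__ == "__main__":
    print("MuJoCo 2.1.0 found:", mujoco_installed())
```

If this prints False, re-check the mkdir and mv steps before creating the conda environment.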
Create a conda environment:
conda env create -f conda_environment.yaml
conda activate dispo
Then, install the DiSPo block:
cd dispo_ssm
python setup.py install
cd ..
Generate demonstrations for the clamp task:
python demo_clamp.py -o output_data_path -d pipe_stl_path -chz control_freq -rhz recording_freq
- output_data_path: output dataset path, e.g. data/clamp/dataset.zarr
- pipe_stl_path: pipe STL folder path, e.g. data/straight_thin/path
- control_freq: agent control frequency, e.g. 20
- recording_freq: demonstration recording frequency, e.g. 5; control_freq must be divisible by it (chz % rhz == 0)
For example,
python demo_clamp.py -o data/clamp_20_5.zarr -d dispo/env/clamp_passing/robots/urdf/pipe/path -chz 20 -rhz 5
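The divisibility requirement on the two frequencies can be illustrated with a short sketch. It assumes that recording at a lower frequency amounts to keeping every k-th control step, with k = control_freq / recording_freq; how the demo scripts actually record is not shown here, and both function names are hypothetical.

```python
def recording_stride(control_hz: int, recording_hz: int) -> int:
    """Number of control steps per recorded step (assumed interpretation)."""
    if control_hz <= 0 or recording_hz <= 0:
        raise ValueError("frequencies must be positive")
    if control_hz % recording_hz != 0:
        raise ValueError("control_hz must be divisible by recording_hz")
    return control_hz // recording_hz


def downsample(steps, control_hz: int, recording_hz: int):
    """Keep every k-th control step, where k is the recording stride."""
    return steps[:: recording_stride(control_hz, recording_hz)]


if __name__ == "__main__":
    # 20 Hz control with 5 Hz recording keeps one step out of four.
    print(recording_stride(20, 5))
    print(downsample(list(range(8)), 20, 5))
```

Under this reading, -chz 20 -rhz 5 is valid, while a pair such as -chz 20 -rhz 3 is rejected.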
Generate demonstrations for the passage task:
python demo_passage.py -o output_data_path -chz control_freq -rhz recording_freq
For example,
python demo_passage.py -o data/passage_20_5.zarr -chz 20 -rhz 5
Generate demonstrations for the button task:
python demo_button.py -a action_obs_mode -o output_data_path -chz control_freq -rhz recording_freq
- action_obs_mode: observation/action mode; choosing 1 is recommended
  - 0: joint angle obs/action
  - 1: eef pose obs/action
  - 2: joint angle obs / eef pose action
Pretraining the AE is optional.
Before you start, check that the data path is correct. We recommend overriding the setting from the command line.
python ae_pretrain.py --config-name pretrain_dispo_workspace.yaml task=task task.dataset.zarr_path=data_path
- task: name of the task; currently supported: clamp, passage, button
- data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr
For example,
python ae_pretrain.py --config-name pretrain_dispo_workspace.yaml task=passage task.dataset.zarr_path=data/passage_20_5.zarr
For example, to train the button touch task from the original demonstrations (e.g. 20Hz agent w/ 20Hz recording):
python train.py --config-dir=dispo/config --config-name=train_dispo_workspace.yaml hydra.run.dir='data/outputs_button/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='button_20' policy.ae_ckpt_path=ae_ckpt_path task=button task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=None checkpoint.topk.monitor_key='test_success_rate' checkpoint.topk.format_str='epoch\=\{epoch:04d\}-test_success_rate\=\{test_success_rate:.3f\}.ckpt'
- ae_ckpt_path: set this based on the output of python ae_pretrain.py, e.g. data/ae_checkpoints/Train_button_20_20.zarr_2025_04_05_20_28/model_encoder_best.pt. If you do not want a pretrained AE, set it to None
- data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr
Make sure you set task.env_runner.proxy_sample_rate=None.
For example, to train the passage passing task from 5Hz demonstrations (e.g. 20Hz agent w/ 5Hz recording):
python train.py --config-dir=dispo/config --config-name=train_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='passage_5' policy.ae_ckpt_path=ae_ckpt_path task=passage task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=0.25
The key point here is setting task.env_runner.proxy_sample_rate=0.25.
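The two training examples suggest that proxy_sample_rate is the ratio of the recording frequency to the agent control frequency (5Hz / 20Hz = 0.25), and None when the two match. That reading is inferred from the examples only; it is not confirmed by the configuration files. The helper below is a hypothetical sketch of that rule, not code from this repository.

```python
def proxy_sample_rate(recording_hz: int, control_hz: int):
    """Return the assumed proxy_sample_rate override for a frequency pair.

    Assumption: the rate is recording_hz / control_hz, and None is used
    when demonstrations were recorded at the control frequency.
    """
    if recording_hz <= 0 or control_hz <= 0:
        raise ValueError("frequencies must be positive")
    if control_hz % recording_hz != 0:
        raise ValueError("control_hz must be divisible by recording_hz")
    if recording_hz == control_hz:
        return None
    return recording_hz / control_hz


if __name__ == "__main__":
    for rec, ctl in [(5, 20), (20, 20)]:
        rate = proxy_sample_rate(rec, ctl)
        print(f"task.env_runner.proxy_sample_rate={rate}")
```

For the passage example this gives 0.25, and for the button example it gives None, matching the overrides used in the commands.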
First, generate pseudo demonstrations by running train.py with the savedata_dispo_workspace.yaml config.
For example, to train the passage passing task from 5Hz demonstrations by generating 10Hz pseudo demonstrations:
python train.py --config-dir=dispo/config --config-name=savedata_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='save_passage_5_10' task=passage task.dataset.zarr_path=data_path training.proxy_sample_rate=0.5 training.base_ckpt_path=base_ckpt_path training.save_path=save_path
- base_ckpt_path: path to the trained checkpoint
- save_path: path to store the generated pseudo demonstrations
Then, you can start finetuning:
python train.py --config-dir=dispo/config --config-name=finetune_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='passage_finetune_5_10' task=passage training.base_ckpt_path=base_ckpt_path training.ori_zarr_path=data_path task.dataset.zarr_path=save_path task.dataset.sample_rate=2 task.env_runner.proxy_sample_rate=0.25 training.proxy_sample_rate=0.5
You can also try the following options and change the learning rate to achieve the best performance:
- freeze_more: to freeze more DiSPo layers
- no_freeze_body: to train all DiSPo layers
You can also generate trajectories that are more than twice as fine. For example, to generate 20Hz pseudo demonstrations:
python train.py --config-dir=dispo/config --config-name=savedata_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='save_passage_5_20' task=passage task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=0.25 training.base_ckpt_path=base_ckpt_path training.save_path=save_path
We recommend repeating the finetuning, generating demonstrations twice as fine each time, until you reach the desired granularity.
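The recommendation to refine by a factor of two per round can be written out as a schedule of (source Hz, target Hz) stages. The sketch below only enumerates those stages; the function name doubling_schedule is hypothetical, and the overrides to pass at each stage still need to be chosen as in the commands above.

```python
def doubling_schedule(source_hz: int, target_hz: int):
    """List the (from_hz, to_hz) finetuning rounds when doubling each time.

    Raises ValueError if target_hz cannot be reached from source_hz by
    repeated doubling.
    """
    if source_hz <= 0 or target_hz < source_hz:
        raise ValueError("need 0 < source_hz <= target_hz")
    stages = []
    hz = source_hz
    while hz < target_hz:
        nxt = hz * 2
        if nxt > target_hz:
            raise ValueError("target_hz is not source_hz times a power of two")
        stages.append((hz, nxt))
        hz = nxt
    return stages


if __name__ == "__main__":
    # Going from 5Hz demonstrations to 20Hz takes two rounds.
    for src, dst in doubling_schedule(5, 20):
        print(f"generate {dst}Hz pseudo demonstrations from {src}Hz, then finetune")
```

Each stage corresponds to one round of pseudo demonstration generation followed by one finetuning run.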
You need to generate the dataset for training the factor predictor. Here we use simple rule-based approaches to label the critical regions.
Run
python data_gen_factor_predictor.py -o output_data_path -d original_data_path
- output_data_path: output dataset path, e.g. data/clamp/dataset_rp.zarr
- original_data_path: original dataset path, generated by running the scripts whose names start with demo_, e.g. data/clamp/dataset.zarr
Then, you can start training the factor predictor with
python factor_predictor_train.py --config-name factor_train_dispo_workspace.yaml task=task task.dataset.zarr_path=output_data_path training.base_ckpt_path=base_ckpt_path
- task: name of the task; currently supported: clamp, passage, button
- output_data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr
To evaluate a trained checkpoint, run
python eval.py -c ckpt_path -o output_path
- ckpt_path: path to the best checkpoint
- output_path: path to the output folder
You can run the evaluation with the factor predictor with
python eval.py -c ckpt_path -o output_path -fp factor_predictor_ckpt_path
- factor_predictor_ckpt_path: path to the best checkpoint of a factor predictor, trained through factor_predictor_train.py
We thank the open-source repositories Diffusion Policy and Mamba.
@inproceedings{oh2026dispo,
title={DiSPo: Diffusion-SSM based Policy Learning for Coarse-to-Fine Action Discretization},
author={Oh, Nayoung and Jang, Jaehyeong and Jung, Moonkyeong and Park, Daehyung},
booktitle={Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
year={2026}
}
