DiSPo: Diffusion-SSM based Policy Learning for Coarse-to-Fine Action Discretization

We introduce a novel diffusion and state-space-model based policy (DiSPo) that leverages a state-space model, Mamba, to learn from diverse coarse demonstrations and generate multi-scale actions.

[Figure: overview]

Installation

This repository has been tested on Ubuntu 20.04 with CUDA 12.1 in a Docker environment.

First, clone this repository:

git clone --recurse-submodules https://github.com/rirolab/DiSPo.git

Install the dependencies:

sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf

Install MuJoCo:

wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
tar -zxvf mujoco210-linux-x86_64.tar.gz
mkdir -p ~/.mujoco
mv mujoco210 ~/.mujoco
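  • download the archive (download link)
  • extract the downloaded file using tar -zxvf
  • create the folder ~/.mujoco and move the extracted mujoco210 folder into it, giving ~/.mujoco/mujoco210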

Create a conda environment:

conda env create -f conda_environment.yaml
conda activate dispo

Then, install the DiSPo block:

cd dispo_ssm
python setup.py install
cd ..

Dataset preparation

Task: clamp passing

python demo_clamp.py -o output_data_path -d pipe_stl_path -chz control_freq -rhz recording_freq
  • output_data_path: output dataset path, e.g. data/clamp/dataset.zarr
  • pipe_stl_path: pipe stl folder path, e.g. data/straight_thin/path
  • control_freq: agent control frequency, e.g. 20
  • recording_freq: demonstration recording frequency, e.g. 5. control_freq must be a multiple of recording_freq (chz % rhz == 0)

For example,

python demo_clamp.py -o data/clamp_20_5.zarr -d dispo/env/clamp_passing/robots/urdf/pipe/path -chz 20 -rhz 5
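
The chz % rhz == 0 requirement listed above can be checked before launching a recording run. The sketch below is illustrative only: recording_stride is a hypothetical helper, not part of this repository, and it assumes that recording at a lower rate amounts to keeping one sample every control_freq / recording_freq control steps.

```python
def recording_stride(control_freq: int, recording_freq: int) -> int:
    """Return the number of control steps per recorded step.

    Hypothetical helper (not part of DiSPo). Assumes a demonstration is
    recorded by keeping every (control_freq // recording_freq)-th control step.
    """
    if control_freq <= 0 or recording_freq <= 0:
        raise ValueError("frequencies must be positive")
    if control_freq % recording_freq != 0:
        raise ValueError(
            f"control_freq ({control_freq}) must be a multiple of "
            f"recording_freq ({recording_freq})"
        )
    return control_freq // recording_freq


# Values from the example command: -chz 20 -rhz 5
print(recording_stride(20, 5))  # 4
```

With -chz 20 -rhz 5 this gives a stride of 4, while a pair such as 20 and 3 is rejected up front.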

Task: passage passing

python demo_passage.py -o output_data_path -chz control_freq -rhz recording_freq

For example,

python demo_passage.py -o data/passage_20_5.zarr -chz 20 -rhz 5

Task: button touch

python demo_button.py -a action_obs_mode -o output_data_path -chz control_freq -rhz recording_freq
  • action_obs_mode: observation/action mode. 0: joint angle observations and actions; 1: end-effector (eef) pose observations and actions; 2: joint angle observations with eef pose actions. Mode 1 is recommended

Running the training code

Pre-training observation encoder/decoder

This process is optional.

Before you start, check that the data path is correct. We recommend overriding the setting from the command line.

python ae_pretrain.py --config-name pretrain_dispo_workspace.yaml task=task task.dataset.zarr_path=data_path
  • task: name of the task; currently supported: [clamp, passage, button]
  • data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr

For example,

python ae_pretrain.py --config-name pretrain_dispo_workspace.yaml task=passage task.dataset.zarr_path=data/passage_20_5.zarr

Main training

For example, to train the button touch task from the original demonstrations (e.g. a 20 Hz agent with 20 Hz recordings):

python train.py --config-dir=dispo/config --config-name=train_dispo_workspace.yaml hydra.run.dir='data/outputs_button/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='button_20' policy.ae_ckpt_path=ae_ckpt_path task=button task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=None checkpoint.topk.monitor_key='test_success_rate' checkpoint.topk.format_str='epoch\=\{epoch:04d\}-test_success_rate\=\{test_success_rate:.3f\}.ckpt'
  • ae_ckpt_path: path to the pretrained autoencoder checkpoint, taken from the output of python ae_pretrain.py, e.g. data/ae_checkpoints/Train_button_20_20.zarr_2025_04_05_20_28/model_encoder_best.pt. If you do not want to use a pretrained autoencoder, set it to None
  • data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr. Make sure you set task.env_runner.proxy_sample_rate=None
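
Because the training command carries many key=value overrides, it can be easier to assemble it from a dictionary. The following is a minimal sketch, not part of this repository: build_train_command is a hypothetical helper, and the dictionary holds only a subset of the overrides shown above (the hydra.run.dir, policy.ae_ckpt_path, and checkpoint.topk settings are left out for brevity).

```python
import shlex
from typing import Dict


def build_train_command(config_name: str, overrides: Dict[str, object]) -> str:
    """Assemble a `python train.py ...` command line with key=value overrides.

    Hypothetical convenience helper, not part of DiSPo. A value of None is
    rendered as the literal text "None", matching how the README passes
    task.env_runner.proxy_sample_rate=None.
    """
    args = [
        "python",
        "train.py",
        "--config-dir=dispo/config",
        f"--config-name={config_name}",
    ]
    args += [f"{key}={value}" for key, value in overrides.items()]
    # shlex.join quotes any argument that contains shell-special characters.
    return shlex.join(args)


command = build_train_command(
    "train_dispo_workspace.yaml",
    {
        "logging.name": "button_20",
        "task": "button",
        "task.dataset.zarr_path": "data/button/button_20_20.zarr",
        "task.env_runner.proxy_sample_rate": None,
    },
)
print(command)
```

Keeping the overrides in one dictionary makes it harder to forget a setting such as task.env_runner.proxy_sample_rate=None when switching between tasks.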

For example, to train the passage passing task from 5 Hz demonstrations (e.g. a 20 Hz agent with 5 Hz recordings):

python train.py --config-dir=dispo/config --config-name=train_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='passage_5' policy.ae_ckpt_path=ae_ckpt_path task=passage task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=0.25

The key point here is to set task.env_runner.proxy_sample_rate=0.25.
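
In the examples in this README, the value 0.25 coincides with the ratio of the demonstration frequency to the control frequency (5 Hz / 20 Hz), and the 20 Hz / 20 Hz button example uses None instead. Assuming that relationship holds in general, which is an inference from the examples rather than something stated here, the value could be derived as in the sketch below. proxy_sample_rate is a hypothetical helper, not a function in this repository.

```python
from typing import Optional


def proxy_sample_rate(control_freq: int, demo_freq: int) -> Optional[float]:
    """Return demo_freq / control_freq, or None when the two are equal.

    Hypothetical helper (not part of DiSPo). The mapping is inferred from the
    README examples: 5 Hz demos with a 20 Hz agent use 0.25, and 20 Hz demos
    with a 20 Hz agent use None.
    """
    if control_freq <= 0 or demo_freq <= 0:
        raise ValueError("frequencies must be positive")
    if demo_freq > control_freq:
        raise ValueError("demo_freq must not exceed control_freq")
    if demo_freq == control_freq:
        return None
    return demo_freq / control_freq


print(proxy_sample_rate(20, 5))   # 0.25
print(proxy_sample_rate(20, 20))  # None
```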

Finetuning

First, generate pseudo demonstrations by running train.py with savedata_dispo_workspace.yaml. For example, to train the passage passing task from 5 Hz demonstrations by generating 10 Hz pseudo demonstrations:

python train.py --config-dir=dispo/config --config-name=savedata_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='save_passage_5_10' task=passage task.dataset.zarr_path=data_path training.proxy_sample_rate=0.5 training.base_ckpt_path=base_ckpt_path training.save_path=save_path
  • base_ckpt_path: path to the trained checkpoint
  • save_path: path to store the generated pseudo demonstrations

Then, you can start finetuning:

python train.py --config-dir=dispo/config --config-name=finetune_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='passage_finetune_5_10' task=passage training.base_ckpt_path=base_ckpt_path training.ori_zarr_path=data_path task.dataset.zarr_path=save_path task.dataset.sample_rate=2 task.env_runner.proxy_sample_rate=0.25 training.proxy_sample_rate=0.5

Optionally, you can try the following settings, and adjust the learning rate, to achieve the best performance:

  • freeze_more: freeze more DiSPo layers
  • no_freeze_body: train all DiSPo layers

You can also generate trajectories that are more than twice as fine. For example, to generate 20 Hz pseudo demonstrations:

python train.py --config-dir=dispo/config --config-name=savedata_dispo_workspace.yaml hydra.run.dir='data/outputs_passage/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}' logging.name='save_passage_5_20' task=passage task.dataset.zarr_path=data_path task.env_runner.proxy_sample_rate=0.25 training.base_ckpt_path=base_ckpt_path training.save_path=save_path

We recommend repeating the finetuning, each time generating demonstrations that are twice as fine, until you reach the desired granularity.
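
The repeated doubling recommended above can be laid out as a simple schedule of demonstration frequencies. The sketch below is only an illustration: refinement_schedule is a hypothetical helper, not part of this repository, and it assumes the target frequency is the starting frequency multiplied by a power of two.

```python
from typing import List


def refinement_schedule(start_hz: int, target_hz: int) -> List[int]:
    """List the demonstration frequencies visited when doubling each round.

    Hypothetical helper (not part of DiSPo). Each consecutive pair in the
    result corresponds to one round of pseudo-demonstration generation
    followed by finetuning.
    """
    if start_hz <= 0 or target_hz < start_hz:
        raise ValueError("need 0 < start_hz <= target_hz")
    schedule = [start_hz]
    while schedule[-1] < target_hz:
        schedule.append(schedule[-1] * 2)
    if schedule[-1] != target_hz:
        raise ValueError("target_hz must be start_hz times a power of two")
    return schedule


# From 5 Hz recordings up to a 20 Hz agent: two finetuning rounds.
print(refinement_schedule(5, 20))  # [5, 10, 20]
```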

Train the factor predictor

First, generate the dataset for training the factor predictor. Here, we use simple rule-based approaches to label the critical regions.

Run

python data_gen_factor_predictor.py -o output_data_path -d original_data_path
  • output_data_path: output dataset path, e.g. data/clamp/dataset_rp.zarr
  • original_data_path: path to the original dataset generated by one of the demo_ scripts, e.g. data/clamp/dataset.zarr

Then, you can start training the factor predictor:

python factor_predictor_train.py --config-name factor_train_dispo_workspace.yaml task=task task.dataset.zarr_path=output_data_path training.base_ckpt_path=base_ckpt_path
  • task: name of the task; currently supported: [clamp, passage, button]
  • output_data_path: path to the generated dataset, e.g. data/button/button_20_20.zarr

Running the evaluation code

python eval.py -c ckpt_path -o output_path
  • ckpt_path: path to the best checkpoint
  • output_path: path to the output folder

You can run the evaluation with the factor predictor:

python eval.py -c ckpt_path -o output_path -fp factor_predictor_ckpt_path
  • factor_predictor_ckpt_path: path to the best checkpoint of a factor predictor, trained through factor_predictor_train.py

Acknowledgement

We thank the open-source repositories Diffusion Policy and Mamba.

BibTex

@inproceedings{oh2026dispo,
  title={DiSPo: Diffusion-SSM based Policy Learning for Coarse-to-Fine Action Discretization},
  author={Oh, Nayoung and Jang, Jaehyeong and Jung, Moonkyeong and Park, Daehyung},
  booktitle={Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
  year={2026}
}

About

[ICRA 2026] An official implementation of the paper "DiSPo: Diffusion-SSM based Policy Learning for Coarse-to-Fine Action Discretization"