# Diffusion Policy with Cotraining, MMD, UOT, and BEA

Requirements:
- Ubuntu 18.04+
- NVIDIA GPU with CUDA 11.8
The three simulator packages (robomimic, robosuite, mimicgen) come from the simpreal_p repo. **Note:** edit the package path and Miniconda path in `setup.sh` before running.
```
bash setup.sh
```

What `setup.sh` does:

- Installs system packages (`cmake`, `patchelf`, OSMesa, etc.)
- Installs Miniconda to `~/miniconda3` (skipped if already present)
- Creates a conda env named `robomimic` with Python 3.9
- Installs PyTorch 2.0 (GPU build preferred, CPU fallback)
- Installs the `mujoco` pip package
- Clones simpreal_p into `~/opt/simpreal_p` and pip-installs the three sub-packages in editable mode:
  - `~/opt/simpreal_p/robomimic`
  - `~/opt/simpreal_p/robosuite`
  - `~/opt/simpreal_p/mimicgen`
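To sanity-check the install before training, here is a minimal sketch (the package names are taken from the list above; the helper function is illustrative, not part of the repo):

```python
import importlib.util

def check_packages(pkgs):
    """Report which packages are importable, without actually importing them."""
    return {p: importlib.util.find_spec(p) is not None for p in pkgs}

# The packages setup.sh is expected to provide.
print(check_packages(["torch", "mujoco", "robomimic", "robosuite", "mimicgen"]))
```

Run this inside the `robomimic` env; any `False` entry means the corresponding install step failed.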
After installation, activate the environment in any new shell:
```
source ~/miniconda3/etc/profile.d/conda.sh
conda activate robomimic
```

The setup exports `MUJOCO_GL=osmesa` (software rendering). If you have a display or EGL available, you can override this before running:

```
export MUJOCO_GL=egl   # GPU-accelerated headless
# or
export MUJOCO_GL=glfw  # on-screen window
```

Training configs expect HDF5 demo datasets. Test data are uploaded to Drive. Typical files:
| File | Description |
|---|---|
| `rs_demo_500.hdf5` | 500 demos in standard robosuite Stack |
| `rs_wood_20.hdf5` | 20 demos in StackWood (wood table, agentview) |
| `rs_wood_20_s.hdf5` | 20 demos in StackWood (wood table, agentview45 camera) |
| `pair_info_rsw.json` | DTW pair info for UOT (wood) |
| `pair_info_rsws.json` | DTW pair info for UOT (wood + sideview) |
See all training options with:

```
python train.py -h
```

Train a Diffusion Policy on standard robosuite Stack demonstrations:

```
python train.py --config exp/single/dp_stack_rs.json
```

Checkpoints, TensorBoard logs, and rollout videos are saved to the `output_dir` specified in the config (e.g. `~/atz/dp_stack_ckpts_rs`).
Five co-training modes are available via `--cotraining`:

| Mode | Flag | Description |
|---|---|---|
| Off | `--cotraining off` | Standard single-domain BC (default) |
| Vanilla | `--cotraining vanilla` | Convex combination of source and target losses |
| OT | `--cotraining ot` | Optimal Transport alignment on encoder features |
| MMD | `--cotraining mmd` | Energy-distance MMD alignment on encoder features |
| BEA | `--cotraining bea` | Best-Effort Adaptation (q-weighted ERM) |
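The vanilla mode's objective is just a convex combination of the per-domain behavior-cloning losses. A minimal sketch (the function name is illustrative, and whether `--alpha` weights the source or the target domain is an assumption here):

```python
def vanilla_cotrain_loss(loss_source: float, loss_target: float,
                         alpha: float = 0.5) -> float:
    """Convex combination of source- and target-domain BC losses.

    Assumption: alpha weights the source domain, (1 - alpha) the target,
    matching the --alpha flag in the examples below.
    """
    assert 0.0 <= alpha <= 1.0, "alpha must be a convex weight"
    return alpha * loss_source + (1.0 - alpha) * loss_target
```

With `alpha=0.5` (as in the vanilla example below), both domains contribute equally per step.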
Vanilla co-training (source: 500 robosuite demos, target: 20 StackWood demos):
```
python train.py --config exp/cotrain/dp_stack_ct_w.json \
    --cotraining vanilla --alpha 0.5 \
    --source rs 500 --target rsw 20
```

MMD co-training:
```
python train.py --config exp/cotrain/dp_stack_mmd_w.json \
    --cotraining mmd \
    --source rs 500 --target rsw 20
```

OT co-training:
```
python train.py --config exp/cotrain/dp_stack_ot_w.json \
    --cotraining ot \
    --source rs 500 --target rsw 20
```

BEA co-training:
```
python train.py --config exp/cotrain/dp_stack_bea_w.json \
    --cotraining bea \
    --source rs 500 --target rsw 20
```

Additional `train.py` flags:
| Flag | Description |
|---|---|
| `--name NAME` | Override experiment name in config |
| `--dataset PATH` | Override dataset path (single-domain only) |
| `--debug` | Quick run for debugging |
| `--resume` | Resume from latest checkpoint |
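For reference, the energy-distance form of MMD used by the `mmd` mode has a simple closed-form estimator. The sketch below is generic and not necessarily the repo's exact implementation:

```python
import numpy as np

def pairwise_dists(a, b):
    # All Euclidean distances between rows of a (n, d) and rows of b (m, d).
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)

def energy_mmd(x, y):
    """Energy-distance MMD between two feature batches (n, d) and (m, d).

    D_E(x, y) = 2 E||x - y|| - E||x - x'|| - E||y - y'||
    Zero when the batches come from the same distribution; grows as the
    encoder features of the two domains drift apart.
    """
    return (2.0 * pairwise_dists(x, y).mean()
            - pairwise_dists(x, x).mean()
            - pairwise_dists(y, y).mean())
```

During co-training this term would be added to the BC loss, pulling source and target encoder features toward a shared distribution.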
Evaluate a trained checkpoint on any environment variant:
```
python eval.py \
    --agent ckpt_path \
    --num_seeds 200 \
    --start_seed 501 \
    --horizon 250 \
    --output_dir directory \
    --save_videos \
    --camera_name {agentview|agentview45} \
    --video_height 256 \
    --video_width 256 \
    --eval_env {rsw|rsws}
```
Environment aliases for `--eval_env`:

| Alias | Environment |
|---|---|
| `rs` | Standard robosuite Stack (ceramic table, agentview) |
| `rsw` | StackWood (wood table, agentview) |
| `rsws` | StackWood (wood table, agentview45 camera) |
`eval.py` flags:

| Flag | Default | Description |
|---|---|---|
| `--agent` | (required) | Path to `.pth` checkpoint |
| `--eval_env` | (required) | Environment alias (`rs`, `rsw`, `rsws`) |
| `--num_seeds` | 5000 | Number of evaluation episodes |
| `--start_seed` | 0 | Starting seed number |
| `--horizon` | from ckpt | Max steps per episode |
| `--output_dir` | `./eval_results` | Where to save results |
| `--save_videos` | off | Save per-seed videos + success compilation |
| `--camera_name` | `agentview` | Camera for video rendering |
| `--video_height` | 256 | Video frame height |
| `--video_width` | 256 | Video frame width |
| `--dense_reward` | off | Use dense (shaped) reward |
| `--ddim_steps` | None | Override to DDIM with N denoising steps (faster inference) |
Results are written to `--output_dir`:

- `rewards/reward_seed{seed}.npy` — per-seed reward curves
- `videos/test_{seed}.mp4` — per-seed rollout videos (if `--save_videos`)
- `max_reward.npy` — per-seed max rewards
- `avg_success_rate.npy` — scalar success rate
- `summary.json` — human-readable summary
- `videos/successes.mp4` — compilation of successful episodes
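A sketch of how the scalar success rate relates to the per-seed max rewards (assuming a sparse success threshold of 1.0, which may differ from the repo's actual criterion):

```python
import numpy as np

def summarize(max_rewards, threshold=1.0):
    """Success rate over episodes, in the style of avg_success_rate.npy:
    an episode counts as a success if its max reward reaches the threshold."""
    r = np.asarray(max_rewards, dtype=float)
    return {"num_episodes": int(r.size),
            "success_rate": float(np.mean(r >= threshold))}

print(summarize([1.0, 0.0, 1.0, 1.0]))  # {'num_episodes': 4, 'success_rate': 0.75}
```

In practice you would pass the contents of `max_reward.npy` (e.g. `np.load(...)`) as `max_rewards`.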
Utility scripts:

| Script | Description |
|---|---|
| `scripts/gen_dtw.py` | Generate DTW pair-info JSON for OT co-training |
| `scripts/get_obs.py` | Extract sample observation images from an HDF5 dataset |
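For context, `gen_dtw.py` presumably aligns source and target trajectories with dynamic time warping. Below is a generic DTW sketch (the script's actual features and JSON format are not shown here); it returns the kind of `(src_index, tgt_index)` pairs a pair-info file could store:

```python
import numpy as np

def dtw_path(src, tgt):
    """Classic DTW between two feature sequences of shapes (n, d) and (m, d).

    Returns the accumulated alignment cost and the warping path as a list
    of (src_index, tgt_index) pairs.
    """
    n, m = len(src), len(tgt)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = np.linalg.norm(np.asarray(src[i - 1]) - np.asarray(tgt[j - 1]))
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    # Backtrack the optimal warping path from the bottom-right corner.
    path, i, j = [], n, m
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        step = np.argmin([cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]])
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    return float(cost[n, m]), path[::-1]
```

The quadratic cost table makes this fine for short demo trajectories but slow for long ones; a banded or pruned DTW would be the usual scaling fix.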
Example — generate DTW alignment data:
```
python scripts/gen_dtw.py \
    --src data/stack/rs_demo_500.hdf5 \
    --tgt data/stack/rs_wood_20_s.hdf5 \
    --output data/stack/pair_info_rsws.json
```

Pre-built config JSONs live under `exp/`:

```
exp/
├── single/                    # single-domain configs
│   ├── dp_stack_rs.json       # standard Stack
│   ├── dp_stack_rsw.json      # StackWood
│   └── dp_stack_rsws.json     # StackWood (agentview45)
└── cotrain/                   # co-training configs
    ├── dp_stack_ct_w.json     # vanilla, wood table
    ├── dp_stack_ct_ws.json    # vanilla, wood table + shifted camera
    ├── dp_stack_ot_w.json     # OT, wood table
    ├── dp_stack_ot_ws.json    # OT, wood table + shifted camera
    ├── dp_stack_mmd_w.json    # MMD, wood table
    ├── dp_stack_mmd_ws.json   # MMD, wood table + shifted camera
    ├── dp_stack_bea_w.json    # BEA, wood table
    └── dp_stack_bea_ws.json   # BEA, wood table + shifted camera
```
Repository layout:

```
atz/
├── train.py      # main training script
├── eval.py       # evaluation script
├── setup.sh      # one-shot environment setup
├── exp/          # experiment config JSONs
├── data/         # demonstration datasets (HDF5)
├── runs/         # training outputs (checkpoints, logs, videos)
├── scripts/      # utility scripts (DTW, obs extraction, rollout)
├── azenv/        # Azure job submission helpers
└── test.py       # development test script
```