zatpds/beact
Run Guide

Diffusion Policy with Cotraining, MMD, UOT, and BEA.

Prerequisites

  • Ubuntu 18.04+
  • NVIDIA GPU with CUDA 11.8

Installation

The three simulator packages (robomimic, robosuite, mimicgen) come from the simpreal_p repo. Note: edit the package path and the Miniconda path in setup.sh before running.

bash setup.sh

What setup.sh does:

  1. Installs system packages (cmake, patchelf, OSMesa, etc.)
  2. Installs Miniconda to ~/miniconda3 (skipped if already present)
  3. Creates a conda env named robomimic with Python 3.9
  4. Installs PyTorch 2.0 (GPU build preferred, CPU fallback)
  5. Installs the mujoco pip package
  6. Clones simpreal_p into ~/opt/simpreal_p and pip-installs the three sub-packages in editable mode:
    • ~/opt/simpreal_p/robomimic
    • ~/opt/simpreal_p/robosuite
    • ~/opt/simpreal_p/mimicgen

After installation, activate the environment in any new shell:

source ~/miniconda3/etc/profile.d/conda.sh
conda activate robomimic

Rendering backend

The setup exports MUJOCO_GL=osmesa (software rendering). If you have a display or EGL available, you can override this before running:

export MUJOCO_GL=egl      # GPU-accelerated headless
# or
export MUJOCO_GL=glfw     # on-screen window

Data

Training configs expect HDF5 demo datasets. Test datasets are uploaded to Drive. Typical files:

| File | Description |
| --- | --- |
| rs_demo_500.hdf5 | 500 demos in standard robosuite Stack |
| rs_wood_20.hdf5 | 20 demos in StackWood (wood table, agentview) |
| rs_wood_20_s.hdf5 | 20 demos in StackWood (wood table, agentview45 camera) |
| pair_info_rsw.json | DTW pair info for UOT (wood) |
| pair_info_rsws.json | DTW pair info for UOT (wood + sideview) |

Training

python train.py -h

Single-domain training

Train a Diffusion Policy on standard robosuite Stack demonstrations:

python train.py --config exp/single/dp_stack_rs.json

Checkpoints, TensorBoard logs, and rollout videos are saved to the output_dir specified in the config (e.g. ~/atz/dp_stack_ckpts_rs).

Co-training (domain adaptation)

Five co-training modes are available via --cotraining:

| Mode | Flag | Description |
| --- | --- | --- |
| Off | --cotraining off | Standard single-domain BC (default) |
| Vanilla | --cotraining vanilla | Convex combination of source and target losses |
| OT | --cotraining ot | Optimal Transport alignment on encoder features |
| MMD | --cotraining mmd | Energy-distance MMD alignment on encoder features |
| BEA | --cotraining bea | Best-Effort Adaptation (q-weighted ERM) |

Vanilla co-training (source: 500 robosuite demos, target: 20 StackWood demos):

python train.py --config exp/cotrain/dp_stack_ct_w.json \
    --cotraining vanilla --alpha 0.5 \
    --source rs 500 --target rsw 20
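Conceptually, the vanilla mode optimizes a convex combination of the two behavior-cloning losses, with --alpha weighting the source term. A minimal sketch (the function name and scalar losses are illustrative, not the repo's actual API):

```python
def cotrain_loss(loss_source: float, loss_target: float, alpha: float = 0.5) -> float:
    """Convex combination of source- and target-domain BC losses.

    alpha weights the source loss; (1 - alpha) weights the target loss.
    """
    assert 0.0 <= alpha <= 1.0
    return alpha * loss_source + (1.0 - alpha) * loss_target
```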

MMD co-training:

python train.py --config exp/cotrain/dp_stack_mmd_w.json \
    --cotraining mmd \
    --source rs 500 --target rsw 20
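The energy-distance form of MMD compares batches of encoder features without choosing a kernel bandwidth. A self-contained numpy sketch (feature shapes and function names are assumptions, not the repo's code):

```python
import numpy as np

def pairwise_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Euclidean distance between every row of a (n, d) and every row of b (m, d).
    diff = a[:, None, :] - b[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))

def energy_mmd(x: np.ndarray, y: np.ndarray) -> float:
    # Energy-distance MMD: 2 E||x - y|| - E||x - x'|| - E||y - y'||.
    # Zero when the two batches are identical; grows as the feature
    # distributions drift apart.
    return float(2.0 * pairwise_dists(x, y).mean()
                 - pairwise_dists(x, x).mean()
                 - pairwise_dists(y, y).mean())
```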

OT co-training:

python train.py --config exp/cotrain/dp_stack_ot_w.json \
    --cotraining ot \
    --source rs 500 --target rsw 20
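Entropy-regularized OT plans are commonly computed with Sinkhorn iterations. Below is a balanced-marginal numpy sketch over a feature cost matrix; the UOT variant referenced in this repo additionally relaxes the marginal constraints, which this sketch does not show:

```python
import numpy as np

def sinkhorn(cost: np.ndarray, eps: float = 0.1, n_iters: int = 200) -> np.ndarray:
    """Entropy-regularized OT plan between uniform marginals.

    cost: (n, m) pairwise cost between source and target features.
    Returns a plan P whose rows sum to 1/n and columns to 1/m.
    """
    n, m = cost.shape
    a = np.full(n, 1.0 / n)   # uniform source marginal
    b = np.full(m, 1.0 / m)   # uniform target marginal
    k = np.exp(-cost / eps)   # Gibbs kernel
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):  # alternating marginal projections
        u = a / (k @ v)
        v = b / (k.T @ u)
    return u[:, None] * k * v[None, :]
```

The alignment loss would then be the transported cost, `(plan * cost).sum()`.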

BEA co-training:

python train.py --config exp/cotrain/dp_stack_bea_w.json \
    --cotraining bea \
    --source rs 500 --target rsw 20
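BEA's q-weighted ERM reweights per-sample losses with weights q on the simplex. A sketch of just the weighted objective (the q-update rule is method-specific and omitted):

```python
import numpy as np

def q_weighted_erm(losses: np.ndarray, q: np.ndarray) -> float:
    """Weighted empirical risk: sum_i q_i * loss_i with q normalized to the simplex."""
    q = q / q.sum()           # project raw weights onto the simplex
    return float((q * losses).sum())
```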

Additional train.py flags:

| Flag | Description |
| --- | --- |
| --name NAME | Override experiment name in config |
| --dataset PATH | Override dataset path (single-domain only) |
| --debug | Quick run for debugging |
| --resume | Resume from latest checkpoint |

Evaluation

Evaluate a trained checkpoint on any environment variant:

python eval.py \
    --agent ckpt_path \
    --num_seeds 200 \
    --start_seed 501 \
    --horizon 250 \
    --output_dir directory \
    --save_videos \
    --camera_name agentview \
    --video_height 256 \
    --video_width 256 \
    --eval_env rsw

Use --camera_name agentview45 together with --eval_env rsws for the shifted-camera variant.

Environment aliases

| Alias | Environment |
| --- | --- |
| rs | Standard robosuite Stack (ceramic table, agentview) |
| rsw | StackWood (wood table, agentview) |
| rsws | StackWood (wood table, agentview45 camera) |

Full set of eval flags

| Flag | Default | Description |
| --- | --- | --- |
| --agent | (required) | Path to .pth checkpoint |
| --eval_env | (required) | Environment alias (rs, rsw, rsws) |
| --num_seeds | 5000 | Number of evaluation episodes |
| --start_seed | 0 | Starting seed number |
| --horizon | from ckpt | Max steps per episode |
| --output_dir | ./eval_results | Where to save results |
| --save_videos | off | Save per-seed videos + success compilation |
| --camera_name | agentview | Camera for video rendering |
| --video_height | 256 | Video frame height |
| --video_width | 256 | Video frame width |
| --dense_reward | off | Use dense (shaped) reward |
| --ddim_steps | None | Override to DDIM with N denoising steps (faster inference) |

Evaluation outputs

Results are written to --output_dir:

  • rewards/reward_seed{seed}.npy — per-seed reward curves
  • videos/test_{seed}.mp4 — per-seed rollout videos (if --save_videos)
  • max_reward.npy — per-seed max rewards
  • avg_success_rate.npy — scalar success rate
  • summary.json — human-readable summary
  • videos/successes.mp4 — compilation of successful episodes
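The .npy outputs load directly with numpy. A small aggregation helper, assuming only the file layout above (the helper name is illustrative):

```python
import numpy as np
from pathlib import Path

def summarize(output_dir: str) -> dict:
    """Collect headline numbers from an eval.py output directory."""
    out = Path(output_dir)
    max_rewards = np.load(out / "max_reward.npy")            # per-seed max rewards
    success = float(np.load(out / "avg_success_rate.npy"))   # scalar success rate
    return {
        "mean_max_reward": float(max_rewards.mean()),
        "success_rate": success,
        "num_seeds": int(max_rewards.size),
    }
```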

Utility Scripts

| Script | Description |
| --- | --- |
| scripts/gen_dtw.py | Generate DTW pair-info JSON for OT co-training |
| scripts/get_obs.py | Extract sample observation images from an HDF5 dataset |

Example — generate DTW alignment data:

python scripts/gen_dtw.py \
    --src data/stack/rs_demo_500.hdf5 \
    --tgt data/stack/rs_wood_20_s.hdf5 \
    --output data/stack/pair_info_rsws.json
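DTW itself is standard dynamic programming over a pairwise-distance matrix. A minimal numpy sketch of the alignment cost (gen_dtw.py's actual feature extraction and JSON schema are not shown):

```python
import numpy as np

def dtw_cost(x: np.ndarray, y: np.ndarray) -> float:
    """Dynamic-time-warping alignment cost between sequences x (n, d) and y (m, d)."""
    n, m = len(x), len(y)
    # Pairwise Euclidean distances between all frame pairs.
    dist = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    # acc[i, j]: minimal cost of aligning x[:i] with y[:j].
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            acc[i, j] = dist[i - 1, j - 1] + min(acc[i - 1, j],      # insertion
                                                 acc[i, j - 1],      # deletion
                                                 acc[i - 1, j - 1])  # match
    return float(acc[n, m])
```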

Experiment Configs

Pre-built config JSONs live under exp/:

exp/
├── single/                    # single-domain configs
│   ├── dp_stack_rs.json       #   standard Stack
│   ├── dp_stack_rsw.json      #   StackWood
│   └── dp_stack_rsws.json     #   StackWood (agentview45)
└── cotrain/                   # co-training configs
    ├── dp_stack_ct_w.json     #   vanilla, wood table
    ├── dp_stack_ct_ws.json    #   vanilla, wood table + shifted camera
    ├── dp_stack_ot_w.json     #   OT, wood table
    ├── dp_stack_ot_ws.json    #   OT, wood table + shifted camera
    ├── dp_stack_mmd_w.json    #   MMD, wood table
    ├── dp_stack_mmd_ws.json   #   MMD, wood table + shifted camera
    ├── dp_stack_bea_w.json    #   BEA, wood table
    └── dp_stack_bea_ws.json   #   BEA, wood table + shifted camera

Project Structure

atz/
├── train.py               # main training script
├── eval.py                # evaluation script
├── setup.sh               # one-shot environment setup
├── exp/                   # experiment config JSONs
├── data/                  # demonstration datasets (HDF5)
├── runs/                  # training outputs (checkpoints, logs, videos)
├── scripts/               # utility scripts (DTW, obs extraction, rollout)
├── azenv/                 # Azure job submission helpers
└── test.py                # development test script
