Cascading Decision Tree (CDT) for Interpretable Reinforcement Learning

Open-source code for the paper "CDT: Cascading Decision Trees for Explainable Reinforcement Learning" (https://arxiv.org/abs/2011.07553).

File Structure

  • data: all data for experiments
    • mlp: data for MLP model;
    • cdt: data for CDT model;
    • sdt: data for SDT model;
    • il: data for general Imitation Learning (IL);
    • rl: data for general Reinforcement Learning (RL);
    • cdt_compare_depth: data for CDT with different depths in RL;
    • sdt_compare_depth: data for SDT with different depths in RL;
  • src: source code
    • mlp: training configurations for MLP as policy function approximator;
    • cdt: the Cascading Decision Tree (CDT) class and necessary functions;
    • sdt: the Soft Decision Tree (SDT) class and necessary functions;
    • hdt: the heuristic agents;
    • il: configurations for Imitation Learning (IL);
    • rl: configurations for Reinforcement Learning (RL) and RL agents (e.g., PPO);
    • utils: common utility functions;
    • il_data_collect.py: collect dataset (state-action from heuristic or well-trained policy) for IL;
    • rl_data_collect.py: collect dataset (states during training for calculating normalization statistics) for RL;
    • il_train.py: train IL agent with different function approximators (e.g., SDT, CDT);
    • rl_train.py: train RL agent with different function approximators (e.g., SDT, CDT, MLP);
    • il_eval.py: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy;
    • rl_eval.py: evaluate the trained RL agents before and after tree discretization, based on episodic reward;
    • il_train.sh: bash script to run IL experiments with different models on a server;
    • rl_train.sh: bash script to run RL experiments with different models on a server;
    • rl_train_compare_sdt.py: train RL agent with SDT;
    • rl_train_compare_cdt.py: train RL agent with CDT;
    • rl_train_compare_sdt.sh: bash script to run RL experiments with SDT of different depths on a server;
    • rl_train_compare_cdt.sh: bash script to run RL experiments with CDT of different depths on a server;
  • visual
    • plot.ipynb: plot learning curves, etc.
    • params.ipynb: quantitative analysis of model parameters (SDT and CDT).
    • stability_analysis.ipynb: the stability analysis from the paper, which compares the tree weights.

To Run

To fully replicate the experiments in the paper, run the code in the following stages.

A. Reinforcement Learning Comparison with SDT, CDT and MLP

  1. Collect dataset: for state normalization

    cd ./src
    python rl_data_collect.py
  2. Get statistics on dataset

    cd rl
    jupyter notebook

    Open stats.ipynb and run its cells to generate the dataset statistics files.

  3. Train RL agents with different policy function approximators: SDT, CDT, MLP

    cd ..
    python rl_train.py --train --env='CartPole-v1' --method='sdt' --id=0
    python rl_train.py --train --env='LunarLander-v2' --method='cdt' --id=0
    python rl_train.py --train --env='MountainCar-v0' --method='mlp' --id=0

    or simply run with:

    ./rl_train.sh
  4. Evaluate the trained agents (with discretization operation)

    python rl_eval.py --env='CartPole-v1' --method='sdt'
    python rl_eval.py --env='LunarLander-v2' --method='cdt'
  5. Results visualization

    cd ../visual
    jupyter notebook

    See plot.ipynb.
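
Steps 1 and 2 above collect states and compute statistics for state normalization. As a rough illustration of how such statistics are typically used, the sketch below computes a per-dimension mean and standard deviation over a batch of states and standardizes them. It is not the code in stats.ipynb: the function names, the array layout, the epsilon, and the random stand-in data are all assumptions made for this example.

```python
import numpy as np

def compute_stats(states, eps=1e-8):
    """Per-dimension mean and std of a (num_samples, state_dim) array (hypothetical helper)."""
    states = np.asarray(states, dtype=np.float64)
    mean = states.mean(axis=0)
    std = states.std(axis=0) + eps  # eps guards against division by zero for constant dimensions
    return mean, std

def normalize(state, mean, std):
    """Standardize one state or a batch of states with precomputed statistics."""
    return (np.asarray(state, dtype=np.float64) - mean) / std

# Random stand-in for collected states; real data would come from rl_data_collect.py.
rng = np.random.default_rng(0)
states = rng.normal(loc=[0.0, 5.0, -2.0, 1.0], scale=[1.0, 3.0, 0.5, 2.0], size=(1000, 4))

mean, std = compute_stats(states)
normalized = normalize(states, mean, std)
```

After normalization each state dimension has approximately zero mean and unit variance, which keeps features with large raw ranges from dominating the learned tree weights.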

B. Imitation Learning Comparison with SDT and CDT

  1. Collect dataset: for (1) state normalization and (2) as imitation learning dataset

    cd ./src
    python il_data_collect.py
  2. Train IL agents with different policy function approximators: SDT, CDT

    python il_train.py --train --env='CartPole-v1' --method='sdt' --id=0
    python il_train.py --train --env='LunarLander-v2' --method='cdt' --id=0

    or simply run with:

    ./il_train.sh
  3. Evaluate the trained agents

    python il_eval.py --env='CartPole-v1' --method='sdt'
    python il_eval.py --env='LunarLander-v2' --method='cdt'
  4. Results visualization

    cd ../visual
    jupyter notebook

    See plot.ipynb.
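
The evaluation scripts compare agents before and after tree discretization. The sketch below shows one way a single soft decision node could be discretized, assuming the rule "keep only the input feature with the largest absolute weight". This is an illustrative assumption, not necessarily the rule implemented in il_eval.py or rl_eval.py, and all function names here are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def soft_node(x, w, b):
    """Soft decision node: probability of routing x to the left child."""
    return sigmoid(np.dot(w, x) + b)

def discretize_node(w, b):
    """Assumed rule: keep only the feature with the largest absolute weight."""
    k = int(np.argmax(np.abs(w)))
    return k, float(w[k]), float(b)

def hard_node(x, k, w_k, b):
    """Discretized node: route left iff the single retained feature passes the threshold."""
    return bool(w_k * x[k] + b > 0)

# Hypothetical node parameters and input
w = np.array([0.1, -2.0, 0.3])
b = 0.5
x = np.array([1.0, 1.0, 1.0])

p_left = soft_node(x, w, b)          # soft routing probability
k, w_k, b_k = discretize_node(w, b)  # retained feature index and its parameters
go_left = hard_node(x, k, w_k, b_k)  # hard, axis-aligned decision
```

In this example the soft node and its discretized version agree (both route to the right), but in general discretization changes the policy, which is why accuracy and reward are reported both before and after it.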

C. Tree Depths for SDT and CDT in Reinforcement Learning

Run the comparison with different tree depths:

For SDT:

./rl_train_compare_sdt.sh

For CDT:

./rl_train_compare_cdt.sh

D. Stability Analysis

Compare the tree weights of different agents in IL:

cd ./visual
jupyter notebook

See stability_analysis.ipynb.
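
To make "compare the tree weights" concrete, here is a minimal sketch of one possible comparison: normalize each agent's weight vector to unit length and average the pairwise Euclidean distances. The metric and the function names are assumptions chosen for illustration; the notebook and the paper may use a different measure.

```python
import numpy as np
from itertools import combinations

def normalize_weights(w):
    """Scale a weight vector to unit L2 norm (zero vectors are returned unchanged)."""
    w = np.asarray(w, dtype=np.float64)
    norm = np.linalg.norm(w)
    return w / norm if norm > 0 else w

def mean_pairwise_distance(weight_list):
    """Average L2 distance between all pairs of normalized weight vectors."""
    normed = [normalize_weights(w) for w in weight_list]
    dists = [np.linalg.norm(a - b) for a, b in combinations(normed, 2)]
    return float(np.mean(dists))

# Hypothetical node weights from agents trained with different seeds
same_direction = mean_pairwise_distance([[1.0, 0.0], [2.0, 0.0]])
orthogonal = mean_pairwise_distance([[1.0, 0.0], [0.0, 1.0]])
```

Under this measure, a smaller value means the trees learned by different agents are more alike, so the interpretation of the policy is more stable across runs.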

E. Model Simplicity

Quantitative analysis of the number of model parameters:

cd ./visual
jupyter notebook

See params.ipynb.
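
As a back-of-the-envelope illustration of such a count, the sketch below compares a full binary SDT with a small MLP. It assumes an SDT of depth d has 2^d - 1 inner nodes (each with one weight per input feature plus a bias) and 2^d leaves (each with one value per action), and it ignores any extra per-node parameters a given implementation may add. The MLP hidden size is hypothetical, and the CDT count is left to the notebook.

```python
def sdt_param_count(state_dim, action_dim, depth):
    """Parameters of a full binary soft decision tree under the assumptions above."""
    inner_nodes = 2 ** depth - 1
    leaves = 2 ** depth
    return inner_nodes * (state_dim + 1) + leaves * action_dim

def mlp_param_count(layer_sizes):
    """Parameters of a fully connected network with biases, e.g. [4, 32, 2]."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

# CartPole-v1 has 4 state dimensions and 2 actions
sdt_params = sdt_param_count(state_dim=4, action_dim=2, depth=3)  # 7 * 5 + 8 * 2
mlp_params = mlp_param_count([4, 32, 2])                          # 5 * 32 + 33 * 2
print(sdt_params, mlp_params)
```

The SDT count grows exponentially with depth, which is one reason the depth comparison in section C matters when weighing interpretability against capacity.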
