Expert Proximity as Surrogate Rewards for Single Demonstration Imitation Learning

This repository contains the official code implementation of the paper: Expert Proximity as Surrogate Rewards for Single Demonstration Imitation Learning

The code is based on @stanl1y's reinforcement learning framework, which is available at stanl1y/RL_framework.

Note: The neighbor model in the codebase refers to the transition discriminator in our paper.

Result

TDIL enables the agent to learn from a single demonstration and achieve expert-level performance. The following video shows the HalfCheetah-v3 environment. The left side is the expert demonstration, and the right side is the learned policy.

(Demonstration video: hc.mov)

Installation

Clone this repo with:

git clone https://github.com/stanl1y/tdil.git
cd tdil

(Optional) Launch a Docker Container

Install docker and nvidia-docker, and then run:

# assume the current directory is the root of this repository
docker build -t j3soon/tdil .
docker run --rm -it --gpus all --ipc=host -v $(pwd):/workspace j3soon/tdil

Install Dependencies

pip install -r requirements.txt

Expert data

The expert data for performing Imitation Learning (IL) is provided for reproducibility.

Path to expert data

The expert data is stored in the folder saved_expert_transition/. It is generated by a pretrained SAC agent and saved as a dictionary with the following keys: "states", "actions", "rewards", "next_states", "dones". The value of each key is a NumPy array.
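
For reference, a minimal loading sketch is shown below. It assumes the transitions are serialized with pickle; the exact file name, extension, and format used under saved_expert_transition/ may differ, so treat the path below as hypothetical.

# Illustrative sketch only; the actual file name/format in this repo may differ.
import pickle

import numpy as np

# Hypothetical path; adjust to the actual file under saved_expert_transition/.
with open("saved_expert_transition/sac/episode_num1_4114.pkl", "rb") as f:
    expert_data = pickle.load(f)

for key in ("states", "actions", "rewards", "next_states", "dones"):
    assert isinstance(expert_data[key], np.ndarray)
    print(key, expert_data[key].shape)

# Sum of per-step rewards of the single expert trajectory (e.g., ~4114 for Hopper-v3).
print("total reward:", expert_data["rewards"].sum())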

Total reward of the expert trajectory

  • Hopper-v3: 4114
  • Walker2d-v3: 6123
  • Ant-v3: 6561
  • HalfCheetah-v3: 15251
  • Humanoid-v3: 5855

Usage

Please note that TDIL (Transition Discriminator-based Imitation Learning) is our proposed method.

TDIL + BC

  • Hopper-v3:

    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Hopper-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_4114
    • To fix alpha, add:
      --no_update_alpha --log_alpha_init -4.6
      
  • Walker-v3:

    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Walker2d-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6123
    • To fix alpha, add:
      --no_update_alpha --log_alpha_init -1.2
      
  • Ant-v3:

    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Ant-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6561 --terminate_when_unhealthy
    • To fix alpha, add:
      --no_update_alpha --log_alpha_init -1.9
      
  • HalfCheetah-v3:

    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env HalfCheetah-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_15251
    • To fix alpha, add:
      --no_update_alpha --log_alpha_init 0.4
      
  • Humanoid-v3:

    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Humanoid-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_5855 --terminate_when_unhealthy
    • To fix alpha, add:
      --no_update_alpha --log_alpha_init -0.6
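
The --no_update_alpha flag freezes the SAC entropy temperature alpha at its initial value instead of tuning it automatically. As a rough reference only, assuming the standard SAC parameterization alpha = exp(log_alpha) (which this codebase may or may not follow exactly), the --log_alpha_init values above correspond to the following temperatures:

# Rough reference only: assumes alpha = exp(log_alpha), the usual SAC parameterization.
import math

log_alpha_inits = {
    "Hopper-v3": -4.6,
    "Walker2d-v3": -1.2,
    "Ant-v3": -1.9,
    "HalfCheetah-v3": 0.4,
    "Humanoid-v3": -0.6,
}
for env_id, log_alpha in log_alpha_inits.items():
    print(f"{env_id}: alpha ~= {math.exp(log_alpha):.3f}")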
      

TDIL + IRL

  • Hopper-v3:
    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Hopper-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_4114  --no_bc --beta 0.9 --use_discriminator
  • Walker-v3:
    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Walker2d-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6123  --no_bc --beta 0.9 --use_discriminator
  • Ant-v3:
    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Ant-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6561 --terminate_when_unhealthy  --no_bc --beta 0.9 --use_discriminator
  • HalfCheetah-v3:
    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env HalfCheetah-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_15251  --no_bc --beta 0.9 --use_discriminator
  • Humanoid-v3:
    python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Humanoid-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_5855 --terminate_when_unhealthy  --no_bc --beta 0.9 --use_discriminator

Add a custom log name to the experiment

Add the following flag:

--log_name <custom_name>

Run without the BC loss

Add the following flag:

--no_bc

Run without hard negative samples

Add the following flag:

--no_hard_negative_sampling

Run on the toy environment

python main.py --main_stage neighborhood_il --main_task neighborhood_dsac --env Maze-v6 --episodes 300 --policy_threshold_ratio 0.5 --neighbor_model_alpha 0.1 --gamma 0.8

The policy_threshold_ratio hyperparameter filters out state-action pairs whose surrogate reward (i.e., expert proximity) is too high when training the policy. The toy maze environment is commutative: the agent can return from $s_t$ to $s_{t-1}$ by taking the negative action $-a_{t-1}$. Such state-action pairs harm the policy when the transition discriminator assigns them a high reward, so they are filtered out with a threshold: if a state-action pair's reward exceeds the threshold ratio times the average reward of the expert data, that pair is not used to train the policy.
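
As a rough illustration of this filtering rule (not the repository's actual implementation), a mask over a batch of transitions could be computed as follows; all names here are hypothetical:

# Illustrative sketch of the policy_threshold_ratio filter described above.
import numpy as np

def policy_update_mask(batch_rewards: np.ndarray,
                       expert_rewards: np.ndarray,
                       policy_threshold_ratio: float) -> np.ndarray:
    # Discard transitions whose surrogate reward exceeds
    # policy_threshold_ratio * (average reward of the expert data).
    threshold = policy_threshold_ratio * expert_rewards.mean()
    return batch_rewards <= threshold  # True = keep this transition for the policy loss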

Run on AdroitHandDoor-v1

python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env AdroitHandDoor-v1 --wrapper gymnasium --total_timesteps 1000000 --data_name dapg/episode_num1_3019 --max_episode_steps 200 --no_hard_negative_sampling --policy_threshold_ratio 0.005 --ood

The ood argument evaluates the agent on out-of-distribution (OOD) states. More specifically, at the beginning of testing, the agent first takes a few timesteps of random actions, and then starts acting according to the learned policy.
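
A conceptual sketch of this evaluation protocol is shown below; num_random_steps and the dummy policy are placeholders, and the repository's actual warm-up length and evaluation code may differ.

# Conceptual sketch of OOD evaluation: a few random warm-up steps, then the learned policy.
import gymnasium as gym
import gymnasium_robotics  # noqa: F401  (registers the Adroit environments)

env = gym.make("AdroitHandDoor-v1", max_episode_steps=200)
policy = lambda obs: env.action_space.sample()  # placeholder for the learned policy
num_random_steps = 10  # hypothetical warm-up length

obs, _ = env.reset()
done, t = False, 0
while not done:
    if t < num_random_steps:
        action = env.action_space.sample()  # random actions drive the agent into OOD states
    else:
        action = policy(obs)
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    t += 1
env.close()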
