This is the implementation of the paper:
CASPI: Causal-aware Safe Policy Improvement for Task-oriented Dialogue. [paper]
If you use this code, data or our results in your research, please cite with the BibTeX below:

```
@article{ramachandran2021causal,
  title={Causal-aware Safe Policy Improvement for Task-oriented Dialogue},
  author={Ramachandran, Govardana Sachithanandam and Hashimoto, Kazuma and Xiong, Caiming},
  journal={arXiv preprint arXiv:2103.06370},
  year={2021}
}
```
## Setup

Run the following command to install dependencies:

```
pip install -r requirements.txt
```

Run the following command for data setup:

```
./damd_multiwoz/scripts/data_setup.sh
```

## Create K-fold datasets
Please choose an appropriate number of folds. In our work, we use 10 folds. The larger the number of folds, the more models need to be trained; for a quicker turnaround, use a smaller number of folds at a marginal loss in performance.

```
python CreateKFoldDataset.py --seed 111 --folds 10
```

## Generate dataset for pairwise reward model
The following script needs to be run K times. Each time, the value passed to the --fold argument should be incremented by 1, i.e. it ranges between 0 and K-1, preferably with a different seed for each run.

```
./damd_multiwoz/scripts/gen_reward_rollout.sh --cuda 0 --K 10 --fold 0 --metric soft --seed 68690
```
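The K runs above can be scripted as a loop. A minimal sketch, assuming K=10; the seed values below are arbitrary illustrative choices, and the `echo` prints each command instead of executing it (drop the `echo` to actually run the rollouts):

```shell
#!/usr/bin/env bash
# Print the reward-rollout generation command once per fold (K=10).
# The seed list is illustrative; any distinct seeds will do.
K=10
SEEDS=(68690 12345 23456 34567 45678 56789 67890 78901 89012 90123)
for FOLD in $(seq 0 $((K - 1))); do
  echo ./damd_multiwoz/scripts/gen_reward_rollout.sh \
    --cuda 0 --K "$K" --fold "$FOLD" --metric soft --seed "${SEEDS[$FOLD]}"
done
```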
## Pairwise Reward Learning

Please ensure the number of folds chosen matches the previous step.

```
python RewardLearning.py --seed 11 --folds 10 --action_space act --gamma 0.0 --metric soft
```
## Estimate Behavior Policy

Please ensure the number of folds, gamma and action_space match the reward learning step.

```
python EstimateBehaviorPolicy.py --seed 111 --folds 10 --action_space act --gamma 0.0 --metric soft
```

## CASPI(MinTL),M_soft(act)
In this version of CASPI, we use MinTL as the base model.

Please ensure the argument --caspi_returns_file matches the choices made in the previous steps. The file is of the form fn_Gs_<folds>_<gamma>_<action_space>_<metric>.json
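As a sketch, the returns filename can be assembled from the choices made in the earlier steps (the pattern here is inferred from the example command in this README):

```shell
# Build the CASPI returns filename from the pipeline settings.
# Pattern inferred from the example command below.
FOLDS=10
GAMMA=0.0
ACTION_SPACE=act
METRIC=soft
CASPI_RETURNS_FILE="fn_Gs_${FOLDS}_${GAMMA}_${ACTION_SPACE}_${METRIC}.json"
echo "$CASPI_RETURNS_FILE"   # fn_Gs_10_0.0_act_soft.json
```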
```
python train.py --mode train --context_window 2 --pretrained_checkpoint bart-large-cnn --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=111 cuda_device=0 batch_size=8 early_stop_count=7 --caspi_returns_file=fn_Gs_10_0.0_act_soft.json --caspi_wt=5. --caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi
```

## CASPI(DAMD),M_soft(act) End-to-end
In this version of CASPI, we use DAMD as the base model. This script tests end-to-end performance. Please ensure the arguments match the choices made in the reward learning steps.

```
./damd_multiwoz/scripts/caspi_damd.sh --cuda 0 --seed 111 --K 10 --gamma 0.0 --policy_loss L_det,L_sto --action_space act --metric soft --train_e2e True
```

## CASPI(DAMD),M_soft(act) dialogue-context-to-text
In this version of CASPI, we use DAMD as the base model. This script tests only the dialogue-context-to-text generation task of MultiWOZ 2.0. Please ensure the arguments match the choices made in the reward learning steps.

```
./damd_multiwoz/scripts/caspi_damd.sh --cuda 0 --seed 111 --K 10 --gamma 0.0 --policy_loss L_det --action_space act --metric soft --train_e2e False
```

This code extends or uses the following prior codebases and data:
Please refer to the [paper] and feel free to reach out to us.
