Single Agent Imitation Learning #12
Conversation
…ch 'vector-env' of github.com:sjtu-marl/malib into vector-env
Although the code runs without error, there are several problems with the current implementation:
- For malib/agent/agent_interface.py:L341, I require the data from all the datasets to be non-None to fix a bug in adversarial training. However, in offline training such as behavior cloning, the rollout dataset can be empty. I think we can add a way to remove the main environment dataset in offline training.
- The settings update rule now adds the newly specified key after the default key if the settings item is a dictionary, so we need to manually set the default value in malib/agent/agent_interface.py:L130.
- The GAIL+DDPG training cannot converge.
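The settings-merge behavior described in the second point can be sketched as a recursive dict update. This is a minimal illustration with hypothetical helper names, not MALib's actual implementation:

```python
from copy import deepcopy

def update_settings(defaults: dict, custom: dict) -> dict:
    """Recursively merge `custom` into `defaults`.

    Dictionary-valued items are merged key by key, so custom keys are
    added alongside the default ones instead of replacing the whole
    dict. This is why a default that must survive an override has to be
    set explicitly (cf. malib/agent/agent_interface.py:L130).
    """
    merged = deepcopy(defaults)
    for key, value in custom.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # merge nested dicts instead of overwriting them wholesale
            merged[key] = update_settings(merged[key], value)
        else:
            merged[key] = value
    return merged

defaults = {"dataset_config": {"episode_capacity": 1_000_000}}
custom = {"dataset_config": {"learning_start": 2560}}
merged = update_settings(defaults, custom)
# both keys survive: episode_capacity stays, learning_start is added
```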
@@ -76,4 +78,4 @@ global_evaluator:

dataset_config:
  episode_capacity: 1000000
  learning_start: 2560
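A `learning_start` threshold of this kind typically gates training on having enough data in the buffer. A minimal sketch under that assumption (the class and method names are illustrative, not MALib's actual dataset API):

```python
import random

class ReplayBuffer:
    """Sketch of a buffer gated by a learning_start threshold.

    Assumption: `learning_start` means the number of stored transitions
    required before sampling for training is allowed, and
    `episode_capacity` bounds total storage.
    """

    def __init__(self, episode_capacity: int, learning_start: int):
        self.capacity = episode_capacity
        self.learning_start = learning_start
        self.storage = []

    def add(self, transition):
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)  # drop the oldest transition when full
        self.storage.append(transition)

    def ready(self) -> bool:
        # training begins only after learning_start transitions exist
        return len(self.storage) >= self.learning_start

    def sample(self, batch_size: int):
        assert self.ready(), "not enough data to start learning"
        return random.sample(self.storage, batch_size)

buf = ReplayBuffer(episode_capacity=1_000_000, learning_start=4)
for t in range(4):
    buf.add(t)
# buf.ready() is now True; sampling is permitted
```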
will ignore changes in these files.
Thanks for your impressive contribution! I left some comments; please resolve them before merging. Also, please restore the deleted environment implementations and their related install scripts (such as sc2/install.sh, vizdoom/v1, ...); it is an unreasonable removal. @zbzhu99 @Ericonaldo
malib/envs/star_craft2/install.sh
Outdated
#!/bin/bash
# Install SC2 and add the custom maps
if [ -z "$SC2ROOT" ]
then
    SC2ROOT=~
fi

echo 'SC2ROOT:'$SC2ROOT
cd $SC2ROOT

export SC2PATH=$SC2ROOT'/StarCraftII'
echo 'SC2PATH is set to '$SC2PATH

if [ ! -d $SC2PATH ]; then
    echo 'StarCraftII is not installed. Installing now ...';
    wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip
    unzip -P iagreetotheeula SC2.4.10.zip
    rm -rf SC2.4.10.zip
else
    echo 'StarCraftII is already installed.'
fi

echo 'Adding SMAC maps.'
MAP_DIR="$SC2PATH/Maps/"
echo 'MAP_DIR is set to '$MAP_DIR

if [ ! -d $MAP_DIR ]; then
    mkdir -p $MAP_DIR
fi

cd ..
wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip
unzip SMAC_Maps.zip
mv SMAC_Maps $MAP_DIR
rm -rf SMAC_Maps.zip

echo 'StarCraft II and SMAC are installed.'
please recover this file
recovered
malib/envs/vizdoom_v1/v1.py
Outdated
def get_total_reward(self):
    return self.game.get_total_reward()

def step(self, actions: Dict[AgentID, Any]) -> Tuple[Dict, Dict, Dict, Dict]:
    """Environment stepping by taking agent actions and return: `observations`, `rewards`, `dones` and `infos`,
    dicts where each dict looks like {agent_1: item_1, agent_2: item_2}.

    :param Dict[AgentID,Any] actions: A dict of agent actions.
    :return: A tuple of environment returns.
    """

    if not actions:
        self.agents = []
        return {}, {}, {}, {}

    actions = action_transform(actions, 3)
    rewards = {
        agent: self.game.make_action(actions[agent], FRAME_REPEAT)
        for agent in self.agents
    }
    self.num_moves += 1

    env_done = self.num_moves >= NUM_ITERS or self.game.is_episode_finished()
    dones = {agent: env_done for agent in self.agents}
    dones["__all__"] = any(dones.values())
    observations = {
        agent: state_transform(
            self.game.get_state(), resolution=self.observation_spaces[agent].shape
        )
        for agent in self.agents
    }
    infos = {
        agent: {
            "living_reward": self.game.get_living_reward(),
            "last_reward": self.game.get_last_reward(),
            "last_action": self.game.get_last_action(),
            "available_action": self.game.get_available_buttons(),
            "step": self.num_moves,
        }
        for agent in self.agents
    }

    return observations, rewards, dones, infos


def meta_info(data):
    return {
        "type": type(data),
        "shape": data.shape if hasattr(data, "shape") else "No shape",
        "agg_sum_value": np.sum(data) if isinstance(data, np.ndarray) else data,
        "agg_mean_value": np.mean(data) if isinstance(data, np.ndarray) else data,
        "agg_var_value": np.var(data) if isinstance(data, np.ndarray) else data,
    }


def parse_state(state: vzd.GameState):
    if state is None or not isinstance(state, vzd.GameState):
        return state
    else:
        return {
            "time": meta_info(state.number),
            "vars": meta_info(state.game_variables),
            "screen_buf": meta_info(state.screen_buffer),
            "depth_buf": meta_info(state.depth_buffer),
            "labels_buf": meta_info(state.labels_buffer),
            "automap_buf": meta_info(state.automap_buffer),
            "labels": meta_info(state.labels),
            "objects": meta_info(state.objects),
            "sectors": meta_info(state.sectors),
        }


if __name__ == "__main__":
    env = make_env(
        doom_scenario_path=os.path.join(vzd.scenarios_path, "basic.wad"),
        doom_map="map01",
    )

    agents = env.possible_agents
    obs = env.reset()
    done = False

    iter = 0
    while not done:
        actions = {agent: random.choice([0, 1, 2]) for agent in agents}
        observations, rewards, dones, infos = env.step(actions)
        print(f"=================\nstep on #{iter}:")
        parsed_state = {agent: parse_state(v) for agent, v in observations.items()}
        print("game state:")
        pprint(parsed_state)
        pprint(f"reward: {rewards}")
        done = dones["__all__"]
        print("==================")
        iter += 1

    print("Episode finished")
    print(f"Total reward: {env.get_total_reward()}")
    print("********************")
    env.close()
please recover this file
recovered
@@ -70,12 +70,30 @@ def __init__(
    worker_index=worker_idx,
    env_desc=self._env_desc,
    metric_type=self._metric_type,
    test=False,
What's the functionality of `test`?
It is used for creating rollout workers for deterministic evaluation.
Please also take a look at:
https://github.com/apexrl/malib/blob/196de6592fd82ea889cb871d4663d9e4d5028dde/malib/rollout/base_worker.py#L339
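The `test` flag, as explained, switches workers into deterministic evaluation. A sketch of what that switch might look like inside a rollout worker's action selection (hypothetical helper, not MALib's actual API):

```python
import numpy as np

def select_action(logits: np.ndarray, test: bool = False) -> int:
    """Illustrative action selection with a `test` flag.

    Assumption: during training rollouts the action is sampled from the
    policy distribution; during deterministic evaluation (test=True)
    the greedy action is taken instead.
    """
    # softmax over logits (shifted for numerical stability)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if test:
        # deterministic evaluation: always take the greedy action
        return int(np.argmax(probs))
    # stochastic rollout: sample from the policy distribution
    return int(np.random.choice(len(probs), p=probs))

logits = np.array([0.1, 2.0, 0.3])
greedy = select_action(logits, test=True)  # always index 1 here
```

Deterministic evaluation makes returns comparable across evaluation runs, since exploration noise is removed.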
- PICKLE_PROTOCOL_VER = 5
+ PICKLE_PROTOCOL_VER = 4
why downgrade the pickle version?
From my test, it seems that protocol version 5 cannot be used in Python 3.7.10.
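That matches the language history: pickle protocol 5 was added in Python 3.8 (PEP 574), so 3.7 only supports up to protocol 4. One possible alternative to hard-coding a version is to cap it at whatever the running interpreter supports:

```python
import pickle

# Protocol 5 exists only on Python 3.8+ (PEP 574). Capping at the
# interpreter's maximum avoids hard-coding a version that older
# Pythons reject.
PICKLE_PROTOCOL_VER = min(5, pickle.HIGHEST_PROTOCOL)

data = {"weights": [1.0, 2.0]}
blob = pickle.dumps(data, protocol=PICKLE_PROTOCOL_VER)
restored = pickle.loads(blob)  # round-trips regardless of Python version
```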
"test_num_episodes": 0, | ||
"test_episode_seg": 0, |
For evaluation?
Yes, this is also used for deterministic evaluation.
Please take a look at:
https://github.com/apexrl/malib/blob/196de6592fd82ea889cb871d4663d9e4d5028dde/malib/manager/rollout_worker_manager.py#L81
* Single Agent Imitation Learning (#12)
* update ignores
* tmp save
* init vector_env
* ignore build
* shared vector env
* replace Func with stepping
* one training_interface one rollout_worker
* remove useless logger init
* replace func with stepping
* use explicit params for sampler
* formatted
* vector rollout test passed
* mute vizdoom
* in progress: bridge remote servers
* migrate from dev repository
* rollout test passed
* specify versions
* support nested transformation and stack mode
* test passed for rollout
* fix bug: no data saved
* resolve comments
* collect configs for mpe
* collect other configs
* fix: id to env_id
* wrap sc2
* fix: asuyc simple
* fix: behavior policy not specified
* Add gym environment wrapper
* add dqn test
* dqn test passed
* test passed for ppo
* add contributing markdown
* update link
* update
* update
* Add model customizing interface, e.g. qmixer
* support explicit tagging
* apply explicit tagging to collect summary
* add docs of SequentialEpisode
* update
* Add gym cartpole
* maddpg and psro worked
* update summary keys
* apply async data saving
* update centralized agent batch usage
* Continuous DDPG on Pendulum
* Continuous PPO on Pendulum
* update
* update bc and imitation trainer
* Continuous SAC on Pendulum
* Reformat code
* update algo (still have problems with PPO and SAC)
* update
* Add test rollout worker with deterministic action
* bug fix for deterministic evaluation
* temporal saving, dumping test
* offline dataset passed single agent table test
* add unittest for parameter server
* Black format
* explaining doc
* Fix merge bug
* policy model save & sample with loaded model (only for the use of single agent imitation learning; may not apply to the general MALib framework)
* BC on Pendulum with DDPG expert
* update irl interface
* update imitation trainer
* update adv irl alg and interface
* black format
* update of advirl
* temp save
* Successfully run advirl with ddpg (known issues listed in the first comment above)
* GAIL with SAC worked on Pendulum
* reorganize imitation learning interface structure
* black format
* action squash for applying sac on mujoco
* add space between comment description
* recover the MLP class & move action_squash config
* recover multi-agent env files

Co-authored-by: Ming Zhou <kornbergfresnel@outlook.com>
Co-authored-by: morning9393 <243549184@qq.com>
Co-authored-by: ericonaldo <ericliuof97@gmail.com>
Co-authored-by: hanjing <wanghanjingwhj@gmail.com>

* Skip open spiel installation
* Fix env id naming
* Add classic environment implementation
* Format
* Fix parameter errors
* Single agent instance should group all agents
* Remove print
* Update sync buffer desc

Co-authored-by: Zhengbang Zhu <zbzhu.yz@gmail.com>
Co-authored-by: morning9393 <243549184@qq.com>
Co-authored-by: ericonaldo <ericliuof97@gmail.com>
Co-authored-by: hanjing <wanghanjingwhj@gmail.com>
Experiment Results on Pendulum
![WechatIMG233](https://user-images.githubusercontent.com/18256149/126757715-272baf6d-3dee-415d-a774-d9a40a2549ba.jpeg)