-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
63 lines (54 loc) · 1.58 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
import numpy as np
import gym
import math
import os
import model
import architecture as policies
import poke_environ as env
# SubprocVecEnv creates a vector of n environments to run them simultaneously.
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from tensorflow.keras import backend as K
def main(num_envs=16):
    """Configure TensorFlow and launch PPO training on parallel environments.

    Restricts CUDA to device 0, opens a TF1 session with on-demand GPU
    memory growth, builds a ``SubprocVecEnv`` from ``num_envs`` calls to
    ``env.make_poke_env()`` and passes it to ``model.learn`` together with
    the fixed hyper-parameters below.

    Args:
        num_envs: Number of ``env.make_poke_env()`` results handed to
            ``SubprocVecEnv``. Defaults to 16, the count that was
            previously hard-coded.

    Raises:
        ValueError: If ``num_envs`` is smaller than 1.
    """
    if num_envs < 1:
        raise ValueError("num_envs must be at least 1, got %r" % (num_envs,))

    # Make only CUDA device 0 visible; set before the session is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    config = tf.ConfigProto()
    # Let TensorFlow grow its GPU allocation on demand instead of
    # reserving the whole device up front.
    config.gpu_options.allow_growth = True

    # Reset Keras' global state before opening a fresh session.
    K.clear_session()

    with tf.Session(config=config):
        # One entry per parallel environment (built inside the session
        # block, exactly where the original constructed it).
        vec_env = SubprocVecEnv(
            [env.make_poke_env() for _ in range(num_envs)]
        )
        try:
            model.learn(
                policy=policies.PPOPolicy,
                env=vec_env,
                nsteps=32,  # Steps per environment
                total_timesteps=10000000,
                gamma=0.99,
                lam=0.95,
                vf_coef=0.5,
                ent_coef=0.01,
                # Constant schedules: the callable's argument is ignored.
                lr=lambda _: 2e-4,
                cliprange=lambda _: 0.2,
                max_grad_norm=0.5,
                log_interval=10,
            )
        finally:
            # Release the vectorized environment even when training
            # raises, so its workers are not left running.
            vec_env.close()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()