# -*- coding: utf-8 -*-
"""
Created on Fri Mar 25 03:24:50 2022

@author: damieQWQ
"""
import os
import argparse
from time import sleep, time
from game_env import qwopEnv
from stable_baselines3 import DQN
import stable_baselines3.common.vec_env
def get_new_model():
    """Build a fresh, untrained DQN agent attached to a new qwopEnv.

    Returns a stable-baselines3 DQN with an MLP policy (hidden layers
    of 256 and 128 units), a low final exploration epsilon, a small
    learning rate, and forced CPU execution.
    """
    environment = qwopEnv()
    agent = DQN(
        "MlpPolicy",
        environment,
        policy_kwargs=dict(net_arch=[256, 128]),
        exploration_final_eps=0.075,
        learning_rate=0.00008,
        verbose=1,
        device="cpu",
    )
    return agent
def run_train():
    """Train a fresh DQN agent on qwopEnv and checkpoint it repeatedly.

    Performs one initial 10k-timestep learning phase, saves the model
    as "YDv3_3", then runs 50 extra 1k-timestep rounds, saving a
    numbered checkpoint ("YDv3_3_f<i>") after each round.
    """
    model_name = "YDv3_3"
    model = get_new_model()
    # model.train_freq = 5
    model.learning_starts = 1000
    model.exploration_fraction = 0.2
    model.learn(total_timesteps=10000, log_interval=5)
    model.save(model_name)
    for i in range(50):
        # NOTE(review): each learn() call uses SB3's default
        # reset_num_timesteps=True, which restarts the timestep counter and
        # exploration schedule per round — confirm this is intended.
        model.learn(total_timesteps=1000, log_interval=5)
        model.save(f"{model_name}_f{i}")  # e.g. "YDv3_3_f0"
        print(f"round {i}")
def run_test():
    """Load the "YDv3_3_f78" checkpoint and run it in qwopEnv indefinitely.

    Acts greedily (deterministic=True), renders every step, and resets
    the environment whenever an episode finishes. Never returns.
    """
    environment = qwopEnv()
    agent = DQN.load("YDv3_3_f78")
    observation = environment.reset()
    while True:
        chosen_action, _states = agent.predict(observation, deterministic=True)
        observation, reward, finished, info = environment.step(chosen_action)
        environment.render()
        if finished:
            observation = environment.reset()
def run_train_old():
    """Resume fine-tuning from the "YDv3_3_f78" checkpoint.

    Loads the saved model, lowers its learning rate, attaches a fresh
    qwopEnv, then runs rounds 79-99 of 1k timesteps each, saving a
    numbered checkpoint ("YDv3_3_f<i>") after every round.
    """
    model_name = "YDv3_3_f78"
    # os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    model_name2 = "YDv3_3"
    model = DQN.load(model_name, learning_rate=0.00005, device="cpu")
    env = qwopEnv()  # SubprocVecEnv([lambda: QWOPEnv()])
    model.set_env(env)
    sleep(1)
    model.learning_rate = 0.00005
    # model.train_freq = 5
    model.learning_starts = 100
    model.exploration_final_eps = 0.05
    for i in range(79, 100):
        model.learn(total_timesteps=1000, log_interval=5)
        model.save(f"{model_name2}_f{i}")
        # Fixed typo: message was "rpound {}"; now matches run_train's output.
        print(f"round {i}")
if __name__ == '__main__':
    # Command-line entry point: exactly one mode flag selects the action;
    # with no flag, print usage instead of doing anything.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true', help='train')
    parser.add_argument('--retrain', action='store_true', help='retrain')
    parser.add_argument('--test', action='store_true', help='test')
    args = parser.parse_args()
    if args.train:
        run_train()
    elif args.retrain:
        run_train_old()
    elif args.test:
        run_test()
    else:
        parser.print_help()
# Removed dead scratch code that previously trailed the entry point:
# a disabled CUDA_VISIBLE_DEVICES line and a garbled triple-quoted
# manual-stepping experiment (it was an inert string literal with no effect).