dqn_agent.py
import numpy as np
import torch as T
from deep_q_network import DeepQNetwork
from replay_memory import ReplayBuffer


class DQNAgent(object):
    def __init__(self, gamma, epsilon, lr, n_actions, input_dims,
                 mem_size, batch_size, eps_min=0.01, eps_dec=5e-7,
                 replace=1000, algo=None, env_name=None, chkpt_dir='tmp/dqn'):
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.replace_target_cnt = replace
        self.algo = algo
        self.env_name = env_name
        self.chkpt_dir = chkpt_dir
        self.action_space = [i for i in range(n_actions)]
        self.learn_step_counter = 0

        self.memory = ReplayBuffer(mem_size, input_dims, n_actions)

        # Online (evaluation) network, updated on every learning step.
        self.q_eval = DeepQNetwork(self.lr, self.n_actions,
                                   input_dims=self.input_dims,
                                   name=self.env_name+'_'+self.algo+'_q_eval',
                                   chkpt_dir=self.chkpt_dir)

        # Target network, synced with q_eval every `replace` learning steps.
        self.q_next = DeepQNetwork(self.lr, self.n_actions,
                                   input_dims=self.input_dims,
                                   name=self.env_name+'_'+self.algo+'_q_next',
                                   chkpt_dir=self.chkpt_dir)
    def choose_action(self, observation):
        # Epsilon-greedy action selection: exploit with probability 1 - epsilon.
        if np.random.random() > self.epsilon:
            state = T.tensor([observation], dtype=T.float).to(self.q_eval.device)
            actions = self.q_eval.forward(state)
            action = T.argmax(actions).item()
        else:
            action = np.random.choice(self.action_space)

        return action
    def store_transition(self, state, action, reward, state_, done):
        self.memory.store_transition(state, action, reward, state_, done)
    def sample_memory(self):
        state, action, reward, new_state, done = \
                                self.memory.sample_buffer(self.batch_size)

        # Convert the sampled numpy arrays to tensors on the network's device.
        states = T.tensor(state).to(self.q_eval.device)
        rewards = T.tensor(reward).to(self.q_eval.device)
        dones = T.tensor(done).to(self.q_eval.device)
        actions = T.tensor(action).to(self.q_eval.device)
        states_ = T.tensor(new_state).to(self.q_eval.device)

        return states, actions, rewards, states_, dones
    def replace_target_network(self):
        # Periodically copy the online network's weights into the target network.
        if self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def decrement_epsilon(self):
        # Linearly anneal epsilon until it reaches eps_min.
        self.epsilon = self.epsilon - self.eps_dec \
                           if self.epsilon > self.eps_min else self.eps_min
    def save_models(self):
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()
    def learn(self):
        # Wait until the replay buffer holds at least one full batch.
        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()

        self.replace_target_network()

        states, actions, rewards, states_, dones = self.sample_memory()

        indices = np.arange(self.batch_size)

        # Q(s, a) for the actions actually taken, and max_a' Q_target(s', a').
        q_pred = self.q_eval.forward(states)[indices, actions]
        q_next = self.q_next.forward(states_).max(dim=1)[0]

        # Terminal states contribute no bootstrapped value.
        q_next[dones] = 0.0
        q_target = rewards + self.gamma*q_next

        loss = self.q_eval.loss(q_target, q_pred).to(self.q_eval.device)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1

        self.decrement_epsilon()
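
A minimal training-loop sketch showing how this agent might be driven. The environment (`gym`, 'CartPole-v1'), the hyperparameter values, and the number of episodes are assumptions for illustration and are not part of this file; they follow the classic `gym` step/reset API, and whether a particular observation shape is compatible depends on the DeepQNetwork architecture, which is defined elsewhere in the repo.

    # usage_sketch.py -- hypothetical driver, not part of dqn_agent.py
    import gym
    from dqn_agent import DQNAgent

    env = gym.make('CartPole-v1')  # assumed discrete-action environment
    agent = DQNAgent(gamma=0.99, epsilon=1.0, lr=1e-4,
                     n_actions=env.action_space.n,
                     input_dims=env.observation_space.shape,
                     mem_size=50000, batch_size=32,
                     eps_min=0.1, eps_dec=1e-5, replace=1000,
                     algo='DQNAgent', env_name='CartPole-v1',
                     chkpt_dir='tmp/dqn')

    for episode in range(500):          # assumed episode count
        done = False
        score = 0
        observation = env.reset()
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            score += reward
            agent.store_transition(observation, action, reward,
                                   observation_, done)
            agent.learn()
            observation = observation_
        print('episode', episode, 'score', score)

    agent.save_models()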