-
Notifications
You must be signed in to change notification settings - Fork 0
/
q_table.py
50 lines (40 loc) · 1.56 KB
/
q_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
class DiscreteQTable(object):
    """Tabular action-value function over a uniformly discretised 2-D state.

    The first state component is binned over ``[-pi, pi]`` and the second
    over ``[-15*pi, 15*pi]`` (presumably an angle and an angular velocity --
    confirm against the environment).  Each cell stores one value per
    action; the number of actions is fixed at 3.

    ``table_args`` must provide:
        ``"grid"``   -- number of bins per state dimension (the code indexes
                        ``grid[0]`` and ``grid[1]``; two entries are assumed).
        ``"lr"``     -- learning rate.
        ``"lambda"`` -- eligibility-trace decay; a falsy value (``0``/``None``)
                        selects plain one-step SARSA.
    """

    def __init__(self, table_args):
        grid = table_args["grid"]
        self.table = np.zeros((*grid, 3))
        self.lr = table_args["lr"]
        self.lambda_t = table_args["lambda"]
        self.gamma = 0.98
        # Bin widths: 2*pi split into grid[0] bins, 30*pi split into grid[1].
        self.step_size = [np.pi * 2 / grid[0], np.pi * 30 / grid[1]]
        self.count, self.sum_delta = 0, 0.0
        # Create the eligibility trace up front so train_sarsa() works even
        # when the caller has not invoked reset_et() yet.
        self.reset_et()

    def getIndex(self, s):
        """Return the ``(row, col)`` table cell that contains state ``s``.

        Values outside the covered range -- including the exact upper
        boundary -- are clamped to the nearest edge bin, so the result is
        always a valid, non-negative index into ``self.table``.
        """
        n_rows, n_cols = self.table.shape[0], self.table.shape[1]
        row = int((s[0] + np.pi) // self.step_size[0])
        col = int((s[1] + 15 * np.pi) // self.step_size[1])
        # Without clamping, s[0] == pi would index one past the end, and a
        # value below the lower bound would silently wrap around through
        # negative indexing.
        row = min(max(row, 0), n_rows - 1)
        col = min(max(col, 0), n_cols - 1)
        return (row, col)

    def greedy_action(self, s):
        """Return the index of the highest-valued action in the cell of ``s``."""
        return np.argmax(self.table[self.getIndex(s)])

    def train_sarsa(self, s, a, r, s_1, a_1):
        """Apply one SARSA update for the transition ``(s, a, r, s_1, a_1)``.

        The TD error is ``r + gamma * Q(s_1, a_1) - Q(s, a)``.  Its square is
        accumulated for ``get_delta()``.  With a truthy ``lambda`` every table
        entry is updated in proportion to its eligibility trace; otherwise
        only the visited entry is updated.
        """
        index_0 = (*self.getIndex(s), a)
        index_1 = (*self.getIndex(s_1), a_1)
        delta = r + self.gamma * self.table[index_1] - self.table[index_0]

        self.sum_delta += delta ** 2
        self.count += 1

        if self.lambda_t:
            # Decay all traces, then bump the visited pair (accumulating
            # trace), then move the whole table along the trace.
            self.et *= self.gamma * self.lambda_t
            self.et[index_0] += 1
            self.table += self.lr * delta * self.et
        else:
            self.table[index_0] += self.lr * delta

    def reset_et(self):
        """Zero the eligibility trace (same shape as the table)."""
        self.et = np.zeros(self.table.shape)

    def get_weight(self):
        """Return the underlying table array (not a copy)."""
        return self.table

    def load_weight(self, filename):
        """Replace the table with the array stored in ``filename``.

        NOTE(review): the loaded array's shape is not checked against the
        grid used to compute ``step_size`` -- confirm callers keep them
        consistent.
        """
        self.table = np.load(filename)

    def get_delta(self):
        """Return the mean squared TD error since the last call, then reset.

        Returns ``0.0`` when no update has been recorded since the last
        reset, instead of dividing by zero.
        """
        if self.count == 0:
            return 0.0
        avg_delta = self.sum_delta / self.count
        self.count, self.sum_delta = 0, 0.0
        return avg_delta

    def set_lr(self, lr):
        """Set the learning rate used by subsequent updates."""
        self.lr = lr