-
Notifications
You must be signed in to change notification settings - Fork 4
/
memory.py
164 lines (119 loc) · 6.12 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import numpy as np
class SumTree(object):
    """
    Binary sum tree: leaves hold per-experience priorities, every internal
    node holds the sum of its children, so the root is the total priority.

    Modified version of Morvan Zhou:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py
    """
    # Index of the next data slot to write (wraps around when full).
    data_pointer = 0

    def __init__(self, capacity):
        # capacity leaves; the first capacity-1 cells of `tree` are internal nodes.
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # Parallel storage for the experiences themselves.
        self.data = np.zeros(capacity, dtype=object)

    def add(self, priority, data):
        """Write `data` into the next slot and set its leaf priority."""
        # Leaves live in tree[capacity-1 : 2*capacity-1].
        leaf = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(leaf, priority)
        # Wrap around: oldest entries are overwritten once the buffer is full.
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, tree_index, priority):
        """Set a leaf's priority and propagate the delta up to the root."""
        delta = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        node = tree_index
        # Walk parents up to the root, adjusting each partial sum.
        while node != 0:
            node = (node - 1) // 2
            self.tree[node] += delta

    def get_leaf(self, v):
        """
        Descend from the root following the prefix-sum `v` and return
        (leaf index, leaf priority, stored experience).
        """
        node = 0
        size = len(self.tree)
        while True:
            left = 2 * node + 1
            # Past the last internal node means `node` is a leaf: stop.
            if left >= size:
                break
            # Go left if v falls inside the left subtree's mass,
            # otherwise subtract it and go right.
            if v <= self.tree[left]:
                node = left
            else:
                v -= self.tree[left]
                node = left + 1
        return node, self.tree[node], self.data[node - self.capacity + 1]

    @property
    def total_priority(self):
        # Root node holds the sum over all leaf priorities.
        return self.tree[0]
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """
    Prioritized experience replay buffer backed by a SumTree.

    This SumTree code is modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    PER_e = 0.01  # Hyperparameter that we use to avoid some experiences to have 0 probability of being taken
    PER_a = 0.8   # Hyperparameter that we use to make a tradeoff between taking only exp with high priority and sampling randomly
    PER_b = 0.4   # importance-sampling, from initial value increasing to 1
    PER_b_increment_per_sampling = 0.001
    absolute_error_upper = 100.  # clipped abs error

    def __init__(self, capacity):
        """
        Build the backing tree.

        Our tree is composed of a sum tree that contains the priority scores
        at its leaves, and also a data array. We don't use a deque because it
        would shift every experience's index by one each timestep; instead we
        use a simple array and overwrite when the memory is full.
        """
        self.tree = SumTree(capacity)

    def store(self, experience, priority=None):
        """
        Store a new experience in our tree.

        Each new experience gets the current max leaf priority (it will then
        be refined via batch_update once the experience is used for training),
        unless an explicit `priority` is supplied.
        """
        # Find the max priority among the current leaves.
        max_priority = np.max(self.tree.tree[-self.tree.capacity:])
        # If the max priority = 0 we can't use priority = 0, since that
        # experience would never have a chance to be selected.
        if max_priority == 0:
            max_priority = self.absolute_error_upper
        # An explicitly supplied priority overrides the max-priority default.
        if priority is not None:
            max_priority = priority
        self.tree.add(max_priority, experience)  # set the max p for new p

    def sample(self, n):
        """
        Sample up to `n` experiences, stratified over the priority range.

        Returns (tree_indices, experiences); duplicates and unfilled slots are
        skipped, so fewer than `n` items may come back.
        """
        memory_b = []
        b_idx = []
        # As in the PER paper, divide [0, total_priority] into n equal segments
        # and draw one sample uniformly from each.
        priority_segment = self.tree.total_priority / n
        # Anneal PER_b toward 1 each time we sample a minibatch.
        self.PER_b = np.min([1., self.PER_b + self.PER_b_increment_per_sampling])
        for i in range(n):
            value = np.random.uniform(priority_segment * i, priority_segment * (i + 1))
            index, priority, data = self.tree.get_leaf(value)
            # Unfilled slots still hold the int 0 placeholder from
            # np.zeros(..., dtype=object); skip them.
            if isinstance(data, int):
                continue
            # Skip duplicate draws of the same leaf.
            if index in b_idx:
                continue
            b_idx.append(index)
            memory_b.append([data])
        return np.array(b_idx).reshape((-1,)), memory_b

    def batch_update(self, tree_idx, abs_errors):
        """
        Update the priorities of the sampled leaves from their new TD errors.

        `tree_idx` are the indices returned by sample(); `abs_errors` are the
        corresponding error magnitudes (array-like).
        """
        # Out-of-place arithmetic so the caller's array is NOT mutated (the
        # original did `abs_errors += PER_e` in place), and actually take the
        # absolute value as the old comment claimed but never did.
        priorities = np.abs(abs_errors) + self.PER_e  # avoid zero priority
        clipped_errors = np.minimum(priorities, self.absolute_error_upper)
        ps = np.power(clipped_errors, self.PER_a)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)