Commit

Merge e566a79 into 7c96789
JSPromisel committed Apr 16, 2019
2 parents 7c96789 + e566a79 commit 1aab844
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions src/mdptoolbox/mdp.py
@@ -971,6 +971,11 @@ class for details.
By default we run a check on the ``transitions`` and ``rewards``
arguments to make sure they describe a valid MDP. You can set this
argument to True in order to skip this check.
learning_rate : float, optional
The learning rate, between 0.0 and 1.0. Setting it to 0.0 means that the
Q-values are never updated, so nothing is learned. Setting a high value
such as 0.9 means that learning can occur quickly.
Default: -1.0, which means the learning rate decays as (1/sqrt(n+2)),
where n is the iteration number.
Data Attributes
---------------
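
For context, here is a minimal usage sketch of the new parameter. It assumes the forest() example generator that ships with pymdptoolbox and the constructor signature added in this commit; the discount of 0.96 and the fixed rate of 0.1 are arbitrary illustrative values, not recommendations from the change itself.

import mdptoolbox.example
from mdptoolbox.mdp import QLearning

# A small forest-management MDP bundled with the toolbox.
P, R = mdptoolbox.example.forest()

# Default behaviour: learning_rate=-1.0, so the update coefficient
# decays as 1/sqrt(n+2) over the iterations.
ql_decay = QLearning(P, R, 0.96)
ql_decay.run()

# New behaviour: a fixed learning rate of 0.1 for every update.
ql_fixed = QLearning(P, R, 0.96, learning_rate=0.1)
ql_fixed.run()

print(ql_decay.policy, ql_fixed.policy)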
@@ -1023,14 +1028,19 @@ class for details.
"""

def __init__(self, transitions, reward, discount, n_iter=10000,
skip_check=False):
skip_check=False, learning_rate=-1.0):
# Initialise a Q-learning MDP.

# The following check won't be done in MDP()'s initialisation, so let's
# do it here
self.max_iter = int(n_iter)
assert self.max_iter >= 10000, "'n_iter' should be at least 10000."

# Validate the learning rate value
self.learning_rate = learning_rate
assert self.learning_rate == -1.0 or 0.0 <= self.learning_rate <= 1.0, \
    "'learning_rate' should be -1.0 or between 0.0 and 1.0."

if not skip_check:
# We don't want to send this to MDP because _computePR should not
# be run on it, so check that it defines an MDP
@@ -1057,6 +1067,10 @@ def run(self):
# initial state choice
s = _np.random.randint(0, self.S)

# Use the fixed learning rate unless the -1.0 sentinel was given, in
# which case fall back to the decaying update coefficient.
use_learn_rate = self.learning_rate != -1.0

for n in range(1, self.max_iter + 1):

# Reinitialisation of trajectories every 100 transitions
@@ -1089,9 +1103,14 @@ def run(self):
r = self.R[s]

# Updating the value of Q
# Decaying update coefficient (1/sqrt(n+2)) can be changed
# If learning_rate is -1.0, use the decaying update coefficient
# (1/sqrt(n+2)); otherwise use the fixed learning_rate
if use_learn_rate:
decay_rate = self.learning_rate
else:
decay_rate = (1 / _math.sqrt(n + 2))
delta = r + self.discount * self.Q[s_new, :].max() - self.Q[s, a]
dQ = (1 / _math.sqrt(n + 2)) * delta
dQ = decay_rate * delta
self.Q[s, a] = self.Q[s, a] + dQ

# current state is updated
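
As a self-contained illustration of the rule implemented in this hunk, the sketch below reproduces one tabular Q-learning step with the same choice between the decaying coefficient and a fixed rate. The function name q_update and the toy table are hypothetical; only the update formula mirrors the diff.

import math
import numpy as np

def q_update(Q, s, a, s_new, r, discount, n, learning_rate=-1.0):
    # Pick the update coefficient: decaying by default, fixed otherwise.
    if learning_rate == -1.0:
        alpha = 1 / math.sqrt(n + 2)
    else:
        alpha = learning_rate
    # Temporal-difference error and in-place Q-value update.
    delta = r + discount * Q[s_new, :].max() - Q[s, a]
    Q[s, a] += alpha * delta
    return Q

# Example: a 3-state, 2-action table updated once at iteration n = 1.
Q = np.zeros((3, 2))
Q = q_update(Q, s=0, a=1, s_new=2, r=1.0, discount=0.9, n=1, learning_rate=0.1)
print(Q)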
@@ -1581,4 +1600,4 @@ def run(self):
self.V[s] = Q.max()
self.policy.append(int(Q.argmax()))

self._endRun()
self._endRun()
