"""
SARSA using LFA with discrete actions.
"""
import numpy as np
from .algo_base import LearningAlgorithm


class DiscreteSARSA(LearningAlgorithm):
    """SARSA(λ) with linear function approximation.

    Actions are assumed to be discrete, while states are represented via a
    feature vector `x`; that is, `Q(s, a) = [⟨w, x⟩]_a`, the `a`-th entry of
    the matrix-vector product `w @ x`.

    Exploration occurs via an ε-greedy policy.
    """

    def __init__(self, num_features, num_actions, epsilon=5e-2):
        self.num_features = num_features
        self.num_actions = num_actions
        self.epsilon = epsilon
        # Weight matrix: one row of feature weights per action
        self.w = np.random.randn(self.num_actions, self.num_features)
        # Eligibility traces, same shape as the weight matrix
        self.z = np.zeros((self.num_actions, self.num_features))

    def start_episode(self):
        """Get ready to start a new episode."""
        # Clear the eligibility traces accumulated over the previous episode
        self.z.fill(0.0)

    def get_config(self):
        """Return the parameters needed to specify the algorithm's state."""
        ret = {
            'num_features': self.num_features,
            'weights': self.w.copy(),
            'traces': self.z.copy(),
        }
        return ret

    def act(self, x):
        """Select an action following the ε-greedy policy.

        Parameters
        ----------
        x : Vector[float]
            Feature vector for the current state.

        Returns
        -------
        action : int
        """
        if np.random.random() <= self.epsilon:
            # Explore: choose an action uniformly at random
            action = np.random.randint(self.num_actions)
        else:
            # Exploit: choose the action with the highest estimated value
            action = np.argmax(np.dot(self.w, x))
        return action

    def learn(self, x, a, r, xp, ap, alpha, gm, lm):
        """Update from new experience.

        Parameters
        ----------
        x : Vector[float]
            Feature vector for the current state.
        a : int
            Action taken in the current state.
        r : float
            Reward received for the transition.
        xp : Vector[float]
            Feature vector for the next state.
        ap : int
            Action taken in the next state.
        alpha : float
            Step size.
        gm : float
            Discount factor (γ).
        lm : float
            Trace decay parameter (λ).
        """
        v = np.dot(self.w[a], x)
        vp = np.dot(self.w[ap], xp)
        # Compute the TD-error: δ = r + γ·Q(s', a') - Q(s, a)
        δ = r + gm*vp - v
        # Decay the eligibility traces, then bump the trace for the taken action
        self.z *= gm*lm
        self.z[a] += x
        # Update the weights along the traces
        self.w += alpha * δ * self.z
        # Return the TD-error, for lack of something more informative
        return δ

    def get_value(self, x, a=None):
        """Get the value of a state-action pair, or, if the action is left
        unspecified, the value of the best action in the given state.

        Parameters
        ----------
        x : Vector[float]
        a : int, optional
        """
        if a is None:
            return np.max(np.dot(self.w, x))
        else:
            return np.dot(self.w[a], x)

    def save_weights(self, fname):
        """Save the weights to a file."""
        # NB: `np.save` appends '.npy' to `fname` if it has no such extension
        np.save(fname, self.w)

    def load_weights(self, fname):
        """Load the weights from a file."""
        self.w = np.load(fname)
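

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). It trains
# DiscreteSARSA on a hypothetical 5-state random-walk chain with one-hot
# state features; `step` and `features` below are assumptions made purely
# for demonstration. Since this module uses a relative import, run it as
# part of its package (e.g. `python -m <package>.discrete_actions_sarsa`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    N_STATES = 5   # chain states 0..4; episodes end at either end
    N_ACTIONS = 2  # 0 = left, 1 = right

    def step(s, a):
        """Move left/right along the chain; +1 reward at the right end."""
        s_next = s + (1 if a == 1 else -1)
        if s_next >= N_STATES - 1:
            return N_STATES - 1, 1.0, True
        if s_next <= 0:
            return 0, 0.0, True
        return s_next, 0.0, False

    def features(s):
        """One-hot encoding of the state index."""
        x = np.zeros(N_STATES)
        x[s] = 1.0
        return x

    agent = DiscreteSARSA(num_features=N_STATES, num_actions=N_ACTIONS)
    for episode in range(500):
        agent.start_episode()
        s = N_STATES // 2
        x = features(s)
        a = agent.act(x)
        done = False
        while not done:
            s, r, done = step(s, a)
            xp = features(s)
            ap = agent.act(xp)
            # Pass gm=0 on the terminal transition so we don't bootstrap
            # from the (nonexistent) successor state.
            agent.learn(x, a, r, xp, ap,
                        alpha=0.1, gm=0.0 if done else 0.99, lm=0.9)
            x, a = xp, ap

    # State values should increase toward the rewarding right end.
    print([round(float(agent.get_value(features(s))), 2)
           for s in range(N_STATES)])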