-
Notifications
You must be signed in to change notification settings - Fork 5.6k
/
trajectory_view_utilizing_models.py
136 lines (115 loc) · 5.53 KB
/
trajectory_view_utilizing_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.tf_ops import one_hot
from ray.rllib.utils.torch_ops import one_hot as torch_one_hot
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
# __sphinx_doc_begin__
class FrameStackingCartPoleModel(TFModelV2):
    """A simple FC model that takes the last n observations as input.

    The last `num_frames` observations, one-hot'd actions, and rewards are
    flattened, concatenated, and fed through a 2-layer (256 units each) FC
    network with a separate 1-node value branch.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Initializes a FrameStackingCartPoleModel.

        Args:
            obs_space: Single-frame observation space; must be 1D (flat).
            action_space: Discrete action space (`.n` is read below).
            num_outputs (int): Number of output nodes (action logits).
            model_config (dict): Model config dict (unused here beyond the
                base-class call).
            name (str): Name of this model.
            num_frames (int): How many frames to stack as model input.
        """
        super(FrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Construct actual (very simple) FC model.
        assert len(obs_space.shape) == 1
        obs = tf.keras.layers.Input(
            shape=(self.num_frames, obs_space.shape[0]))
        obs_reshaped = tf.keras.layers.Reshape(
            [obs_space.shape[0] * self.num_frames])(obs)
        # Fix: `shape` must be a tuple; `(self.num_frames)` was a bare int.
        rewards = tf.keras.layers.Input(shape=(self.num_frames, ))
        rewards_reshaped = tf.keras.layers.Reshape([self.num_frames])(rewards)
        actions = tf.keras.layers.Input(
            shape=(self.num_frames, self.action_space.n))
        actions_reshaped = tf.keras.layers.Reshape(
            [action_space.n * self.num_frames])(actions)
        input_ = tf.keras.layers.Concatenate(axis=-1)(
            [obs_reshaped, actions_reshaped, rewards_reshaped])
        layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
        layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
        out = tf.keras.layers.Dense(self.num_outputs)(layer2)
        # Fix: branch the value head off `layer2` (not `layer1`) to match
        # the torch version of this model below.
        values = tf.keras.layers.Dense(1)(layer2)
        self.base_model = tf.keras.models.Model([obs, actions, rewards],
                                                [out, values])
        # Value output of the most recent `forward()` call.
        self._last_value = None

        # Request the current plus the `num_frames - 1` previous
        # observations, and the rewards/actions of the `num_frames`
        # preceding timesteps.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Computes action logits; caches the value output for later.

        Args:
            input_dict: Dict containing the "prev_n_obs", "prev_n_rewards",
                and "prev_n_actions" views requested in `__init__`.
            states: RNN state inputs (unused; this model is stateless).
            seq_lens: RNN sequence lengths (unused).

        Returns:
            Tuple of (action logits, empty state-outs list).
        """
        obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
        rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
        actions = one_hot(input_dict["prev_n_actions"], self.action_space)
        out, self._last_value = self.base_model([obs, actions, rewards])
        return out, []

    def value_function(self):
        """Returns the value branch output of the last `forward()` call."""
        return tf.squeeze(self._last_value, -1)
# __sphinx_doc_end__
class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
    """A simple FC model that takes the last n observations as input.

    Flattens and concatenates the last `num_frames` observations,
    one-hot'd actions, and rewards, then runs them through two 256-unit
    FC layers with separate logits and value heads.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_frames=3):
        """Initializes a TorchFrameStackingCartPoleModel.

        Args:
            obs_space: Single-frame observation space; must be 1D (flat).
            action_space: Discrete action space (`.n` is read below).
            num_outputs (int): Number of output nodes (action logits).
            model_config (dict): Model config dict.
            name (str): Name of this model.
            num_frames (int): How many frames to stack as model input.
        """
        nn.Module.__init__(self)
        super(TorchFrameStackingCartPoleModel, self).__init__(
            obs_space, action_space, None, model_config, name)

        self.num_frames = num_frames
        self.num_outputs = num_outputs

        # Build the (very simple) FC torso. Per frame, the input holds the
        # flat obs, a one-hot action, and one reward scalar.
        assert len(obs_space.shape) == 1
        flat_input_size = self.num_frames * (
            obs_space.shape[0] + action_space.n + 1)
        self.layer1 = SlimFC(
            in_size=flat_input_size, out_size=256, activation_fn="relu")
        self.layer2 = SlimFC(in_size=256, out_size=256, activation_fn="relu")
        self.out = SlimFC(
            in_size=256, out_size=self.num_outputs, activation_fn="linear")
        self.values = SlimFC(in_size=256, out_size=1, activation_fn="linear")
        # Value output of the most recent `forward()` call.
        self._last_value = None

        # Request the current plus the `num_frames - 1` previous
        # observations, and the rewards/actions of the `num_frames`
        # preceding timesteps.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),
            space=obs_space)
        self.view_requirements["prev_n_rewards"] = ViewRequirement(
            data_col="rewards", shift="-{}:-1".format(self.num_frames))
        self.view_requirements["prev_n_actions"] = ViewRequirement(
            data_col="actions",
            shift="-{}:-1".format(self.num_frames),
            space=self.action_space)

    def forward(self, input_dict, states, seq_lens):
        """Computes action logits; caches the value output for later.

        Args:
            input_dict: Dict containing the "prev_n_obs", "prev_n_rewards",
                and "prev_n_actions" views requested in `__init__`.
            states: RNN state inputs (unused; this model is stateless).
            seq_lens: RNN sequence lengths (unused).

        Returns:
            Tuple of (action logits, empty state-outs list).
        """
        n = self.num_frames
        # [B, n, obs_dim] -> [B, n * obs_dim].
        obs_flat = torch.reshape(input_dict["prev_n_obs"],
                                 [-1, self.obs_space.shape[0] * n])
        rew_flat = torch.reshape(input_dict["prev_n_rewards"], [-1, n])
        # One-hot the discrete actions, then flatten across frames.
        acts = torch_one_hot(input_dict["prev_n_actions"], self.action_space)
        acts_flat = torch.reshape(acts, [-1, n * acts.shape[-1]])

        net_in = torch.cat([obs_flat, acts_flat, rew_flat], dim=-1)
        hidden = self.layer2(self.layer1(net_in))
        logits = self.out(hidden)
        self._last_value = self.values(hidden)
        return logits, []

    def value_function(self):
        """Returns the value branch output of the last `forward()` call."""
        return torch.squeeze(self._last_value, -1)