-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy path: pqn_atari_envpool.py
More file actions
291 lines (252 loc) · 11.2 KB
/
pqn_atari_envpool.py
File metadata and controls
291 lines (252 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/pqn/#pqn_atari_envpoolpy
import os
import random
import time
from collections import deque
from dataclasses import dataclass
import envpool
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter
@dataclass
class Args:
    """Command-line configuration for PQN on envpool Atari environments.

    Each field is followed by a string literal describing it (tyro presumably
    surfaces these as ``--help`` text). ``batch_size``, ``minibatch_size`` and
    ``num_iterations`` are derived values: the script overwrites them right
    after parsing, so passing them on the command line has no effect.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """use CUDA when it is available; toggle off to force the CPU"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    # May be left as None; it is passed straight through to `wandb.init(entity=...)`.
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (currently unused by this envpool script)"""

    # Algorithm specific arguments
    env_id: str = "Breakout-v5"
    """the id of the environment"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 8
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per rollout"""
    anneal_lr: bool = True
    """Toggle linear learning rate annealing (towards 0) for the Q-network optimizer"""
    gamma: float = 0.99
    """the discount factor gamma"""
    num_minibatches: int = 4
    """the number of mini-batches each rollout is split into"""
    update_epochs: int = 4
    """the K epochs to update the Q-network on each rollout"""
    max_grad_norm: float = 10.0
    """the maximum norm for the gradient clipping"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.01
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.10
    """the fraction of `total_timesteps` it takes from start_e to end_e"""
    q_lambda: float = 0.65
    """the lambda mixing coefficient of the Q(lambda) return targets"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
class RecordEpisodeStatistics(gym.Wrapper):
    """Accumulate per-environment episode returns and lengths for a batched env.

    Returns are summed from ``infos["reward"]`` and lengths count steps. Both
    are zeroed only where ``infos["terminated"]`` is non-zero, independently of
    the ``dones`` the env returns. After each ``step`` the totals as they stood
    *before* that zeroing are published as ``infos["r"]`` (return) and
    ``infos["l"]`` (length).

    NOTE(review): ``infos["r"]`` / ``infos["l"]`` are the same array objects on
    every call and are overwritten in place; copy them if you keep them.
    """

    def __init__(self, env, deque_size=100):
        # `deque_size` is accepted for signature compatibility but never used.
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        # Counters are allocated by `reset`; `step` must not run before it.
        self.episode_returns = None
        self.episode_lengths = None

    def reset(self, **kwargs):
        observations = super().reset(**kwargs)
        count = self.num_envs
        self.episode_returns = np.zeros(count, dtype=np.float32)
        self.episode_lengths = np.zeros(count, dtype=np.int32)
        # Allocated here but not updated anywhere in this wrapper.
        self.lives = np.zeros(count, dtype=np.int32)
        self.returned_episode_returns = np.zeros(count, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(count, dtype=np.int32)
        return observations

    def step(self, action):
        observations, rewards, dones, infos = super().step(action)

        # Advance the running totals, then snapshot them before any zeroing.
        self.episode_lengths += 1
        self.episode_returns += infos["reward"]
        self.returned_episode_lengths[:] = self.episode_lengths
        self.returned_episode_returns[:] = self.episode_returns

        # Zero the totals of every environment whose episode terminated.
        still_running = 1 - infos["terminated"]
        self.episode_lengths *= still_running
        self.episode_returns *= still_running

        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
        return observations, rewards, dones, infos
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and hand the same object back.

    The weight receives an orthogonal initialisation with gain ``std`` and
    every bias entry is set to ``bias_const``. Returning the layer lets
    constructors be wrapped inline, e.g. ``layer_init(nn.Linear(a, b))``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer
class QNetwork(nn.Module):
    """Convolutional Q-network producing one Q-value per discrete action.

    Three conv layers and one 512-unit hidden layer, each followed by
    LayerNorm and ReLU. The hard-coded LayerNorm shapes (20x20, 9x9, 7x7) and
    the 3136 = 64 * 7 * 7 flattened width require a 4-channel input whose
    spatial size produces those feature maps (84x84 does).
    """

    def __init__(self, env):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, output side length)
        conv_specs = [
            (4, 32, 8, 4, 20),
            (32, 64, 4, 2, 9),
            (64, 64, 3, 1, 7),
        ]
        layers = []
        for in_ch, out_ch, kernel, stride, side in conv_specs:
            layers.append(layer_init(nn.Conv2d(in_ch, out_ch, kernel, stride=stride)))
            layers.append(nn.LayerNorm([out_ch, side, side]))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())
        layers.append(layer_init(nn.Linear(3136, 512)))
        layers.append(nn.LayerNorm(512))
        layers.append(nn.ReLU())
        # Output head: one unnormalised Q-value per action of the env.
        layers.append(layer_init(nn.Linear(512, env.single_action_space.n)))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # Inputs are divided by 255 (presumably raw pixel values in [0, 255]).
        scaled = x / 255.0
        return self.network(scaled)
def linear_schedule(start_e: float, end_e: float, duration: float, t: int):
    """Linearly move from ``start_e`` to ``end_e`` over ``duration`` steps.

    Returns ``start_e`` at ``t = 0``, ``end_e`` once ``t >= duration``, and the
    linear interpolation in between. Both decreasing schedules (e.g. epsilon
    decay) and increasing ones are supported.

    Args:
        start_e: value at ``t = 0``.
        end_e: final value, held once ``t`` reaches ``duration``.
        duration: length of the ramp in steps (may be fractional). A value
            ``<= 0`` means the ramp is already over, so ``end_e`` is returned
            instead of dividing by zero.
        t: current step.

    Returns:
        The scheduled value at step ``t``.
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp at end_e in the direction of travel so the value never overshoots.
    if end_e <= start_e:
        return max(value, end_e)
    return min(value, end_e)
def _compute_q_lambda_returns(rewards, dones, values, next_value, next_done, gamma, q_lambda):
    """Return Q(lambda) targets for one rollout, computed backwards in time.

    For every step ``t`` before the last::

        G_t = r_t + gamma * (q_lambda * G_{t+1} + (1 - q_lambda) * V_{t+1}) * (1 - d_{t+1})

    and the last step bootstraps from ``next_value`` alone::

        G_T = r_T + gamma * next_value * (1 - next_done)

    Args:
        rewards: time-major tensor, e.g. ``(num_steps, num_envs)``; ``rewards[t]``
            is the reward received for the action taken at step ``t``.
        dones: same shape as ``rewards``; ``dones[t]`` is the done flag that was
            current when step ``t`` started (i.e. the one returned by step ``t - 1``).
        values: same shape as ``rewards``; greedy value estimates recorded at each step.
        next_value: greedy value estimate of the observation after the last step.
        next_done: done flag returned by the last step.
        gamma: discount factor.
        q_lambda: weight of the recursive target versus the one-step bootstrap.

    Returns:
        A tensor shaped like ``rewards`` holding the targets.
    """
    num_steps = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            returns[t] = rewards[t] + gamma * next_value * nextnonterminal
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            returns[t] = (
                rewards[t]
                + gamma * (q_lambda * returns[t + 1] + (1 - q_lambda) * values[t + 1]) * nextnonterminal
            )
    return returns


if __name__ == "__main__":
    # Training script: collect epsilon-greedy rollouts from envpool Atari envs,
    # build Q(lambda) targets, and regress the Q-network onto them for several
    # epochs of shuffled minibatches. There is no target network or replay buffer.
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = envpool.make(
        args.env_id,
        env_type="gym",
        num_envs=args.num_envs,
        episodic_life=True,
        reward_clip=True,
        seed=args.seed,
    )
    envs.num_envs = args.num_envs
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space
    envs = RecordEpisodeStatistics(envs)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(envs.action_space, gym.spaces.Discrete):
        raise ValueError("only discrete action space is supported")

    q_network = QNetwork(envs).to(device)
    optimizer = optim.RAdam(q_network.parameters(), lr=args.learning_rate)

    # ALGO Logic: Storage setup (time-major: [num_steps, num_envs, ...])
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    avg_returns = deque(maxlen=20)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    env_indices = torch.arange(args.num_envs, device=device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so (linearly towards 0).
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)

            # Epsilon-greedy: per env, take a uniformly random action with
            # probability epsilon, otherwise the greedy action. The greedy
            # Q-value is stored for the Q(lambda) targets.
            random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
            with torch.no_grad():
                q_values = q_network(next_obs)
                max_actions = torch.argmax(q_values, dim=1)
                values[step] = q_values[env_indices, max_actions]

            explore = torch.rand((args.num_envs,)).to(device) < epsilon
            action = torch.where(explore, random_actions, max_actions)
            actions[step] = action

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs_np, reward, next_done_np, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs_np).to(device), torch.Tensor(next_done_np).to(device)

            # Only log when no lives remain: with episodic_life=True a done flag
            # may presumably also mark a single lost life. Filtering on the
            # host-side arrays avoids a device sync per environment.
            game_over = np.logical_and(next_done_np, info["lives"] == 0)
            for idx in np.flatnonzero(game_over):
                print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
                avg_returns.append(info["r"][idx])
                writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
                writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
                writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)

        # Compute Q(lambda) targets, bootstrapping from the greedy value of the final observation.
        with torch.no_grad():
            next_value, _ = torch.max(q_network(next_obs), dim=-1)
            returns = _compute_q_lambda_returns(
                rewards, dones, values, next_value, next_done, args.gamma, args.q_lambda
            )

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_returns = returns.reshape(-1)

        # Optimizing the Q-network
        b_inds = np.arange(args.batch_size)
        for _ in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                # Q(s, a) for the actions actually taken; squeeze(-1) keeps a 1-D
                # tensor even when the minibatch holds a single sample.
                old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze(-1)
                loss = F.mse_loss(old_val, b_returns[mb_inds])

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
                optimizer.step()

        # Metrics from the last minibatch of this iteration.
        writer.add_scalar("losses/td_loss", loss.item(), global_step)
        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()