-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
343 lines (285 loc) · 11.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import functools
import jax
from jax import numpy as jp
from jax import random
from typing import Dict
import wandb
import numpy as np
from brax import envs
from brax.io import model
import hydra
from omegaconf import DictConfig, OmegaConf
import mujoco
import imageio
from ppo_imitation import train as ppo
from ppo_imitation import ppo_networks
from envs.humanoid import HumanoidTracking, HumanoidStanding
from envs.ant import AntTracking
from envs.rodent import RodentTracking
from typing import Union
from brax import envs
from brax.v1 import envs as envs_v1
import numpy as np
import uuid
from preprocessing.mjx_preprocess import process_clip
# rendering related
from dm_control.mujoco import wrapper
from dm_control.mujoco.wrapper.mjbindings import enums
# Type aliases that accept either the current brax API or the legacy brax.v1 API.
State = Union[envs.State, envs_v1.State]
Env = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]
import warnings
# Silence DeprecationWarnings from this point on; warnings raised while the
# imports above executed have already been emitted.
warnings.filterwarnings("ignore", category=DeprecationWarning)
import os
# Let JAX/XLA preallocate up to 90% of GPU memory. JAX reads this when it
# initializes its GPU backend, so it is set here, before the first device call
# made further down in this file.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
def jax_has_gpu():
    """Return True if JAX can place an array on a GPU device, else False.

    Probes by moving a one-element array onto ``jax.devices("gpu")[0]``.
    Side effect: this initializes JAX's backends, so any XLA environment
    configuration must be in place before it is called.
    """
    try:
        _ = jax.device_put(jp.ones(1), device=jax.devices("gpu")[0])
    except Exception:
        # Any backend/lookup failure (e.g. no GPU backend available) means
        # "no GPU". Unlike a bare ``except:``, this still lets
        # KeyboardInterrupt and SystemExit propagate.
        return False
    return True
# XLA parses XLA_FLAGS once, when its backend is initialized. jax_has_gpu()
# initializes the backend, so the flags must be exported *before* the probe;
# exporting them afterwards silently has no effect. Flags already present in
# the user's environment are kept rather than overwritten.
_user_xla_flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
    f"{_user_xla_flags} "
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True "
).strip()

# Scale the number of parallel environments / batch size by the GPU count.
if jax_has_gpu():
    n_devices = jax.device_count(backend="gpu")
    print(f"Using {n_devices} GPUs")
else:
    n_devices = 1
# Expose the project's custom tasks to envs.get_environment() by name.
# Registration order matches the original sequence of calls.
for _registered_name, _env_class in (
    ("humanoidtracking", HumanoidTracking),
    ("ant", AntTracking),
    ("rodent", RodentTracking),
    ("humanoidstanding", HumanoidStanding),
):
    envs.register_environment(_registered_name, _env_class)
del _registered_name, _env_class
def _log_frame_line_plot(key, values, column, title):
    """Log ``values`` against their frame index as one wandb line plot under ``key``."""
    table = wandb.Table(
        data=[[frame, value] for frame, value in enumerate(values)],
        columns=["frame", column],
    )
    wandb.log({key: wandb.plot.line(table, "frame", column, title=title)})


def _log_frame_series_plot(key, rows, title):
    """Log per-frame vectors ``rows`` as a wandb multi-line plot, one line per component."""
    # Stack to (frames, components) and transpose so each component becomes
    # one series; assumes every row has the same 1-D length.
    data = np.array(rows).T
    wandb.log(
        {
            key: wandb.plot.line_series(
                xs=range(data.shape[1]),
                ys=data,
                keys=[str(i) for i in range(data.shape[0])],
                xname="Frame",
                title=title,
            )
        }
    )


@hydra.main(config_path="./configs", config_name="train_config", version_base=None)
def main(train_config: DictConfig):
    """Train an intention-network PPO imitation policy on a single reference clip.

    Steps:
      1. Compose ``env_config`` and select the entry for ``train_config.env_name``.
      2. Load one reference clip with ``process_clip`` and build a training env
         plus an evaluation env whose ``sub_clip_length`` is enlarged.
      3. Run ``ppo.train`` with wandb progress logging. At every evaluation,
         ``policy_params_fn`` saves the params, rolls the stochastic policy out
         in the eval env, logs per-frame plots and logs a rendered video of the
         rollout next to the reference trajectory.
      4. Save the final params to ``./model_checkpoints/<run_id>/finished``.

    Args:
        train_config: Hydra config loaded from ``./configs/train_config``.
    """
    env_cfg = hydra.compose(config_name="env_config")
    env_cfg = OmegaConf.to_container(env_cfg, resolve=True)
    rodent_config = env_cfg[train_config.env_name]
    env_args = rodent_config["env_args"]

    # Process rodent clip
    reference_clip = process_clip(
        rodent_config["stac_path"],
        start_step=rodent_config["clip_idx"] * env_args["clip_length"],
        clip_length=env_args["clip_length"],
    )

    # Init env
    env = envs.get_environment(
        env_cfg[train_config.env_name]["name"],
        reference_clip=reference_clip,
        **env_args,
    )

    # TODO: Also have preset solver params here for eval
    # so we can relax params in training for faster sps?
    # Set the env to always start at frame 0 by maximizing sub_clip_length
    eval_env_args = env_args.copy()
    eval_env_args["sub_clip_length"] = (
        env_args["clip_length"] - env_args["ref_traj_length"]
    )
    eval_env = envs.get_environment(
        env_cfg[train_config.env_name]["name"],
        reference_clip=reference_clip,
        **eval_env_args,
    )

    # TODO: make the intention network factory a part of the config
    intention_network_factory = functools.partial(
        ppo_networks.make_intention_ppo_networks,
        intention_latent_size=train_config.intention_latent_size,
        encoder_layer_sizes=train_config.encoder_layer_sizes,
        decoder_layer_sizes=train_config.decoder_layer_sizes,
    )

    train_fn = functools.partial(
        ppo.train,
        num_timesteps=train_config["num_timesteps"],
        num_evals=int(train_config["num_timesteps"] / train_config["eval_every"]),
        reward_scaling=1,
        episode_length=train_config["episode_length"],
        normalize_observations=True,
        action_repeat=1,
        unroll_length=20,
        num_minibatches=train_config["num_minibatches"],
        num_updates_per_batch=train_config["num_updates_per_batch"],
        discounting=0.99,
        learning_rate=train_config["learning_rate"],
        entropy_cost=1e-3,
        # Per-device counts from the config are scaled by the GPU count.
        num_envs=train_config["num_envs"] * n_devices,
        batch_size=train_config["batch_size"] * n_devices,
        seed=0,
        clipping_epsilon=train_config["clipping_epsilon"],
        kl_weight=train_config["kl_weight"],
        network_factory=intention_network_factory,
    )

    # Generates a completely random UUID (version 4)
    run_id = uuid.uuid4()
    model_path = f"./model_checkpoints/{run_id}"

    wandb.init(
        project="VNL_SingleClipImitationPPO_Intention",
        config=OmegaConf.to_container(train_config, resolve=True),
        notes="",
        dir="/tmp",
    )
    wandb.run.name = f"{train_config.env_name}_{train_config.task_name}_{train_config['algo_name']}_{run_id}"

    def wandb_progress(num_steps, metrics):
        """Progress callback for ppo.train: log ``metrics`` tagged with ``num_steps``."""
        metrics["num_steps"] = num_steps
        wandb.log(metrics)

    # Wrap the eval step in jit once for the whole run instead of re-wrapping
    # ``eval_env.step`` (and potentially re-tracing it) at every evaluation.
    jit_step = jax.jit(eval_env.step)

    # TODO: make the rollout into a scan (or call brax's rollout fn?)
    def policy_params_fn(num_steps, make_policy, params, model_path=model_path):
        """Evaluation callback for ppo.train.

        Saves ``params`` to ``{model_path}/{num_steps}``, rolls the stochastic
        policy out for ``episode_length`` steps in ``eval_env``, logs per-frame
        plots to wandb and logs an mp4 of the rollout rendered alongside the
        reference trajectory.
        """
        os.makedirs(model_path, exist_ok=True)
        model.save_params(f"{model_path}/{num_steps}", params)

        jit_inference_fn = jax.jit(make_policy(params, deterministic=False))
        # Fixed seed: every evaluation starts from the same reset key.
        reset_rng, act_rng = jax.random.split(jax.random.PRNGKey(0))
        state = eval_env.reset(reset_rng)

        rollout = [state.pipeline_state]
        errors = []
        rewards = []
        means = []
        stds = []
        log_probs = []
        rand_probs = []
        for _ in range(train_config["episode_length"]):
            _, act_rng = jax.random.split(act_rng)
            # extras is a dictionary; "logits", "log_prob" and
            # "rand_log_prob" are read from it below.
            ctrl, extras = jit_inference_fn(state.info["traj"], state.obs, act_rng)
            state = jit_step(state, ctrl)
            # termination_error is not collected for the standing task.
            if train_config.env_name != "humanoidstanding":
                errors.append(state.info["termination_error"])
            rewards.append(state.reward)
            # The first half of the logits is plotted as action means, the
            # second half as stds.
            mean, std = np.split(extras["logits"], 2)
            # Previously these two keys were swapped, so the "policy" plot
            # showed the random-action log prob and vice versa.
            log_probs.append(extras["log_prob"])
            rand_probs.append(extras["rand_log_prob"])
            means.append(mean)
            stds.append(std)
            rollout.append(state.pipeline_state)

        # Per-frame plots, logged in the original order.
        _log_frame_line_plot(
            "eval/rollout_rtrunk", errors, "rtrunk", "rtrunk for each rollout frame"
        )
        _log_frame_series_plot(
            "logits/rollout_means",
            means,
            "Action actuator means for each rollout frame",
        )
        _log_frame_series_plot(
            "logits/rollout_stds",
            stds,
            "Action actuator stds for each rollout frame",
        )
        _log_frame_line_plot(
            "logits/rollout_log_probs",
            log_probs,
            "log_probs",
            "Policy action probability for each rollout frame",
        )
        _log_frame_line_plot(
            "logits/rollout_rand_probs",
            rand_probs,
            "rand_probs",
            "Random action probability for each rollout frame",
        )
        _log_frame_line_plot(
            "eval/rollout_reward", rewards, "reward", "reward for each rollout frame"
        )

        # Render the walker with the reference expert demonstration trajectory
        # NOTE(review): mujoco appears to choose its GL backend from MUJOCO_GL
        # when the module is first imported, so setting it here may have no
        # effect -- confirm and consider moving it to the top of the file.
        os.environ["MUJOCO_GL"] = "osmesa"

        def take_episode_frames(x):
            """Keep the first episode_length entries of non-1-D leaves; 1-D leaves become empty."""
            if len(x.shape) != 1:
                return jax.lax.dynamic_slice_in_dim(
                    x,
                    0,
                    train_config["episode_length"],
                )
            return jp.array([])

        # extract qpos from the reference trajectory (a private env attribute)
        # and from the rollout
        ref_traj = jax.tree_util.tree_map(take_episode_frames, eval_env._ref_traj)
        qposes_ref = jp.hstack(
            [ref_traj.position, ref_traj.quaternion, ref_traj.joints]
        )
        qposes_rollout = [pipeline_state.qpos for pipeline_state in rollout]

        mj_model = mujoco.MjModel.from_xml_path(
            f"./assets/{env_cfg[train_config.env_name]['rendering_mjcf']}"
        )
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6
        mj_model.opt.jacobian = 0  # dense
        mj_data = mujoco.MjData(mj_model)

        # save rendering and log to wandb
        mujoco.mj_kinematics(mj_model, mj_data)
        renderer = mujoco.Renderer(mj_model, height=512, width=512)
        camera = f"{env_cfg[train_config.env_name]['camera']}"
        video_path = f"{model_path}/{num_steps}.mp4"
        # Each frame sets qpos to the reference pose followed by the rollout
        # pose (presumably the rendering MJCF holds two walkers -- confirm);
        # zip stops at the shorter of the two sequences.
        with imageio.get_writer(video_path, fps=float(1.0 / eval_env.dt)) as video:
            for qpos_ref, qpos_rollout in zip(qposes_ref, qposes_rollout):
                mj_data.qpos = np.append(qpos_ref, qpos_rollout)
                mujoco.mj_forward(mj_model, mj_data)
                renderer.update_scene(mj_data, camera=camera)
                video.append_data(renderer.render())
        # Log only after the writer is closed so the mp4 is fully written.
        wandb.log({"eval/rollout": wandb.Video(video_path, format="mp4")})

    _, params, _ = train_fn(
        environment=env,
        progress_fn=wandb_progress,
        policy_params_fn=policy_params_fn,
        eval_env=eval_env,
    )

    final_save_path = f"{model_path}/finished"
    model.save_params(final_save_path, params)
    print(f"Run finished. Model saved to {final_save_path}")
# Script entry point: the hydra.main decorator on main() supplies
# ``train_config`` from ./configs and any command-line overrides.
if __name__ == "__main__":
    main()