forked from pengjujin/rllab_mujoco
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mujoco_env.py
176 lines (148 loc) · 5.5 KB
/
mujoco_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
from rllab import spaces
from rllab.envs.base import Env
# from gym import error, spaces
from gym import error
from gym.utils import seeding
import numpy as np
from os import path
import gym
import six
try:
import mujoco_py
from mujoco_py.mjlib import mjlib
except ImportError as e:
raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))
class MujocoEnv(Env):
    """Superclass for all MuJoCo environments.

    Loads a MuJoCo model through ``mujoco_py`` and provides the plumbing
    shared by concrete environments: resetting, setting the state, stepping
    the simulator, rendering and looking up per-body quantities.

    Subclasses must implement ``reset_model`` and ``step`` (``step`` is not
    defined here, yet it is called once from ``__init__`` to measure the
    observation size).  ``viewer_setup`` may optionally be overridden.
    """

    def __init__(self, model_path, frame_skip):
        """Load the model file and initialise the environment attributes.

        Args:
            model_path: Path of the model file.  An absolute path is used
                as is; a relative path is resolved against the directory
                that contains this module.
            frame_skip: Multiplier applied to the model timestep to obtain
                ``dt``.

        Raises:
            IOError: If the resolved model path does not exist.
        """
        # os.path.isabs matches the former ``startswith("/")`` check on
        # POSIX and additionally recognises absolute paths elsewhere.
        if os.path.isabs(model_path):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), model_path)
        if not os.path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)

        self.frame_skip = frame_skip
        self.model = mujoco_py.MjModel(fullpath)
        self.data = self.model.data
        self.viewer = None  # created lazily by _get_viewer()

        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt)),
        }

        # Snapshot of the freshly loaded state, kept for subclasses to use
        # when resetting.
        self.init_qpos = self.model.data.qpos.ravel().copy()
        self.init_qvel = self.model.data.qvel.ravel().copy()

        # Take one step with an all-zero control vector purely to learn how
        # large an observation is; the spaces themselves are built on demand
        # by the ``observation_space`` / ``action_space`` properties.
        observation, _reward, done, _info = self.step(np.zeros(self.model.nu))
        assert not done
        self.obs_dim = observation.size

        self._seed()

    @property
    def observation_space(self):
        """Unbounded ``Box`` with ``obs_dim`` entries (-inf .. +inf)."""
        high = np.inf * np.ones(self.obs_dim)
        low = -high
        return spaces.Box(low, high)

    @property
    def action_space(self):
        """``Box`` whose bounds are the columns of ``actuator_ctrlrange``."""
        bounds = self.model.actuator_ctrlrange.copy()
        low = bounds[:, 0]
        high = bounds[:, 1]
        return spaces.Box(low, high)

    def _seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``.

        The returned seed is the one reported by ``seeding.np_random``.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------

    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized and after every reset
        Optionally implement this method, if you need to tinker with camera position
        and so forth.
        """
        pass

    # -----------------------------

    def reset(self):
        """Reset the simulator data and return ``reset_model()``'s result.

        If a viewer already exists it is autoscaled and ``viewer_setup`` is
        invoked again.
        """
        mjlib.mj_resetData(self.model.ptr, self.data.ptr)
        ob = self.reset_model()
        if self.viewer is not None:
            self.viewer.autoscale()
            self.viewer_setup()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite the simulator's ``qpos`` and ``qvel``.

        Args:
            qpos: Array of shape ``(model.nq,)``.
            qvel: Array of shape ``(model.nv,)``.

        Raises:
            AssertionError: If either array has the wrong shape.
        """
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        self.model.data.qpos = qpos
        self.model.data.qvel = qvel
        # Ask the model to refresh its derived quantities after the
        # state has been written.
        self.model._compute_subtree()  # pylint: disable=W0212
        self.model.forward()

    @property
    def dt(self):
        """Model timestep multiplied by ``frame_skip``."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Set the control vector and call ``model.step()`` ``n_frames`` times.

        Args:
            ctrl: Value assigned to ``model.data.ctrl``.
            n_frames: Number of ``model.step()`` calls to make.
        """
        self.model.data.ctrl = ctrl
        for _ in range(n_frames):
            self.model.step()

    def render(self, mode='human', close=False):
        """Render the scene or close the viewer.

        Args:
            mode: ``'rgb_array'`` returns the current frame as a
                ``(height, width, 3)`` uint8 array; ``'human'`` runs one
                iteration of the viewer loop and returns ``None``.  Any
                other value does nothing and returns ``None``.
            close: When true, shut down the viewer (if one exists) and
                return ``None`` without rendering.

        Returns:
            The image array for ``mode='rgb_array'``, otherwise ``None``.
        """
        if close:
            if self.viewer is not None:
                self.viewer.finish()
                self.viewer = None
            return

        if mode == 'rgb_array':
            viewer = self._get_viewer()
            viewer.render()
            data, width, height = viewer.get_image()
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring.  The rows are reversed as before, and the
            # trailing copy() keeps the result writable, since frombuffer
            # itself yields a read-only view of immutable input.
            frame = np.frombuffer(data, dtype='uint8').reshape(height, width, 3)
            return frame[::-1, :, :].copy()
        elif mode == 'human':
            self._get_viewer().loop_once()

    def _get_viewer(self):
        """Return the viewer, creating and configuring it on first use."""
        if self.viewer is None:
            self.viewer = mujoco_py.MjViewer()
            self.viewer.start()
            self.viewer.set_model(self.model)
            self.viewer_setup()
        return self.viewer

    def get_body_com(self, body_name):
        """Return the ``data.com_subtree`` entry of the body named ``body_name``."""
        # body_names is searched with a bytes key, hence six.b().
        idx = self.model.body_names.index(six.b(body_name))
        return self.model.data.com_subtree[idx]

    def get_body_comvel(self, body_name):
        """Return the ``model.body_comvels`` entry of the body named ``body_name``."""
        idx = self.model.body_names.index(six.b(body_name))
        return self.model.body_comvels[idx]

    def get_body_xmat(self, body_name):
        """Return the ``data.xmat`` entry of the named body as a 3x3 array."""
        idx = self.model.body_names.index(six.b(body_name))
        return self.model.data.xmat[idx].reshape((3, 3))

    def state_vector(self):
        """Return ``qpos`` followed by ``qvel``, flattened into one 1-D array."""
        return np.concatenate([
            self.model.data.qpos.flat,
            self.model.data.qvel.flat
        ])