from gymnasium.spaces import Box, MultiDiscrete, Tuple as TupleSpace
import logging
import numpy as np
import random
import time
from typing import Callable, Optional, Tuple

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiAgentDict, PolicyID, AgentID

logger = logging.getLogger(__name__)


@PublicAPI
class Unity3DEnv(MultiAgentEnv):
"""A MultiAgentEnv representing a single Unity3D game instance.
For an example on how to use this Env with a running Unity3D editor
or with a compiled game, see:
`rllib/examples/unity3d_env_local.py`
For an example on how to use it inside a Unity game client, which
connects to an RLlib Policy server, see:
`rllib/examples/envs/external_envs/unity3d_[client|server].py`
Supports all Unity3D (MLAgents) examples, multi- or single-agent and
gets converted automatically into an ExternalMultiAgentEnv, when used
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
"""
# Default base port when connecting directly to the Editor
_BASE_PORT_EDITOR = 5004
# Default base port when connecting to a compiled environment
_BASE_PORT_ENVIRONMENT = 5005
# The worker_id for each environment instance
_WORKER_ID = 0

    def __init__(
self,
        file_name: Optional[str] = None,
port: Optional[int] = None,
seed: int = 0,
no_graphics: bool = False,
timeout_wait: int = 300,
episode_horizon: int = 1000,
):
"""Initializes a Unity3DEnv object.
Args:
file_name (Optional[str]): Name of the Unity game binary.
If None, will assume a locally running Unity3D editor
to be used, instead.
port (Optional[int]): Port number to connect to Unity environment.
seed: A random seed value to use for the Unity3D game.
no_graphics: Whether to run the Unity3D simulator in
no-graphics mode. Default: False.
timeout_wait: Time (in seconds) to wait for connection from
the Unity3D instance.
episode_horizon: A hard horizon to abide to. After at most
this many steps (per-agent episode `step()` calls), the
Unity3D game is reset and will start again (finishing the
multi-agent episode that the game represents).
Note: The game itself may contain its own episode length
limits, which are always obeyed (on top of this value here).
"""
        super().__init__()

        if file_name is None:
            print(
                "No game binary provided, will use a running Unity editor "
                "instead.\nMake sure to press the Play (|>) button in your "
                "editor to start."
            )
import mlagents_envs
from mlagents_envs.environment import UnityEnvironment
        # Try connecting to the Unity3D game instance. If the port is
        # already in use, retry with an increased worker_id (and thus a
        # different effective port).
        port_ = None
        while True:
            # Sleep for a random time to allow for the concurrent startup of
            # many environments (num_env_runners >> 1). Otherwise, this could
            # sometimes lead to port conflicts.
            if port_ is not None:
                time.sleep(random.randint(1, 10))
port_ = port or (
self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR
)
            # Cache the worker_id and increase it for the next environment.
            worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
Unity3DEnv._WORKER_ID += 1
try:
self.unity_env = UnityEnvironment(
file_name=file_name,
worker_id=worker_id_,
base_port=port_,
seed=seed,
no_graphics=no_graphics,
timeout_wait=timeout_wait,
)
print("Created UnityEnvironment for port {}".format(port_ + worker_id_))
except mlagents_envs.exception.UnityWorkerInUseException:
pass
else:
break
# ML-Agents API version.
self.api_version = self.unity_env.API_VERSION.split(".")
self.api_version = [int(s) for s in self.api_version]
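        # E.g. an `API_VERSION` string such as "1.5.0" parses into [1, 5, 0].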
        # Reset the entire env after this many `step()` calls.
        self.episode_horizon = episode_horizon
# Keep track of how many times we have called `step` so far.
self.episode_timesteps = 0

    def step(
self, action_dict: MultiAgentDict
) -> Tuple[
MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
]:
"""Performs one multi-agent step through the game.
Args:
action_dict: Multi-agent action dict with:
keys=agent identifier consisting of
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
[Agent index, a unique MLAgent-assigned index per single agent]
Returns:
tuple:
- obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
- rewards: Rewards dict matching `obs`.
- dones: Done dict with only an __all__ multi-agent entry in
it. __all__=True, if episode is done for all agents.
- infos: An (empty) info dict.
"""
from mlagents_envs.base_env import ActionTuple
# Set only the required actions (from the DecisionSteps) in Unity3D.
all_agents = []
for behavior_name in self.unity_env.behavior_specs:
            # New ML-Agents API: Set all agents' actions at the same time
            # via an ActionTuple. Available since API v1.4.0.
if self.api_version[0] > 1 or (
self.api_version[0] == 1 and self.api_version[1] >= 4
):
actions = []
for agent_id in self.unity_env.get_steps(behavior_name)[0].agent_id:
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
actions.append(action_dict[key])
if actions:
if actions[0].dtype == np.float32:
action_tuple = ActionTuple(continuous=np.array(actions))
else:
action_tuple = ActionTuple(discrete=np.array(actions))
self.unity_env.set_actions(behavior_name, action_tuple)
# Old behavior: Do not use an ActionTuple and set each agent's
# action individually.
else:
for agent_id in self.unity_env.get_steps(behavior_name)[
0
].agent_id_to_index.keys():
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
self.unity_env.set_action_for_agent(
behavior_name, agent_id, action_dict[key]
)
# Do the step.
self.unity_env.step()
obs, rewards, terminateds, truncateds, infos = self._get_step_results()
# Global horizon reached? -> Return __all__ truncated=True, so user
# can reset. Set all agents' individual `truncated` to True as well.
self.episode_timesteps += 1
if self.episode_timesteps > self.episode_horizon:
return (
obs,
rewards,
terminateds,
dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}),
infos,
)
return obs, rewards, terminateds, truncateds, infos

    def reset(
self, *, seed=None, options=None
) -> Tuple[MultiAgentDict, MultiAgentDict]:
"""Resets the entire Unity3D scene (a single multi-agent episode)."""
self.episode_timesteps = 0
self.unity_env.reset()
obs, _, _, _, infos = self._get_step_results()
return obs, infos

    def _get_step_results(self):
"""Collects those agents' obs/rewards that have to act in next `step`.
Returns:
Tuple:
obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
rewards: Rewards dict matching `obs`.
dones: Done dict with only an __all__ multi-agent entry in it.
__all__=True, if episode is done for all agents.
infos: An (empty) info dict.
"""
obs = {}
rewards = {}
infos = {}
for behavior_name in self.unity_env.behavior_specs:
decision_steps, terminal_steps = self.unity_env.get_steps(behavior_name)
            # Loop through all agents ("sub-envs") that are currently
            # available in the DecisionSteps and fill in whatever
            # information we have.
for agent_id, idx in decision_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
os = tuple(o[idx] for o in decision_steps.obs)
os = os[0] if len(os) == 1 else os
obs[key] = os
rewards[key] = (
decision_steps.reward[idx] + decision_steps.group_reward[idx]
)
for agent_id, idx in terminal_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
                # Only overwrite the rewards (last reward in the episode);
                # the obs here is the terminal obs, from which no new action
                # is computed. Exception: `key` is not in `obs` yet.
if key not in obs:
os = tuple(o[idx] for o in terminal_steps.obs)
obs[key] = os = os[0] if len(os) == 1 else os
rewards[key] = (
terminal_steps.reward[idx] + terminal_steps.group_reward[idx]
)
        # Never signal termination here; episode ends are handled via the
        # horizon-based truncation in `step()`, after which a reset follows.
return obs, rewards, {"__all__": False}, {"__all__": False}, infos

    @staticmethod
def get_policy_configs_for_game(
game_name: str,
) -> Tuple[dict, Callable[[AgentID], PolicyID]]:
# The RLlib server must know about the Spaces that the Client will be
# using inside Unity3D, up-front.
obs_spaces = {
# 3DBall.
"3DBall": Box(float("-inf"), float("inf"), (8,)),
# 3DBallHard.
"3DBallHard": Box(float("-inf"), float("inf"), (45,)),
# GridFoodCollector
"GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)),
# Pyramids.
"Pyramids": TupleSpace(
[
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (56,)),
Box(float("-inf"), float("inf"), (4,)),
]
),
# SoccerTwos.
"SoccerPlayer": TupleSpace(
[
Box(-1.0, 1.0, (264,)),
Box(-1.0, 1.0, (72,)),
]
),
# SoccerStrikersVsGoalie.
"Goalie": Box(float("-inf"), float("inf"), (738,)),
"Striker": TupleSpace(
[
Box(float("-inf"), float("inf"), (231,)),
Box(float("-inf"), float("inf"), (63,)),
]
),
# Sorter.
"Sorter": TupleSpace(
[
Box(
float("-inf"),
float("inf"),
(
20,
23,
),
),
Box(float("-inf"), float("inf"), (10,)),
Box(float("-inf"), float("inf"), (8,)),
]
),
# Tennis.
"Tennis": Box(float("-inf"), float("inf"), (27,)),
# VisualHallway.
"VisualHallway": Box(float("-inf"), float("inf"), (84, 84, 3)),
# Walker.
"Walker": Box(float("-inf"), float("inf"), (212,)),
# FoodCollector.
"FoodCollector": TupleSpace(
[
Box(float("-inf"), float("inf"), (49,)),
Box(float("-inf"), float("inf"), (4,)),
]
),
}
action_spaces = {
# 3DBall.
"3DBall": Box(-1.0, 1.0, (2,), dtype=np.float32),
# 3DBallHard.
"3DBallHard": Box(-1.0, 1.0, (2,), dtype=np.float32),
# GridFoodCollector.
"GridFoodCollector": MultiDiscrete([3, 3, 3, 2]),
# Pyramids.
"Pyramids": MultiDiscrete([5]),
# SoccerStrikersVsGoalie.
"Goalie": MultiDiscrete([3, 3, 3]),
"Striker": MultiDiscrete([3, 3, 3]),
# SoccerTwos.
"SoccerPlayer": MultiDiscrete([3, 3, 3]),
# Sorter.
"Sorter": MultiDiscrete([3, 3, 3]),
# Tennis.
"Tennis": Box(-1.0, 1.0, (3,)),
# VisualHallway.
"VisualHallway": MultiDiscrete([5]),
# Walker.
"Walker": Box(-1.0, 1.0, (39,)),
# FoodCollector.
"FoodCollector": MultiDiscrete([3, 3, 3, 2]),
}
# Policies (Unity: "behaviors") and agent-to-policy mapping fns.
if game_name == "SoccerStrikersVsGoalie":
policies = {
"Goalie": PolicySpec(
observation_space=obs_spaces["Goalie"],
action_space=action_spaces["Goalie"],
),
"Striker": PolicySpec(
observation_space=obs_spaces["Striker"],
action_space=action_spaces["Striker"],
),
}

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "Striker" if "Striker" in agent_id else "Goalie"
elif game_name == "SoccerTwos":
policies = {
"PurplePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
"BluePlayer": PolicySpec(
observation_space=obs_spaces["SoccerPlayer"],
action_space=action_spaces["SoccerPlayer"],
),
}

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
return "BluePlayer" if "1_" in agent_id else "PurplePlayer"
else:
policies = {
game_name: PolicySpec(
observation_space=obs_spaces[game_name],
action_space=action_spaces[game_name],
),
}

            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                return game_name

        return policies, policy_mapping_fn
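

# A minimal usage sketch (not part of the original module): connect to a
# locally running Unity editor with the 3DBall scene loaded and in Play
# mode, then step the env with random actions. The game name, horizon, and
# step count below are illustrative assumptions, not requirements.
if __name__ == "__main__":
    policies, mapping_fn = Unity3DEnv.get_policy_configs_for_game("3DBall")
    action_space = policies["3DBall"].action_space

    env = Unity3DEnv(file_name=None, episode_horizon=100)
    obs, infos = env.reset()
    for _ in range(300):
        # Sample a random action for every agent that requested one.
        actions = {agent_id: action_space.sample() for agent_id in obs}
        obs, rewards, terminateds, truncateds, infos = env.step(actions)
        # An `__all__` truncation signals that the horizon was reached.
        if truncateds["__all__"]:
            obs, infos = env.reset()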