-
Notifications
You must be signed in to change notification settings - Fork 0
/
mine_sweeper_demo.py
189 lines (170 loc) · 6.3 KB
/
mine_sweeper_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import random
from dataclasses import dataclass
from typing import Dict, List, Mapping, Optional, Tuple
import enn_trainer.config as config
import hyperstate
from enn_trainer.agent import PPOAgent
from enn_trainer.train import State, init_train_state, train
from entity_gym.env import (
Action,
ActionName,
ActionSpace,
CategoricalAction,
CategoricalActionMask,
CategoricalActionSpace,
Entity,
Environment,
Observation,
ObsSpace,
SelectEntityAction,
SelectEntityActionMask,
SelectEntityActionSpace,
)
from hyperstate import StateManager
class MineSweeper(Environment):
    """
    Grid world containing mines, robots and (optionally) an orbital cannon.

    The player controls all robots in the environment.
    On every step, each robot may move in one of four cardinal directions, or
    stay in place and defuse all adjacent mines (Manhattan distance <= 1).
    A defused mine is removed from the environment.
    A robot that ends the step on a mine is removed from the environment.
    When enabled, the orbital cannon may destroy a single mine or robot once
    its cooldown has reached zero; firing restarts the cooldown.

    The episode ends when no mines are left (reward 1.0) or when no robots are
    left (reward 0.0).
    """

    def __init__(
        self,
        width: int = 6,
        height: int = 6,
        nmines: int = 5,
        nrobots: int = 2,
        orbital_cannon: bool = False,
        cooldown_period: int = 5,
    ):
        """
        Store the board configuration; positions are populated by `reset`.

        Args:
            width: Number of columns; valid x coordinates are 0..width-1.
            height: Number of rows; valid y coordinates are 0..height-1.
            nmines: Number of mines placed by `reset`.
            nrobots: Number of robots placed by `reset`.
            orbital_cannon: Whether the "Orbital Cannon" entity is observed
                and allowed to act.
            cooldown_period: Number of steps the cannon has to wait after a
                reset or after firing before it may fire (again).
        """
        self.width = width
        self.height = height
        self.nmines = nmines
        self.nrobots = nrobots
        self.orbital_cannon = orbital_cannon
        self.cooldown_period = cooldown_period
        self.orbital_cannon_cooldown = cooldown_period
        # Positions of robots and mines
        self.robots: List[Tuple[int, int]] = []
        self.mines: List[Tuple[int, int]] = []

    # NOTE: there is no @classmethod decorator, so `cls` is bound to the
    # instance here; the parameter name is kept unchanged.
    def obs_space(cls) -> ObsSpace:
        """Declare the entity types and their feature names."""
        return ObsSpace(
            entities={
                "Mine": Entity(features=["x", "y"]),
                "Robot": Entity(features=["x", "y"]),
                "Orbital Cannon": Entity(features=["cooldown"]),
            }
        )

    def action_space(cls) -> Dict[ActionName, ActionSpace]:
        """Declare the available actions.

        The order of the "Move" choices defines the choice indices used by
        `act` and the mask layout produced by `valid_moves`.
        """
        return {
            "Move": CategoricalActionSpace(
                ["Up", "Down", "Left", "Right", "Defuse Mines"],
            ),
            "Fire Orbital Cannon": SelectEntityActionSpace(),
        }

    def reset(self) -> Observation:
        """Place mines and robots on distinct random cells and restart the cooldown.

        Raises:
            ValueError: (from `random.sample`) if nmines + nrobots exceeds the
                number of cells on the board.
        """
        cells = [(x, y) for x in range(self.width) for y in range(self.height)]
        # Sampling without replacement guarantees that no two entities overlap.
        positions = random.sample(cells, self.nmines + self.nrobots)
        self.mines = positions[: self.nmines]
        self.robots = positions[self.nmines :]
        self.orbital_cannon_cooldown = self.cooldown_period
        return self.observe()

    def observe(self) -> Observation:
        """Build the observation (entities, action masks, done, reward) for the current state."""
        all_mines_cleared = not self.mines
        # The game is done once there are no more mines or robots
        done = all_mines_cleared or not self.robots
        # Give reward of 1.0 for defusing all mines
        reward = 1.0 if all_mines_cleared else 0.0
        # The cannon may only act if it exists and has finished cooling down.
        cannon_ready = self.orbital_cannon and self.orbital_cannon_cooldown == 0
        return Observation(
            entities={
                "Mine": (
                    self.mines,
                    [("Mine", i) for i in range(len(self.mines))],
                ),
                "Robot": (
                    self.robots,
                    [("Robot", i) for i in range(len(self.robots))],
                ),
                "Orbital Cannon": (
                    [(self.orbital_cannon_cooldown,)],
                    [("Orbital Cannon", 0)],
                )
                if self.orbital_cannon
                else None,
            },
            actions={
                "Move": CategoricalActionMask(
                    # Allow all robots to move
                    actor_types=["Robot"],
                    mask=[self.valid_moves(x, y) for x, y in self.robots],
                ),
                "Fire Orbital Cannon": SelectEntityActionMask(
                    # Only the Orbital Cannon can fire, but not if cooldown > 0
                    actor_types=["Orbital Cannon"] if cannon_ready else [],
                    # Both mines and robots can be fired at
                    actee_types=["Mine", "Robot"],
                ),
            },
            done=done,
            reward=reward,
        )

    def act(self, actions: Mapping[ActionName, Action]) -> Observation:
        """Apply one step of cannon fire and robot moves, then observe.

        Order of effects: mines hit by the cannon are removed first, then all
        robots execute their chosen action, then robots hit by the cannon are
        removed, and finally every robot standing on a remaining mine is
        removed.

        Args:
            actions: Chosen actions keyed by action name. A missing entry is
                treated as "no actor performed this action".
        """
        cannon_fired = False
        robots_hit = set()
        fire = actions.get("Fire Orbital Cannon")
        if fire is not None:
            assert isinstance(fire, SelectEntityAction)
            mines_hit = set()
            for entity_type, i in fire.actees:
                cannon_fired = True
                if entity_type == "Mine":
                    mines_hit.add(i)
                elif entity_type == "Robot":
                    # Don't remove yet to keep indices valid
                    robots_hit.add(i)
            if mines_hit:
                # Filter by index in one pass so that several hits cannot
                # shift each other's indices.
                self.mines = [
                    mine for j, mine in enumerate(self.mines) if j not in mines_hit
                ]

        # Firing restarts the cooldown; otherwise it counts down towards zero.
        if cannon_fired:
            self.orbital_cannon_cooldown = self.cooldown_period
        elif self.orbital_cannon_cooldown > 0:
            self.orbital_cannon_cooldown -= 1

        move = actions.get("Move")
        if move is not None:
            assert isinstance(move, CategoricalAction)
            for (_, i), choice in zip(move.actors, move.indices):
                # Action space is ["Up", "Down", "Left", "Right", "Defuse Mines"],
                x, y = self.robots[i]
                if choice == 0 and y < self.height - 1:
                    self.robots[i] = (x, y + 1)
                elif choice == 1 and y > 0:
                    self.robots[i] = (x, y - 1)
                elif choice == 2 and x > 0:
                    self.robots[i] = (x - 1, y)
                elif choice == 3 and x < self.width - 1:
                    self.robots[i] = (x + 1, y)
                elif choice == 4:
                    # Remove all mines adjacent to this robot
                    self.mines = [
                        (mx, my)
                        for (mx, my) in self.mines
                        if abs(mx - x) + abs(my - y) > 1
                    ]

        if robots_hit:
            self.robots = [
                robot for j, robot in enumerate(self.robots) if j not in robots_hit
            ]
        # Remove all robots that stepped on a mine
        self.robots = [r for r in self.robots if r not in self.mines]
        return self.observe()

    def valid_moves(self, x: int, y: int) -> List[bool]:
        """Return the "Move" mask for a robot at (x, y).

        Entries follow the action space order
        ["Up", "Down", "Left", "Right", "Defuse Mines"] and use the same
        boundary conditions as `act`.
        """
        return [
            # Up increases y
            y < self.height - 1,
            # Down decreases y
            y > 0,
            # Left decreases x
            x > 0,
            # Right increases x
            x < self.width - 1,
            # Always allow staying in place and defusing mines
            True,
        ]
@dataclass
class TrainConfig(config.TrainConfig):
    """Training configuration for this demo; adds no fields to `config.TrainConfig`."""
@hyperstate.stateful_command(TrainConfig, State, init_train_state)
def main(state_manager: StateManager) -> None:
    """Entry point: run `train` on the MineSweeper environment.

    Args:
        state_manager: Passed through to `train` unchanged; supplied by the
            `hyperstate.stateful_command` decorator.
    """
    # No pre-built agent is handed over.
    # NOTE(review): presumably `train` then creates its own agent - confirm.
    no_agent: Optional[PPOAgent] = None
    # The environment is passed as a class, not as an instance.
    train(state_manager=state_manager, env=MineSweeper, agent=no_agent)


if __name__ == "__main__":
    main()