-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
exception_wrapper.py
38 lines (30 loc) · 1.1 KB
/
exception_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import traceback
import gymnasium as gym
# Module-level logger; the wrapper below uses it to record exceptions raised
# by the wrapped environment's reset() and step() calls.
logger = logging.getLogger(__name__)
class TooManyResetAttemptsException(Exception):
    """Error signalling that an environment reset kept failing.

    Attributes:
        max_attempts: The attempt limit that was exhausted.
    """

    def __init__(self, max_attempts: int):
        """Build the error message from the exhausted attempt limit.

        Args:
            max_attempts: Number of reset attempts allowed before giving up.
        """
        super().__init__(
            f"Reached the maximum number of attempts ({max_attempts}) "
            f"to reset an environment."
        )
        # Keep the limit available so handlers need not parse the message.
        self.max_attempts = max_attempts

    def __reduce__(self):
        """Rebuild copies/unpickled instances from the original limit.

        The default exception reduction re-calls ``__init__`` with the
        already-formatted message as ``max_attempts``, which nests the message
        inside itself. Reconstructing from ``max_attempts`` avoids that.
        """
        return (type(self), (self.max_attempts,))
class ResetOnExceptionWrapper(gym.Wrapper):
    """Environment wrapper that recovers from exceptions by resetting.

    Exceptions raised by the wrapped environment's ``reset`` are logged and
    the reset is retried up to ``max_reset_attempts`` times. An exception
    raised by ``step`` is logged and answered with the result of a fresh
    reset instead of propagating to the caller.
    """

    def __init__(self, env: gym.Env, max_reset_attempts: int = 5):
        """Wrap ``env``.

        Args:
            env: The environment whose ``reset``/``step`` calls are guarded.
            max_reset_attempts: Maximum number of ``reset`` calls tried before
                giving up with ``TooManyResetAttemptsException``.
        """
        super().__init__(env)
        self.max_reset_attempts = max_reset_attempts

    def reset(self, **kwargs):
        """Reset the wrapped environment, retrying whenever it raises.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset`` on every attempt.

        Returns:
            Whatever the wrapped environment's ``reset`` returns on the first
            successful attempt.

        Raises:
            TooManyResetAttemptsException: If ``max_reset_attempts`` attempts
                all raised (or if the limit is not positive, in which case no
                attempt is made). The last underlying error, if any, is
                attached as the cause.
        """
        last_error = None
        attempt = 0
        while attempt < self.max_reset_attempts:
            try:
                return self.env.reset(**kwargs)
            except Exception as exc:
                attempt += 1
                # logger.exception records the active traceback with the message.
                logger.exception(
                    "Environment reset failed (attempt %s of %s).",
                    attempt,
                    self.max_reset_attempts,
                )
                last_error = exc
        raise TooManyResetAttemptsException(self.max_reset_attempts) from last_error

    def step(self, action):
        """Step the wrapped environment, resetting it if the step raises.

        Args:
            action: Forwarded unchanged to the wrapped environment's ``step``.

        Returns:
            The wrapped environment's ``step`` result on success. If the step
            raised, a gymnasium-style 5-tuple
            ``(obs, reward, terminated, truncated, info)`` built from a fresh
            reset: the reset observation, reward ``0.0``, both end-of-episode
            flags ``False``, and the reset info extended with
            ``"__terminated__": True``.

        Raises:
            TooManyResetAttemptsException: If the recovery reset itself could
                not be completed.
        """
        try:
            return self.env.step(action)
        except Exception:
            logger.exception("Environment step failed; resetting the environment.")
            # gymnasium's reset() yields (observation, info); unpack it so the
            # observation slot does not end up holding the whole tuple.
            obs, info = self.reset()
            # The flags stay False as before; the forced reset is signalled
            # through the info key only.
            # NOTE(review): presumably the caller checks "__terminated__" to
            # detect the forced reset -- confirm against the consumer.
            return obs, 0.0, False, False, {**(info or {}), "__terminated__": True}