Allow setting max_allowed_loss in TradingEnvironment.
notadamking committed Feb 10, 2020
1 parent ab9d9e4 commit a6f04ca
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tensortrade/environments/trading_environment.py
@@ -80,8 +80,9 @@ def __init__(self,

         self._enable_logger = kwargs.get('enable_logger', False)
         self._observation_dtype = kwargs.get('dtype', np.float32)
-        self._observation_lows = kwargs.get('observation_lows', 0)
-        self._observation_highs = kwargs.get('observation_highs', 1)
+        self._observation_lows = kwargs.get('observation_lows', -np.iinfo(np.int32).max)
+        self._observation_highs = kwargs.get('observation_highs', np.iinfo(np.int32).max)
+        self._max_allowed_loss = kwargs.get('max_allowed_loss', 0.1)

         if self._enable_logger:
             self.logger = logging.getLogger(kwargs.get('logger_name', __name__))
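The hunk above widens the default observation bounds from [0, 1] to the full int32 range and records the new max_allowed_loss keyword argument. Below is a standalone, illustrative sketch of that kwargs handling (the helper function is not part of TensorTrade; only the default values are taken from this hunk):

import numpy as np

# Illustrative helper mirroring the constructor defaults introduced above.
def read_env_kwargs(**kwargs):
    lows = kwargs.get('observation_lows', -np.iinfo(np.int32).max)    # default is now -2147483647
    highs = kwargs.get('observation_highs', np.iinfo(np.int32).max)   # default is now 2147483647
    max_allowed_loss = kwargs.get('max_allowed_loss', 0.1)            # new kwarg, defaults to 10%
    return lows, highs, max_allowed_loss

print(read_env_kwargs())                      # (-2147483647, 2147483647, 0.1)
print(read_env_kwargs(max_allowed_loss=0.3))  # (-2147483647, 2147483647, 0.3)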
@@ -188,14 +189,15 @@ def step(self, action: int) -> Tuple[np.array, float, bool, dict]:
         self.history.push(obs_row)

         obs = self.history.observe()
         obs = obs.astype(self._observation_dtype)

         reward = self.reward_scheme.get_reward(self._portfolio)
         reward = np.nan_to_num(reward)

         if np.bitwise_not(np.isfinite(reward)):
             raise ValueError('Reward returned by the reward scheme must by a finite float.')

-        done = (self.portfolio.profit_loss < 0.1) or not self.feed.has_next()
+        done = (self.portfolio.profit_loss < self._max_allowed_loss) or not self.feed.has_next()

         info = {
             'step': self.clock.step,
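With the threshold stored on the environment, an episode now ends once portfolio.profit_loss falls below max_allowed_loss rather than the hard-coded 0.1. A hypothetical usage sketch follows; the constructor arguments other than max_allowed_loss (portfolio, action_scheme, reward_scheme, feed) are assumptions about the surrounding TensorTrade API and are not shown in this diff:

from tensortrade.environments import TradingEnvironment

# Hypothetical setup: portfolio, action_scheme, reward_scheme, and feed are
# assumed to be built elsewhere with the rest of the TensorTrade API.
env = TradingEnvironment(
    portfolio=portfolio,
    action_scheme=action_scheme,
    reward_scheme=reward_scheme,
    feed=feed,
    max_allowed_loss=0.25,  # end the episode once profit_loss drops below 0.25
)

obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # assumes a gym-style action_space
    obs, reward, done, info = env.step(action)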
