Skip to content

Commit

Permalink
Merge pull request #392 from tensortrade-org/environment-random-start
Browse files: browse the repository at this point in the history
Environment random start
  • Loading branch information
carlogrisetti committed Feb 15, 2022
2 parents 3c75fb1 + 88eada2 commit bdc83c0
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 11 deletions.
5 changes: 5 additions & 0 deletions tensortrade/env/default/__init__.py
Expand Up @@ -20,6 +20,7 @@ def create(portfolio: 'Portfolio',
feed: 'DataFeed',
window_size: int = 1,
min_periods: int = None,
random_start_pct: float = 0.00,
**kwargs) -> TradingEnv:
"""Creates the default `TradingEnv` of the project to be used in training
RL agents.
Expand All @@ -39,6 +40,9 @@ def create(portfolio: 'Portfolio',
The size of the look back window to use for the observation space.
min_periods : int, optional
The minimum number of steps to warm up the `feed`.
random_start_pct : float, optional
Fraction of the sample (e.g. 0.10 for the first 10%) within which a random
starting point is drawn at each environment reset. The default of 0.00
disables randomization, so every episode starts at the beginning.
**kwargs : keyword arguments
Extra keyword arguments needed to build the environment.
Expand Down Expand Up @@ -86,5 +90,6 @@ def create(portfolio: 'Portfolio',
informer=kwargs.get("informer", informers.TensorTradeInformer()),
renderer=renderer,
min_periods=min_periods,
random_start_pct=random_start_pct,
)
return env
4 changes: 2 additions & 2 deletions tensortrade/env/default/observers.py
Expand Up @@ -280,11 +280,11 @@ def has_next(self) -> bool:
"""
return self.feed.has_next()

def reset(self) -> None:
def reset(self, random_start=0) -> None:
"""Resets the observer"""
self.renderer_history = []
self.history.reset()
self.feed.reset()
self.feed.reset(random_start)
self.warmup()


Expand Down
2 changes: 1 addition & 1 deletion tensortrade/env/generic/components/observer.py
Expand Up @@ -54,6 +54,6 @@ def observe(self, env: 'TradingEnv') -> np.array:
"""
raise NotImplementedError()

def reset(self):
def reset(self, random_start=0):
"""Resets the observer."""
pass
14 changes: 13 additions & 1 deletion tensortrade/env/generic/environment.py
Expand Up @@ -16,6 +16,7 @@
import logging

from typing import Dict, Any, Tuple
from random import randint

import gym
import numpy as np
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self,
informer: Informer,
renderer: Renderer,
min_periods: int = None,
random_start_pct: float = 0.00,
**kwargs) -> None:
super().__init__()
self.clock = Clock()
Expand All @@ -76,6 +78,7 @@ def __init__(self,
self.informer = informer
self.renderer = renderer
self.min_periods = min_periods
self.random_start_pct = random_start_pct

for c in self.components.values():
c.clock = self.clock
Expand Down Expand Up @@ -139,12 +142,21 @@ def reset(self) -> 'np.array':
obs : `np.array`
The first observation of the environment.
"""
if self.random_start_pct > 0.00:
size = len(self.observer.feed.process[-1].inputs[0].iterable)
random_start = randint(0, int(size * self.random_start_pct))
else:
random_start = 0

self.episode_id = str(uuid.uuid4())
self.clock.reset()

for c in self.components.values():
if hasattr(c, "reset"):
c.reset()
if isinstance(c, Observer):
c.reset(random_start=random_start)
else:
c.reset()

obs = self.observer.observe(self)

Expand Down
9 changes: 7 additions & 2 deletions tensortrade/feed/core/base.py
Expand Up @@ -528,6 +528,8 @@ def __init__(self, source: "Iterable[T]", dtype: str = None):
except StopIteration:
self.stop = True

self._random_start = 0

def forward(self) -> T:
v = self.current
try:
Expand All @@ -539,11 +541,14 @@ def forward(self) -> T:
def has_next(self):
return not self.stop

def reset(self):
def reset(self, random_start=0):
if random_start != 0:
self._random_start = random_start

if self.is_gen:
self.generator = self.gen_fn()
else:
self.generator = iter(self.iterable)
self.generator = iter(self.iterable[self._random_start:])
self.stop = False

try:
Expand Down
9 changes: 6 additions & 3 deletions tensortrade/feed/core/feed.py
Expand Up @@ -2,7 +2,7 @@

from typing import List

from tensortrade.feed.core.base import Stream, T, Placeholder
from tensortrade.feed.core.base import Stream, T, Placeholder, IterableStream


class DataFeed(Stream[dict]):
Expand Down Expand Up @@ -54,9 +54,12 @@ def next(self) -> dict:
def has_next(self) -> bool:
return all(s.has_next() for s in self.process)

def reset(self) -> None:
def reset(self, random_start=0) -> None:
for s in self.process:
s.reset()
if isinstance(s, IterableStream):
s.reset(random_start)
else:
s.reset()


class PushFeed(DataFeed):
Expand Down
2 changes: 1 addition & 1 deletion tensortrade/version.py
@@ -1 +1 @@
__version__ = "1.0.4-dev0"
__version__ = "1.0.4-dev1"
53 changes: 52 additions & 1 deletion tests/tensortrade/unit/env/default/test_env.py
Expand Up @@ -87,7 +87,58 @@ def test_runs_with_external_feed_only(portfolio):
reward_scheme=reward_scheme,
feed=feed,
window_size=50,
enable_logger=False
enable_logger=False,
)

done = False
obs = env.reset()
while not done:
action = env.action_space.sample()
obs, reward, done, info = env.step(action)

assert obs.shape[0] == 50


def test_runs_with_random_start(portfolio):

df = pd.read_csv("tests/data/input/bitfinex_(BTC,ETH)USD_d.csv").tail(100)
df = df.rename({"Unnamed: 0": "date"}, axis=1)
df = df.set_index("date")

bitfinex_btc = df.loc[:, [name.startswith("BTC") for name in df.columns]]
bitfinex_eth = df.loc[:, [name.startswith("ETH") for name in df.columns]]

ta.add_all_ta_features(
bitfinex_btc,
colprefix="BTC:",
**{k: "BTC:" + k for k in ['open', 'high', 'low', 'close', 'volume']}
)
ta.add_all_ta_features(
bitfinex_eth,
colprefix="ETH:",
**{k: "ETH:" + k for k in ['open', 'high', 'low', 'close', 'volume']}
)

streams = []
with NameSpace("bitfinex"):
for name in bitfinex_btc.columns:
streams += [Stream.source(list(bitfinex_btc[name]), dtype="float").rename(name)]
for name in bitfinex_eth.columns:
streams += [Stream.source(list(bitfinex_eth[name]), dtype="float").rename(name)]

feed = DataFeed(streams)

action_scheme = ManagedRiskOrders()
reward_scheme = SimpleProfit()

env = default.create(
portfolio=portfolio,
action_scheme=action_scheme,
reward_scheme=reward_scheme,
feed=feed,
window_size=50,
enable_logger=False,
random_start_pct=0.10, # Randomly start within the first 10% of the sample
)

done = False
Expand Down

0 comments on commit bdc83c0

Please sign in to comment.