diff --git a/tensortrade/data/stream/feed.py b/tensortrade/data/stream/feed.py index 0e7330503..def63d28f 100644 --- a/tensortrade/data/stream/feed.py +++ b/tensortrade/data/stream/feed.py @@ -99,14 +99,16 @@ def has_next(self) -> bool: return all(node.has_next() for node in self.process) def __add__(self, other): - if isinstance(other, DataFeed): - nodes = list(set(self.inputs + other.inputs)) - feed = DataFeed(nodes) + if not isinstance(other, DataFeed): + raise TypeError(f'can only concatenate DataFeed (not "{type(other).__name__}") to DataFeed.') - for listener in self.listeners + other.listeners: - feed.attach(listener) + nodes = self.inputs + other.inputs + feed = DataFeed(nodes) - return feed + for listener in self.listeners + other.listeners: + feed.attach(listener) + + return feed def reset(self): for node in self.process: diff --git a/tensortrade/environments/trading_environment.py b/tensortrade/environments/trading_environment.py index 4c0ed4fd8..61dcfaf7b 100644 --- a/tensortrade/environments/trading_environment.py +++ b/tensortrade/environments/trading_environment.py @@ -104,8 +104,8 @@ def compile(self): if not self.feed: self.feed = create_internal_feed(self.portfolio) - - self.feed = self.feed + create_internal_feed(self.portfolio) + else: + self.feed = self.feed + create_internal_feed(self.portfolio) initial_obs = self.feed.next() n_features = len(initial_obs.keys()) if self.use_internal else len(self._external_keys)