Skip to content

Commit

Permalink
Remove base_instrument necessity from TradingContext.
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Feb 10, 2020
1 parent 22b47fb commit ab9d9e4
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 84 deletions.
7 changes: 4 additions & 3 deletions tensortrade/base/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ class InitContextMeta(ABCMeta):
"""Metaclass that executes `__init__` of instance in it's base."""

def __call__(cls, *args, **kwargs):
context = TradingContext.get_context()
registered_name = get_registry()[cls]
tc = TradingContext.get_context()
data = tc.data.get(registered_name, {})
config = {**tc.shared, **data}

data = context.data.get(registered_name, {})
config = {**context.shared, **data}

instance = cls.__new__(cls, *args, **kwargs)
setattr(instance, 'context', Context(**config))
Expand Down
53 changes: 8 additions & 45 deletions tensortrade/base/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ class TradingContext(UserDict):
Arguments:
shared: A context that is shared between all components that are made under the overarching `TradingContext`.
exchanges: A context that is specific to components with a registered name of `exchanges`.
actions: A context that is specific to components with a registered name of `actions`.
rewards: A context that is specific to components with a registered name of `rewards`.
features: A context that is specific to components with a registered name of `features`.
Warnings:
If there is a conflict in the contexts of different components because
Expand All @@ -36,8 +32,8 @@ class TradingContext(UserDict):
"""
contexts = threading.local()

def __init__(self, base_instrument: Instrument = USD, **config):
super().__init__(base_instrument=base_instrument, **config)
def __init__(self, config: dict):
super().__init__(**config)

for name in registered_names():
if name not in get_major_component_names():
Expand All @@ -46,15 +42,10 @@ def __init__(self, base_instrument: Instrument = USD, **config):
config_items = {k: config[k] for k in config.keys()
if k not in registered_names()}

self._config = config
self._shared = config.get('shared', {})
self._exchanges = config.get('exchanges', {})
self._actions = config.get('actions', {})
self._rewards = config.get('rewards', {})
self._features = config.get('features', {})
self._slippage = config.get('slippage', {})

self._shared = {
'base_instrument': base_instrument,
**self._shared,
**config_items
}
Expand All @@ -63,26 +54,6 @@ def __init__(self, base_instrument: Instrument = USD, **config):
def shared(self) -> dict:
return self._shared

@property
def exchanges(self) -> dict:
return self._exchanges

@property
def actions(self) -> dict:
return self._actions

@property
def rewards(self) -> dict:
return self._rewards

@property
def features(self) -> dict:
return self._features

@property
def slippage(self) -> dict:
return self._slippage

def __enter__(self):
"""Adds a new context to the context stack.
Expand All @@ -99,7 +70,7 @@ def __exit__(self, typ, value, traceback):
@classmethod
def get_contexts(cls):
if not hasattr(cls.contexts, 'stack'):
cls.contexts.stack = [TradingContext()]
cls.contexts.stack = [TradingContext({})]

return cls.contexts.stack

Expand All @@ -113,34 +84,26 @@ def from_json(cls, path: str):
with open(path, "rb") as fp:
config = json.load(fp)

return TradingContext(**config)
return TradingContext(config)

@classmethod
def from_yaml(cls, path: str):
with open(path, "rb") as fp:
config = yaml.load(fp, Loader=yaml.FullLoader)

return TradingContext(**config)
return TradingContext(config)


class Context(UserDict):
"""A context that is injected into every instance of a class that is
a subclass of component.
Arguments:
base_instrument: The exchange symbol of the instrument to store/measure value in.
"""

def __init__(self, base_instrument: Instrument = USD, **kwargs):
super(Context, self).__init__(base_instrument=base_instrument, **kwargs)
def __init__(self, **kwargs):
super(Context, self).__init__(**kwargs)

self._base_instrument = base_instrument
self.__dict__ = {**self.__dict__, **self.data}

@property
def base_instrument(self) -> Instrument:
return self._base_instrument

def __str__(self):
data = ['{}={}'.format(k, getattr(self, k)) for k in self.__slots__]
return '<{}: {}>'.format(self.__class__.__name__, ', '.join(data))
6 changes: 3 additions & 3 deletions tensortrade/base/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@


MAJOR_COMPONENTS = [
'exchanges',
'actions',
'rewards',
'features',
'slippage'
'portfolio',
'exchanges',
'slippage',
]


Expand Down
15 changes: 7 additions & 8 deletions tests/tensortrade/unit/base/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def message(self):
class WorthMessageComponent(DataMessageComponent):

def __init__(self, name, value):
super(WorthMessageComponent, self).__init__({name: value})
super(WorthMessageComponent, self).__init__(data={name: value})
self.name = name
self.value = value

Expand Down Expand Up @@ -82,7 +82,6 @@ def test_no_context_injected_outside_with():
value = 'the time and effort.'
instance = WorthMessageComponent(name=name, value=value)

assert instance.context
assert instance.name == name
assert instance.value == value

Expand All @@ -95,7 +94,7 @@ def test_no_context_injected_outside_with():

def test_injects_concrete_tensor_trade_component_with_context():

with td.TradingContext(**config):
with td.TradingContext(config):

name = 'TensorTrade'
value = 'the time and effort.'
Expand All @@ -106,7 +105,7 @@ def test_injects_concrete_tensor_trade_component_with_context():

def test_inject_multiple_components_with_context():

with td.TradingContext(**config):
with td.TradingContext(config):
name = 'TensorTrade'
value = 'the time and effort.'
instance = WorthMessageComponent(name=name, value=value)
Expand All @@ -127,7 +126,7 @@ def test_injects_component_space():
**config
}

with td.TradingContext(**c) as c:
with td.TradingContext(c) as c:
name = 'TensorTrade'
value = 'the time and effort.'
instance = WorthMessageComponent(name=name, value=value)
Expand All @@ -154,7 +153,7 @@ def test_only_name_registered_component_space():
**config
}

with td.TradingContext(**c) as c:
with td.TradingContext(c) as c:
name = 'TensorTrade'
value = 'the time and effort.'
instance = WorthMessageComponent(name=name, value=value)
Expand Down Expand Up @@ -182,7 +181,7 @@ def test_inject_contexts_at_different_levels():
**config
}

with td.TradingContext(**c1):
with td.TradingContext(c1):
name = 'TensorTrade'
value = 'the time and effort.'
instance1 = WorthMessageComponent(name=name, value=value)
Expand All @@ -192,7 +191,7 @@ def test_inject_contexts_at_different_levels():
assert hasattr(win1.context, 'plans_var')
assert hasattr(lose1.context, 'plans_var')

with td.TradingContext(**c2):
with td.TradingContext(c2):
name = 'TensorTrade'
value = 'the time and effort.'
instance2 = WorthMessageComponent(name=name, value=value)
Expand Down
38 changes: 19 additions & 19 deletions tests/tensortrade/unit/base/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ def test_is_trading_context_class_there():


def test_has_config_attribute():
c = TradingContext()
c = TradingContext({
"test": True,
"exchanges": {"test": True},
"actions": {"test": True},
"rewards": {"test": True},
})

assert hasattr(c, 'shared')
assert hasattr(c, 'exchanges')
assert hasattr(c, 'actions')
assert hasattr(c, 'rewards')
assert hasattr(c, 'features')


config = {
Expand All @@ -36,48 +37,48 @@ def test_has_config_attribute():


def test_init():
c = TradingContext(base_instrument=config['base_instrument'],
instruments=config['instruments'])
c = TradingContext({"base_instrument": config['base_instrument'],
"instruments": config['instruments']})
assert c.shared.get('base_instrument') == 'EURO'
assert c.shared.get('instruments') == ['BTC', 'ETH']


def test_init_with_kwargs():
c = TradingContext(**config)
c = TradingContext(config)
assert c.shared.get('base_instrument') == 'EURO'
assert c.shared.get('instruments') == ['BTC', 'ETH']


def test_context_creation():

with td.TradingContext(**config) as tc1:
with td.TradingContext(config) as tc1:
assert tc1.data == config

with td.TradingContext(**config) as tc2:
with td.TradingContext(config) as tc2:
assert TradingContext.get_context() == tc2

assert TradingContext.get_context() == tc1


def test_get_context_from_tensor_trade_level():
with td.TradingContext(**config) as tc:
with td.TradingContext(config) as tc:
assert get_context() == tc


def test_context_within_context():

with td.TradingContext(**config) as tc1:
with td.TradingContext(config) as tc1:
assert get_context() == tc1

with td.TradingContext(**config) as tc2:
with td.TradingContext(config) as tc2:
assert get_context() == tc2

assert get_context() == tc1


def test_context_retains_data_outside_with():

with td.TradingContext(**config) as tc:
with td.TradingContext(config) as tc:
assert tc.data == config

assert tc.data == config
Expand All @@ -96,11 +97,10 @@ def test_create_trading_context_from_json():
}

with td.TradingContext.from_json(path) as tc:

assert tc.shared['base_instrument'] == "EURO"
assert tc.shared['instruments'] == ["BTC", "ETH"]
assert tc.actions == actions
assert tc.exchanges == exchanges
assert tc._config['actions'] == actions
assert tc._config['exchanges'] == exchanges


def test_create_trading_context_from_yaml():
Expand All @@ -119,5 +119,5 @@ def test_create_trading_context_from_yaml():

assert tc.shared['base_instrument'] == "EURO"
assert tc.shared['instruments'] == ["BTC", "ETH"]
assert tc.actions == actions
assert tc.exchanges == exchanges
assert tc._config['actions'] == actions
assert tc._config['exchanges'] == exchanges
13 changes: 9 additions & 4 deletions tests/tensortrade/unit/base/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@
import tensortrade.actions as actions
import tensortrade.rewards as rewards

from tensortrade.actions import DynamicOrders, ManagedRiskOrders
from tensortrade.actions import SimpleOrders, ManagedRiskOrders
from tensortrade.rewards import SimpleProfit, RiskAdjustedReturns

warnings.filterwarnings("ignore")


def test_dynamic_actions():
assert isinstance(actions.get('dynamic'), DynamicOrders)
def test_simple_actions():
assert isinstance(actions.get('simple'), SimpleOrders)


def test_managed_risk_actions():
assert isinstance(actions.get('managed-risk'), ManagedRiskOrders)


def test_simple_reward_scheme():
assert isinstance(rewards.get('simple'), rewards.SimpleProfit)
assert isinstance(rewards.get('simple'), SimpleProfit)


def test_risk_adjusted_reward_scheme():
assert isinstance(rewards.get('risk-adjusted'), RiskAdjustedReturns)
4 changes: 2 additions & 2 deletions tests/tensortrade/unit/rewards/test_reward_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_reward(self, current_step: int, trade: Trade) -> float:

def test_injects_reward_scheme_with_context():

with TradingContext(**config):
with TradingContext(config):

reward_scheme = ConcreteRewardScheme()

Expand All @@ -31,7 +31,7 @@ def test_injects_reward_scheme_with_context():

def test_injects_string_intialized_reward_scheme():

with TradingContext(**config):
with TradingContext(config):

reward_scheme = get('simple')

Expand Down

0 comments on commit ab9d9e4

Please sign in to comment.