diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index cf4145bbee..94c2b0acc4 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -53,6 +53,7 @@ SetMaxOrderCountAlgorithm, SetMaxOrderSizeAlgorithm, api_algo, + api_get_environment_algo, api_symbol_algo, call_all_order_methods, call_order_in_init, @@ -407,6 +408,13 @@ def test_api_calls_string(self): algo = TradingAlgorithm(script=api_algo) algo.run(self.df) + def test_api_get_environment(self): + environment = 'zipline' + algo = TradingAlgorithm(script=api_get_environment_algo, + environment=environment) + algo.run(self.df) + self.assertEqual(algo.environment, environment) + def test_api_symbol(self): algo = TradingAlgorithm(script=api_symbol_algo) algo.run(self.df) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index c10daba3d4..1ccfe46a4f 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -126,6 +126,8 @@ def __init__(self, *args, **kwargs): How much capital to start with. instant_fill : bool Whether to fill orders immediately or on next bar. + environment : str + The environment that this algorithm is running in. """ self.datetime = None @@ -139,6 +141,8 @@ def __init__(self, *args, **kwargs): self._recorded_vars = {} self.namespace = kwargs.get('namespace', {}) + self._environment = kwargs.pop('environment', 'zipline') + self.logger = None self.benchmark_return_source = None @@ -470,6 +474,10 @@ def add_transform(self, transform_class, tag, *args, **kwargs): 'args': args, 'kwargs': kwargs} + @api_method + def get_environment(self): + return self._environment + @api_method def record(self, *args, **kwargs): """ diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index d3928d852c..fe05543c04 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -939,6 +939,16 @@ def handle_data(context, data): record(incr=context.incr) """ +api_get_environment_algo = """ +from zipline.api import get_environment, order, symbol + + +def initialize(context): + context.environment = get_environment() + +handle_data = lambda context, data: order(symbol(0), 1) +""" + api_symbol_algo = """ from zipline.api import (order, symbol)