Skip to content

Commit

Permalink
BUG: Put initialization of perf_tracker back in __init__
Browse files Browse the repository at this point in the history
The initialization of perf_tracker had been moved from __init__
in TradingAlgorithm to _create_generator. This caused perf_tracker
to not be ready when portfolio requested it. portfolio was consequently
not ready for access in init. portfolio can now be accessed in init
again, assuming valid sim_params are passed. Otherwise it will be
available in handle_data().
  • Loading branch information
CaptainKanuk committed Jul 17, 2014
1 parent eae41b8 commit 96ceaa2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
19 changes: 19 additions & 0 deletions tests/test_algorithm.py
Expand Up @@ -34,6 +34,7 @@
TradingControlViolation,
)
from zipline.test_algorithms import (
access_portfolio_in_init,
AmbitiousStopLimitAlgorithm,
EmptyPositionsAlgorithm,
InvalidOrderAlgorithm,
Expand Down Expand Up @@ -617,6 +618,24 @@ def test_order_in_init(self):
)
set_algo_instance(test_algo)

def test_portfolio_in_init(self):
"""
Test that accessing portfolio in init doesn't break.
"""
test_algo = TradingAlgorithm(
script=access_portfolio_in_init,
sim_params=self.sim_params,
)
set_algo_instance(test_algo)

self.zipline_test_config['algorithm'] = test_algo
self.zipline_test_config['trade_count'] = 1

zipline = simfactory.create_test_zipline(
**self.zipline_test_config)

output, _ = drain_zipline(self, zipline)


class TestHistory(TestCase):
def test_history(self):
Expand Down
4 changes: 1 addition & 3 deletions zipline/algorithm.py
Expand Up @@ -157,9 +157,7 @@ def __init__(self, *args, **kwargs):
self.sim_params = create_simulation_parameters(
capital_base=self.capital_base
)

# perf_tacker gets instantiated in ._create_generator()
self.perf_tracker = None
self.perf_tracker = PerformanceTracker(self.sim_params)

self.blotter = kwargs.pop('blotter', None)
if not self.blotter:
Expand Down
9 changes: 9 additions & 0 deletions zipline/test_algorithms.py
Expand Up @@ -961,6 +961,15 @@ def handle_data(context, data):
pass
"""

access_portfolio_in_init = """
def initialize(context):
var = context.portfolio.cash
pass
def handle_data(context, data):
pass
"""

call_all_order_methods = """
from zipline.api import (order,
order_value,
Expand Down

0 comments on commit 96ceaa2

Please sign in to comment.