Skip to content

Commit

Permalink
First pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
Eddie Hebert committed Mar 30, 2015
1 parent 3b56a62 commit d3c421e
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 347 deletions.
6 changes: 0 additions & 6 deletions tests/test_algorithm.py
Expand Up @@ -553,9 +553,6 @@ def handle_data(context, data):
# placed.
self.zipline_test_config['order_count'] = 1

# self.zipline_test_config['transforms'] = \
# test_algo.transform_visitor.transforms.values()

zipline = simfactory.create_test_zipline(
**self.zipline_test_config)

Expand Down Expand Up @@ -618,9 +615,6 @@ def handle_data(context, data):
# https://www.dropbox.com/s/ulrk2qt0nrtrigb/Volume%20Share%20Worksheet.xlsx
self.zipline_test_config['expected_transactions'] = 67

# self.zipline_test_config['transforms'] = \
# test_algo.transform_visitor.transforms.values()

zipline = simfactory.create_test_zipline(
**self.zipline_test_config)
output, _ = assert_single_position(self, zipline)
Expand Down
16 changes: 0 additions & 16 deletions tests/test_exception_handling.py
Expand Up @@ -24,15 +24,13 @@
SetPortfolioAlgorithm,
)
from zipline.finance.slippage import FixedSlippage
from zipline.transforms.utils import StatefulTransform


from zipline.utils.test_utils import (
drain_zipline,
setup_logger,
teardown_logger,
ExceptionSource,
ExceptionTransform
)

DEFAULT_TIMEOUT = 15 # seconds
Expand Down Expand Up @@ -60,20 +58,6 @@ def test_datasource_exception(self):
with self.assertRaises(ZeroDivisionError):
output, _ = drain_zipline(self, zipline)

def test_tranform_exception(self):
exc_tnfm = StatefulTransform(ExceptionTransform)
self.zipline_test_config['transforms'] = [exc_tnfm]

zipline = simfactory.create_test_zipline(
**self.zipline_test_config
)

with self.assertRaises(AssertionError) as ctx:
output, _ = drain_zipline(self, zipline)

self.assertEqual(str(ctx.exception),
'An assertion message')

def test_exception_in_handle_data(self):
# Simulation
# ----------
Expand Down
38 changes: 6 additions & 32 deletions zipline/algorithm.py
Expand Up @@ -62,13 +62,9 @@
SlippageModel,
transact_partial
)
from zipline.gens.composites import (
date_sorted_sources,
sequential_transforms,
)
from zipline.gens.composites import date_sorted_sources
from zipline.gens.tradesimulation import AlgorithmSimulator
from zipline.sources import DataFrameSource, DataPanelSource
from zipline.transforms.utils import StatefulTransform
from zipline.utils.api_support import ZiplineAPI, api_method
import zipline.utils.events
from zipline.utils.events import (
Expand Down Expand Up @@ -143,8 +139,6 @@ def __init__(self, *args, **kwargs):
"""
self.datetime = None

self.registered_transforms = {}
self.transforms = []
self.sources = []

# List of trading controls to be used to validate orders.
Expand Down Expand Up @@ -312,8 +306,8 @@ def __repr__(self):

def _create_data_generator(self, source_filter, sim_params=None):
"""
Create a merged data generator using the sources and
transforms attached to this algorithm.
Create a merged data generator using the sources attached to this
algorithm.
::source_filter:: is a method that receives events in date
sorted order, and returns True for those events that should be
Expand Down Expand Up @@ -350,20 +344,16 @@ def update_time(date):
if source_filter:
date_sorted = filter(source_filter, date_sorted)

with_tnfms = sequential_transforms(date_sorted,
*self.transforms)

with_benchmarks = date_sorted_sources(benchmark_return_source,
with_tnfms)
date_sorted)

# Group together events with the same dt field. This depends on the
# events already being sorted.
return groupby(with_benchmarks, attrgetter('dt'))

def _create_generator(self, sim_params, source_filter=None):
"""
Create a basic generator setup using the sources and
transforms attached to this algorithm.
Create a basic generator setup using the sources to this algorithm.
::source_filter:: is a method that receives events in date
sorted order, and returns True for those events that should be
Expand Down Expand Up @@ -459,23 +449,11 @@ def run(self, source, overwrite_sim_params=True,
self.sim_params.data_frequency,
)

# Create transforms by wrapping them into StatefulTransforms
self.transforms = []
for namestring, trans_descr in iteritems(self.registered_transforms):
sf = StatefulTransform(
trans_descr['class'],
*trans_descr['args'],
**trans_descr['kwargs']
)
sf.namestring = namestring

self.transforms.append(sf)

# force a reset of the performance tracker, in case
# this is a repeat run of the algorithm.
self.perf_tracker = None

# create transforms and zipline
# create zipline
self.gen = self._create_generator(self.sim_params)

with ZiplineAPI(self):
Expand Down Expand Up @@ -854,10 +832,6 @@ def set_sources(self, sources):
assert isinstance(sources, list)
self.sources = sources

def set_transforms(self, transforms):
assert isinstance(transforms, list)
self.transforms = transforms

# Remain backwards compatibility
@property
def data_frequency(self):
Expand Down
16 changes: 0 additions & 16 deletions zipline/gens/composites.py
Expand Up @@ -15,8 +15,6 @@

import heapq

from six.moves import reduce


def _decorate_source(source):
for message in source:
Expand All @@ -33,17 +31,3 @@ def date_sorted_sources(*sources):
# Strip out key decoration
for _, message in sorted_stream:
yield message


def sequential_transforms(stream_in, *transforms):
"""
Apply each transform in transforms sequentially to each event in stream_in.
Each transform application will add a new entry indexed to the transform's
hash string.
"""
# Recursively apply all transforms to the stream.
stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream),
transforms,
stream_in)

return stream_out
24 changes: 22 additions & 2 deletions zipline/transforms/batch_transform.py
Expand Up @@ -37,8 +37,6 @@

from zipline.finance import trading

from . utils import check_window_length

log = logbook.Logger('BatchTransform')
func_map = {'open_price': 'first',
'close_price': 'last',
Expand Down Expand Up @@ -92,6 +90,28 @@ def get_date(mkt_close, d1, d2, d):
return d1


class InvalidWindowLength(Exception):
"""
Error raised when the window length is unusable.
"""
pass


def check_window_length(window_length):
"""
Ensure the window length provided to a transform is valid.
"""
if window_length is None:
raise InvalidWindowLength("window_length must be provided")
if not isinstance(window_length, Integral):
raise InvalidWindowLength(
"window_length must be an integer-like number")
if window_length == 0:
raise InvalidWindowLength("window_length must be non-zero")
if window_length < 0:
raise InvalidWindowLength("window_length must be positive")


class BatchTransform(object):
"""Base class for batch transforms with a trailing window of
variable length. As opposed to pure EventWindows that get a stream
Expand Down

0 comments on commit d3c421e

Please sign in to comment.