/
utils.py
560 lines (461 loc) · 20.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
#
# Copyright 2012 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generator versions of transforms.
"""
import types
import logbook
from copy import deepcopy
from datetime import datetime
from collections import deque
from abc import ABCMeta, abstractmethod
import pandas as pd
from zipline import ndict
from zipline.utils.tradingcalendar import non_trading_days
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
# Module-level logbook logger. StatefulTransform._gen uses it to report when
# each transform run starts and finishes.
log = logbook.Logger('Transform')
class Passthrough(object):
    """
    Trivial class for forwarding events.

    update() computes nothing and returns None. The presence of the
    PASSTHROUGH class attribute is what StatefulTransform checks (via
    hasattr) to decide that, in merged mode, the whole copied message is
    forwarded with tnfm_id and tnfm_value attached, instead of being
    reduced to a minimal dt/tnfm_id/tnfm_value message.
    """

    # Marker detected with hasattr() by StatefulTransform; its value is
    # not inspected.
    PASSTHROUGH = True

    def __init__(self):
        pass

    def update(self, event):
        # Nothing to compute: the event itself is what gets forwarded.
        pass
class TransformMeta(type):
    """
    Metaclass that wraps every instantiation in a StatefulTransform.

    For a class Foo whose __metaclass__ is TransformMeta, the expression
    Foo(*args, **kwargs) evaluates to
    StatefulTransform(Foo, *args, **kwargs) rather than to a plain Foo
    instance. The raw Foo object can still be recovered from the 'state'
    attribute of the returned StatefulTransform.
    """

    def __call__(cls, *args, **kwargs):
        # Hand construction over to StatefulTransform; the caller gets the
        # wrapper, not a bare instance of cls.
        wrapper = StatefulTransform(cls, *args, **kwargs)
        return wrapper
class StatefulTransform(object):
    """
    Generic transform generator that takes each message from an
    in-stream and passes it to a state object. For every message, the
    value returned by the state object's update() is attached to a
    message fed downstream.

    How that value is emitted depends on two instance flags, which the
    consuming composite (e.g. merged_transforms) is expected to set:

    * merged (default True): the incoming message is deep-copied first.
      If the transform class defines a PASSTHROUGH attribute, the whole
      copied message is forwarded with tnfm_id and tnfm_value added;
      otherwise a new ndict holding only tnfm_id, tnfm_value and dt is
      emitted.
    * sequential (default False; only consulted when merged is False):
      the value is stored on the message under this transform's
      namestring and that message is forwarded.

    If neither flag is set, messages are consumed but nothing is yielded.
    """

    def __init__(self, tnfm_class, *args, **kwargs):
        """
        :Arguments:
            tnfm_class : class
                Class whose instance holds the transform state. It must
                define update(event).
            *args, **kwargs
                Passed to the tnfm_class constructor, and to hash_args to
                build this transform's namestring.
        """
        import inspect

        # inspect.isclass accepts new-style classes (including those built
        # by TransformMeta) as well as Python 2 old-style classes. The old
        # isinstance check against types.ObjectType was always true, since
        # every value is an instance of object, so non-classes got through.
        assert inspect.isclass(tnfm_class), \
            "Stateful transform requires a class."
        assert hasattr(tnfm_class, 'update'), \
            "Stateful transform requires the class to have an update method"

        # Flag set inside the Passthrough transform class to signify special
        # behavior if we are being fed to merged_transforms.
        self.passthrough = hasattr(tnfm_class, 'PASSTHROUGH')

        # Flags specifying how to append the calculated value.
        # Merged is the default for ease of testing, but we use sequential
        # in production.
        self.sequential = False
        self.merged = True

        # Create an instance of our transform class.
        if isinstance(tnfm_class, TransformMeta):
            # TransformMeta overrides __call__ so that calling the class
            # returns a StatefulTransform. Calling type.__call__ (the
            # parent of TransformMeta) directly performs ordinary
            # instantiation instead of recursing back into this constructor.
            self.state = super(TransformMeta, tnfm_class).__call__(
                *args, **kwargs)
        else:
            # Normal object instantiation.
            self.state = tnfm_class(*args, **kwargs)

        # Create the string associated with this generator's output.
        self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)

    def get_hash(self):
        """Return the identifier string attached to this transform's output."""
        return self.namestring

    def transform(self, stream_in):
        """Return a generator that applies this transform to stream_in."""
        return self._gen(stream_in)

    def _gen(self, stream_in):
        # IMPORTANT: Messages may contain pointers that are shared with
        # other streams. Transforms that modify their input
        # messages should only manipulate copies.
        log.info('Running StatefulTransform [%s]' % self.get_hash())
        for message in stream_in:
            # allow upstream generators to yield None to avoid
            # blocking.
            if message is None:
                continue

            assert_sort_unframe_protocol(message)

            # merged_transforms relies on this copy to isolate each
            # transform's messages from the other transforms.
            if self.merged:
                message = deepcopy(message)

            tnfm_value = self.state.update(message)

            # PASSTHROUGH flag means we want to keep all original
            # values, plus append tnfm_id and tnfm_value. Used for
            # preserving the original event fields when our output
            # will be fed into a merge. Currently only Passthrough
            # uses this flag.
            if self.passthrough and self.merged:
                out_message = message
                out_message.tnfm_id = self.namestring
                out_message.tnfm_value = tnfm_value
                yield out_message

            # Otherwise in merged mode, emit a new message containing
            # just the tnfm_id, the event's datetime, and the calculated
            # tnfm_value. This is the default behavior for a
            # non-passthrough transform being fed into a merge.
            elif self.merged:
                out_message = ndict()
                out_message.tnfm_id = self.namestring
                out_message.tnfm_value = tnfm_value
                out_message.dt = message.dt
                yield out_message

            # Sequential mode adds a single key-value pair to the event:
            # the key is this transform's namestring and the value is
            # what state.update(event) returned. No copy is made in this
            # mode, so the incoming message itself is modified and
            # forwarded. This mode is used by the sequential_transforms
            # composite.
            elif self.sequential:
                out_message = message
                out_message[self.namestring] = tnfm_value
                yield out_message

        log.info('Finished StatefulTransform [%s]' % self.get_hash())
class EventWindow(object):
    """
    Abstract base class for transform classes that calculate iterative
    metrics on events within a trailing window.

    Maintains a deque of events that lie inside the window ending at the
    most recent tick. Calls self.handle_add(event) for each event added
    to the window and self.handle_remove(event) for each event removed
    from it. Subclasses override these methods (and may extend __init__)
    to calculate metrics over the window.

    If the market_aware flag is True, the EventWindow drops old events
    based on the number of elapsed trading days between newest and
    oldest (calendar days minus tracked non-trading days), and
    window_length is required. Otherwise old events are dropped based on
    a raw timedelta, and delta is required.

    See zipline/transforms/mavg.py and zipline/transforms/vwap.py for
    example implementations of moving average and volume-weighted
    average price.
    """

    # Mark this as an abstract base class (Python 2 metaclass syntax).
    __metaclass__ = ABCMeta

    def __init__(self, market_aware=True, window_length=None, delta=None):
        self.market_aware = market_aware
        self.window_length = window_length
        self.delta = delta

        # Events currently inside the window, oldest at the left end.
        self.ticks = deque()

        if self.market_aware:
            # Market-aware mode only works with full-day windows.
            assert self.window_length and self.delta is None, \
                "Market-aware mode only works with full-day windows."
            # Non-trading days not yet reached by the newest event.
            # NOTE(review): add_new_holidays assumes non_trading_days is
            # sorted ascending - confirm.
            self.all_holidays = deque(non_trading_days)
            # Non-trading days already reached by the newest event; ones
            # older than the oldest tick are pruned in out_of_market_window.
            self.cur_holidays = deque()
            self.drop_condition = self.out_of_market_window
        else:
            # Non-market-aware mode requires a timedelta.
            assert self.delta and not self.window_length, \
                "Non-market-aware mode requires a timedelta."
            self.drop_condition = self.out_of_delta

    @abstractmethod
    def handle_add(self, event):
        """Hook called with each event appended to the window."""
        raise NotImplementedError()

    @abstractmethod
    def handle_remove(self, event):
        """Hook called with each event evicted from the window."""
        raise NotImplementedError()

    def __len__(self):
        return len(self.ticks)

    def update(self, event):
        """Append event to the window, then evict every expired tick.

        Raises AssertionError (from assert_well_formed) when event lacks a
        datetime 'dt' or is older than the newest tick already held.
        """
        self.assert_well_formed(event)

        # Add new event and increment totals. A deep copy is stored, so
        # later mutation of the caller's event does not alter the window.
        self.ticks.append(deepcopy(event))

        # Subclasses should override handle_add to define behavior for
        # adding new ticks. Note it receives the caller's event, not the
        # stored copy.
        self.handle_add(event)

        if self.market_aware:
            self.add_new_holidays(event['dt'])

        # Clear out any expired events. drop_condition changes depending
        # on whether or not we are running in market_aware mode.
        #
        #                  oldest               newest
        #                     |                    |
        #                     V                    V
        while self.drop_condition(self.ticks[0]['dt'], self.ticks[-1]['dt']):
            # popleft removes and returns the oldest tick in self.ticks
            popped = self.ticks.popleft()

            # Subclasses should override handle_remove to define
            # behavior for removing ticks.
            self.handle_remove(popped)

    def add_new_holidays(self, newest):
        # Add to our tracked window any untracked holidays that are
        # older than our newest event. (newest should always be
        # self.ticks[-1])
        while self.all_holidays and self.all_holidays[0] <= newest:
            self.cur_holidays.append(self.all_holidays.popleft())

    def drop_old_holidays(self, oldest):
        # Drop from our tracked window any holidays that are older
        # than our oldest tracked event. (oldest should always
        # be self.ticks[0])
        while self.cur_holidays and self.cur_holidays[0] < oldest:
            self.cur_holidays.popleft()

    def out_of_market_window(self, oldest, newest):
        """Return True when oldest is window_length or more trading days
        behind newest."""
        self.drop_old_holidays(oldest)

        calendar_dates_between = (newest.date() - oldest.date()).days
        holidays_between = len(self.cur_holidays)
        trading_days_between = calendar_dates_between - holidays_between

        # If oldest falls later in its day than newest, the most recent
        # day in the window is not yet complete, so don't count it.
        if oldest.time() > newest.time():
            trading_days_between -= 1

        return trading_days_between >= self.window_length

    def out_of_delta(self, oldest, newest):
        """Return True when newest is at least self.delta after oldest."""
        return (newest - oldest) >= self.delta

    # All event windows expect to receive events with datetime fields
    # that arrive in sorted order.
    def assert_well_formed(self, event):
        assert 'dt' in event, "Missing dt in EventWindow:%s" % event
        assert isinstance(event['dt'], datetime), \
            "Bad dt in EventWindow:%s" % event
        if self.ticks:
            # Compare against (and report) the newest tick held so far.
            newest = self.ticks[-1]
            # Something is wrong if new event is older than previous.
            assert event['dt'] >= newest['dt'], \
                "Events arrived out of order in EventWindow: %s -> %s" % \
                (event, newest)
class BatchTransform(EventWindow):
    """Base class for batch transforms with a trailing window of
    variable length. As opposed to pure EventWindows that get a stream
    of events and are bound to a single SID, this class creates a stream
    of pandas Panels whose minor axis (columns) holds the sids.

    There are two ways to create a new batch window:
    (i) Inherit from BatchTransform and overload get_value(data).
    E.g.:
    ```
    class MyBatchTransform(BatchTransform):
        def get_value(self, data):
            # compute difference between the mean prices of sid 0 and
            # sid 1 (panel items are field names, columns are sids)
            return data['price'][0].mean() - data['price'][1].mean()
    ```
    (ii) Use the batch_transform decorator.
    E.g.:
    ```
    @batch_transform
    def my_batch_transform(data):
        return data['price'][0].mean() - data['price'][1].mean()
    ```
    In your algorithm you would then have to instantiate this in the
    initialize() method. window_length is required:
    ```
    self.my_batch_transform = MyBatchTransform(refresh_period=1,
                                               window_length=10)
    ```
    To then use it, inside of the algorithm handle_data(), call the
    handle_data() of the BatchTransform and pass it the current event:
    ```
    result = self.my_batch_transform.handle_data(data)
    ```
    (Calling the instance itself, self.my_batch_transform(f), instead
    installs f as the batch function and returns handle_data.)
    """

    def __init__(self,
                 func=None,
                 refresh_period=None,
                 window_length=None,
                 clean_nans=True,
                 sids=None,
                 fields=None,
                 create_panel=True,
                 compute_only_full=True):
        """Instantiate new batch_transform object.

        :Arguments:
            func : python function <optional>
                If supplied will be called after each refresh_period
                with the data panel and all args and kwargs supplied
                to the handle_data() call. Otherwise get_value() is used.
            refresh_period : int
                Interval, in days, at which to call the batch_transform
                function. None behaves like 0, i.e. the function is
                recomputed on every event.
            window_length : int
                How many days the trailing window should have. Required:
                the window always runs in market-aware mode.
            clean_nans : bool <default=True>
                Whether to (forward) fill in nans.
            sids : list <optional>
                Which sids to include in the moving window. If not
                supplied sids will be extracted from incoming
                events.
            fields : list <optional>
                Which fields to include in the moving window
                (e.g. 'price'). If not supplied, fields will be
                extracted from incoming events.
            create_panel : bool <default=True>
                If True, will create a pandas panel every refresh
                period and pass it to the user-defined function.
                If False, will pass the underlying deque reference
                directly to the function which will be significantly
                faster.
            compute_only_full : bool <default=True>
                Only call the user-defined function once the window is
                full. Returns None if window is not full yet.
        """
        # The parent constructor stores window_length and asserts it is
        # set (market-aware mode).
        super(BatchTransform, self).__init__(True,
                                             window_length=window_length)

        if func is not None:
            self.compute_transform_value = func
        else:
            self.compute_transform_value = self.get_value

        self.clean_nans = clean_nans
        self.create_panel = create_panel
        self.compute_only_full = compute_only_full

        # A single sid / field name is shorthand for a one-element list.
        self.sids = sids
        if isinstance(self.sids, (str, int)):
            self.sids = [self.sids]

        self.field_names = fields
        if isinstance(self.field_names, str):
            self.field_names = [self.field_names]

        self.refresh_period = refresh_period

        self.trading_days_since_update = 0
        self.trading_days_total = 0
        self.full = False
        self.last_dt = None
        self.updated = False
        self.cached = None

    def handle_data(self, data, *args, **kwargs):
        """
        New method to handle a data frame as sent to the algorithm's
        handle_data method.

        Appends the current bar to the window and returns the newly
        computed or cached batch-transform value (None while the window
        is not yet full and compute_only_full is set). Extra args and
        kwargs are forwarded to the batch-transform function.
        """
        # The window needs a single dt per event; use the latest sid
        # timestamp. This is only used to check whether events fall
        # outside the window, so a couple of seconds shouldn't matter.
        # It is not added to the per-sid data so dt doesn't mix with the
        # sid keys.
        dts = [sid_data['datetime'] for sid_data in data.itervalues()]
        event = {'dt': max(dts)}

        # Hack: convert (and copy) each sid's frame to a plain dict for
        # later panel conversion.
        event['data'] = dict((sid, dict(frame))
                             for sid, frame in data.iteritems())

        # append data frame to window. update() will call handle_add() and
        # handle_remove() appropriately
        self.update(event)

        # return newly computed or cached value
        return self.get_transform_value(*args, **kwargs)

    def _extract_field_names(self, event):
        """Return the numeric (int/float) field names of the sids in
        event['data'], minus bookkeeping fields."""
        sid_keys = []
        for sid in event['data'].itervalues():
            keys = set([name for name, value in sid.items()
                        if isinstance(value, (int, float))])
            sid_keys.append(keys)

        # NOTE(review): this only checks that the first sid's numeric
        # fields appear in every other sid; extra fields on other sids
        # still pass.
        assert sid_keys[0] == set.intersection(*sid_keys),\
            "Each sid must have the same keys."

        unwanted_fields = set(['portfolio', 'sid', 'dt', 'type',
                               'datetime', 'source_id'])
        return sid_keys[0] - unwanted_fields

    def handle_add(self, event):
        """Advance the day counters and the full/updated flags for a newly
        added event."""
        if self.last_dt is None:
            # First event seen: infer the field names if none were given.
            if self.field_names is None:
                self.field_names = self._extract_field_names(event)
            self.last_dt = event['dt']

        # Update trading day counters whenever the calendar date changes.
        # Compare whole dates: comparing only day-of-month would miss a
        # gap such as Jan 5 -> Feb 5.
        if self.last_dt.date() != event['dt'].date():
            self.last_dt = event['dt']
            self.trading_days_since_update += 1
            self.trading_days_total += 1

        if self.trading_days_total >= self.window_length:
            self.full = True

        # A refresh_period of None behaves like 0 (refresh on every
        # event), which is what Python 2's `int >= None` comparison
        # produced; spelling it out keeps that valid on Python 3 too.
        refresh_period = self.refresh_period or 0
        if self.trading_days_since_update >= refresh_period:
            # Setting updated to True will cause get_transform_value()
            # to call the user-defined batch-transform with the most
            # recent datapanel
            self.updated = True
            self.trading_days_since_update = 0
        else:
            self.updated = False

    def get_data(self):
        """Create a pandas.Panel (i.e. 3d DataFrame) from the
        events in the current window.

        Returns:
        The resulting panel looks like this:
            index : field_name (e.g. price)
            major axis/rows : dt
            minor axis/colums : sid
        """
        # This Panel data structure ultimately gets passed to the
        # user-overloaded get_value() method.
        data_dict = dict((tick['dt'], tick['data']) for tick in self.ticks)
        data = pd.Panel(data_dict, major_axis=self.field_names,
                        minor_axis=self.sids)

        # Panel interprets the outer-most keys as the items, and the
        # inner dicts are treated as though passed to DataFrame
        # (e.g. their outer keys become the columns--then
        # minor_axis--of each inner dataframe). So the panel starts as
        # dates x fields x sids; swap axes 0 and 1 to get
        # fields x dates x sids.
        data = data.swapaxes(0, 1)

        if self.clean_nans:
            # Fills in gaps of missing data during transform
            # of multiple stocks. E.g. we may be missing
            # minute data because of illiquidity of one stock
            data = data.fillna(method='ffill')

        # Drop any empty rows after the fill.
        # This will drop a leading row of N/A
        data = data.dropna(axis=1)

        return data

    def handle_remove(self, event):
        # Evicted ticks need no bookkeeping here.
        pass

    def get_value(self, *args, **kwargs):
        raise NotImplementedError(
            "Either overwrite get_value or provide a func argument.")

    def get_transform_value(self, *args, **kwargs):
        """Call user-defined batch-transform function passing all
        arguments.

        Note that this will only call the transform if the datapanel
        has actually been updated. Otherwise, the previously cached
        value will be returned.
        """
        if self.compute_only_full and not self.full:
            return None

        if self.updated:
            # Either create new pandas panel or pass ticks dequeue
            # directly
            data = self.get_data() if self.create_panel else self.ticks
            self.cached = self.compute_transform_value(data, *args,
                                                       **kwargs)

        return self.cached

    def __call__(self, f):
        """Install f as the batch-transform function and return the
        bound handle_data method."""
        self.compute_transform_value = f
        return self.handle_data
def batch_transform(func):
    """Turn a plain function into a BatchTransform factory.

    This is the decorator alternative to subclassing BatchTransform: the
    returned callable accepts the usual BatchTransform constructor
    arguments and builds a BatchTransform that runs *func* in place of
    get_value(). See the BatchTransform docstring for a usage example.
    """
    def create_window(*args, **kwargs):
        # func is handed over as BatchTransform's `func` argument; all
        # other arguments are passed through unchanged.
        window = BatchTransform(*args, func=func, **kwargs)
        return window

    return create_window