Skip to content

Commit

Permalink
Merge 46331ab into e5aa830
Browse files Browse the repository at this point in the history
  • Loading branch information
llllllllll committed Mar 7, 2018
2 parents e5aa830 + 46331ab commit 45fe635
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
27 changes: 27 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ZiplineTestCase,
)
from zipline.testing.slippage import TestingSlippage
from zipline.testing.predicates import wildcard, instance_of
from zipline.utils.numpy_utils import bool_dtype


Expand Down Expand Up @@ -173,3 +174,29 @@ def test_fill_all(self):

self.assertEqual(price, self.EQUITY_MINUTE_CONSTANT_CLOSE)
self.assertEqual(volume, order_amount)


class TestPredicates(ZiplineTestCase):

def test_wildcard(self):
for obj in 1, object(), "foo", {}:
self.assertEqual(obj, wildcard)
self.assertEqual([obj], [wildcard])
self.assertEqual({'foo': wildcard}, {'foo': wildcard})

def test_instance_of(self):
self.assertEqual(1, instance_of(int))
self.assertNotEqual(1, instance_of(str))
self.assertEqual(1, instance_of((str, int)))
self.assertEqual("foo", instance_of((str, int)))

def test_instance_of_exact(self):

class Foo(object):
pass

class Bar(Foo):
pass

self.assertEqual(Bar(), instance_of(Foo))
self.assertNotEqual(Bar(), instance_of(Foo, exact=True))
44 changes: 43 additions & 1 deletion zipline/testing/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,49 @@ def __ne__(other):

def __repr__(self):
return '<%s>' % type(self).__name__
__str__ = __repr__


class instance_of(object):
"""An object that compares equal to any instance of a given type or types.
Parameters
----------
types : type or tuple[type]
The types to compare equal to.
exact : bool, optional
Only compare equal to exact instances, not instances of subclasses?
"""
def __init__(self, types, exact=False):
if not isinstance(types, tuple):
types = (types,)

for type_ in types:
if not isinstance(type_, type):
raise TypeError('types must be a type or tuple of types')

self.types = types
self.exact = exact

def __eq__(self, other):
if self.exact:
return type(other) in self.types

return isinstance(other, self.types)

def __ne__(self, other):
return not self == other

def __repr__(self):
typenames = tuple(t.__name__ for t in self.types)
return '%s(%s%s)' % (
type(self).__name__,
(
typenames[0]
if len(typenames) == 1 else
'(%s)' % ', '.join(typenames)
),
', exact=True' if self.exact else ''
)


def keywords(func):
Expand Down

0 comments on commit 45fe635

Please sign in to comment.