diff --git a/predicthq/endpoints/schemas.py b/predicthq/endpoints/schemas.py index ea662ef..84cfe2b 100644 --- a/predicthq/endpoints/schemas.py +++ b/predicthq/endpoints/schemas.py @@ -5,7 +5,7 @@ import itertools -from datetime import datetime +from datetime import datetime, date from dateutil.parser import parse as parse_date import six @@ -13,7 +13,7 @@ from schematics.models import Model from schematics.transforms import Role -from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType, DateType +from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType, DateType as SchematicsDateType from schematics.types.compound import ListType as SchematicsListType, ModelType, DictType from schematics.exceptions import ValidationError as SchematicsValidationError, DataError as SchematicsDataError, ConversionError from schematics.types.serializable import serializable @@ -28,6 +28,14 @@ def to_native(self, value, context=None): return parse_date(value) +class DateType(SchematicsDateType): + + def to_native(self, value, context=None): + if isinstance(value, date): + return value + return parse_date(value).date() + + class StringModelType(ModelType): def _convert(self, value, context): @@ -103,7 +111,7 @@ class Area(StringModel): longitude = FloatType(required=True) -class DateRange(Model): +class DateTimeRange(Model): class Options: serialize_when_none = False diff --git a/predicthq/endpoints/v1/events/schemas.py b/predicthq/endpoints/v1/events/schemas.py index 9b1d559..c57aa1e 100644 --- a/predicthq/endpoints/v1/events/schemas.py +++ b/predicthq/endpoints/v1/events/schemas.py @@ -3,7 +3,7 @@ from predicthq.endpoints.schemas import PaginatedMixin, SortableMixin, Model, ResultSet, \ ListType, StringType, GeoJSONPointType, StringListType, StringModelType, Area, \ - ModelType, IntRange, IntType, DateRange, DateTimeType, FloatType, ResultType, \ + ModelType, IntRange, IntType, DateTimeRange, DateTimeType, FloatType, ResultType, \ DictType, DateType @@ -16,8 +16,8 @@ class Options: q = StringType() label = ListType(StringType) category = ListType(StringType) - start = ModelType(DateRange) - end = ModelType(DateRange) + start = ModelType(DateTimeRange) + end = ModelType(DateTimeRange) rank_level = ListType(IntType(min_value=1, max_value=5)) rank = ModelType(IntRange) country = ListType(StringType) @@ -57,6 +57,7 @@ class TopEventsSearchParams(SortableMixin, Model): class CalendarParams(SearchParams): + dates = ModelType(DateTimeRange) top_events = ModelType(TopEventsSearchParams) diff --git a/predicthq/endpoints/v1/signals/schemas.py b/predicthq/endpoints/v1/signals/schemas.py index 545eddb..77c1f1f 100644 --- a/predicthq/endpoints/v1/signals/schemas.py +++ b/predicthq/endpoints/v1/signals/schemas.py @@ -4,7 +4,7 @@ from schematics.transforms import blacklist, wholelist, whitelist from predicthq.endpoints.schemas import Model, StringType, ListType, ModelType, DateTimeType, ResultSet, ResultType, SortableMixin, FloatType, IntType, DictType, \ - PaginatedMixin, DateRange, StringListType, StringModelType, Area, BooleanType + PaginatedMixin, DateTimeRange, StringListType, StringModelType, Area, BooleanType, DateType class SignalsSearchParams(SortableMixin, Model): @@ -176,7 +176,7 @@ class DailyAnalysisDetails(Model): class DailyAnalysis(Model): - date = DateTimeType() + date = DateType() trend = FloatType() actual = FloatType() expected = FloatType() @@ -195,9 +195,9 @@ class Options: serialize_when_none = False id = StringType(required=True) - date = ModelType(DateRange) - initiated = ModelType(DateRange) - completed = ModelType(DateRange) + date = ModelType(DateTimeRange) + initiated = ModelType(DateTimeRange) + completed = ModelType(DateTimeRange) within = StringListType(StringModelType(Area), separator="+") significance = IntType(min_value=0, max_value=100) lead = BooleanType(default=False) diff --git a/tests/endpoints/schemas_tests.py b/tests/endpoints/schemas_tests.py index 82adee0..69c2021 100644 --- a/tests/endpoints/schemas_tests.py +++ b/tests/endpoints/schemas_tests.py @@ -4,7 +4,7 @@ import unittest import pytz -from datetime import datetime +from datetime import datetime, date from predicthq.endpoints import decorators, schemas from predicthq.endpoints.base import BaseEndpoint @@ -16,14 +16,27 @@ def test_datetime_type(self): class SchemaExample(schemas.Model): - my_date = schemas.DateTimeType() + my_datetime = schemas.DateTimeType() test_date = datetime(2016, 1, 1, tzinfo=pytz.UTC) - self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00+00:00"}).my_date, test_date) + self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00+00:00"}).my_datetime, test_date) + self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00+0000"}).my_datetime, test_date) + self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00Z"}).my_datetime, test_date) + self.assertEqual(SchemaExample({"my_datetime": test_date}).my_datetime, test_date) + + def test_date_type(self): + + class SchemaExample(schemas.Model): + + my_date = schemas.DateType() + + test_date = date(2016, 1, 1) + self.assertEqual(SchemaExample({"my_date": "2016-01-01"}).my_date, test_date) self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00+0000"}).my_date, test_date) self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00Z"}).my_date, test_date) self.assertEqual(SchemaExample({"my_date": test_date}).my_date, test_date) + def test_string_model_and_string_model_type(self): class MyModel(schemas.StringModel): diff --git a/tests/endpoints/v1/events_tests.py b/tests/endpoints/v1/events_tests.py index 08f9be4..6444c44 100644 --- a/tests/endpoints/v1/events_tests.py +++ b/tests/endpoints/v1/events_tests.py @@ -53,7 +53,7 @@ def test_count(self, client, responses): @with_client() @with_mock_responses() def test_calendar(self, client, responses): - result = client.events.calendar(start__gte="2015-12-24", start__lte="2015-12-26", country="NZ", top_events__limit=1, top_events__sort=["rank"]) + result = client.events.calendar(start__gte="2015-12-24", start__lte="2015-12-26", country="NZ", top_events__limit=1, top_events__sort=["rank"], dates__tz="Pacific/Auckland") self.assertIsInstance(result, CalendarResultSet) self.assertEqual(result.count, 60) self.assertEqual(3, len(list(result.iter_all())))