Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions predicthq/endpoints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

import itertools

from datetime import datetime
from datetime import datetime, date

from dateutil.parser import parse as parse_date
import six
import pytz

from schematics.models import Model
from schematics.transforms import Role
from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType, DateType
from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType, DateType as SchematicsDateType
from schematics.types.compound import ListType as SchematicsListType, ModelType, DictType
from schematics.exceptions import ValidationError as SchematicsValidationError, DataError as SchematicsDataError, ConversionError
from schematics.types.serializable import serializable
Expand All @@ -28,6 +28,14 @@ def to_native(self, value, context=None):
return parse_date(value)


class DateType(SchematicsDateType):

def to_native(self, value, context=None):
if isinstance(value, date):
return value
return parse_date(value).date()


class StringModelType(ModelType):

def _convert(self, value, context):
Expand Down Expand Up @@ -103,7 +111,7 @@ class Area(StringModel):
longitude = FloatType(required=True)


class DateRange(Model):
class DateTimeRange(Model):

class Options:
serialize_when_none = False
Expand Down
7 changes: 4 additions & 3 deletions predicthq/endpoints/v1/events/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from predicthq.endpoints.schemas import PaginatedMixin, SortableMixin, Model, ResultSet, \
ListType, StringType, GeoJSONPointType, StringListType, StringModelType, Area, \
ModelType, IntRange, IntType, DateRange, DateTimeType, FloatType, ResultType, \
ModelType, IntRange, IntType, DateTimeRange, DateTimeType, FloatType, ResultType, \
DictType, DateType


Expand All @@ -16,8 +16,8 @@ class Options:
q = StringType()
label = ListType(StringType)
category = ListType(StringType)
start = ModelType(DateRange)
end = ModelType(DateRange)
start = ModelType(DateTimeRange)
end = ModelType(DateTimeRange)
rank_level = ListType(IntType(min_value=1, max_value=5))
rank = ModelType(IntRange)
country = ListType(StringType)
Expand Down Expand Up @@ -57,6 +57,7 @@ class TopEventsSearchParams(SortableMixin, Model):

class CalendarParams(SearchParams):

dates = ModelType(DateTimeRange)
top_events = ModelType(TopEventsSearchParams)


Expand Down
10 changes: 5 additions & 5 deletions predicthq/endpoints/v1/signals/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from schematics.transforms import blacklist, wholelist, whitelist

from predicthq.endpoints.schemas import Model, StringType, ListType, ModelType, DateTimeType, ResultSet, ResultType, SortableMixin, FloatType, IntType, DictType, \
PaginatedMixin, DateRange, StringListType, StringModelType, Area, BooleanType
PaginatedMixin, DateTimeRange, StringListType, StringModelType, Area, BooleanType, DateType


class SignalsSearchParams(SortableMixin, Model):
Expand Down Expand Up @@ -176,7 +176,7 @@ class DailyAnalysisDetails(Model):

class DailyAnalysis(Model):

date = DateTimeType()
date = DateType()
trend = FloatType()
actual = FloatType()
expected = FloatType()
Expand All @@ -195,9 +195,9 @@ class Options:
serialize_when_none = False

id = StringType(required=True)
date = ModelType(DateRange)
initiated = ModelType(DateRange)
completed = ModelType(DateRange)
date = ModelType(DateTimeRange)
initiated = ModelType(DateTimeRange)
completed = ModelType(DateTimeRange)
within = StringListType(StringModelType(Area), separator="+")
significance = IntType(min_value=0, max_value=100)
lead = BooleanType(default=False)
Expand Down
19 changes: 16 additions & 3 deletions tests/endpoints/schemas_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

import pytz
from datetime import datetime
from datetime import datetime, date

from predicthq.endpoints import decorators, schemas
from predicthq.endpoints.base import BaseEndpoint
Expand All @@ -16,14 +16,27 @@ def test_datetime_type(self):

class SchemaExample(schemas.Model):

my_date = schemas.DateTimeType()
my_datetime = schemas.DateTimeType()

test_date = datetime(2016, 1, 1, tzinfo=pytz.UTC)
self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00+00:00"}).my_date, test_date)
self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00+00:00"}).my_datetime, test_date)
self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00+0000"}).my_datetime, test_date)
self.assertEqual(SchemaExample({"my_datetime": "2016-01-01T00:00:00Z"}).my_datetime, test_date)
self.assertEqual(SchemaExample({"my_datetime": test_date}).my_datetime, test_date)

def test_date_type(self):

class SchemaExample(schemas.Model):

my_date = schemas.DateType()

test_date = date(2016, 1, 1)
self.assertEqual(SchemaExample({"my_date": "2016-01-01"}).my_date, test_date)
self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00+0000"}).my_date, test_date)
self.assertEqual(SchemaExample({"my_date": "2016-01-01T00:00:00Z"}).my_date, test_date)
self.assertEqual(SchemaExample({"my_date": test_date}).my_date, test_date)


def test_string_model_and_string_model_type(self):

class MyModel(schemas.StringModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/endpoints/v1/events_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_count(self, client, responses):
@with_client()
@with_mock_responses()
def test_calendar(self, client, responses):
result = client.events.calendar(start__gte="2015-12-24", start__lte="2015-12-26", country="NZ", top_events__limit=1, top_events__sort=["rank"])
result = client.events.calendar(start__gte="2015-12-24", start__lte="2015-12-26", country="NZ", top_events__limit=1, top_events__sort=["rank"], dates__tz="Pacific/Auckland")
self.assertIsInstance(result, CalendarResultSet)
self.assertEqual(result.count, 60)
self.assertEqual(3, len(list(result.iter_all())))
Expand Down