diff --git a/README.rst b/README.rst index 3219cf6..406f026 100644 --- a/README.rst +++ b/README.rst @@ -42,6 +42,7 @@ Endpoints * ``Client.accounts`` * ``Client.events`` * ``Client.signals`` +* ``Client.places`` For a description of all available endpoints, refer to our `API Documentation `_. diff --git a/predicthq/client.py b/predicthq/client.py index bd00bc1..b9aff76 100644 --- a/predicthq/client.py +++ b/predicthq/client.py @@ -32,6 +32,7 @@ def initialize_endpoints(self): self.events = endpoints.EventsEndpoint(proxy(self)) self.accounts = endpoints.AccountsEndpoint(proxy(self)) self.signals = endpoints.SignalsEndpoint(proxy(self)) + self.places = endpoints.PlacesEndpoint(proxy(self)) def get_headers(self, headers): _headers = {"Accept": "application/json",} diff --git a/predicthq/endpoints/__init__.py b/predicthq/endpoints/__init__.py index 85cd5fc..6083a13 100644 --- a/predicthq/endpoints/__init__.py +++ b/predicthq/endpoints/__init__.py @@ -5,3 +5,4 @@ from .v1.accounts import AccountsEndpoint from .v1.events import EventsEndpoint from .v1.signals import SignalsEndpoint +from .v1.places import PlacesEndpoint diff --git a/predicthq/endpoints/schemas.py b/predicthq/endpoints/schemas.py index 6040104..b775cb8 100644 --- a/predicthq/endpoints/schemas.py +++ b/predicthq/endpoints/schemas.py @@ -89,7 +89,7 @@ def _export(self, *args, **kwargs): class ListType(SchematicsListType): def _coerce(self, value): - if isinstance(value, six.string_types): + if not isinstance(value, (list, tuple)): return [value] else: return super(ListType, self)._coerce(value) @@ -111,6 +111,20 @@ class Area(StringModel): longitude = FloatType(required=True) +class Location(StringModel): + + import_format = r'@(?P-?\d+(\.\d+)?),(?P-?\d+(\.\d+)?)' + export_format = "@{latitude},{longitude}" + + latitude = FloatType(required=True) + longitude = FloatType(required=True) + + +class Place(Model): + + scope = ListType(IntType, required=True) + + class DateTimeRange(Model): class Options: @@ -145,12 +159,21 @@ class Options: lte = IntType() -class PaginatedMixin(Model): +class LimitMixin(Model): limit = IntType(min_value=1, max_value=200) + + +class OffsetMixin(Model): + offset = IntType(min_value=0, max_value=50) +class PaginatedMixin(LimitMixin, OffsetMixin): + + pass + + class SortableMixin(Model): sort = ListType(StringType()) diff --git a/predicthq/endpoints/v1/events/schemas.py b/predicthq/endpoints/v1/events/schemas.py index 95e41a2..eca2cd0 100644 --- a/predicthq/endpoints/v1/events/schemas.py +++ b/predicthq/endpoints/v1/events/schemas.py @@ -4,7 +4,7 @@ from predicthq.endpoints.schemas import PaginatedMixin, SortableMixin, Model, ResultSet, \ ListType, StringType, GeoJSONPointType, StringListType, StringModelType, Area, \ ModelType, IntRange, IntType, DateTimeRange, DateTimeType, FloatType, ResultType, \ - DictType, DateType + DictType, DateType, Place class SearchParams(PaginatedMixin, SortableMixin, Model): @@ -23,6 +23,7 @@ class Options: rank = ModelType(IntRange) country = ListType(StringType) within = StringListType(StringModelType(Area), separator="+") + place = ModelType(Place) class Event(Model): diff --git a/predicthq/endpoints/v1/places/__init__.py b/predicthq/endpoints/v1/places/__init__.py new file mode 100644 index 0000000..c6e5ee9 --- /dev/null +++ b/predicthq/endpoints/v1/places/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals, absolute_import, print_function + +from .endpoint import PlacesEndpoint +from .schemas import Place diff --git a/predicthq/endpoints/v1/places/endpoint.py b/predicthq/endpoints/v1/places/endpoint.py new file mode 100644 index 0000000..d3100b7 --- /dev/null +++ b/predicthq/endpoints/v1/places/endpoint.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals, absolute_import, print_function + +from predicthq.endpoints.base import BaseEndpoint +from predicthq.endpoints.decorators import accepts, returns +from .schemas import SearchParams, PlaceResultSet + + +class PlacesEndpoint(BaseEndpoint): + + @accepts(SearchParams) + @returns(PlaceResultSet) + def search(self, **params): + return self.client.get(self.build_url('v1', 'places'), params=params) diff --git a/predicthq/endpoints/v1/places/schemas.py b/predicthq/endpoints/v1/places/schemas.py new file mode 100644 index 0000000..da031b2 --- /dev/null +++ b/predicthq/endpoints/v1/places/schemas.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals, absolute_import, print_function + +from predicthq.endpoints.schemas import LimitMixin, Model, ResultSet, \ + ListType, StringType, GeoJSONPointType, StringListType, StringModelType, Location, \ + DateTimeType, ResultType, SchematicsValidationError + + +class SearchParams(LimitMixin, Model): + + class Options: + serialize_when_none = False + + q = StringType() + id = ListType(StringType) + location = StringListType(StringModelType(Location), separator="+") + country = ListType(StringType) + type = ListType(StringType(choices=('planet', 'continent', 'country', 'region', 'county', 'local', 'major', 'metro', 'all'))) + + def validate(self, *args, **kwargs): + super(SearchParams, self).validate(*args, **kwargs) + if not any((self.q, self.id, self.location, self.country)): + raise SchematicsValidationError("Places search requires one of q, id, location or country") + + +class Place(Model): + + class Options: + serialize_when_none = False + + id = StringType() + type = StringType() + name = StringType() + county = StringType() + region = StringType() + country = StringType() + country_alpha2 = StringType() + country_alpha3 = StringType() + location = GeoJSONPointType() + + +class PlaceResultSet(ResultSet): + + results = ResultType(Place) diff --git a/tests/client_tests.py b/tests/client_tests.py index aa74704..a218649 100644 --- a/tests/client_tests.py +++ b/tests/client_tests.py @@ -26,6 +26,7 @@ def test_endpoints_initialization(self): self.assertIsInstance(self.client.accounts, endpoints.AccountsEndpoint) self.assertIsInstance(self.client.events, endpoints.EventsEndpoint) self.assertIsInstance(self.client.signals, endpoints.SignalsEndpoint) + self.assertIsInstance(self.client.places, endpoints.PlacesEndpoint) @with_mock_responses() def test_request(self, responses): diff --git a/tests/endpoints/schemas_tests.py b/tests/endpoints/schemas_tests.py index 69c2021..c4993eb 100644 --- a/tests/endpoints/schemas_tests.py +++ b/tests/endpoints/schemas_tests.py @@ -141,6 +141,31 @@ class SchemaExample(schemas.Model): with self.assertRaises(schemas.SchematicsDataError): m.import_data(invalid_data) + def test_location_model(self): + + class SchemaExample(schemas.Model): + + location = schemas.StringModelType(schemas.Location) + + short_data = {"location": "@-36.847585,174.765742"} + long_data = {"location": {"latitude": -36.847585, "longitude": 174.765742}} + model_data = {"location": schemas.Location("@-36.847585,174.765742")} + invalid_data = {"location": "-36.847585,174.765742"} + + expected_expected = {"location": "@-36.847585,174.765742"} + + m = SchemaExample() + self.assertDictEqual(m.import_data(short_data).to_primitive(), expected_expected) + self.assertDictEqual(m.import_data(long_data).to_primitive(), expected_expected) + self.assertDictEqual(m.import_data(model_data).to_primitive(), expected_expected) + + self.assertDictEqual(m.import_data(short_data).to_dict(), expected_expected) + self.assertDictEqual(m.import_data(long_data).to_dict(), expected_expected) + self.assertDictEqual(m.import_data(model_data).to_dict(), expected_expected) + + with self.assertRaises(schemas.SchematicsDataError): + m.import_data(invalid_data) + def test_resultset(self): class ResultExample(schemas.Model): diff --git a/tests/endpoints/v1/places_tests.py b/tests/endpoints/v1/places_tests.py new file mode 100644 index 0000000..7f833a2 --- /dev/null +++ b/tests/endpoints/v1/places_tests.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals, absolute_import, print_function + +import unittest + +from predicthq.endpoints import schemas +from tests import with_mock_client, with_mock_responses, with_client + +from predicthq.endpoints.v1.places.schemas import PlaceResultSet + + +class PlacesTest(unittest.TestCase): + + @with_mock_client() + def test_search_params(self, client): + client.places.search(country=["NZ", "AU"]) + client.request.assert_called_once_with('get', '/v1/places/', params={'country': 'NZ,AU'}) + + @with_mock_client() + def test_invalide_search_params(self, client): + with self.assertRaises(schemas.SchematicsValidationError): + client.places.search() + + @with_client() + @with_mock_responses() + def test_search(self, client, responses): + result = client.places.search(country=["NZ", "AU"]) + self.assertIsInstance(result, PlaceResultSet) + self.assertEqual(result.count, len(list(result.iter_all()))) + self.assertEqual(1, len(responses.calls)) diff --git a/tests/fixtures/requests_responses/places_test/test_search.json b/tests/fixtures/requests_responses/places_test/test_search.json new file mode 100644 index 0000000..369f4cc --- /dev/null +++ b/tests/fixtures/requests_responses/places_test/test_search.json @@ -0,0 +1,44 @@ +[ + { + "method": "GET", + "match_querystring": true, + "url": "/v1/places/?country=NZ,AU", + "status": 200, + "content_type": "application/json", + "body": { + "count": 2, + "next": null, + "previous": null, + "results": [ + { + "id": "2186224", + "type": "country", + "name": "New Zealand", + "county": null, + "region": null, + "country": "New Zealand", + "country_alpha2": "NZ", + "country_alpha3": "NZL", + "location": [ + 174, + -42 + ] + }, + { + "id": "2077456", + "type": "country", + "name": "Australia", + "county": null, + "region": null, + "country": "Australia", + "country_alpha2": "AU", + "country_alpha3": "AUS", + "location": [ + 135, + -25 + ] + } + ] + } + } +] \ No newline at end of file