diff --git a/predicthq/client.py b/predicthq/client.py index 67bd0bd..bd00bc1 100644 --- a/predicthq/client.py +++ b/predicthq/client.py @@ -43,7 +43,7 @@ def get_headers(self, headers): def request(self, method, path, **kwargs): headers = self.get_headers(kwargs.pop("headers", {})) response = requests.request(method, self.build_url(path), headers=headers, **kwargs) - + self.logger.debug(response.request.url) try: response.raise_for_status() except requests.HTTPError: diff --git a/predicthq/endpoints/base.py b/predicthq/endpoints/base.py index e7c08a3..9e94b36 100644 --- a/predicthq/endpoints/base.py +++ b/predicthq/endpoints/base.py @@ -19,3 +19,28 @@ class BaseEndpoint(six.with_metaclass(MetaEndpoint)): def __init__(self, client): self.client = client + + def build_url(self, prefix, suffix): + return '/{}/{}/'.format(prefix.strip('/'), suffix.strip('/')) + + +class UserBaseEndpoint(BaseEndpoint): + + def __init__(self, client, account_id=None): + self.account_id = account_id + super(UserBaseEndpoint, self).__init__(client) + + def for_account(self, account_id): + """ + Parameterised endpoint for account. + + Required when using a user access token, as the user may belong to multiple accounts with different plans + + """ + return self.__class__(self.client, account_id) + + def build_url(self, prefix, suffix): + if self.account_id is not None: + return '/{}/accounts/{}/{}/'.format(prefix.strip('/'), self.account_id, suffix.strip('/')) + else: + return super(UserBaseEndpoint, self).build_url(prefix, suffix) diff --git a/predicthq/endpoints/schemas.py b/predicthq/endpoints/schemas.py index 84cfe2b..6040104 100644 --- a/predicthq/endpoints/schemas.py +++ b/predicthq/endpoints/schemas.py @@ -197,7 +197,11 @@ def iter_items(self): return iter(self.results) def iter_all(self): - return itertools.chain(self.iter_items(), *(page.iter_items() for page in self.iter_pages())) + for item in self.iter_items(): + yield item + for page in self.iter_pages(): + for item in page.iter_items(): + yield item def __iter__(self): return self.iter_items() diff --git a/predicthq/endpoints/v1/accounts/endpoint.py b/predicthq/endpoints/v1/accounts/endpoint.py index 02d955f..3c8fb07 100644 --- a/predicthq/endpoints/v1/accounts/endpoint.py +++ b/predicthq/endpoints/v1/accounts/endpoint.py @@ -10,4 +10,4 @@ class AccountsEndpoint(BaseEndpoint): @returns(Account) def self(self): - return self.client.get('/v1/accounts/self/') + return self.client.get(self.build_url('v1', 'accounts/self')) diff --git a/predicthq/endpoints/v1/events/endpoint.py b/predicthq/endpoints/v1/events/endpoint.py index 07681b4..0c51b1b 100644 --- a/predicthq/endpoints/v1/events/endpoint.py +++ b/predicthq/endpoints/v1/events/endpoint.py @@ -1,23 +1,24 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals, absolute_import, print_function -from predicthq.endpoints.base import BaseEndpoint +from predicthq.endpoints.base import UserBaseEndpoint from predicthq.endpoints.decorators import accepts, returns -from .schemas import SearchParams, EventResultSet, CalendarParams, CalendarResultSet +from .schemas import SearchParams, EventResultSet, Count, CalendarParams, CalendarResultSet -class EventsEndpoint(BaseEndpoint): +class EventsEndpoint(UserBaseEndpoint): @accepts(SearchParams) @returns(EventResultSet) def search(self, **params): - return self.client.get('/v1/events', params=params) + return self.client.get(self.build_url('v1', 'events'), params=params) @accepts(SearchParams) + @returns(Count) def count(self, **params): - return self.client.get('/v1/events/count/', params=params).get('count') + return self.client.get(self.build_url('v1', 'events/count'), params=params) @accepts(CalendarParams) @returns(CalendarResultSet) def calendar(self, **params): - return self.client.get('/v1/events/calendar/', params=params) + return self.client.get(self.build_url('v1', 'events/calendar'), params=params) diff --git a/predicthq/endpoints/v1/events/schemas.py b/predicthq/endpoints/v1/events/schemas.py index c57aa1e..95e41a2 100644 --- a/predicthq/endpoints/v1/events/schemas.py +++ b/predicthq/endpoints/v1/events/schemas.py @@ -18,6 +18,7 @@ class Options: category = ListType(StringType) start = ModelType(DateTimeRange) end = ModelType(DateTimeRange) + active = ModelType(DateTimeRange) rank_level = ListType(IntType(min_value=1, max_value=5)) rank = ModelType(IntRange) country = ListType(StringType) @@ -50,6 +51,15 @@ class EventResultSet(ResultSet): results = ResultType(Event) +class Count(Model): + + count = IntType() + top_rank = FloatType() + rank_levels = DictType(IntType) + categories = DictType(IntType) + labels = DictType(IntType) + + class TopEventsSearchParams(SortableMixin, Model): limit = IntType(min_value=0, max_value=10) diff --git a/predicthq/endpoints/v1/signals/endpoint.py b/predicthq/endpoints/v1/signals/endpoint.py index b220137..d6f3580 100644 --- a/predicthq/endpoints/v1/signals/endpoint.py +++ b/predicthq/endpoints/v1/signals/endpoint.py @@ -5,7 +5,7 @@ import itertools -from predicthq.endpoints.base import BaseEndpoint +from predicthq.endpoints.base import UserBaseEndpoint from predicthq.endpoints.decorators import returns, accepts from predicthq.endpoints.v1.signals.schemas import SignalsSearchParams, AnalysisResultSet, AnalysisParams, Dimensions from .schemas import Signal, SignalID, SavedSignal, SignalResultSet, SignalDataPoints @@ -17,44 +17,44 @@ def chunks(iterator, size): yield itertools.chain([next(iterable)], itertools.islice(iterable, size - 1)) -class SignalsEndpoint(BaseEndpoint): +class SignalsEndpoint(UserBaseEndpoint): @accepts(SignalsSearchParams) @returns(SignalResultSet) def search(self, **params): - return self.client.get('/v1/signals/', params=params) + return self.client.get(self.build_url('v1', 'signals'), params=params) @accepts(SignalID) @returns(SavedSignal) def get(self, id): - return self.client.get('/v1/signals/{}/'.format(id)) + return self.client.get(self.build_url('v1', 'signals/{}'.format(id))) @accepts(Signal, query_string=False, role="create") @returns(SavedSignal) def create(self, **data): - return self.client.post('/v1/signals/', json=data) + return self.client.post(self.build_url('v1', 'signals'), json=data) @accepts(SavedSignal, query_string=False, role="update") @returns(SavedSignal) def update(self, id, **data): - return self.client.put('/v1/signals/{}/'.format(id), json=data) + return self.client.put(self.build_url('v1', 'signals/{}'.format(id)), json=data) @accepts(SignalID) def delete(self, id): - self.client.delete('/v1/signals/{}/'.format(id)) + self.client.delete(self.build_url('v1', 'signals/{}'.format(id))) @accepts(SignalDataPoints, query_string=False) def sink(self, id, data_points, chunk_size): for data_chunk in chunks(data_points, chunk_size): data = "\n".join(json.dumps(item, indent=None) for item in data_chunk) - self.client.post('/v1/signals/{}/sink/'.format(id), data=data, headers={"Content-Type": "application/x-ldjson"}) + self.client.post(self.build_url('v1', 'signals/{}/sink'.format(id)), data=data, headers={"Content-Type": "application/x-ldjson"}) @accepts(SignalID) @returns(Dimensions) def dimensions(self, id): - return self.client.get('/v1/signals/{}/dimensions/'.format(id)) + return self.client.get(self.build_url('v1', 'signals/{}/dimensions'.format(id))) @accepts(AnalysisParams) @returns(AnalysisResultSet) def analysis(self, id, **params): - return self.client.get('/v1/signals/{}/analysis/'.format(id), params=params) + return self.client.get(self.build_url('v1', 'signals/{}/analysis'.format(id)), params=params) diff --git a/tests/endpoints/v1/events_tests.py b/tests/endpoints/v1/events_tests.py index 6444c44..f2c0337 100644 --- a/tests/endpoints/v1/events_tests.py +++ b/tests/endpoints/v1/events_tests.py @@ -6,7 +6,7 @@ from tests import with_mock_client, with_mock_responses, with_client -from predicthq.endpoints.v1.events.schemas import EventResultSet, CalendarResultSet +from predicthq.endpoints.v1.events.schemas import EventResultSet, CalendarResultSet, Count class EventsTest(unittest.TestCase): @@ -18,22 +18,15 @@ def test_search_params(self, client): label=["label1", "label2"], category="category", start__gte="2016-03-01", start__lt=datetime(2016, 4, 1), start__tz="Pacific/Auckland",) - client.request.assert_called_once_with('get', '/v1/events', params={ + client.request.assert_called_once_with('get', '/v1/events/', params={ 'id': 'id', 'rank.gt': 85, 'rank_level': '4,5', 'category': 'category', 'country': 'NZ,AU', 'within': '2km@42.346,-71.0432', 'label': 'label1,label2', 'q': 'query', 'start.lt': '2016-04-01T00:00:00.000000', 'start.gte': '2016-03-01T00:00:00.000000', 'start.tz': 'Pacific/Auckland'}) - @with_mock_client(request_returns={"count": 12}) - def test_count_params(self, client): - client.events.count(id="id", q="query", rank_level=[4,5], rank__gt=85, country=["NZ", "AU"], - within__radius="2km", within__longitude=-71.0432, within__latitude=42.346, - label=["label1", "label2"], category="category", - start__gte="2016-03-01", start__lt=datetime(2016, 4, 1), start__tz="Pacific/Auckland",) - - client.request.assert_called_once_with('get', '/v1/events/count/', params={ - 'id': 'id', 'rank.gt': 85, 'rank_level': '4,5', 'category': 'category', 'country': 'NZ,AU', - 'within': '2km@42.346,-71.0432', 'label': 'label1,label2', 'q': 'query', - 'start.lt': '2016-04-01T00:00:00.000000', 'start.gte': '2016-03-01T00:00:00.000000', 'start.tz': 'Pacific/Auckland'}) + @with_mock_client() + def test_search_for_account(self, client): + client.events.for_account('account-id').search(q="query") + client.request.assert_called_once_with('get', '/v1/accounts/account-id/events/', params={'q': 'query'}) @with_client() @with_mock_responses() @@ -46,9 +39,9 @@ def test_search(self, client, responses): @with_client() @with_mock_responses() def test_count(self, client, responses): - result = client.events.count(q="Foo Fighters", country="AU", limit=10) - self.assertIsInstance(result, int) - self.assertEqual(result, 12) + result = client.events.count(active__gte="2015-01-01", active__lte="2015-12-31", within="50km@-27.470784,153.030124") + self.assertIsInstance(result, Count) + self.assertEqual(result.count, 2501) @with_client() @with_mock_responses() diff --git a/tests/endpoints/v1/signals_tests.py b/tests/endpoints/v1/signals_tests.py index a02597a..3bed7b7 100644 --- a/tests/endpoints/v1/signals_tests.py +++ b/tests/endpoints/v1/signals_tests.py @@ -17,6 +17,11 @@ def test_search_params(self, client): client.signals.search(sort=["-created_at", "updated_at"]) client.request.assert_called_once_with('get', '/v1/signals/', params={'sort': '-created_at,updated_at'}) + @with_mock_client() + def test_search_for_account(self, client): + client.signals.for_account('account-id').search() + client.request.assert_called_once_with('get', '/v1/accounts/account-id/signals/', params={}) + @with_mock_client(request_returns={"id": "signal-id", "name": "Test", "dimensions": [], "country": "NZ"}) def test_get_params(self, client): client.signals.get(id="signal-id") diff --git a/tests/fixtures/requests_responses/events_test/test_count.json b/tests/fixtures/requests_responses/events_test/test_count.json index 6e6f85e..8fea4d5 100644 --- a/tests/fixtures/requests_responses/events_test/test_count.json +++ b/tests/fixtures/requests_responses/events_test/test_count.json @@ -2,11 +2,84 @@ { "method": "GET", "match_querystring": true, - "url": "/v1/events/count/?q=Foo+Fighters&country=AU&limit=10", + "url": "/v1/events/count/?active.gte=2015-01-01T00%3A00%3A00.000000&active.lte=2015-12-31T00%3A00%3A00.000000&within=50km%40-27.470784%2C153.030124", "status": 200, "content_type": "application/json", "body": { - "count": 12 + "count": 2501, + "labels": { + "outdoor": 69, + "observance": 16, + "concert": 966, + "family": 138, + "food": 12, + "sales": 2, + "community": 16, + "fundraiser": 22, + "weather": 5, + "sport": 627, + "politics": 6, + "performing-arts": 219, + "technology": 9, + "campus": 10, + "holiday-religious": 50, + "conference": 347, + "education": 16, + "religion": 29, + "festival": 220, + "movie": 38, + "delay": 16, + "attraction": 3, + "holiday-national": 8, + "music": 978, + "holiday-hebrew": 17, + "observance-local": 1, + "health": 31, + "soccer": 417, + "holiday": 84, + "cricket": 9, + "storm": 5, + "business": 289, + "holiday-christian": 19, + "club": 2, + "social": 111, + "rain": 1, + "holiday-orthodox": 6, + "school": 5, + "comedy": 6, + "rugby": 17, + "science": 9, + "daylight-savings": 2, + "holiday-local-common": 3, + "airport": 16, + "holiday-local": 1, + "expo": 48, + "cold-wave": 1, + "observance-season": 4, + "holiday-muslim": 8 + }, + "rank_levels": { + "1": 560, + "3": 592, + "2": 1194, + "5": 27, + "4": 111 + }, + "categories": { + "public-holidays": 58, + "expos": 48, + "conferences": 346, + "sports": 625, + "observances": 21, + "airport-delays": 16, + "school-holidays": 5, + "severe-weather": 5, + "festivals": 210, + "performing-arts": 219, + "daylight-savings": 2, + "concerts": 946 + }, + "top_rank": 90.0 } } ]