Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion predicthq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_headers(self, headers):
def request(self, method, path, **kwargs):
headers = self.get_headers(kwargs.pop("headers", {}))
response = requests.request(method, self.build_url(path), headers=headers, **kwargs)

self.logger.debug(response.request.url)
try:
response.raise_for_status()
except requests.HTTPError:
Expand Down
25 changes: 25 additions & 0 deletions predicthq/endpoints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,28 @@ class BaseEndpoint(six.with_metaclass(MetaEndpoint)):

def __init__(self, client):
self.client = client

def build_url(self, prefix, suffix):
return '/{}/{}/'.format(prefix.strip('/'), suffix.strip('/'))


class UserBaseEndpoint(BaseEndpoint):

def __init__(self, client, account_id=None):
self.account_id = account_id
super(UserBaseEndpoint, self).__init__(client)

def for_account(self, account_id):
"""
Parameterised endpoint for account.

Required when using a user access token, as the user may belong to multiple accounts with different plans

"""
return self.__class__(self.client, account_id)

def build_url(self, prefix, suffix):
if self.account_id is not None:
return '/{}/accounts/{}/{}/'.format(prefix.strip('/'), self.account_id, suffix.strip('/'))
else:
return super(UserBaseEndpoint, self).build_url(prefix, suffix)
6 changes: 5 additions & 1 deletion predicthq/endpoints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,11 @@ def iter_items(self):
return iter(self.results)

def iter_all(self):
return itertools.chain(self.iter_items(), *(page.iter_items() for page in self.iter_pages()))
for item in self.iter_items():
yield item
for page in self.iter_pages():
for item in page.iter_items():
yield item

def __iter__(self):
return self.iter_items()
Expand Down
2 changes: 1 addition & 1 deletion predicthq/endpoints/v1/accounts/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class AccountsEndpoint(BaseEndpoint):

@returns(Account)
def self(self):
return self.client.get('/v1/accounts/self/')
return self.client.get(self.build_url('v1', 'accounts/self'))
13 changes: 7 additions & 6 deletions predicthq/endpoints/v1/events/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.endpoints.base import BaseEndpoint
from predicthq.endpoints.base import UserBaseEndpoint
from predicthq.endpoints.decorators import accepts, returns
from .schemas import SearchParams, EventResultSet, CalendarParams, CalendarResultSet
from .schemas import SearchParams, EventResultSet, Count, CalendarParams, CalendarResultSet


class EventsEndpoint(BaseEndpoint):
class EventsEndpoint(UserBaseEndpoint):

@accepts(SearchParams)
@returns(EventResultSet)
def search(self, **params):
return self.client.get('/v1/events', params=params)
return self.client.get(self.build_url('v1', 'events'), params=params)

@accepts(SearchParams)
@returns(Count)
def count(self, **params):
return self.client.get('/v1/events/count/', params=params).get('count')
return self.client.get(self.build_url('v1', 'events/count'), params=params)

@accepts(CalendarParams)
@returns(CalendarResultSet)
def calendar(self, **params):
return self.client.get('/v1/events/calendar/', params=params)
return self.client.get(self.build_url('v1', 'events/calendar'), params=params)
10 changes: 10 additions & 0 deletions predicthq/endpoints/v1/events/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Options:
category = ListType(StringType)
start = ModelType(DateTimeRange)
end = ModelType(DateTimeRange)
active = ModelType(DateTimeRange)
rank_level = ListType(IntType(min_value=1, max_value=5))
rank = ModelType(IntRange)
country = ListType(StringType)
Expand Down Expand Up @@ -50,6 +51,15 @@ class EventResultSet(ResultSet):
results = ResultType(Event)


class Count(Model):

count = IntType()
top_rank = FloatType()
rank_levels = DictType(IntType)
categories = DictType(IntType)
labels = DictType(IntType)


class TopEventsSearchParams(SortableMixin, Model):

limit = IntType(min_value=0, max_value=10)
Expand Down
20 changes: 10 additions & 10 deletions predicthq/endpoints/v1/signals/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import itertools

from predicthq.endpoints.base import BaseEndpoint
from predicthq.endpoints.base import UserBaseEndpoint
from predicthq.endpoints.decorators import returns, accepts
from predicthq.endpoints.v1.signals.schemas import SignalsSearchParams, AnalysisResultSet, AnalysisParams, Dimensions
from .schemas import Signal, SignalID, SavedSignal, SignalResultSet, SignalDataPoints
Expand All @@ -17,44 +17,44 @@ def chunks(iterator, size):
yield itertools.chain([next(iterable)], itertools.islice(iterable, size - 1))


class SignalsEndpoint(BaseEndpoint):
class SignalsEndpoint(UserBaseEndpoint):

@accepts(SignalsSearchParams)
@returns(SignalResultSet)
def search(self, **params):
return self.client.get('/v1/signals/', params=params)
return self.client.get(self.build_url('v1', 'signals'), params=params)

@accepts(SignalID)
@returns(SavedSignal)
def get(self, id):
return self.client.get('/v1/signals/{}/'.format(id))
return self.client.get(self.build_url('v1', 'signals/{}'.format(id)))

@accepts(Signal, query_string=False, role="create")
@returns(SavedSignal)
def create(self, **data):
return self.client.post('/v1/signals/', json=data)
return self.client.post(self.build_url('v1', 'signals'), json=data)

@accepts(SavedSignal, query_string=False, role="update")
@returns(SavedSignal)
def update(self, id, **data):
return self.client.put('/v1/signals/{}/'.format(id), json=data)
return self.client.put(self.build_url('v1', 'signals/{}'.format(id)), json=data)

@accepts(SignalID)
def delete(self, id):
self.client.delete('/v1/signals/{}/'.format(id))
self.client.delete(self.build_url('v1', 'signals/{}'.format(id)))

@accepts(SignalDataPoints, query_string=False)
def sink(self, id, data_points, chunk_size):
for data_chunk in chunks(data_points, chunk_size):
data = "\n".join(json.dumps(item, indent=None) for item in data_chunk)
self.client.post('/v1/signals/{}/sink/'.format(id), data=data, headers={"Content-Type": "application/x-ldjson"})
self.client.post(self.build_url('v1', 'signals/{}/sink'.format(id)), data=data, headers={"Content-Type": "application/x-ldjson"})

@accepts(SignalID)
@returns(Dimensions)
def dimensions(self, id):
return self.client.get('/v1/signals/{}/dimensions/'.format(id))
return self.client.get(self.build_url('v1', 'signals/{}/dimensions'.format(id)))

@accepts(AnalysisParams)
@returns(AnalysisResultSet)
def analysis(self, id, **params):
return self.client.get('/v1/signals/{}/analysis/'.format(id), params=params)
return self.client.get(self.build_url('v1', 'signals/{}/analysis'.format(id)), params=params)
25 changes: 9 additions & 16 deletions tests/endpoints/v1/events_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests import with_mock_client, with_mock_responses, with_client

from predicthq.endpoints.v1.events.schemas import EventResultSet, CalendarResultSet
from predicthq.endpoints.v1.events.schemas import EventResultSet, CalendarResultSet, Count


class EventsTest(unittest.TestCase):
Expand All @@ -18,22 +18,15 @@ def test_search_params(self, client):
label=["label1", "label2"], category="category",
start__gte="2016-03-01", start__lt=datetime(2016, 4, 1), start__tz="Pacific/Auckland",)

client.request.assert_called_once_with('get', '/v1/events', params={
client.request.assert_called_once_with('get', '/v1/events/', params={
'id': 'id', 'rank.gt': 85, 'rank_level': '4,5', 'category': 'category', 'country': 'NZ,AU',
'within': '2km@42.346,-71.0432', 'label': 'label1,label2', 'q': 'query',
'start.lt': '2016-04-01T00:00:00.000000', 'start.gte': '2016-03-01T00:00:00.000000', 'start.tz': 'Pacific/Auckland'})

@with_mock_client(request_returns={"count": 12})
def test_count_params(self, client):
client.events.count(id="id", q="query", rank_level=[4,5], rank__gt=85, country=["NZ", "AU"],
within__radius="2km", within__longitude=-71.0432, within__latitude=42.346,
label=["label1", "label2"], category="category",
start__gte="2016-03-01", start__lt=datetime(2016, 4, 1), start__tz="Pacific/Auckland",)

client.request.assert_called_once_with('get', '/v1/events/count/', params={
'id': 'id', 'rank.gt': 85, 'rank_level': '4,5', 'category': 'category', 'country': 'NZ,AU',
'within': '2km@42.346,-71.0432', 'label': 'label1,label2', 'q': 'query',
'start.lt': '2016-04-01T00:00:00.000000', 'start.gte': '2016-03-01T00:00:00.000000', 'start.tz': 'Pacific/Auckland'})
@with_mock_client()
def test_search_for_account(self, client):
client.events.for_account('account-id').search(q="query")
client.request.assert_called_once_with('get', '/v1/accounts/account-id/events/', params={'q': 'query'})

@with_client()
@with_mock_responses()
Expand All @@ -46,9 +39,9 @@ def test_search(self, client, responses):
@with_client()
@with_mock_responses()
def test_count(self, client, responses):
result = client.events.count(q="Foo Fighters", country="AU", limit=10)
self.assertIsInstance(result, int)
self.assertEqual(result, 12)
result = client.events.count(active__gte="2015-01-01", active__lte="2015-12-31", within="50km@-27.470784,153.030124")
self.assertIsInstance(result, Count)
self.assertEqual(result.count, 2501)

@with_client()
@with_mock_responses()
Expand Down
5 changes: 5 additions & 0 deletions tests/endpoints/v1/signals_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def test_search_params(self, client):
client.signals.search(sort=["-created_at", "updated_at"])
client.request.assert_called_once_with('get', '/v1/signals/', params={'sort': '-created_at,updated_at'})

@with_mock_client()
def test_search_for_account(self, client):
client.signals.for_account('account-id').search()
client.request.assert_called_once_with('get', '/v1/accounts/account-id/signals/', params={})

@with_mock_client(request_returns={"id": "signal-id", "name": "Test", "dimensions": [], "country": "NZ"})
def test_get_params(self, client):
client.signals.get(id="signal-id")
Expand Down
77 changes: 75 additions & 2 deletions tests/fixtures/requests_responses/events_test/test_count.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,84 @@
{
"method": "GET",
"match_querystring": true,
"url": "/v1/events/count/?q=Foo+Fighters&country=AU&limit=10",
"url": "/v1/events/count/?active.gte=2015-01-01T00%3A00%3A00.000000&active.lte=2015-12-31T00%3A00%3A00.000000&within=50km%40-27.470784%2C153.030124",
"status": 200,
"content_type": "application/json",
"body": {
"count": 12
"count": 2501,
"labels": {
"outdoor": 69,
"observance": 16,
"concert": 966,
"family": 138,
"food": 12,
"sales": 2,
"community": 16,
"fundraiser": 22,
"weather": 5,
"sport": 627,
"politics": 6,
"performing-arts": 219,
"technology": 9,
"campus": 10,
"holiday-religious": 50,
"conference": 347,
"education": 16,
"religion": 29,
"festival": 220,
"movie": 38,
"delay": 16,
"attraction": 3,
"holiday-national": 8,
"music": 978,
"holiday-hebrew": 17,
"observance-local": 1,
"health": 31,
"soccer": 417,
"holiday": 84,
"cricket": 9,
"storm": 5,
"business": 289,
"holiday-christian": 19,
"club": 2,
"social": 111,
"rain": 1,
"holiday-orthodox": 6,
"school": 5,
"comedy": 6,
"rugby": 17,
"science": 9,
"daylight-savings": 2,
"holiday-local-common": 3,
"airport": 16,
"holiday-local": 1,
"expo": 48,
"cold-wave": 1,
"observance-season": 4,
"holiday-muslim": 8
},
"rank_levels": {
"1": 560,
"3": 592,
"2": 1194,
"5": 27,
"4": 111
},
"categories": {
"public-holidays": 58,
"expos": 48,
"conferences": 346,
"sports": 625,
"observances": 21,
"airport-delays": 16,
"school-holidays": 5,
"severe-weather": 5,
"festivals": 210,
"performing-arts": 219,
"daylight-savings": 2,
"concerts": 946
},
"top_rank": 90.0
}
}
]