Skip to content

Commit

Permalink
endpoint schemas & decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
Thierry Jossermoz committed Feb 29, 2016
1 parent 70eaae8 commit 733fd52
Show file tree
Hide file tree
Showing 17 changed files with 57 additions and 39 deletions.
15 changes: 14 additions & 1 deletion predicthq/endpoints/base.py
@@ -1,8 +1,21 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

import six

class BaseEndpoint(object):

class MetaEndpoint(type):

def __new__(mcs, name, bases, data):
if 'Meta' not in data:
class Meta:
""" Used by decorators when overriding schema classes """
pass
data['Meta'] = Meta
return super(MetaEndpoint, mcs).__new__(mcs, name, bases, data)


class BaseEndpoint(six.with_metaclass(MetaEndpoint)):

def __init__(self, client):
self.client = client
25 changes: 16 additions & 9 deletions predicthq/decorators.py → predicthq/endpoints/decorators.py
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

import six
import functools
from collections import defaultdict

from .exceptions import ValidationError
from .schemas import ResultSet, Model, SchematicsDataError
import six

from predicthq.endpoints.schemas import ResultSet, Model, SchematicsDataError
from predicthq.exceptions import ValidationError


def _to_url_params(data, glue=".", separator=","):
Expand Down Expand Up @@ -43,11 +44,14 @@ def accepts(schema_class, query_string=True, role=None):
def decorator(f):

@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(endpoint, *args, **kwargs):

schema = getattr(endpoint.Meta, f.__name__, {}).get("accepts", schema_class)

if not kwargs: # accept instance of schema_class
new_args = tuple(a for a in args if not isinstance(a, (schema_class, dict)))
new_args = tuple(a for a in args if not isinstance(a, (schema, dict)))
if args != new_args:
instance = next(a for a in args if isinstance(a, (schema_class, dict)))
instance = next(a for a in args if isinstance(a, (schema, dict)))
if isinstance(instance, dict):
kwargs = instance
else:
Expand All @@ -56,7 +60,7 @@ def wrapper(*args, **kwargs):

try:
data = _process_kwargs(kwargs)
model = schema_class()
model = schema()
model.import_data(data, strict=True, partial=False)
model.validate()
except SchematicsDataError as e:
Expand All @@ -67,7 +71,7 @@ def wrapper(*args, **kwargs):
else:
params = model.to_primitive(role=role)

return f(*args, **params)
return f(endpoint, *args, **params)

return wrapper

Expand All @@ -80,9 +84,12 @@ def decorator(f):

@functools.wraps(f)
def wrapper(endpoint, *args, **kwargs):

schema = getattr(endpoint.Meta, f.__name__, {}).get("returns", schema_class)

data = f(endpoint, *args, **kwargs)
try:
model = schema_class()
model = schema()
model._endpoint = endpoint

# if schema class is a ResultSet, tell it how to load more results
Expand Down
3 changes: 1 addition & 2 deletions predicthq/endpoints/oauth2/endpoint.py
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.decorators import accepts, returns
from predicthq.endpoints.base import BaseEndpoint

from predicthq.endpoints.decorators import accepts, returns
from .schemas import AccessToken, GetTokenParams, RevokeTokenParams


Expand Down
2 changes: 1 addition & 1 deletion predicthq/endpoints/oauth2/schemas.py
Expand Up @@ -2,7 +2,7 @@
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.config import config
from predicthq.schemas import Model, StringType, StringListType, IntType
from predicthq.endpoints.schemas import Model, StringType, StringListType, IntType


class GetTokenParams(Model):
Expand Down
2 changes: 1 addition & 1 deletion predicthq/schemas.py → predicthq/endpoints/schemas.py
Expand Up @@ -13,7 +13,7 @@

from schematics.models import Model
from schematics.transforms import Role
from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType
from schematics.types import StringType, DateTimeType as SchematicsDateTimeType, IntType, FloatType, URLType, GeoPointType, BooleanType, DateType
from schematics.types.compound import ListType as SchematicsListType, ModelType, DictType
from schematics.exceptions import ValidationError as SchematicsValidationError, DataError as SchematicsDataError, ConversionError
from schematics.types.serializable import serializable
Expand Down
3 changes: 1 addition & 2 deletions predicthq/endpoints/v1/accounts/endpoint.py
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.decorators import returns
from predicthq.endpoints.base import BaseEndpoint

from predicthq.endpoints.decorators import returns
from .schemas import Account


Expand Down
2 changes: 1 addition & 1 deletion predicthq/endpoints/v1/accounts/schemas.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.schemas import Model, StringType, DateTimeType
from predicthq.endpoints.schemas import Model, StringType, DateTimeType


class Account(Model):
Expand Down
3 changes: 1 addition & 2 deletions predicthq/endpoints/v1/events/endpoint.py
Expand Up @@ -2,8 +2,7 @@
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.endpoints.base import BaseEndpoint
from predicthq.decorators import accepts, returns

from predicthq.endpoints.decorators import accepts, returns
from .schemas import SearchParams, EventResultSet


Expand Down
6 changes: 3 additions & 3 deletions predicthq/endpoints/v1/events/schemas.py
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from predicthq.schemas import PaginatedMixin, SortableMixin, Model, ResultSet, ListType, StringType, \
GeoJSONPointType, StringListType, StringModelType, Area, ModelType, IntRange, \
IntType, DateRange, DateTimeType, FloatType, ResultType
from predicthq.endpoints.schemas import PaginatedMixin, SortableMixin, Model, ResultSet, \
ListType, StringType, GeoJSONPointType, StringListType, StringModelType, Area, \
ModelType, IntRange, IntType, DateRange, DateTimeType, FloatType, ResultType


class SearchParams(PaginatedMixin, SortableMixin, Model):
Expand Down
6 changes: 3 additions & 3 deletions predicthq/endpoints/v1/signals/endpoint.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

import itertools
import json

from predicthq.decorators import returns, accepts
import itertools

from predicthq.endpoints.base import BaseEndpoint
from predicthq.endpoints.decorators import returns, accepts
from predicthq.endpoints.v1.signals.schemas import SignalsSearchParams, AnalysisResultSet, AnalysisParams, Dimensions

from .schemas import Signal, SignalID, SavedSignal, SignalResultSet, SignalDataPoints


Expand Down
2 changes: 1 addition & 1 deletion predicthq/endpoints/v1/signals/schemas.py
Expand Up @@ -3,7 +3,7 @@

from schematics.transforms import blacklist, wholelist, whitelist

from predicthq.schemas import Model, StringType, ListType, ModelType, DateTimeType, ResultSet, ResultType, SortableMixin, FloatType, IntType, DictType, \
from predicthq.endpoints.schemas import Model, StringType, ListType, ModelType, DateTimeType, ResultSet, ResultType, SortableMixin, FloatType, IntType, DictType, \
PaginatedMixin, DateRange, StringListType, StringModelType, Area, BooleanType


Expand Down
Expand Up @@ -3,8 +3,7 @@

import unittest

from predicthq import schemas
from predicthq import decorators
from predicthq.endpoints import decorators, schemas
from predicthq.endpoints.base import BaseEndpoint
from predicthq.exceptions import ValidationError

Expand All @@ -27,13 +26,13 @@ class SchemaExample(schemas.Model):
arg1 = schemas.StringType(required=True)
arg2 = schemas.ListType(schemas.IntType)

class EndpointExample(object):
class EndpointExample(BaseEndpoint):

@decorators.accepts(SchemaExample)
def func(self, **kwargs):
return kwargs

endpoint = EndpointExample()
endpoint = EndpointExample(None)
self.assertDictEqual(endpoint.func(arg1="test", arg2=[1, 2]), {'arg1': 'test', 'arg2': '1,2'})

self.assertDictEqual(endpoint.func(SchemaExample({"arg1": "test", "arg2": [1, 2]})), {'arg1': 'test', 'arg2': '1,2'})
Expand All @@ -52,13 +51,13 @@ class SchemaExample(schemas.Model):
arg1 = schemas.StringType(required=True)
arg2 = schemas.ListType(schemas.IntType)

class EndpointExample(object):
class EndpointExample(BaseEndpoint):

@decorators.accepts(SchemaExample, query_string=False)
def func(self, **kwargs):
return kwargs

endpoint = EndpointExample()
endpoint = EndpointExample(None)
self.assertDictEqual(endpoint.func({"arg1": "test", "arg2": [1, 2]}), {'arg1': 'test', 'arg2': [1, 2]})

def test_returns(self):
Expand All @@ -67,13 +66,13 @@ class SchemaExample(schemas.Model):
arg1 = schemas.StringType(required=True)
arg2 = schemas.ListType(schemas.IntType)

class EndpointExample(object):
class EndpointExample(BaseEndpoint):

@decorators.returns(SchemaExample)
def func(self, **kwargs):
return kwargs

endpoint = EndpointExample()
endpoint = EndpointExample(None)
self.assertEqual(endpoint.func(arg1="test", arg2=[1, 2]), SchemaExample({'arg1': 'test', 'arg2': [1, 2]}))

with self.assertRaises(ValidationError):
Expand Down
9 changes: 5 additions & 4 deletions tests/schemas_tests.py → tests/endpoints/schemas_tests.py
Expand Up @@ -2,11 +2,12 @@
from __future__ import unicode_literals, absolute_import, print_function

import unittest

import pytz
from datetime import datetime

from predicthq import schemas
from predicthq import decorators
from predicthq.endpoints import decorators, schemas
from predicthq.endpoints.base import BaseEndpoint


class SchemasTest(unittest.TestCase):
Expand Down Expand Up @@ -137,7 +138,7 @@ class ResultSetExample(schemas.ResultSet):

results = schemas.ResultType(ResultExample)

class EndpointExample(object):
class EndpointExample(BaseEndpoint):

@decorators.returns(ResultSetExample)
def load_page(self, page):
Expand All @@ -149,7 +150,7 @@ def load_page(self, page):
"results": [{"value": 1 + (3 * (page - 1))}, {"value": 2 + (3 * (page - 1))}, {"value": 3 + (3 * (page - 1))}]
}

endpoint = EndpointExample()
endpoint = EndpointExample(None)

p1 = endpoint.load_page(page=1)
self.assertEqual(p1.count, 9)
Expand Down
2 changes: 2 additions & 0 deletions tests/endpoints/v1/__init__.py
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function
File renamed without changes.
Expand Up @@ -23,7 +23,6 @@ def test_search_params(self, client):
'within': '2km@42.346,-71.0432', 'label': 'label1,label2', 'q': 'query',
'start.lt': '2016-04-01T00:00:00.000000', 'start.gte': '2016-03-01T00:00:00.000000', 'start.tz': 'Pacific/Auckland'})


@with_mock_client(request_returns={"count": 12})
def test_count_params(self, client):
client.events.count(id="id", q="query", rank_level=[4,5], rank__gt=85, country=["NZ", "AU"],
Expand Down
File renamed without changes.

0 comments on commit 733fd52

Please sign in to comment.