Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Merge pull request #51 from postatum/94990302_es_aggregation
Browse files Browse the repository at this point in the history
Elasticsearch aggregations
  • Loading branch information
jstoiko committed Jun 8, 2015
2 parents 84a11fa + 3e29e07 commit 0a2e58e
Show file tree
Hide file tree
Showing 8 changed files with 401 additions and 32 deletions.
44 changes: 44 additions & 0 deletions nefertari/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,50 @@ def do_count(self, params):
except IndexNotFoundException:
return 0

def aggregate(self, **params):
""" Perform aggreration
Arguments:
:_aggregations_params: Dict of aggregation params. Root key is an
aggregation name. Required.
:__raise_on_empty: Boolean indicating whether to raise exception
when IndexNotFoundException exception happens. Optional,
defaults to False.
:_search_type: Type of search to use. Optional, defaults to
'count'. You might want to provide this argument explicitly
when performing nested aggregations on buckets.
"""
_aggregations_params = params.pop('_aggregations_params', None)
__raise_on_empty = params.pop('__raise_on_empty', False)
_search_type = params.pop('_search_type', 'count')

if not _aggregations_params:
raise Exception('Missing _aggregations_params')

# Set limit so ES won't complain. It is ignored in the end
params['_limit'] = 0
search_params = self.build_search_params(params)
search_params.pop('size', None)
search_params.pop('from_', None)
search_params.pop('sort', None)

search_params['search_type'] = _search_type
search_params['body']['aggregations'] = _aggregations_params

log.debug('Performing aggregation: {}'.format(_aggregations_params))
try:
response = ES.api.search(**search_params)
except IndexNotFoundException:
if __raise_on_empty:
raise JHTTPNotFound(
'Aggregation failed: Index does not exist')
return {}

try:
return response['aggregations']
except KeyError:
raise JHTTPNotFound('No aggregations returned from ES')

def get_collection(self, **params):
__raise_on_empty = params.pop('__raise_on_empty', False)

Expand Down
2 changes: 1 addition & 1 deletion nefertari/utils/dictset.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def process_int_param(self, name, default=None):
try:
self[name] = int(self[name])
except ValueError:
raise ValueError('%s must be a decimal' % name)
raise ValueError('%s must be an integer' % name)

elif default is not None:
self[name] = default
Expand Down
36 changes: 36 additions & 0 deletions nefertari/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,39 @@ def issequence(arg):
return (not hasattr(arg, 'strip') and
hasattr(arg, '__getitem__') or
hasattr(arg, '__iter__'))


def merge_dicts(a, b, path=None):
""" Merge dict :b: into dict :a:
Code snippet from http://stackoverflow.com/a/7205107
"""
if path is None:
path = []

for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
merge_dicts(a[key], b[key], path + [str(key)])
elif a[key] == b[key]:
pass # same leaf value
else:
raise Exception(
'Conflict at %s' % '.'.join(path + [str(key)]))
else:
a[key] = b[key]
return a


def str2dict(dotted_str, value=None, separator='.'):
""" Convert dotted string to dict splitting by :separator: """
dict_ = {}
parts = dotted_str.split(separator)
d, prev = dict_, None
for part in parts:
prev = d
d = d.setdefault(part, {})
else:
if value is not None:
prev[part] = value
return dict_
126 changes: 111 additions & 15 deletions nefertari/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from nefertari.json_httpexceptions import (
JHTTPBadRequest, JHTTPNotFound, JHTTPMethodNotAllowed)
from nefertari.utils import dictset
from nefertari import wrappers
from nefertari.utils import dictset, merge_dicts, str2dict
from nefertari import wrappers, engine
from nefertari.resource import ACTIONS
from nefertari import engine

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,13 +75,11 @@ def convert_dotted(params):
if not isinstance(params, dictset):
params = dictset(params)

dotted = defaultdict(dict)
dotted_items = {k: v for k, v in params.items() if '.' in k}

if dotted_items:
for key, value in dotted_items.items():
field, subfield = key.split('.')
dotted[field].update({subfield: value})
dicts = [str2dict(key, val) for key, val in dotted_items.items()]
dotted = reduce(merge_dicts, dicts)
params = params.subset(['-' + k for k in dotted_items.keys()])
params.update(dict(dotted))

Expand Down Expand Up @@ -142,6 +139,9 @@ def __init__(self, context, request, _query_params={}, _json_params={}):
else:
self.refresh_index = None

root_resource = getattr(self, 'root_resource', None)
self._auth_enabled = root_resource is not None and root_resource.auth

self._run_init_actions()

def _run_init_actions(self):
Expand Down Expand Up @@ -171,9 +171,7 @@ def set_public_limits(self):
""" Set public limits if auth is enabled and user is not
authenticated.
"""
root_resource = getattr(self, 'root_resource', None)
auth_enabled = root_resource is not None and root_resource.auth
if auth_enabled and not getattr(self.request, 'user', None):
if self._auth_enabled and not getattr(self.request, 'user', None):
wrappers.set_public_limits(self)

def convert_ids2objects(self, model_cls=None):
Expand Down Expand Up @@ -202,14 +200,11 @@ def get_debug(self, package=None):
return asbool(self.request.registry.settings.get(key))

def setup_default_wrappers(self):
root_resource = getattr(self, 'root_resource', None)
auth_enabled = root_resource and root_resource.auth

self._after_calls['index'] = [
wrappers.wrap_in_dict(self.request),
wrappers.add_meta(self.request),
]
if auth_enabled:
if self._auth_enabled:
self._after_calls['index'] += [
wrappers.apply_privacy(self.request),
]
Expand All @@ -221,7 +216,7 @@ def setup_default_wrappers(self):
wrappers.wrap_in_dict(self.request),
wrappers.add_meta(self.request),
]
if auth_enabled:
if self._auth_enabled:
self._after_calls['show'] += [
wrappers.apply_privacy(self.request),
]
Expand Down Expand Up @@ -309,6 +304,107 @@ def _get_object(id_):
self._json_params[name] = _get_object(ids)


class ESAggregationMixin(object):
""" Mixin that provides methods to perform Elasticsearch aggregations.
Should be mixed with subclasses of `nefertari.view.BaseView`.
To use aggregation at collection route requests, simply return
`self.aggregate()`.
Attributes:
:_aggregations_keys: Sequence of strings representing name(s) of the
root key under which aggregations names are defined. Order of keys
matters - first key found in request is popped and returned.
:_auth_enabled: Boolean indicating whether authentication is enabled.
Is calculated in BaseView.
Examples:
If _aggregations_keys=('_aggregations',), then query string params
should look like:
_aggregations.min_price.min.field=price
"""
_aggregations_keys = ('_aggregations', '_aggs')
_auth_enabled = None

def pop_aggregations_params(self):
""" Pop and return aggregation params from query string params.
Aggregation params are expected to be prefixed(nested under) by
any of `self._aggregations_keys`.
"""
self._query_params = BaseView.convert_dotted(self._query_params)

for key in self._aggregations_keys:
if key in self._query_params:
return self._query_params.pop(key)
else:
raise KeyError('Missing aggregation params')

def stub_wrappers(self):
""" Remove default 'index' after call wrappers and add only
those needed for aggregation results output.
"""
self._after_calls['index'] = []

@classmethod
def get_aggregations_fields(cls, params):
""" Recursively get values under the 'field' key.
Is used to get names of fields on which aggregations should be
performed.
"""
fields = []
for key, val in params.items():
if isinstance(val, dict):
fields += cls.get_aggregations_fields(val)
if key == 'field':
fields.append(val)
return fields

def check_aggregations_privacy(self, aggregations_params):
""" Check per-field privacy rules in aggregations.
Privacy is checked by making sure user has access to the fields
used in aggregations.
"""
fields = self.get_aggregations_fields(aggregations_params)
fields_dict = dictset.fromkeys(fields)
fields_dict['_type'] = self._model_class.__name__

wrapper = wrappers.apply_privacy(self.request)
allowed_fields = set(wrapper(result=fields_dict).keys())
not_allowed_fields = set(fields) - set(allowed_fields)

if not_allowed_fields:
err = 'Not enough permissions to aggregate on fields: {}'.format(
','.join(not_allowed_fields))
raise ValueError(err)

def aggregate(self):
""" Perform aggregation and return response. """
from nefertari.elasticsearch import ES
if not ES.settings.asbool('enable_aggregations'):
log.warn('Elasticsearch aggregations are disabled')
raise KeyError('Elasticsearch aggregations are disabled')

aggregations_params = self.pop_aggregations_params()
if self._auth_enabled:
self.check_aggregations_privacy(aggregations_params)
self.stub_wrappers()

search_params = []
if 'q' in self._query_params:
search_params.append(self._query_params.pop('q'))
_raw_terms = ' AND '.join(search_params)

return ES(self._model_class.__name__).aggregate(
_aggregations_params=aggregations_params,
_raw_terms=_raw_terms,
**self._query_params
)


def key_error_view(context, request):
return JHTTPBadRequest("Bad or missing param '%s'" % context.message)

Expand Down
50 changes: 50 additions & 0 deletions tests/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,56 @@ def test_do_count_no_index(self, mock_count):
assert val == 0
mock_count.assert_called_once_with(foo=1)

def test_aggregate_no_aggregations(self):
obj = es.ES('Foo', 'foondex')
with pytest.raises(Exception) as ex:
obj.aggregate(foo='bar')
assert 'Missing _aggregations_params' in str(ex.value)

@patch('nefertari.elasticsearch.ES.build_search_params')
@patch('nefertari.elasticsearch.ES.api.search')
def test_aggregation(self, mock_search, mock_build):
mock_search.return_value = {'aggregations': {'foo': 1}}
mock_build.return_value = {
'size': 1, 'from_': 2, 'sort': 3,
'body': {'query': 'query1'}
}
obj = es.ES('Foo', 'foondex')
resp = obj.aggregate(_aggregations_params={'zoo': 5}, param1=6)
assert resp == {'foo': 1}
mock_build.assert_called_once_with({'_limit': 0, 'param1': 6})
mock_search.assert_called_once_with(
search_type='count',
body={'aggregations': {'zoo': 5}, 'query': 'query1'},
)

@patch('nefertari.elasticsearch.ES.build_search_params')
@patch('nefertari.elasticsearch.ES.api.search')
def test_aggregation_nothing_returned(self, mock_search, mock_build):
mock_search.return_value = {}
mock_build.return_value = {
'size': 1, 'from_': 2, 'sort': 3,
'body': {'query': 'query1'}
}
obj = es.ES('Foo', 'foondex')
with pytest.raises(JHTTPNotFound) as ex:
obj.aggregate(_aggregations_params={'zoo': 5}, param1=6)
assert 'No aggregations returned from ES' in str(ex.value)

@patch('nefertari.elasticsearch.ES.build_search_params')
@patch('nefertari.elasticsearch.ES.api.search')
def test_aggregation_index_not_exists(self, mock_search, mock_build):
mock_search.side_effect = es.IndexNotFoundException()
mock_build.return_value = {
'size': 1, 'from_': 2, 'sort': 3,
'body': {'query': 'query1'}
}
obj = es.ES('Foo', 'foondex')
with pytest.raises(JHTTPNotFound) as ex:
obj.aggregate(_aggregations_params={'zoo': 5}, param1=6,
__raise_on_empty=True)
assert 'Aggregation failed: Index does not exist' in str(ex.value)

@patch('nefertari.elasticsearch.ES.build_search_params')
@patch('nefertari.elasticsearch.ES.do_count')
def test_get_collection_count_without_body(self, mock_count, mock_build):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/test_dictset.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_process_int_param_value_err(self):
dset = dictset({'boo': 'a'})
with pytest.raises(ValueError) as ex:
dset.process_int_param('boo')
assert 'boo must be a decimal' == str(ex.value)
assert 'boo must be an integer' == str(ex.value)

def test_process_int_param_default(self):
dset = dictset({'boo': '1'})
Expand Down
22 changes: 22 additions & 0 deletions tests/test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,25 @@ def test_issequence(self):
assert not utils.issequence('asd')
assert not utils.issequence(1)
assert not utils.issequence(2.0)

def test_merge_dicts(self):
dict1 = {'a': {'b': {'c': 1}}}
dict2 = {'a': {'d': 2}, 'q': 3}
merged = utils.merge_dicts(dict1, dict2)
assert merged == {
'a': {
'b': {'c': 1},
'd': 2,
},
'q': 3
}

def test_str2dict(self):
assert utils.str2dict('foo.bar') == {'foo': {'bar': {}}}

def test_str2dict_value(self):
assert utils.str2dict('foo.bar', value=2) == {'foo': {'bar': 2}}

def test_str2dict_separator(self):
assert utils.str2dict('foo:bar', value=2, separator=':') == {
'foo': {'bar': 2}}

0 comments on commit 0a2e58e

Please sign in to comment.