Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Merge pull request #25 from postatum/93492110_improve_coverage
Browse files Browse the repository at this point in the history
Add tests for elasticsearch and some authentication tests
  • Loading branch information
jstoiko committed May 6, 2015
2 parents 74756cc + 010f1cc commit b386bd6
Show file tree
Hide file tree
Showing 19 changed files with 1,042 additions and 489 deletions.
55 changes: 35 additions & 20 deletions nefertari/authentication/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyramid.security import authenticated_userid, forget

from nefertari.json_httpexceptions import *
from nefertari import engine as eng
from nefertari import engine
from nefertari.utils import dictset

log = logging.getLogger(__name__)
Expand All @@ -29,6 +29,18 @@ class AuthModelDefaultMixin(object):
All implemented methods must me class methods.
"""
@classmethod
def get_resource(self, *args, **kwargs):
return super(AuthModelDefaultMixin, self).get_resource(*args, **kwargs)

@classmethod
def id_field(self, *args, **kwargs):
return super(AuthModelDefaultMixin, self).id_field(*args, **kwargs)

@classmethod
def get_or_create(self, *args, **kwargs):
return super(AuthModelDefaultMixin, self).get_or_create(*args, **kwargs)

@classmethod
def is_admin(cls, user):
""" Determine if :user: is an admin. Is used by `apply_privacy` wrapper.
Expand All @@ -46,8 +58,9 @@ def token_credentials(cls, username, request):
except Exception as ex:
log.error(unicode(ex))
forget(request)
if user:
return user.api_key.token
else:
if user:
return user.api_key.token

@classmethod
def groups_by_token(cls, username, token, request):
Expand All @@ -61,8 +74,10 @@ def groups_by_token(cls, username, token, request):
except Exception as ex:
log.error(unicode(ex))
forget(request)
if user and user.api_key.token == token:
return ['g:%s' % g for g in user.groups]
return
else:
if user and user.api_key.token == token:
return ['g:%s' % g for g in user.groups]

@classmethod
def authenticate_by_password(cls, params):
Expand All @@ -73,14 +88,14 @@ def authenticate_by_password(cls, params):
def verify_password(user, password):
return crypt.check(user.password, password)

success = False
user = None
login = params['login'].lower().strip()
key = 'email' if '@' in login else 'username'
try:
user = cls.get_resource(**{key: login})
except Exception as ex:
log.error(unicode(ex))
success = False
user = None

if user:
password = params.get('password', None)
Expand Down Expand Up @@ -141,24 +156,24 @@ def authuser_by_name(cls, request):
return cls.get_resource(username=username)


class AuthUser(AuthModelDefaultMixin, eng.BaseDocument):
class AuthUser(AuthModelDefaultMixin, engine.BaseDocument):
""" Class that is meant to be User class in Auth system.
Implements basic operations to support Pyramid Ticket-based and custom
ApiKey token-based authentication.
"""
__tablename__ = 'nefertari_authuser'

id = eng.IdField(primary_key=True)
username = eng.StringField(
id = engine.IdField(primary_key=True)
username = engine.StringField(
min_length=1, max_length=50, unique=True,
required=True, processors=[lower_strip])
email = eng.StringField(
email = engine.StringField(
unique=True, required=True, processors=[lower_strip])
password = eng.StringField(
password = engine.StringField(
min_length=3, required=True, processors=[crypt_password])
groups = eng.ListField(
item_type=eng.StringField,
groups = engine.ListField(
item_type=engine.StringField,
choices=['admin', 'user'], default=['user'])


Expand All @@ -182,7 +197,7 @@ def apikey_model(user_model):
be generated and with which ApiKey will have relationship.
"""
try:
return eng.get_document_cls('ApiKey')
return engine.get_document_cls('ApiKey')
except ValueError:
pass

Expand All @@ -194,17 +209,17 @@ def apikey_model(user_model):
user_model.__tablename__, user_model.id_field()])
fk_kwargs['ref_column_type'] = user_model.id_field_type()

class ApiKey(eng.BaseDocument):
class ApiKey(engine.BaseDocument):
__tablename__ = 'nefertari_apikey'

id = eng.IdField(primary_key=True)
token = eng.StringField(default=apikey_token)
user = eng.Relationship(
id = engine.IdField(primary_key=True)
token = engine.StringField(default=apikey_token)
user = engine.Relationship(
document=user_model.__name__,
uselist=False,
backref_name='api_key',
backref_uselist=False)
user_id = eng.ForeignKeyField(
user_id = engine.ForeignKeyField(
ref_document=user_model.__name__,
**fk_kwargs)

Expand Down
35 changes: 22 additions & 13 deletions nefertari/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nefertari.utils import (
dictset, dict2obj, process_limit, split_strip)
from nefertari.json_httpexceptions import *
from nefertari.engine import ESJSONSerializer
from nefertari import engine

log = logging.getLogger(__name__)

Expand All @@ -33,10 +33,11 @@ def perform_request(self, *args, **kw):

return super(ESHttpConnection, self).perform_request(*args, **kw)
except Exception as e:
if e.status_code == 'N/A':
e.status_code = 400
status_code = e.status_code
if status_code == 'N/A':
status_code = 400
raise exception_response(
e.status_code,
status_code,
detail='elasticsearch error.',
extra=dict(data=e))

Expand All @@ -46,13 +47,17 @@ def includeme(config):
ES.setup(Settings)


def _bulk_body(body):
return ES.api.bulk(body=body)


def apply_sort(_sort):
_sort_param = []

if _sort:
for each in [e.strip() for e in _sort.split(',')]:
if each.startswith('-'):
_sort_param.append(each[1:]+':desc')
_sort_param.append(each[1:] + ':desc')
elif each.startswith('+'):
_sort_param.append(each[1:] + ':asc')
else:
Expand Down Expand Up @@ -117,7 +122,7 @@ def setup(cls, settings):
)

ES.api = elasticsearch.Elasticsearch(
hosts=hosts, serializer=ESJSONSerializer(),
hosts=hosts, serializer=engine.ESJSONSerializer(),
connection_class=ESHttpConnection, **params)
log.info('Including ElasticSearch. %s' % ES.settings)

Expand Down Expand Up @@ -154,8 +159,9 @@ def prep_bulk_documents(self, action, documents):
_docs = []
for doc in documents:
if not isinstance(doc, dict):
raise ValueError('document type must be `dict` not a '
'%s' % (type(doc)))
raise ValueError(
'Document type must be `dict` not a `{}`'.format(
type(doc).__name__))

if '_type' in doc:
_doc_type = self.src2type(doc['_type'])
Expand All @@ -176,7 +182,9 @@ def prep_bulk_documents(self, action, documents):
return _docs

def _bulk(self, action, documents, chunk_size=None):
chunk_size = chunk_size or self.chunk_size
if chunk_size is None:
chunk_size = self.chunk_size

if not documents:
log.debug('empty documents: %s' % self.doc_type)
return
Expand All @@ -198,10 +206,10 @@ def _bulk(self, action, documents, chunk_size=None):
# meta, document, meta, ...
self.process_chunks(
documents=body,
operation=lambda b: ES.api.bulk(body=b),
operation=_bulk_body,
chunk_size=chunk_size*2)
else:
log.warning('empty body')
log.warning('Empty body')

def index(self, documents, chunk_size=None):
""" Reindex all `document`. """
Expand Down Expand Up @@ -239,7 +247,8 @@ def delete(self, ids):
if not isinstance(ids, list):
ids = [ids]

self._bulk('delete', [{'id':_id, '_type': self.doc_type} for _id in ids])
documents = [{'id': _id, '_type': self.doc_type} for _id in ids]
self._bulk('delete', documents)

def get_by_ids(self, ids, **params):
if not ids:
Expand Down Expand Up @@ -364,7 +373,7 @@ def get_collection(self, **params):
documents = _ESDocs()

for da in data['hits']['hits']:
_d = da['fields'] if 'fields' in _params else da['_source']
_d = da['fields'] if _fields else da['_source']
_d['_score'] = da['_score']
documents.append(dict2obj(_d))

Expand Down
124 changes: 0 additions & 124 deletions nefertari/utility_views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from pyramid.view import view_config

import nefertari
from nefertari.json_httpexceptions import *
from nefertari import wrappers
from nefertari.view import BaseView


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,124 +29,3 @@ def __call__(self):
'origin, x-requested-with, content-type'

return request.response


class EngineView(BaseView):
def __init__(self, context, request):
super(EngineView, self).__init__(context, request)
self._params.process_int_param('_limit', 20)

def add_self(**kwargs):
result = kwargs['result']
request = kwargs['request']

try:
for each in result['data']:
each['self'] = "%s?id=%s" % (
request.current_route_url(), each['id'])
except KeyError:
pass

return result

self.add_after_call('show', add_self)
# Wrap in a dict so it acts as "index"
self.add_after_call('show', wrappers.wrap_in_dict(self.request), pos=0)

def show(self, id):
return self._model_class.get_collection(**self._params)

def delete(self, id):
objs = self._model_class.get_collection(**self._params)

if self.needs_confirmation():
return objs

count = self._model_class.count(objs)
self._model_class._delete_many(objs)
return JHTTPOk("Deleted %s %s objects" % (count, id))


LOGNAME_MAP = dict(
NOTSET=logging.NOTSET,
DEBUG=logging.DEBUG,
INFO=logging.INFO,
WARNING=logging.WARNING,
ERROR=logging.ERROR,
CRITICAL=logging.CRITICAL,
)


class LogLevelView(BaseView):
def __init__(self, *arg, **kw):
super(LogLevelView, self).__init__(*arg, **kw)

self.name = self.request.matchdict.get('id', 'root')
if self.name == 'root':
self.log = logging.getLogger()
else:
self.log = logging.getLogger(self.name)

def setlevel(self, level):
log.info("SET logger '%s' to '%s'" % (self.name, level))
self.log.setLevel(LOGNAME_MAP[level])

def show(self, id=None):
return dict(
logger=self.name,
level=logging.getLevelName(self.log.getEffectiveLevel())
)

def update(self, id=None):
level = self._params.keys()[0].upper()
self.setlevel(level)
return JHTTPOk()

def delete(self, id=None):
self.setlevel('INFO')
return JHTTPOk()


class SettingsView(BaseView):
settings = None
__orig = None

def __init__(self, *arg, **kw):
super(SettingsView, self).__init__(*arg, **kw)
assert(self.settings)
self.__orig = self.settings.copy()

def index(self):
return dict(self.settings)

def show(self, id):
return self.settings[id]

def update(self, id):
self.settings[id] = self._params['value']
return JHTTPOk()

def create(self):
key = self._params['key']
value = self._params['value']

self.settings[key] = value

return JHTTPCreated()

def delete(self, id):
if 'reset' in self._params:
self.settings[id] = self.request.registry.settings[id]
else:
self.settings.pop(id, None)

return JHTTPOk()

def delete_many(self):
if self.needs_confirmation():
return self.settings.keys()

for name, val in self.settings.items():
self.settings[name] = self.__orig[name]

return JHTTPOk("Reset the settings to original values")
2 changes: 0 additions & 2 deletions nefertari/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from nefertari.utils.data import *
from nefertari.utils.dictset import *
from nefertari.utils.utils import *
from nefertari.utils.request import *

_requests = Requests
_split = split_strip

0 comments on commit b386bd6

Please sign in to comment.