Skip to content

Commit

Permalink
rest: introduce generic permission framework
Browse files Browse the repository at this point in the history
Until now, the REST API didn't really have a good permission system.  In
particular, it did not know about user groups.  We will need to allow
the "importer" group to add messages to any project, so it's time to
improve the permissions.  Similar to the old API, all the knowledge of
permissions is encapsulated in one class, in this case a DRF Permission
subclass.  The new class replaces the old IsAdminUserOrReadOnly and
IsMaintainerUserOrReadOnly permissions.
  • Loading branch information
bonzini committed May 11, 2018
1 parent ebfacd1 commit 9093169
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 19 deletions.
77 changes: 58 additions & 19 deletions api/rest.py
Expand Up @@ -30,23 +30,52 @@

# patchew-specific permission classes

class IsAdminUserOrReadOnly(permissions.BasePermission):
class PatchewPermission(permissions.BasePermission):
"""
Allows access only to admin users.
Generic code to lookup for permissions based on message and project
objects. If the view has a "project" property, it should return an
api.models.Project, and has_permission will check that property too.
Subclasses can override the methods, or specify a set of groups that
are granted authorization independent of object permissions.
"""

allowed_groups = ()

def is_superuser(self, request):
return request.user and request.user.is_superuser

def has_project_permission(self, request, view, obj):
return obj.maintained_by(request.user)

def has_message_permission(self, request, view, obj):
return obj.project.maintained_by(request.user)

def has_group_permission(self, request, view):
for grp in request.user.groups.all():
if grp.name in self.allowed_groups:
return True
return False

def has_generic_permission(self, request, view):
return (request.method in permissions.SAFE_METHODS) or \
self.is_superuser(request) or \
self.has_group_permission(request, view)

def has_permission(self, request, view):
return request.method in permissions.SAFE_METHODS or \
(request.user and request.user.is_superuser)
return self.has_generic_permission(request, view) or \
(hasattr(view, 'project') and view.project and \
self.has_project_permission(request, view, view.project))

class IsMaintainerOrReadOnly(permissions.BasePermission):
"""
Allows access only to admin users or maintainers.
"""
def has_object_permission(self, request, view, obj):
if isinstance(obj, Message):
obj = obj.project
return request.method in permissions.SAFE_METHODS or \
obj.maintained_by(request.user)
return self.has_generic_permission(request, view) or \
(isinstance(obj, Message) and \
self.has_message_permission(request, view, obj)) or \
(isinstance(obj, Project) and \
self.has_project_permission(request, view, obj))

class ImportPermission(PatchewPermission):
allowed_groups = ('importers',)

# pluggable field for plugin support

Expand Down Expand Up @@ -87,7 +116,7 @@ class Meta:
class UsersViewSet(viewsets.ModelViewSet):
queryset = User.objects.all().order_by('id')
serializer_class = UserSerializer
permission_classes = (IsAdminUserOrReadOnly,)
permission_classes = (PatchewPermission,)

# Projects

Expand All @@ -110,7 +139,7 @@ class Meta:
class ProjectsViewSet(viewsets.ModelViewSet):
queryset = Project.objects.all().order_by('id')
serializer_class = ProjectSerializer
permission_classes = (IsMaintainerOrReadOnly,)
permission_classes = (PatchewPermission,)

# Common classes for series and messages

Expand Down Expand Up @@ -152,7 +181,7 @@ def create(self, validated_data):
class BaseMessageViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
serializer_class = BaseMessageSerializer
queryset = Message.objects.all()
permission_classes = ()
permission_classes = (ImportPermission,)
lookup_field = 'message_id'
lookup_value_regex = '[^/]+'

Expand All @@ -161,11 +190,21 @@ class ProjectMessagesViewSetMixin(mixins.RetrieveModelMixin):
def get_queryset(self):
return self.queryset.filter(project=self.kwargs['projects_pk'])

def get_serializer_context(self):
@property
def project(self):
if hasattr(self, '__project'):
return self.__project
try:
return {'project': Project.objects.get(id=self.kwargs['projects_pk']), 'request': self.request}
except:
self.__project = Project.objects.get(id=self.kwargs['projects_pk'])
except:
self.__project = None
return self.__project

def get_serializer_context(self):
if self.project is None:
return Http404
return {'project': self.project, 'request': self.request}

# Series

class ReplySerializer(BaseMessageSerializer):
Expand Down Expand Up @@ -248,7 +287,6 @@ class SeriesViewSet(BaseMessageViewSet):
queryset = Message.objects.filter(is_series_head=True).order_by('-last_reply_date')
filter_backends = (PatchewSearchFilter,)
search_fields = (SEARCH_PARAM,)
permission_classes = (IsMaintainerOrReadOnly,)


class ProjectSeriesViewSet(ProjectMessagesViewSetMixin,
Expand Down Expand Up @@ -376,6 +414,7 @@ class ResultSerializerFull(ResultSerializer):
class ResultsViewSet(viewsets.ViewSet, generics.GenericAPIView):
lookup_field = 'name'
lookup_value_regex = '[^/]+'
permission_classes = (PatchewPermission,)

def get_serializer_class(self, *args, **kwargs):
if self.lookup_field in self.kwargs:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_rest.py
Expand Up @@ -84,10 +84,18 @@ def test_project(self):
self.assertEquals(resp.data['mailing_list'], "qemu-block@nongnu.org")
self.assertEquals(resp.data['parent_project'], self.PROJECT_BASE)

def test_project_post_no_login(self):
data = {
'name': 'keycodemapdb',
}
resp = self.api_client.post(self.REST_BASE + 'projects/', data=data)
self.assertEquals(resp.status_code, 403)

def test_project_post_minimal(self):
data = {
'name': 'keycodemapdb',
}
self.api_client.login(username=self.user, password=self.password)
resp = self.api_client.post(self.REST_BASE + 'projects/', data=data)
self.assertEquals(resp.status_code, 201)
self.assertEquals(resp.data['resource_uri'].startswith(self.REST_BASE + 'projects/'), True)
Expand All @@ -97,6 +105,7 @@ def test_project_post_minimal(self):
self.assertEquals(resp.data['name'], data['name'])

def test_project_post(self):
self.api_client.login(username=self.user, password=self.password)
data = {
'name': 'keycodemapdb',
'mailing_list': 'qemu-devel@nongnu.org',
Expand Down Expand Up @@ -267,6 +276,7 @@ def test_create_message(self):
dp = self.get_data_path("0022-another-simple-patch.json.gz")
with open(dp, "r") as f:
data = f.read()
self.api_client.login(username=self.user, password=self.password)
resp = self.api_client.post(self.PROJECT_BASE + "messages/", data, content_type='application/json')
self.assertEqual(resp.status_code, 201)
resp_get = self.api_client.get(self.PROJECT_BASE + "messages/20171023201055.21973-11-andrew.smirnov@gmail.com/")
Expand All @@ -278,6 +288,7 @@ def test_create_text_message(self):
dp = self.get_data_path("0004-multiple-patch-reviewed.mbox.gz")
with open(dp, "r") as f:
data = f.read()
self.api_client.login(username=self.user, password=self.password)
resp = self.api_client.post(self.PROJECT_BASE + "messages/", data, content_type='message/rfc822')
self.assertEqual(resp.status_code, 201)
resp_get = self.api_client.get(self.PROJECT_BASE + "messages/1469192015-16487-1-git-send-email-berrange@redhat.com/")
Expand Down

0 comments on commit 9093169

Please sign in to comment.