Skip to content

Commit

Permalink
Merge branch 'master' into issues/608
Browse files Browse the repository at this point in the history
  • Loading branch information
kholdaway committed Feb 9, 2018
2 parents 127a988 + 88b148a commit 2ebe67b
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 1 deletion.
6 changes: 6 additions & 0 deletions quipucords/api/credential/tests_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ def test_hostcred_update_double(self):
format='json')
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)

def test_hostcred_get_bad_id(self):
"""Tests the get view set of the Credential API with a bad id."""
url = reverse('cred-detail', args=('string',))
resp = self.client.get(url, format='json')
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)

def test_hostcred_delete_view(self):
"""Tests the delete view set of the Credential API."""
cred = Credential(name='cred2', username='user2',
Expand Down
10 changes: 10 additions & 0 deletions quipucords/api/credential/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@

import os
from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
from rest_framework import status
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework.authentication import (TokenAuthentication,
SessionAuthentication)
from rest_framework.permissions import IsAuthenticated
from rest_framework.filters import OrderingFilter
from rest_framework.serializers import ValidationError
from django_filters.rest_framework import (DjangoFilterBackend,
FilterSet)
from filters import mixins
from api.filters import ListFilter
from api.serializers import CredentialSerializer
from api.models import Credential, Source
import api.messages as messages
from api.common.util import is_int

IDENTIFIER_KEY = 'id'
NAME_KEY = 'name'
Expand Down Expand Up @@ -123,6 +127,12 @@ def create(self, request, *args, **kwargs):

def retrieve(self, request, pk=None): # pylint: disable=unused-argument
"""Get a host credential."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)

host_cred = get_object_or_404(self.queryset, pk=pk)
serializer = CredentialSerializer(host_cred)
cred = format_credential(serializer.data)
Expand Down
1 change: 1 addition & 0 deletions quipucords/api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
COMMON_CHOICE_STR = 'Must be a string. Valid values are %s.'
COMMON_CHOICE_BLANK = 'This field may not be blank. Valid values are %s.'
COMMON_CHOICE_INV = '%s, is an invalid choice. Valid values are %s.'
COMMON_ID_INV = 'The id must be an integer.'

# report messages
REPORT_GROUP_COUNT_FILTER = 'The group_count filter cannot be used with ' \
Expand Down
9 changes: 9 additions & 0 deletions quipucords/api/report/tests_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ def test_get_fact_collection_404(self):
response = self.client.get(url, {'fact_collection_id': 2})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_get_fact_collection_bad_id(self):
"""Fail to get a report for missing collection."""
url = '/api/v1/reports/'

# Query API
response = self.client.get(url, {'fact_collection_id': 'string'})
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

def test_group_count_400_invalid_field(self):
"""Fail to get report with invalid field for group_count."""
url = '/api/v1/reports/'
Expand Down
7 changes: 7 additions & 0 deletions quipucords/api/report/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from api.models import SystemFingerprint
from api.serializers import FingerprintSerializer
import api.messages as messages
from api.common.util import is_int


# Get an instance of a logger
Expand Down Expand Up @@ -73,6 +74,12 @@ def get(self, request):
fact_collection_id)
return Response(collection_report_list)
else:
if not is_int(fact_collection_id):
error = {
'fact_collection_id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)

report = self.build_report(fact_collection_id,
request.query_params)
if report is not None:
Expand Down
30 changes: 30 additions & 0 deletions quipucords/api/scanjob/tests_scanjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,12 @@ def test_retrieve(self, start_scan):
self.assertEqual(
sources, [{'id': 1, 'name': 'source1', 'source_type': 'network'}])

def test_retrieve_bad_id(self):
"""Get ScanJob details by bad primary key."""
url = reverse('scanjob-detail', args=('string',))
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

@patch('api.scanjob.view.start_scan', side_effect=dummy_start)
def test_details(self, start_scan):
"""Get ScanJob result details by primary key."""
Expand Down Expand Up @@ -608,6 +614,14 @@ def test_pause_bad_state(self, start_scan):
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

def test_pause_bad_id(self):
"""Pause a scanjob with bad id."""
url = reverse('scanjob-detail', args=('string',))
pause_url = '{}pause/'.format(url)
response = self.client.put(pause_url, format='json')
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

@patch('api.scanjob.view.start_scan', side_effect=dummy_start)
def test_cancel(self, start_scan):
"""Cancel a scanjob."""
Expand All @@ -625,6 +639,14 @@ def test_cancel(self, start_scan):
self.assertEqual(response.status_code,
status.HTTP_200_OK)

def test_cancel_bad_id(self):
"""Cancel a scanjob with bad id."""
url = reverse('scanjob-detail', args=('string',))
pause_url = '{}cancel/'.format(url)
response = self.client.put(pause_url, format='json')
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

@patch('api.scanjob.view.start_scan', side_effect=dummy_start)
def test_restart_bad_state(self, start_scan):
"""Restart a scanjob."""
Expand All @@ -642,6 +664,14 @@ def test_restart_bad_state(self, start_scan):
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

def test_restart_bad_id(self):
"""Restart a scanjob with bad id."""
url = reverse('scanjob-detail', args=('string',))
pause_url = '{}restart/'.format(url)
response = self.client.put(pause_url, format='json')
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

def test_expand_scanjob(self):
"""Test view expand_scanjob."""
scan_job = ScanJob(scan_type=ScanTask.SCAN_TYPE_INSPECT)
Expand Down
23 changes: 23 additions & 0 deletions quipucords/api/scanjob/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
SessionAuthentication)
from rest_framework.permissions import IsAuthenticated
from rest_framework.filters import OrderingFilter
from rest_framework.serializers import ValidationError
from django_filters.rest_framework import (DjangoFilterBackend, FilterSet)
from django.http import JsonResponse
from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
import api.messages as messages
from api.common.util import is_int
from api.models import (ScanTask, ScanJob, Source,
ConnectionResults, InspectionResults)
from api.serializers import (ScanJobSerializer,
Expand Down Expand Up @@ -236,6 +238,12 @@ def list(self, request): # pylint: disable=unused-argument
# pylint: disable=unused-argument, arguments-differ
def retrieve(self, request, pk=None):
"""Get a scan job."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)

scan = get_object_or_404(self.queryset, pk=pk)
serializer = ScanJobSerializer(scan)
json_scan = serializer.data
Expand Down Expand Up @@ -272,6 +280,11 @@ def results(self, request, pk=None):
@detail_route(methods=['put'])
def pause(self, request, pk=None):
"""Pause the running scan."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)
scan = get_object_or_404(self.queryset, pk=pk)
if scan.status == ScanTask.RUNNING:
scan.pause()
Expand All @@ -290,6 +303,11 @@ def pause(self, request, pk=None):
@detail_route(methods=['put'])
def cancel(self, request, pk=None):
"""Cancel the running scan."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)
scan = get_object_or_404(self.queryset, pk=pk)
if (scan.status == ScanTask.COMPLETED or
scan.status == ScanTask.FAILED or
Expand All @@ -307,6 +325,11 @@ def cancel(self, request, pk=None):
@detail_route(methods=['put'])
def restart(self, request, pk=None):
"""Restart a paused scan."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)
scan = get_object_or_404(self.queryset, pk=pk)
if scan.status == ScanTask.PAUSED:
scan.restart()
Expand Down
7 changes: 7 additions & 0 deletions quipucords/api/source/tests_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,13 @@ def test_retrieve(self):
self.assertIn('hosts', response_json)
self.assertEqual(response_json['hosts'][0], '1.2.3.4')

def test_retrieve_bad_id(self):
"""Get details on a specific Source by bad primary key."""
url = reverse('source-detail', args=('string',))
response = self.client.get(url, format='json')
self.assertEqual(response.status_code,
status.HTTP_400_BAD_REQUEST)

# We don't have to test that update validates fields correctly
# because the validation code is shared between create and update.
def test_update(self):
Expand Down
19 changes: 18 additions & 1 deletion quipucords/api/source/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@

import os
from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework.authentication import (TokenAuthentication,
SessionAuthentication)
from rest_framework.permissions import IsAuthenticated
from rest_framework.filters import OrderingFilter
from rest_framework.serializers import ValidationError
from django_filters.rest_framework import (DjangoFilterBackend, FilterSet)
from api.filters import ListFilter
from api.serializers import SourceSerializer
from api.models import Source, Credential
import api.messages as messages
from api.common.util import is_int


CREDENTIALS_KEY = 'credentials'
Expand Down Expand Up @@ -101,14 +105,27 @@ def create(self, request, *args, **kwargs):

# Modify json for response
json_source = response.data
get_object_or_404(self.queryset, pk=json_source['id'])
source_id = json_source.get('id')
if not source_id or (source_id and not isinstance(source_id, int)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)

get_object_or_404(self.queryset, pk=source_id)

# Create expanded host cred JSON
expand_credential(json_source)
return response

def retrieve(self, request, pk=None): # pylint: disable=unused-argument
"""Get a source."""
if not pk or (pk and not is_int(pk)):
error = {
'id': [_(messages.COMMON_ID_INV)]
}
raise ValidationError(error)

source = get_object_or_404(self.queryset, pk=pk)
serializer = SourceSerializer(source)
json_source = serializer.data
Expand Down

0 comments on commit 2ebe67b

Please sign in to comment.