Skip to content

Commit

Permalink
Update code to handle scans via signal. Closes #177.
Browse files Browse the repository at this point in the history
  • Loading branch information
chambridge committed Nov 6, 2017
1 parent 56b4a48 commit 95eec5e
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[report]
omit = */test*.py,*/manage.py,*/apps.py,*/wsgi.py,*/es_receivers.py,*/settings.py,*/migrations/*
omit = */test*.py,*/manage.py,*/apps.py,*/wsgi.py,*/es_receiver.py,*/settings.py,*/migrations/*
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ PYDIRS = quipucords

BINDIR = bin

OMIT_PATTERNS = */test*.py,*/manage.py,*/apps.py,*/wsgi.py,*/es_receivers.py,*/settings.py,*/migrations/*
OMIT_PATTERNS = */test*.py,*/manage.py,*/apps.py,*/wsgi.py,*/es_receiver.py,*/settings.py,*/migrations/*

help:
@echo "Please use \`make <target>' where <target> is one of:"
Expand Down
4 changes: 2 additions & 2 deletions quipucords/api/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ApiConfig(AppConfig):
def ready(self):
"""Mark server ready."""
# pylint: disable=W0612
import api.fact_collection_receiver # noqa: F401
import api.signals.fact_collection_receiver # noqa: F401

if settings.USE_ELASTICSEARCH == 'True':
import api.es_receivers # noqa: F401
import api.signals.es_receiver # noqa: F401
24 changes: 7 additions & 17 deletions quipucords/api/scanjob/tests_scanjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""Test the API application."""

import json
from unittest.mock import patch
from django.test import TestCase
from django.core.urlresolvers import reverse
from rest_framework import status
Expand Down Expand Up @@ -60,8 +59,7 @@ def create_expect_201(self, data):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
return response.json()

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_successful_create(self, start): # pylint: disable=unused-argument
def test_successful_create(self):
"""A valid create request should succeed."""
data = {'profile': self.network_profile.id,
'scan_type': 'discovery'}
Expand All @@ -79,9 +77,7 @@ def test_create_invalid_scan_type(self):
'scan_type': 'foo'}
self.create_expect_400(data)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
# pylint: disable=unused-argument
def test_create_default_host_type(self, start):
def test_create_default_host_type(self):
"""A valid create request should succeed with defaulted type."""
data = {'profile': self.network_profile.id}
response = self.create_expect_201(data)
Expand All @@ -100,8 +96,7 @@ def test_create_invalid_forks(self):
'max_concurrency': -5}
self.create_expect_400(data)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_list(self, start): # pylint: disable=unused-argument
def test_list(self):
"""List all ScanJob objects."""
data_default = {'profile': self.network_profile.id}
data_discovery = {'profile': self.network_profile.id,
Expand Down Expand Up @@ -133,8 +128,7 @@ def test_list(self, start): # pylint: disable=unused-argument
'fact_collection_id': None}]
self.assertEqual(content, expected)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_retrieve(self, start): # pylint: disable=unused-argument
def test_retrieve(self):
"""Get details on a specific ScanJob by primary key."""
data_discovery = {'profile': self.network_profile.id,
'scan_type': 'discovery'}
Expand All @@ -148,9 +142,7 @@ def test_retrieve(self, start): # pylint: disable=unused-argument

self.assertEqual(profile, {'id': 1, 'name': 'profile1'})

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
# pylint: disable=unused-argument
def test_update_not_allowed(self, start):
def test_update_not_allowed(self):
"""Completely update a NetworkProfile."""
data_discovery = {'profile': self.network_profile.id,
'scan_type': 'discovery'}
Expand All @@ -166,8 +158,7 @@ def test_update_not_allowed(self, start):
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_partial_update(self, start): # pylint: disable=unused-argument
def test_partial_update(self):
"""Partially update a ScanJob is not supported."""
data_discovery = {'profile': self.network_profile.id,
'scan_type': 'discovery'}
Expand All @@ -182,8 +173,7 @@ def test_partial_update(self, start): # pylint: disable=unused-argument
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_delete(self, start): # pylint: disable=unused-argument
def test_delete(self):
"""Delete a ScanJob is not supported."""
data_discovery = {'profile': self.network_profile.id,
'scan_type': 'discovery'}
Expand Down
22 changes: 6 additions & 16 deletions quipucords/api/scanjob/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
from rest_framework.decorators import detail_route
from django.core.urlresolvers import reverse
from django.shortcuts import get_object_or_404
from api.models import ScanJob, NetworkProfile, ScanJobResults
from api.models import ScanJob, ScanJobResults
from api.serializers import (ScanJobSerializer,
ScanJobResultsSerializer,
ResultsSerializer,
ResultKeyValueSerializer)
from scanner.discovery import DiscoveryScanner
from scanner.host import HostScanner
from api.signals.scanjob_signal import start_scan


PROFILE_KEY = 'profile'
Expand Down Expand Up @@ -95,19 +94,10 @@ def create(self, request, *args, **kwargs):
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
scanjob = serializer.data
scan_type = scanjob['scan_type']
scanjob_id = scanjob['id']
scanjob_profile_id = scanjob['profile']
scanjob_obj = ScanJob.objects.get(pk=scanjob_id)
profile = NetworkProfile.objects.get(pk=scanjob_profile_id)
if scan_type == ScanJob.DISCOVERY:
scan = DiscoveryScanner(scanjob_obj, profile)
scan.start()
else:
fact_endpoint = request.build_absolute_uri(reverse('facts-list'))
scan = HostScanner(scanjob_obj, profile, fact_endpoint)
scan.start()
scanjob_obj = ScanJob.objects.get(pk=serializer.data['id'])
fact_endpoint = request.build_absolute_uri(reverse('facts-list'))
start_scan.send(sender=self.__class__, instance=scanjob_obj,
fact_endpoint=fact_endpoint)

return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
Expand Down
1 change: 1 addition & 0 deletions quipucords/api/signals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# noqa
File renamed without changes.
File renamed without changes.
45 changes: 45 additions & 0 deletions quipucords/api/signals/scanjob_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2017 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 3 (GPLv3). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv3
# along with this software; if not, see
# https://www.gnu.org/licenses/gpl-3.0.txt.
#

"""Signal manager to handle scan triggering."""

import django.dispatch
from api.models import ScanJob
from scanner.discovery import DiscoveryScanner
from scanner.host import HostScanner


def handle_scan(sender, instance, fact_endpoint, **kwargs):
"""Handle incoming scan.
:param sender: Class that was saved
:param instance: ScanJob that was saved
:param fact_endpoint: The API endpoint to send collect fact to
:param kwargs: Other args
:returns: None
"""
# pylint: disable=unused-argument
if kwargs.get('created', False):
# nothing need for an existing scan.
return

if instance.scan_type == ScanJob.DISCOVERY:
scan = DiscoveryScanner(instance)
scan.start()
else:
scan = HostScanner(instance, fact_endpoint)
scan.start()


# pylint: disable=C0103
start_scan = django.dispatch.Signal(providing_args=['instance',
'fact_endpoint'])
start_scan.connect(handle_scan)
2 changes: 1 addition & 1 deletion quipucords/quipucords/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
'handlers': ['console'],
'level': os.getenv('DJANGO_LOG_LEVEL', 'DEBUG'),
},
'api.es_receivers': {
'api.signals.es_receiver': {
'handlers': ['console'],
'level': os.getenv('DJANGO_LOG_LEVEL', 'DEBUG'),
},
Expand Down
3 changes: 2 additions & 1 deletion quipucords/scanner/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ class DiscoveryScanner(Thread):
failures (host/ip).
"""

def __init__(self, scanjob, network_profile):
def __init__(self, scanjob):
"""Create discovery scanner."""
Thread.__init__(self)
self.scanjob = scanjob
network_profile = scanjob.profile
serializer = NetworkProfileSerializer(network_profile)
self.network_profile = serializer.data
self.scan_results = ScanJobResults(scan_job=self.scanjob)
Expand Down
4 changes: 2 additions & 2 deletions quipucords/scanner/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class HostScanner(DiscoveryScanner):
reachable. Collects the associated facts for the scanned systems
"""

def __init__(self, scanjob, network_profile, fact_endpoint):
def __init__(self, scanjob, fact_endpoint):
"""Create host scanner."""
DiscoveryScanner.__init__(self, scanjob, network_profile)
DiscoveryScanner.__init__(self, scanjob)
self.fact_endpoint = fact_endpoint

# pylint: disable=too-many-locals
Expand Down
4 changes: 2 additions & 2 deletions quipucords/scanner/tests_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@ def test_discovery(self, mock_connect):
"""Test running a discovery scan with mocked connection."""
expected = ([('1.2.3.4', {'name': 'cred1'})], [])
mock_connect.return_value = expected
scanner = DiscoveryScanner(self.scanjob, self.network_profile)
scanner = DiscoveryScanner(self.scanjob)
conn_dict = scanner.run()
mock_connect.assert_called()
self.assertEqual(conn_dict, {'1.2.3.4': {'name': 'cred1'}})

def test_store_discovery_success(self):
"""Test running a discovery scan with mocked connection."""
scanner = DiscoveryScanner(self.scanjob, self.network_profile)
scanner = DiscoveryScanner(self.scanjob)
hc_serializer = HostCredentialSerializer(self.cred)
cred = hc_serializer.data
connected = [('1.2.3.4', cred)]
Expand Down
12 changes: 4 additions & 8 deletions quipucords/scanner/tests_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def setUp(self):

def test_store_host_scan_success(self):
"""Test success storage."""
scanner = HostScanner(self.scanjob, self.network_profile,
self.fact_endpoint)
scanner = HostScanner(self.scanjob, self.fact_endpoint)
facts = [{'connection_host': '1.2.3.4',
'key1': 'value1',
'key2': 'value2'}]
Expand Down Expand Up @@ -98,17 +97,15 @@ def test_scan_inventory(self):
@patch('scanner.utils.TaskQueueManager.run', side_effect=mock_run_failed)
def test_host_scan_failure(self, mock_run):
"""Test scan flow with mocked manager and failure."""
scanner = HostScanner(self.scanjob, self.network_profile,
self.fact_endpoint)
scanner = HostScanner(self.scanjob, self.fact_endpoint)
with self.assertRaises(AnsibleError):
scanner.host_scan()
mock_run.assert_called()

@patch('scanner.host.HostScanner.host_scan', side_effect=mock_scan_error)
def test_host_scan_error(self, mock_scan):
"""Test scan flow with mocked manager and failure."""
scanner = HostScanner(self.scanjob, self.network_profile,
self.fact_endpoint)
scanner = HostScanner(self.scanjob, self.fact_endpoint)
facts = scanner.run()
mock_scan.assert_called()
self.assertEqual(facts, [])
Expand All @@ -120,8 +117,7 @@ def test_host_scan(self, mock_run):
mock_run.return_value = expected
with requests_mock.Mocker() as mocker:
mocker.post(self.fact_endpoint, status_code=201, json={'id': 1})
scanner = HostScanner(self.scanjob, self.network_profile,
self.fact_endpoint)
scanner = HostScanner(self.scanjob, self.fact_endpoint)
facts = scanner.run()
mock_run.assert_called()
self.assertEqual(facts, [])

0 comments on commit 95eec5e

Please sign in to comment.