Skip to content

Commit

Permalink
Add control of max_concurrency to scan API. Closes #138.
Browse files Browse the repository at this point in the history
  • Loading branch information
chambridge committed Nov 1, 2017
1 parent 487aba8 commit 1cc1dc3
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
6 changes: 3 additions & 3 deletions quipucords/api/scanjob/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ class ScanJob(models.Model):
choices=STATUS_CHOICES,
default=PENDING,
)
max_concurrency = models.PositiveIntegerField(default=50)

def __str__(self):
return '{id:%s, scan_type:%s, profile:%s}' % (self.id,
self.scan_type,
self.profile)
return '{id:%s, scan_type:%s, profile:%s, max_concurrency: %d}' % \
(self.id, self.scan_type, self.profile, self.max_concurrency)

class Meta:
verbose_name_plural = _(messages.PLURAL_SCAN_JOBS_MSG)
4 changes: 3 additions & 1 deletion quipucords/api/scanjob/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from rest_framework.serializers import (ModelSerializer,
PrimaryKeyRelatedField,
ValidationError,
ChoiceField)
ChoiceField,
IntegerField)
from api.models import NetworkProfile, ScanJob
import api.messages as messages

Expand All @@ -37,6 +38,7 @@ class ScanJobSerializer(ModelSerializer):
scan_type = ChoiceField(required=False, choices=ScanJob.SCAN_TYPE_CHOICES)
status = ChoiceField(required=False, read_only=True,
choices=ScanJob.STATUS_CHOICES)
max_concurrency = IntegerField(required=False, min_value=1)

class Meta:
model = ScanJob
Expand Down
12 changes: 10 additions & 2 deletions quipucords/api/scanjob/tests_scanjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def test_create_invalid_profile(self):
self.create_expect_400(
{'profile': -1})

def test_create_invalid_forks(self):
"""Test valid number of forks"""
data = {'profile': self.network_profile.id,
'max_concurrency': -5}
self.create_expect_400(data)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
def test_list(self, start): # pylint: disable=unused-argument
"""List all ScanJob objects."""
Expand All @@ -116,9 +122,11 @@ def test_list(self, start): # pylint: disable=unused-argument

content = response.json()
expected = [{'id': 1, 'profile': {'id': 1, 'name': 'profile1'},
'scan_type': 'host', 'status': 'pending'},
'scan_type': 'host', 'status': 'pending',
'max_concurrency': 50},
{'id': 2, 'profile': {'id': 1, 'name': 'profile1'},
'scan_type': 'discovery', 'status': 'pending'}]
'scan_type': 'discovery', 'status': 'pending',
'max_concurrency': 50}]
self.assertEqual(content, expected)

@patch('api.scanjob.view.DiscoveryScanner.start', side_effect=dummy_start)
Expand Down
4 changes: 3 additions & 1 deletion quipucords/scanner/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ def discovery(self):

logger.info('Discovery scan started for %s.', self.scanjob)

forks = self.scanjob.max_concurrency
for cred_id in credentials:
cred_obj = HostCredential.objects.get(pk=cred_id)
hc_serializer = HostCredentialSerializer(cred_obj)
cred = hc_serializer.data
connected, remaining = connect(remaining, cred, connection_port)
connected, remaining = connect(remaining, cred, connection_port,
forks=forks)
if remaining == []:
break

Expand Down
3 changes: 2 additions & 1 deletion quipucords/scanner/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def host_scan(self):
inventory = construct_scan_inventory(connected, connection_port)
inventory_file = write_inventory(inventory)
callback = ResultCallback()
result = run_playbook(inventory_file, callback, playbook)
forks = self.scanjob.max_concurrency
result = run_playbook(inventory_file, callback, playbook, forks=forks)

if result != TaskQueueManager.RUN_OK:
raise _construct_error(result)
Expand Down
5 changes: 3 additions & 2 deletions quipucords/scanner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,13 @@ def _construct_error(return_code):
return AnsibleError(message=message)


def connect(hosts, credential, connection_port):
def connect(hosts, credential, connection_port, forks=50):
"""Attempt to connect to hosts using the given credential.
:param hosts: The collection of hosts to test connections
:param credential: The credential used for connections
:param connection_port: The connection port
:param forks: number of forks to run with, default of 50
:returns: list of connected hosts credential tuples and
list of host that failed connection
"""
Expand All @@ -239,7 +240,7 @@ def connect(hosts, credential, connection_port):
'tasks': [{'action': {'module': 'raw',
'args': parse_kv('echo "Hello"')}}]}

result = run_playbook(inventory_file, callback, playbook)
result = run_playbook(inventory_file, callback, playbook, forks=forks)
if (result != TaskQueueManager.RUN_OK and
result != TaskQueueManager.RUN_UNREACHABLE_HOSTS):
raise _construct_error(result)
Expand Down

0 comments on commit 1cc1dc3

Please sign in to comment.