Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/52675.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added list target type support to the `scan` salt-ssh roster.
33 changes: 15 additions & 18 deletions salt/roster/scan.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
# -*- coding: utf-8 -*-
"""
Scan a netmask or ipaddr for open ssh ports
"""

# Import python libs
from __future__ import absolute_import, print_function, unicode_literals

import copy
import logging
import socket

# Import salt libs
import salt.utils.network
from salt._compat import ipaddress
from salt.ext import six

# Import 3rd-party libs
from salt.ext.six.moves import map # pylint: disable=import-error,redefined-builtin

log = logging.getLogger(__name__)
Expand All @@ -30,7 +22,7 @@ def targets(tgt, tgt_type="glob", **kwargs):
return rmatcher.targets()


class RosterMatcher(object):
class RosterMatcher:
"""
Matcher for the roster data structure
"""
Expand All @@ -44,21 +36,26 @@ def targets(self):
Return ip addrs based on netmask, sitting in the "glob" spot because
it is the default
"""
addrs = ()
addrs = []
ret = {}
ports = __opts__["ssh_scan_ports"]
if not isinstance(ports, list):
# Comma-separate list of integers
ports = list(map(int, six.text_type(ports).split(",")))
try:
addrs = [ipaddress.ip_address(self.tgt)]
except ValueError:
ports = list(map(int, str(ports).split(",")))
if self.tgt_type == "list":
tgts = self.tgt
else:
tgts = [self.tgt]
for tgt in tgts:
try:
addrs = ipaddress.ip_network(self.tgt).hosts()
addrs.append(ipaddress.ip_address(tgt))
except ValueError:
pass
try:
addrs.extend(ipaddress.ip_network(tgt).hosts())
except ValueError:
pass
for addr in addrs:
addr = six.text_type(addr)
addr = str(addr)
ret[addr] = copy.deepcopy(__opts__.get("roster_defaults", {}))
log.trace("Scanning host: %s", addr)
for port in ports:
Expand All @@ -70,6 +67,6 @@ def targets(self):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
ret[addr].update({"host": addr, "port": port})
except socket.error:
except OSError:
pass
return ret
106 changes: 106 additions & 0 deletions tests/unit/roster/test_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Test the scan roster.
"""

import socket

import salt.roster.scan as scan_
from tests.support import mixins
from tests.support.mock import MagicMock, patch
from tests.support.unit import TestCase


class ScanRosterTestCase(TestCase, mixins.LoaderModuleMockMixin):
"""Test the directory roster"""

def setup_loader_modules(self):
return {scan_: {"__opts__": {"ssh_scan_ports": "22", "ssh_scan_timeout": 0.01}}}

def test_single_ip(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets("127.0.0.1")
self.assertEqual(ret, {"127.0.0.1": {"host": "127.0.0.1", "port": 22}})

def test_single_network(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets("127.0.0.0/30")
self.assertEqual(
ret,
{
"127.0.0.1": {"host": "127.0.0.1", "port": 22},
"127.0.0.2": {"host": "127.0.0.2", "port": 22},
},
)

def test_multiple_ips(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets(["127.0.0.1", "127.0.0.2"], tgt_type="list")
self.assertEqual(
ret,
{
"127.0.0.1": {"host": "127.0.0.1", "port": 22},
"127.0.0.2": {"host": "127.0.0.2", "port": 22},
},
)

def test_multiple_networks(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets(
["127.0.0.0/30", "127.0.2.1", "127.0.1.0/30"], tgt_type="list"
)
self.assertEqual(
ret,
{
"127.0.0.1": {"host": "127.0.0.1", "port": 22},
"127.0.0.2": {"host": "127.0.0.2", "port": 22},
"127.0.2.1": {"host": "127.0.2.1", "port": 22},
"127.0.1.1": {"host": "127.0.1.1", "port": 22},
"127.0.1.2": {"host": "127.0.1.2", "port": 22},
},
)

def test_malformed_ip(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets("127001")
self.assertEqual(ret, {})

def test_multiple_with_malformed(self):
"""Test that minion files in the directory roster match and render."""
with patch("salt.utils.network.get_socket"):
ret = scan_.targets(
["127.0.0.1", "127002", "127.0.1.0/30"], tgt_type="list"
)
self.assertEqual(
ret,
{
"127.0.0.1": {"host": "127.0.0.1", "port": 22},
"127.0.1.1": {"host": "127.0.1.1", "port": 22},
"127.0.1.2": {"host": "127.0.1.2", "port": 22},
},
)

def test_multiple_no_connection(self):
"""Test that minion files in the directory roster match and render."""
socket_mock = MagicMock()
socket_mock.connect = MagicMock(
side_effect=[None, socket.error(), None, socket.error(), None]
)
with patch("salt.utils.network.get_socket", return_value=socket_mock):
ret = scan_.targets(
["127.0.0.0/30", "127.0.2.1", "127.0.1.0/30"], tgt_type="list"
)
self.assertEqual(
ret,
{
"127.0.0.1": {"host": "127.0.0.1", "port": 22},
"127.0.0.2": {},
"127.0.2.1": {"host": "127.0.2.1", "port": 22},
"127.0.1.1": {},
"127.0.1.2": {"host": "127.0.1.2", "port": 22},
},
)