Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions hwinfo/tools/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
from hwinfo.host import dmidecode
from hwinfo.host import cpuinfo

def remote_command(host, username, password, cmd):
def get_ssh_client(host, username, password, timeout=10):
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(host, username=username, password=password, timeout=10)
client.connect(host, username=username, password=password, timeout=timeout)
return client

def remote_command(client, cmd):
cmdstr = ' '.join(cmd)
#print "Executing '%s' on host '%s'" % (cmdstr, host)
_, stdout, stderr = client.exec_command(cmdstr)
output = stdout.readlines()
error = stderr.readlines()
if error:
raise Exception("stderr: %s" % error)
client.close()
return ''.join(output)

def local_command(cmd):
Expand All @@ -44,12 +46,22 @@ def __init__(self, host='localhost', username=None, password=None):
self.host = host
self.username = username
self.password = password
self.client = None
if self.is_remote():
self.client = get_ssh_client(self.host, self.username, self.password)

def __del__(self):
if self.client:
self.client.close()

def is_remote(self):
return self.host != 'localhost'

def exec_command(self, cmd):
if self.host == 'localhost':
return local_command(cmd)
if self.is_remote():
return remote_command(self.client, cmd)
else:
return remote_command(self.host, self.username, self.password, cmd)
return local_command(cmd)

def get_lspci_data(self):
return self.exec_command(['lspci', '-nnmm'])
Expand Down
19 changes: 15 additions & 4 deletions hwinfo/tools/tests/test_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ def test_local_exec_command(self, local_command):
host.exec_command('ls')
inspector.local_command.assert_called_once_with('ls')

@patch('hwinfo.tools.inspector.get_ssh_client')
@patch('hwinfo.tools.inspector.remote_command')
def test_remote_exec_command(self, remote_command):
def test_remote_exec_command(self, remote_command, get_ssh_client):
mclient = get_ssh_client.return_value = mock.MagicMock()
host = inspector.Host('mymachine', 'root', 'pass')
host.exec_command('ls')
inspector.remote_command.assert_called_once_with('mymachine', 'root', 'pass', 'ls')
inspector.remote_command.assert_called_once_with(mclient, 'ls')

@patch('hwinfo.tools.inspector.Host.exec_command')
def test_get_pci_devices(self, exec_command):
Expand All @@ -35,6 +37,15 @@ def test_get_info(self, mock_exec_command, mock_dmidecode_parser_cls):
rec = host.get_info()
self.assertEqual(rec, {'key':'value'})

def test_is_not_remote(self):
host = inspector.Host()
self.assertEqual(host.is_remote(), False)

@patch('hwinfo.tools.inspector.get_ssh_client')
def test_is_remote(self, get_ssh_client):
get_ssh_client.return_value = mock.MagicMock()
host = inspector.Host('test', 'user', 'pass')
self.assertEqual(host.is_remote(), True)

class RemoteCommandTests(unittest.TestCase):

Expand All @@ -47,15 +58,15 @@ def setUp(self):
def test_ssh_connect(self, ssh_client_cls):
client = ssh_client_cls.return_value = mock.Mock()
client.exec_command.return_value = self.stdout, self.stdin, self.stderr
inspector.remote_command('test', 'user', 'pass', 'ls')
inspector.get_ssh_client('test', 'user', 'pass')
client.connect.assert_called_with('test', password='pass', username='user', timeout=10)

@patch('paramiko.SSHClient')
def test_ssh_connect_error(self, ssh_client_cls):
client = ssh_client_cls.return_value = mock.Mock()
client.exec_command.return_value = self.stdout, self.stdin, StringIO("Error")
with self.assertRaises(Exception) as context:
inspector.remote_command('test', 'user', 'pass', 'ls')
inspector.remote_command(client, 'ls')
self.assertEqual(context.exception.message, "stderr: ['Error']")

class LocalCommandTests(unittest.TestCase):
Expand Down