Skip to content

Commit

Permalink
Moving the SSH connect into the host class to improve performance of …
Browse files Browse the repository at this point in the history
…making multiple calls to a single host.

Signed-off-by: Rob Dobson <rob.dobson@citrix.com>
  • Loading branch information
rdobson committed Jul 16, 2014
1 parent 82211e4 commit d2fffe5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
24 changes: 18 additions & 6 deletions hwinfo/tools/inspector.py
Expand Up @@ -13,18 +13,20 @@
from hwinfo.host import dmidecode
from hwinfo.host import cpuinfo

def remote_command(host, username, password, cmd):
def get_ssh_client(host, username, password, timeout=10):
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(host, username=username, password=password, timeout=10)
client.connect(host, username=username, password=password, timeout=timeout)
return client

def remote_command(client, cmd):
cmdstr = ' '.join(cmd)
#print "Executing '%s' on host '%s'" % (cmdstr, host)
_, stdout, stderr = client.exec_command(cmdstr)
output = stdout.readlines()
error = stderr.readlines()
if error:
raise Exception("stderr: %s" % error)
client.close()
return ''.join(output)

def local_command(cmd):
Expand All @@ -44,12 +46,22 @@ def __init__(self, host='localhost', username=None, password=None):
self.host = host
self.username = username
self.password = password
self.client = None
if self.is_remote():
self.client = get_ssh_client(self.host, self.username, self.password)

def __del__(self):
if self.client:
self.client.close()

def is_remote(self):
return self.host != 'localhost'

def exec_command(self, cmd):
if self.host == 'localhost':
return local_command(cmd)
if self.is_remote():
return remote_command(self.client, cmd)
else:
return remote_command(self.host, self.username, self.password, cmd)
return local_command(cmd)

def get_lspci_data(self):
return self.exec_command(['lspci', '-nnmm'])
Expand Down
19 changes: 15 additions & 4 deletions hwinfo/tools/tests/test_inspector.py
Expand Up @@ -13,11 +13,13 @@ def test_local_exec_command(self, local_command):
host.exec_command('ls')
inspector.local_command.assert_called_once_with('ls')

@patch('hwinfo.tools.inspector.get_ssh_client')
@patch('hwinfo.tools.inspector.remote_command')
def test_remote_exec_command(self, remote_command):
def test_remote_exec_command(self, remote_command, get_ssh_client):
mclient = get_ssh_client.return_value = mock.MagicMock()
host = inspector.Host('mymachine', 'root', 'pass')
host.exec_command('ls')
inspector.remote_command.assert_called_once_with('mymachine', 'root', 'pass', 'ls')
inspector.remote_command.assert_called_once_with(mclient, 'ls')

@patch('hwinfo.tools.inspector.Host.exec_command')
def test_get_pci_devices(self, exec_command):
Expand All @@ -35,6 +37,15 @@ def test_get_info(self, mock_exec_command, mock_dmidecode_parser_cls):
rec = host.get_info()
self.assertEqual(rec, {'key':'value'})

def test_is_not_remote(self):
host = inspector.Host()
self.assertEqual(host.is_remote(), False)

@patch('hwinfo.tools.inspector.get_ssh_client')
def test_is_remote(self, get_ssh_client):
get_ssh_client.return_value = mock.MagicMock()
host = inspector.Host('test', 'user', 'pass')
self.assertEqual(host.is_remote(), True)

class RemoteCommandTests(unittest.TestCase):

Expand All @@ -47,15 +58,15 @@ def setUp(self):
def test_ssh_connect(self, ssh_client_cls):
client = ssh_client_cls.return_value = mock.Mock()
client.exec_command.return_value = self.stdout, self.stdin, self.stderr
inspector.remote_command('test', 'user', 'pass', 'ls')
inspector.get_ssh_client('test', 'user', 'pass')
client.connect.assert_called_with('test', password='pass', username='user', timeout=10)

@patch('paramiko.SSHClient')
def test_ssh_connect_error(self, ssh_client_cls):
client = ssh_client_cls.return_value = mock.Mock()
client.exec_command.return_value = self.stdout, self.stdin, StringIO("Error")
with self.assertRaises(Exception) as context:
inspector.remote_command('test', 'user', 'pass', 'ls')
inspector.remote_command(client, 'ls')
self.assertEqual(context.exception.message, "stderr: ['Error']")

class LocalCommandTests(unittest.TestCase):
Expand Down

0 comments on commit d2fffe5

Please sign in to comment.