From d2fffe5bef550f94e92605cc7427b85ced94bd38 Mon Sep 17 00:00:00 2001 From: Rob Dobson Date: Wed, 16 Jul 2014 12:19:49 +0100 Subject: [PATCH] Moving the SSH connect into the host class to improve performance of making multiple calls to a single host. Signed-off-by: Rob Dobson --- hwinfo/tools/inspector.py | 24 ++++++++++++++++++------ hwinfo/tools/tests/test_inspector.py | 19 +++++++++++++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/hwinfo/tools/inspector.py b/hwinfo/tools/inspector.py index 5c35ab8..10f1a6c 100644 --- a/hwinfo/tools/inspector.py +++ b/hwinfo/tools/inspector.py @@ -13,10 +13,13 @@ from hwinfo.host import dmidecode from hwinfo.host import cpuinfo -def remote_command(host, username, password, cmd): +def get_ssh_client(host, username, password, timeout=10): client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect(host, username=username, password=password, timeout=10) + client.connect(host, username=username, password=password, timeout=timeout) + return client + +def remote_command(client, cmd): cmdstr = ' '.join(cmd) #print "Executing '%s' on host '%s'" % (cmdstr, host) _, stdout, stderr = client.exec_command(cmdstr) @@ -24,7 +27,6 @@ def remote_command(host, username, password, cmd): error = stderr.readlines() if error: raise Exception("stderr: %s" % error) - client.close() return ''.join(output) def local_command(cmd): @@ -44,12 +46,22 @@ def __init__(self, host='localhost', username=None, password=None): self.host = host self.username = username self.password = password + self.client = None + if self.is_remote(): + self.client = get_ssh_client(self.host, self.username, self.password) + + def __del__(self): + if self.client: + self.client.close() + + def is_remote(self): + return self.host != 'localhost' def exec_command(self, cmd): - if self.host == 'localhost': - return local_command(cmd) + if self.is_remote(): + return remote_command(self.client, cmd) else: - return remote_command(self.host, self.username, self.password, cmd) + return local_command(cmd) def get_lspci_data(self): return self.exec_command(['lspci', '-nnmm']) diff --git a/hwinfo/tools/tests/test_inspector.py b/hwinfo/tools/tests/test_inspector.py index 8762777..4bbe90d 100644 --- a/hwinfo/tools/tests/test_inspector.py +++ b/hwinfo/tools/tests/test_inspector.py @@ -13,11 +13,13 @@ def test_local_exec_command(self, local_command): host.exec_command('ls') inspector.local_command.assert_called_once_with('ls') + @patch('hwinfo.tools.inspector.get_ssh_client') @patch('hwinfo.tools.inspector.remote_command') - def test_remote_exec_command(self, remote_command): + def test_remote_exec_command(self, remote_command, get_ssh_client): + mclient = get_ssh_client.return_value = mock.MagicMock() host = inspector.Host('mymachine', 'root', 'pass') host.exec_command('ls') - inspector.remote_command.assert_called_once_with('mymachine', 'root', 'pass', 'ls') + inspector.remote_command.assert_called_once_with(mclient, 'ls') @patch('hwinfo.tools.inspector.Host.exec_command') def test_get_pci_devices(self, exec_command): @@ -35,6 +37,15 @@ def test_get_info(self, mock_exec_command, mock_dmidecode_parser_cls): rec = host.get_info() self.assertEqual(rec, {'key':'value'}) + def test_is_not_remote(self): + host = inspector.Host() + self.assertEqual(host.is_remote(), False) + + @patch('hwinfo.tools.inspector.get_ssh_client') + def test_is_remote(self, get_ssh_client): + get_ssh_client.return_value = mock.MagicMock() + host = inspector.Host('test', 'user', 'pass') + self.assertEqual(host.is_remote(), True) class RemoteCommandTests(unittest.TestCase): @@ -47,7 +58,7 @@ def setUp(self): def test_ssh_connect(self, ssh_client_cls): client = ssh_client_cls.return_value = mock.Mock() client.exec_command.return_value = self.stdout, self.stdin, self.stderr - inspector.remote_command('test', 'user', 'pass', 'ls') + inspector.get_ssh_client('test', 'user', 'pass') client.connect.assert_called_with('test', password='pass', username='user', timeout=10) @patch('paramiko.SSHClient') @@ -55,7 +66,7 @@ def test_ssh_connect_error(self, ssh_client_cls): client = ssh_client_cls.return_value = mock.Mock() client.exec_command.return_value = self.stdout, self.stdin, StringIO("Error") with self.assertRaises(Exception) as context: - inspector.remote_command('test', 'user', 'pass', 'ls') + inspector.remote_command(client, 'ls') self.assertEqual(context.exception.message, "stderr: ['Error']") class LocalCommandTests(unittest.TestCase):