From a1c5ac1c9fe94acfa65d6d383c764622997460d1 Mon Sep 17 00:00:00 2001 From: Chris Hambridge Date: Mon, 16 Oct 2017 10:26:08 -0400 Subject: [PATCH] improve keyfile path validation. Closes #350. --- rho/authaddcommand.py | 12 +++++++++ rho/autheditcommand.py | 12 +++++++++ test/test_clicommand.py | 58 ++++++++++++++++++++++++++++++++++------- 3 files changed, 72 insertions(+), 10 deletions(-) diff --git a/rho/authaddcommand.py b/rho/authaddcommand.py index 8a2a3cb..860468f 100644 --- a/rho/authaddcommand.py +++ b/rho/authaddcommand.py @@ -141,6 +141,18 @@ def _validate_options(self): self.parser.print_help() sys.exit(1) + if self.options.filename: + keyfile_path = os.path.abspath(os.path.normpath( + os.path.expanduser(os.path.expandvars(self.options.filename)))) + if os.path.isfile(keyfile_path) is False: + print(_('You must provide a valid file path for' + ' "--sshkeyfile", "%s" could not be found.' + % keyfile_path)) + self.parser.print_help() + sys.exit(1) + else: + self.options.filename = keyfile_path + def _do_command(self): vault = get_vault(self.options.vaultfile) auth_name = self.options.name diff --git a/rho/autheditcommand.py b/rho/autheditcommand.py index ebb6d3f..d4dcf05 100644 --- a/rho/autheditcommand.py +++ b/rho/autheditcommand.py @@ -108,6 +108,18 @@ def _validate_options(self): self.parser.print_help() sys.exit(1) + if self.options.filename: + keyfile_path = os.path.abspath(os.path.normpath( + os.path.expanduser(os.path.expandvars(self.options.filename)))) + if os.path.isfile(keyfile_path) is False: + print(_('You must provide a valid file path for' + ' "--sshkeyfile", "%s" could not be found.' + % keyfile_path)) + self.parser.print_help() + sys.exit(1) + else: + self.options.filename = keyfile_path + def _do_command(self): vault = get_vault(self.options.vaultfile) auth_found = False diff --git a/test/test_clicommand.py b/test/test_clicommand.py index 15b88aa..fc0f5b9 100644 --- a/test/test_clicommand.py +++ b/test/test_clicommand.py @@ -39,7 +39,7 @@ from rho.scancommand import ScanCommand TEST_VAULT_PASSWORD = 'password' - +TMP_KEY = "/tmp/privatekey" TMP_VAULT_PASS = "/tmp/vault_pass" TMP_FACTS = "/tmp/facts.txt" TMP_HOSTS = "/tmp/hosts.txt" @@ -151,6 +151,11 @@ def setUp(self): with open(TMP_VAULT_PASS, 'w') as vault_pass_file: vault_pass_file.write(TEST_VAULT_PASSWORD) + if os.path.isfile(TMP_KEY): + os.remove(TMP_KEY) + with open(TMP_KEY, 'w') as privatekey_file: + privatekey_file.write(TEST_VAULT_PASSWORD) + if os.path.isfile(TMP_FACTS): os.remove(TMP_FACTS) with open(TMP_FACTS, 'w') as facts_file: @@ -227,6 +232,9 @@ def tearDown(self): if os.path.isfile(TMP_VAULT_PASS): os.remove(TMP_VAULT_PASS) + if os.path.isfile(TMP_KEY): + os.remove(TMP_KEY) + if os.path.isfile(TMP_FACTS): os.remove(TMP_FACTS) @@ -243,7 +251,7 @@ def test_auth_add(self, uuid4): sys.argv = ['/bin/rho', "auth", "add", "--name", "auth_1", "--username", "user", "--sshkeyfile", - "./privatekey", "--vault", + TMP_KEY, "--vault", TMP_VAULT_PASS] creds = list() @@ -255,7 +263,22 @@ def test_auth_add(self, uuid4): u'username': u'user', u'password': None, u'sudo_password': None, - u'ssh_key_file': u'./privatekey'}]) + u'ssh_key_file': u'/tmp/privatekey'}]) + + # pylint: disable=unused-argument + @mock.patch('uuid.uuid4', return_value=1) + def test_auth_add_bad_key(self, uuid4): + """Testing the auth add command execution""" + sys.argv = ['/bin/rho', "auth", "add", "--name", "auth_1", + "--username", "user", "--sshkeyfile", + "/not/a/valid/path", "--vault", + TMP_VAULT_PASS] + + auth_add_out = six.StringIO() + with self.assertRaises(SystemExit): + with redirect_stdout(auth_add_out): + AuthAddCommand().main() + self.assertIn("/not/a/valid/path", auth_add_out) # pylint: disable=unused-argument @mock.patch('uuid.uuid4', return_value=2) @@ -264,12 +287,12 @@ def test_auth_add_again(self, uuid4): sys.argv = ['/bin/rho', "auth", "add", "--name", "auth_2", "--username", "user", "--sshkeyfile", - "./privatekey", "--vault", + TMP_KEY, "--vault", TMP_VAULT_PASS] creds = [{u'id': u'1', u'name': u'auth_1', u'username': u'user', u'password': u'', u'sudo_password': None, - u'ssh_key_file': u'./privatekey'}] + u'ssh_key_file': u'/tmp/privatekey'}] with redirect_credentials(creds): AuthAddCommand().main() @@ -280,13 +303,13 @@ def test_auth_add_again(self, uuid4): u'username': u'user', u'password': u'', u'sudo_password': None, - u'ssh_key_file': u'./privatekey'}, + u'ssh_key_file': u'/tmp/privatekey'}, {u'id': u'2', u'name': u'auth_2', u'username': u'user', u'password': None, u'sudo_password': None, - u'ssh_key_file': u'./privatekey'}]) + u'ssh_key_file': u'/tmp/privatekey'}]) def test_auth_list(self): """Testing the auth list command execution""" @@ -312,10 +335,10 @@ def test_auth_edit(self): sys.argv = ['/bin/rho', "auth", "edit", "--name", "auth_1", "--username", "user_2", - "--sshkeyfile", "file_2", + "--sshkeyfile", TMP_KEY, "--vault", TMP_VAULT_PASS] creds = [{'id': '1', 'name': 'auth_1', 'username': 'user_1', - 'password': 'password', 'ssh_key_file': 'file_1'}] + 'password': 'password', 'ssh_key_file': TMP_KEY}] with redirect_credentials(creds): AuthEditCommand().main() @@ -324,7 +347,22 @@ def test_auth_edit(self): 'name': 'auth_1', 'username': 'user_2', 'password': 'password', - 'ssh_key_file': 'file_2'}]) + 'ssh_key_file': TMP_KEY}]) + + # pylint: disable=unused-argument + @mock.patch('uuid.uuid4', return_value=1) + def test_auth_edit_bad_key(self, uuid4): + """Testing the auth add command execution""" + sys.argv = ['/bin/rho', "auth", "edit", "--name", "auth_1", + "--username", "user", "--sshkeyfile", + "/not/a/valid/path", "--vault", + TMP_VAULT_PASS] + + auth_edit_out = six.StringIO() + with self.assertRaises(SystemExit): + with redirect_stdout(auth_edit_out): + AuthEditCommand().main() + self.assertIn("/not/a/valid/path", auth_edit_out) def test_auth_show(self): """Testing the auth show command execution"""