From 4c27a1e8ce4bd8c9a43b54532177aa18cfa7f9c4 Mon Sep 17 00:00:00 2001 From: Philippe Pepiot Date: Tue, 30 Apr 2019 01:14:44 +0200 Subject: [PATCH] Implement a new Ansible Runner Following specification on https://github.com/philpep/testinfra/issues/431 * Use ansible-inventory cli to parse inventory, group and host variables * Ansible backend re-use existing backend (local, ssh) to run commands. This is a breaking change because we do not support all connection backends from ansible. * The Ansible module run with the ansible cli. * Add more tests * Fix the skipped "encoding" test which pass now. This fix maintainability issue with ansible and the license issue (ansible is GPL). Closes #431 --- test/test_backends.py | 30 ++- testinfra/backend/ansible.py | 15 +- testinfra/utils/ansible_runner.py | 295 +++++++++++++++++------------- 3 files changed, 192 insertions(+), 148 deletions(-) diff --git a/test/test_backends.py b/test/test_backends.py index 535f59b5..9b1661f3 100644 --- a/test/test_backends.py +++ b/test/test_backends.py @@ -43,9 +43,6 @@ def test_command(host): @pytest.mark.testinfra_hosts(*HOSTS) def test_encoding(host): - if host.backend.get_connection_type() == "ansible": - pytest.skip("ansible handle encoding himself") - # stretch image is fr_FR@ISO-8859-15 cmd = host.run("ls -l %s", "/é") if host.backend.get_connection_type() == "docker": @@ -124,6 +121,33 @@ def get_vars(host): } +def test_ansible_get_backend(): + with tempfile.NamedTemporaryFile() as f: + f.write(( + b'localhost ansible_connection=local ansible_become=yes\n' + b'debian ansible_user=u ansible_become=yes\n' + b'centos ansible_connection=ssh ansible_host=127.0.0.1 ' + b'ansible_port=2222\n' + )) + f.flush() + + def get_backend(host): + return AnsibleRunner(f.name).get_backend(host).backend + localhost = get_backend('localhost') + assert localhost.NAME == 'local' + assert localhost.sudo + debian = get_backend('debian') + assert debian.NAME == 'paramiko' + assert debian.sudo + assert debian.host.name == 'debian' + assert debian.host.user == 'u' + centos = get_backend('centos') + assert centos.NAME == 'paramiko' + assert not centos.sudo + assert centos.host.name == '127.0.0.1' + assert centos.host.port == '2222' + + def test_backend_importables(): # just check that all declared backend are importable and NAME is set # correctly diff --git a/testinfra/backend/ansible.py b/testinfra/backend/ansible.py index a5dc77ca..87faca9a 100644 --- a/testinfra/backend/ansible.py +++ b/testinfra/backend/ansible.py @@ -19,7 +19,6 @@ from testinfra.backend import base from testinfra.utils.ansible_runner import AnsibleRunner -from testinfra.utils.ansible_runner import to_bytes logger = logging.getLogger("testinfra") @@ -39,20 +38,10 @@ def ansible_runner(self): def run(self, command, *args, **kwargs): command = self.get_command(command, *args) - out = self.run_ansible("shell", module_args=command) - return self.result( - out['rc'], - command, - stdout_bytes=None, - stderr_bytes=None, - stdout=out["stdout"], stderr=out["stderr"], - ) - - def encode(self, data): - return to_bytes(data) + return self.ansible_runner.run(self.host, command) def run_ansible(self, module_name, module_args=None, **kwargs): - result = self.ansible_runner.run( + result = self.ansible_runner.run_module( self.host, module_name, module_args, **kwargs) logger.info( diff --git a/testinfra/utils/ansible_runner.py b/testinfra/utils/ansible_runner.py index 4cdbb6ed..3fbcb6b0 100644 --- a/testinfra/utils/ansible_runner.py +++ b/testinfra/utils/ansible_runner.py @@ -11,56 +11,183 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=import-error,no-name-in-module,no-member -# pylint: disable=unexpected-keyword-arg,no-value-for-parameter -# pylint: disable=arguments-differ - from __future__ import unicode_literals from __future__ import absolute_import -import pprint - - -try: - import ansible -except ImportError: - raise RuntimeError( - "You must install ansible package to use the ansible backend") +import fnmatch +import json +import os +import tempfile -import ansible.cli.playbook -import ansible.constants -import ansible.executor.task_queue_manager -import ansible.inventory -import ansible.parsing.dataloader -import ansible.playbook.play -import ansible.plugins.callback -import ansible.utils.vars -import ansible.vars +from six.moves import configparser -try: - from ansible.module_utils._text import to_bytes -except ImportError: - from ansible.utils.unicode import to_bytes +import testinfra +from testinfra.utils import cached_property -__all__ = ['AnsibleRunner', 'to_bytes'] +__all__ = ['AnsibleRunner'] - -class AnsibleRunnerBase(object): +EMPTY_INVENTORY = { + "_meta": { + "hostvars": {} + }, + "all": { + "children": [ + "ungrouped" + ] + }, + "ungrouped": {} +} +local = testinfra.get_host('local://') + + +def get_ansible_config(): + fname = os.environ.get('ANSIBLE_CONFIG') + if not fname: + for possible in ( + os.path.join(os.path.expanduser('~'), '.ansible.cfg'), + os.path.join('/', 'etc', 'ansible', 'ansible.cfg'), + ): + if os.path.exists(possible): + fname = possible + break + config = configparser.ConfigParser() + if not fname: + return config + config.read(fname) + return config + + +def get_ansible_inventory(inventory_file): + cmd = 'ansible-inventory --list' + args = [] + if inventory_file: + cmd += ' -i %s' + args += [inventory_file] + return json.loads(local.check_output(cmd, *args)) + + +def get_backend(config, inventory, host): + if inventory == EMPTY_INVENTORY: + return testinfra.get_host('local://') + else: + hostvars = inventory['_meta'].get('hostvars', {}).get(host, {}) + connection = hostvars.get('ansible_connection', 'ssh') + if connection not in ('ssh', 'local', 'docker'): + raise NotImplementedError( + 'unhandled ansible_connection {}'.format(connection)) + if connection == 'ssh': + connection = 'paramiko' + testinfra_host = hostvars.get('ansible_host', host) + user = hostvars.get('ansible_user') + port = hostvars.get('ansible_port') + kwargs = {} + if hostvars.get('ansible_become', False): + kwargs['sudo'] = True + if 'ansible_ssh_private_key_file' in hostvars: + kwargs['ssh_identity_file'] = hostvars[ + 'ansible_ssh_private_key_file'] + try: + host_key_checking = config['defaults']['host_key_checking'] + except KeyError: + pass + else: + if host_key_checking.lower()[:1] in ('n', 'f', '0'): + kwargs['strict_host_key_checking'] = False + spec = '{}://'.format(connection) + if user: + spec += '{}@'.format(user) + spec += testinfra_host + if port: + spec += ':{}'.format(port) + return testinfra.get_host(spec, **kwargs) + + +class AnsibleRunner(object): _runners = {} - def __init__(self, host_list=None): - self.host_list = host_list - super(AnsibleRunnerBase, self).__init__() + def __init__(self, inventory_file=None): + self.inventory_file = inventory_file + self._backend_cache = {} + super(AnsibleRunner, self).__init__() def get_hosts(self, pattern=None): - raise NotImplementedError + inventory = self.inventory + result = set() + if inventory == EMPTY_INVENTORY: + # use localhost as fallback + result.add('localhost') + else: + for group in inventory: + groupmatch = fnmatch.fnmatch(group, pattern) + for host in inventory[group].get('hosts', []): + if (groupmatch or pattern == 'all' + or fnmatch.fnmatch(host, pattern)): + result.add(host) + return sorted(result) + + @cached_property + def inventory(self): + return get_ansible_inventory(self.inventory_file) + + @cached_property + def ansible_config(self): + return get_ansible_config() def get_variables(self, host): - raise NotImplementedError - - def run(self, host, module_name, module_args, **kwargs): - raise NotImplementedError + inventory = self.inventory + hostvars = inventory['_meta'].get( + 'hostvars', {}).get(host, {}) + hostvars.setdefault('inventory_hostname', host) + groups = [] + for group in sorted(inventory): + if group in ('_meta', 'all'): + continue + if host in inventory[group].get('hosts', []): + groups.append(group) + hostvars.setdefault('group_names', groups) + return hostvars + + def get_backend(self, host): + try: + return self._backend_cache[host] + except KeyError: + backend = self._backend_cache[host] = get_backend( + self.ansible_config, self.inventory, host) + return backend + + def run(self, host, command): + return self.get_backend(host).run(command) + + def run_module(self, host, module_name, module_args, become=False, + check=True, **kwargs): + cmd, args = 'ansible --tree %s', [None] + if self.inventory_file: + cmd += ' -i %s' + args += [self.inventory_file] + cmd += ' -m %s' + args += [module_name] + if module_args: + cmd += ' --args %s' + args += [module_args] + if become: + cmd += ' --become' + if check: + cmd += ' --check' + cmd += ' %s' + args += [host] + with tempfile.TemporaryDirectory() as d: + args[0] = d + out = local.run_expect([0, 2], cmd, *args) + files = os.listdir(d) + if not files and 'skipped' in out.stdout.lower(): + return {'failed': True, 'skipped': True, + 'msg': 'Skipped. You might want to try check=False'} + elif not files: + raise RuntimeError('Error while running {}: {}'.format( + ' '.join(cmd), out)) + with open(os.path.join(d, files[0]), 'r') as f: + return json.load(f) @classmethod def get_runner(cls, inventory): @@ -69,99 +196,3 @@ def get_runner(cls, inventory): except KeyError: cls._runners[inventory] = cls(inventory) return cls._runners[inventory] - - -class Callback(ansible.plugins.callback.CallbackBase): - - def __init__(self, *args, **kwargs): - self.result = {} - super(Callback, self).__init__(*args, **kwargs) - - def runner_on_ok(self, host, result): - self.result = result - - def runner_on_failed(self, host, result, ignore_errors=False): - self.result = result - - # pylint: disable=no-self-use - def runner_on_unreachable(self, host, result): - raise RuntimeError( - 'Host {} is unreachable: {}'.format( - host, pprint.pformat(result)), - ) - - def runner_on_skipped(self, host, item=None): - self.result = { - 'failed': True, - 'msg': 'Skipped. You might want to try check=False', - 'item': item, - } - - -class AnsibleRunner(AnsibleRunnerBase): - - def __init__(self, host_list=None): - super(AnsibleRunner, self).__init__(host_list) - self.cli = ansible.cli.playbook.PlaybookCLI(None) - self.cli.options = self.cli.base_parser( - connect_opts=True, - meta_opts=True, - runas_opts=True, - subset_opts=True, - check_opts=True, - inventory_opts=True, - runtask_opts=True, - vault_opts=True, - fork_opts=True, - module_opts=True, - ).parse_args([])[0] - self.cli.normalize_become_options() - self.cli.options.connection = "smart" - self.cli.options.inventory = host_list - # pylint: disable=protected-access - self.loader, self.inventory, self.variable_manager = ( - self.cli._play_prereqs(self.cli.options)) - - def get_hosts(self, pattern=None): - return [ - e.name for e in - self.inventory.get_hosts(pattern=pattern or "all") - ] - - def get_variables(self, host): - host = self.inventory.get_host(host) - return self.variable_manager.get_vars(host=host) - - def run(self, host, module_name, module_args=None, **kwargs): - self.cli.options.check = kwargs.get("check", False) - self.cli.options.become = kwargs.get("become", False) - action = {"module": module_name} - if module_args is not None: - if module_name in ("command", "shell"): - # Workaround https://github.com/ansible/ansible/issues/13862 - module_args = module_args.replace("=", "\\=") - action["args"] = module_args - play = ansible.playbook.play.Play().load({ - "hosts": host, - "gather_facts": "no", - "tasks": [{ - "action": action, - }], - }, variable_manager=self.variable_manager, loader=self.loader) - tqm = None - callback = Callback() - try: - tqm = ansible.executor.task_queue_manager.TaskQueueManager( - inventory=self.inventory, - variable_manager=self.variable_manager, - loader=self.loader, - options=self.cli.options, - passwords=None, - stdout_callback=callback, - ) - tqm.run(play) - finally: - if tqm is not None: - tqm.cleanup() - - return callback.result