Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replicate most of the ansible style host matching logic #753

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions test/test_ansible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest

from testinfra.backend import parse_hostspec
from testinfra.utils.ansible_runner import expand_pattern, get_hosts, Inventory


@pytest.fixture
def inventory() -> Inventory:
"""Hosts are always under a group, the default is "ungrouped" if using the
ini file format. The "all" meta-group always contains all hosts when
expanded."""
return {
"_meta": {
"hostvars": {
"a": None,
"b": None,
"c": None,
}
},
"all": {
"children": ["nested"],
},
"left": {
"hosts": ["a", "b"],
},
"right": {
"hosts": ["b", "c"],
},
"combined": {
"children": ["left", "right"],
},
"nested": {
"children": ["combined"],
}
}


def test_expand_pattern_simple(inventory: Inventory):
"""Simple names are matched, recurring into groups if needed."""
# direct hostname
assert expand_pattern("a", inventory) == {"a"}
# group
assert expand_pattern("left", inventory) == {"a", "b"}
# meta-group
assert expand_pattern("combined", inventory) == {"a", "b", "c"}
# meta-meta-group
assert expand_pattern("nested", inventory) == {"a", "b", "c"}


def test_expand_pattern_fnmatch(inventory: Inventory):
"""Simple names are matched, recurring into groups if needed."""
# l->left
assert expand_pattern("l*", inventory) == {"a", "b"}
# any single letter name
assert expand_pattern("?", inventory) == {"a", "b", "c"}


def test_expand_pattern_regex(inventory: Inventory):
"""Simple names are matched, recurring into groups if needed."""
# simple character matching - "l" matches "left" but not "all"
assert expand_pattern("~l", inventory) == {"a", "b"}
# "b" matches an exact host, not any group
assert expand_pattern("~b", inventory) == {"b"}
# "a" will match all
assert expand_pattern("~a", inventory) == {"a", "b", "c"}


def test_get_hosts(inventory: Inventory):
"""Multiple names/patterns can be combined."""
assert get_hosts("a", inventory) == ["a"]
# the two pattern separators are handled
assert get_hosts("a:b", inventory) == ["a", "b"]
assert get_hosts("a,b", inventory) == ["a", "b"]
# difference works
assert get_hosts("left:!right", inventory) == ["a"]
# intersection works
assert get_hosts("left:&right", inventory) == ["b"]
# intersection is taken with the intersection of the intersection groups
assert get_hosts("all:&left:&right", inventory) == ["b"]
# when the intersections ends up empty, so does the result
assert get_hosts("all:&a:&c", inventory) == []
# negation is taken with the union of negation groups
assert get_hosts("all:!a:!c", inventory) == ["b"]


@pytest.mark.parametrize("left", ["h1", "!h1", "&h1", "~h1", "*h1"])
@pytest.mark.parametrize("sep", [":", ","])
@pytest.mark.parametrize("right", ["h2", "!h2", "&h2", "~h2", "*h2", ""])
def test_parse_hostspec(left: str, sep: str, right: str):
"""Ansible's host patterns are parsed without issue."""
if right:
pattern = f"{left}{sep}{right}"
else:
pattern = left
assert parse_hostspec(pattern) == (pattern, {})
120 changes: 90 additions & 30 deletions testinfra/utils/ansible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import ipaddress
import json
import os
import re
import tempfile
from typing import Any, Callable, Iterator, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Union

import testinfra
import testinfra.host
Expand All @@ -26,6 +27,86 @@

local = testinfra.get_host("local://")

Inventory = Dict[str, Any]


def expand_group(name: str, inventory: Inventory) -> Iterator[str]:
"""Return all the underlying hostnames for the given group name/pattern."""
group = inventory.get(name)
if group is None:
return

# this is a meta-group so recurse
children = group.get("children")
if children is not None:
for child in children:
yield from expand_group(child, inventory)

# this is a regular group
hosts = group.get("hosts")
if hosts is not None:
yield from iter(hosts)


def expand_pattern(pattern: str, inventory: Inventory) -> Set[str]:
"""Return all underlying hostnames for the given name/pattern."""
if pattern.startswith("~"):
# this is a regex, so cut off the indicating character
pattern = re.compile(pattern[1:])
# match is used, not search or fullmatch
filter_ = lambda l: [i for i in l if pattern.match(i)]
else:
filter_ = lambda l: fnmatch.filter(l, pattern)

# hosts in the inventory directly matched by the pattern
matching_hosts = set(filter_(expand_group('all', inventory)))

# look for matches in the groups
for group in filter_(inventory.keys()):
if group == "_meta":
continue
matching_hosts.update(expand_group(group, inventory))

return matching_hosts


def get_hosts(pattern: str, inventory: Inventory) -> List[str]:
"""Return hostnames with a name/group that matches the given name/pattern.

Reference:
https://docs.ansible.com/ansible/latest/inventory_guide/intro_patterns.html

This is but a shadow of Ansible's full InventoryManager. The source of the
`inventory_hostnames` module would be a good starting point for a more
faithful reproduction if this turns out to be insufficient.
"""
from ansible.inventory.manager import split_host_pattern

patterns = split_host_pattern(pattern)

positive = set()
intersect = None
negative = set()

for requirement in patterns:
if requirement.startswith('&'):
expanded = expand_pattern(requirement[1:], inventory)
if intersect is None:
intersect = expanded
else:
intersect &= expanded
elif requirement.startswith('!'):
negative.update(expand_pattern(requirement[1:], inventory))
else:
positive.update(expand_pattern(requirement, inventory))

result = positive
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment about something that I don't see as worth changing right now as there hasn't been any response on this yet, and a significant improvement is much better than nothing.

I suspect a perfect copy of the logic from ansible would use something like this here:

Suggested change
result = positive
if positive:
result = positive
else:
result = expand_pattern('all')

Then you imitate behaviour like !right which is equivalent to all:!right, but this would also be best with additions to manual testing and adding of tests to this codebase for the confirmed functionality.

Given that adding all: to the start of a host pattern/query is equivalent, it's a fine enough workaround. If this change gets merged/released, I'll follow up with this change.

if intersect is not None:
result &= intersect
if negative:
result -= negative
return sorted(result)


def get_ansible_config() -> configparser.ConfigParser:
fname = os.environ.get("ANSIBLE_CONFIG")
Expand All @@ -45,9 +126,6 @@ def get_ansible_config() -> configparser.ConfigParser:
return config


Inventory = dict[str, Any]


def get_ansible_inventory(
config: configparser.ConfigParser, inventory_file: Optional[str]
) -> Inventory:
Expand Down Expand Up @@ -216,16 +294,8 @@ def get_config(
return testinfra.get_host(spec, **kwargs)


def itergroup(inventory: Inventory, group: str) -> Iterator[str]:
for host in inventory.get(group, {}).get("hosts", []):
yield host
for g in inventory.get(group, {}).get("children", []):
for host in itergroup(inventory, g):
yield host


def is_empty_inventory(inventory: Inventory) -> bool:
return not any(True for _ in itergroup(inventory, "all"))
return next(expand_group("all", inventory), None) is None


class AnsibleRunner:
Expand Down Expand Up @@ -275,25 +345,15 @@ def __init__(self, inventory_file: Optional[str] = None):

def get_hosts(self, pattern: str = "all") -> list[str]:
inventory = self.inventory
result = set()
if is_empty_inventory(inventory):
# empty inventory should not return any hosts except for localhost
if pattern == "localhost":
result.add("localhost")
else:
raise RuntimeError(
"No inventory was parsed (missing file ?), "
"only implicit localhost is available"
)
else:
for group in inventory:
groupmatch = fnmatch.fnmatch(group, pattern)
if groupmatch:
result |= set(itergroup(inventory, group))
for host in inventory[group].get("hosts", []):
if fnmatch.fnmatch(host, pattern):
result.add(host)
return sorted(result)
return ["localhost"]
raise RuntimeError(
"No inventory was parsed (missing file ?), "
"only implicit localhost is available"
)
return get_hosts(pattern, inventory)

@functools.cached_property
def inventory(self) -> Inventory:
Expand All @@ -315,7 +375,7 @@ def get_variables(self, host: str) -> dict[str, Any]:
for group in sorted(inventory):
if group == "_meta":
continue
groups[group] = sorted(itergroup(inventory, group))
groups[group] = sorted(expand_group(group, inventory))
if host in groups[group]:
group_names.append(group)

Expand Down