Skip to content

Commit

Permalink
junitxml: convert from py.xml to xml.etree.ElementTree
Browse files Browse the repository at this point in the history
Part of the effort to reduce dependency on the py library.

Besides that, py.xml implements its own XML serialization which is
pretty scary.

I tried to keep the code with minimal changes (though it could use some
cleanups). The differences in behavior I have noticed are:

- Attributes in the output are not sorted.

- Some unneeded escaping is no longer performed, for example escaping
  `"` to `"` in a text node.
  • Loading branch information
bluetech committed Jul 23, 2020
1 parent e799b7d commit d0e3c5e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 80 deletions.
135 changes: 60 additions & 75 deletions src/_pytest/junitxml.py
Expand Up @@ -12,6 +12,7 @@
import os
import platform
import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Callable
from typing import Dict
Expand All @@ -21,14 +22,11 @@
from typing import Tuple
from typing import Union

import py

import pytest
from _pytest import deprecated
from _pytest import nodes
from _pytest import timing
from _pytest._code.code import ExceptionRepr
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser
Expand All @@ -38,29 +36,19 @@
from _pytest.terminal import TerminalReporter
from _pytest.warnings import _issue_warning_captured

if TYPE_CHECKING:
from typing import Type


xml_key = StoreKey["LogXML"]()


class Junit(py.xml.Namespace):
pass


def bin_xml_escape(arg: object) -> py.xml.raw:
r"""Visually escape an object into valid a XML string.
def bin_xml_escape(arg: object) -> str:
r"""Visually escape invalid XML characters.
For example, transforms
'hello\aworld\b'
into
'hello#x07world#x08'
Note that the #xABs are *not* XML escapes - missing the ampersand &#xAB.
The idea is to escape visually for the user rather than for XML itself.
The result is also entity-escaped and wrapped in py.xml.raw() so it can
be embedded directly.
"""

def repl(matchobj: Match[str]) -> str:
Expand All @@ -76,7 +64,7 @@ def repl(matchobj: Match[str]) -> str:
illegal_xml_re = (
"[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]"
)
return py.xml.raw(re.sub(illegal_xml_re, repl, py.xml.escape(str(arg))))
return re.sub(illegal_xml_re, repl, str(arg))


def merge_family(left, right) -> None:
Expand Down Expand Up @@ -108,12 +96,12 @@ def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0
self.properties = [] # type: List[Tuple[str, py.xml.raw]]
self.nodes = [] # type: List[py.xml.Tag]
self.attrs = {} # type: Dict[str, Union[str, py.xml.raw]]
self.properties = [] # type: List[Tuple[str, str]]
self.nodes = [] # type: List[ET.Element]
self.attrs = {} # type: Dict[str, str]

def append(self, node: py.xml.Tag) -> None:
self.xml.add_stats(type(node).__name__)
def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag)
self.nodes.append(node)

def add_property(self, name: str, value: object) -> None:
Expand All @@ -122,17 +110,17 @@ def add_property(self, name: str, value: object) -> None:
def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value)

def make_properties_node(self) -> Union[py.xml.Tag, str]:
def make_properties_node(self) -> Optional[ET.Element]:
"""Return a Junit node containing custom properties, if any.
"""
if self.properties:
return Junit.properties(
[
Junit.property(name=name, value=value)
for name, value in self.properties
]
properties = ET.Element("properties")
properties.extend(
ET.Element("property", name=name, value=value)
for name, value in self.properties
)
return ""
return properties
return None

def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
Expand All @@ -144,7 +132,7 @@ def record_testreport(self, testreport: TestReport) -> None:
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
} # type: Dict[str, Union[str, py.xml.raw]]
} # type: Dict[str, str]
if testreport.location[1] is not None:
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
Expand All @@ -164,16 +152,17 @@ def record_testreport(self, testreport: TestReport) -> None:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs

def to_xml(self) -> py.xml.Tag:
testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs)
testcase.append(self.make_properties_node())
for node in self.nodes:
testcase.append(node)
def to_xml(self) -> ET.Element:
testcase = ET.Element("testcase", self.attrs, time="%.3f" % self.duration)
properties = self.make_properties_node()
if properties is not None:
testcase.append(properties)
testcase.extend(self.nodes)
return testcase

def _add_simple(self, kind: "Type[py.xml.Tag]", message: str, data=None) -> None:
data = bin_xml_escape(data)
node = kind(data, message=message)
def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None:
node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data)
self.append(node)

def write_captured_output(self, report: TestReport) -> None:
Expand Down Expand Up @@ -203,36 +192,33 @@ def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])

def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = getattr(Junit, jheader)
self.append(tag(bin_xml_escape(content)))
tag = ET.Element(jheader)
tag.text = bin_xml_escape(content)
self.append(tag)

def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")

def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly")
self._add_simple("skipped", "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
if getattr(report.longrepr, "reprcrash", None) is not None:
message = report.longrepr.reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
fail = Junit.failure(message=message)
fail.append(bin_xml_escape(report.longrepr))
self.append(fail)
self._add_simple("failure", message, str(report.longrepr))

def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self.append(
Junit.error(bin_xml_escape(report.longrepr), message="collection failure")
)
self._add_simple("error", "collection failure", str(report.longrepr))

def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple(Junit.skipped, "collection skipped", report.longrepr)
self._add_simple("skipped", "collection skipped", str(report.longrepr))

def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None
Expand All @@ -245,40 +231,34 @@ def append_error(self, report: TestReport) -> None:
msg = 'failed on teardown with "{}"'.format(reason)
else:
msg = 'failed on setup with "{}"'.format(reason)
self._add_simple(Junit.error, msg, report.longrepr)
self._add_simple("error", msg, str(report.longrepr))

def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
xfailreason = xfailreason[8:]
self.append(
Junit.skipped(
"", type="pytest.xfail", message=bin_xml_escape(xfailreason)
)
)
xfailreason = bin_xml_escape(xfailreason)
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
self.append(skipped)
else:
assert report.longrepr is not None
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
details = "{}:{}: {}".format(filename, lineno, skipreason)

self.append(
Junit.skipped(
bin_xml_escape(details),
type="pytest.skip",
message=bin_xml_escape(skipreason),
)
)
skipped = ET.Element("skipped", type="pytest.skip", message=skipreason)
skipped.text = bin_xml_escape(details)
self.append(skipped)
self.write_captured_output(report)

def finalize(self) -> None:
data = self.to_xml().unicode(indent=0)
data = self.to_xml()
self.__dict__.clear()
# Type ignored becuase mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: py.xml.raw(data) # type: ignore
self.to_xml = lambda: data # type: ignore[assignment]


def _warn_incompatibility_with_xunit2(
Expand Down Expand Up @@ -502,7 +482,7 @@ def __init__(
{}
) # type: Dict[Tuple[Union[str, TestReport], object], _NodeReporter]
self.node_reporters_ordered = [] # type: List[_NodeReporter]
self.global_properties = [] # type: List[Tuple[str, py.xml.raw]]
self.global_properties = [] # type: List[Tuple[str, str]]

# List of reports that failed on call but teardown is pending.
self.open_reports = [] # type: List[TestReport]
Expand Down Expand Up @@ -654,7 +634,7 @@ def pytest_collectreport(self, report: TestReport) -> None:
def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple(Junit.error, "internal error", excrepr)
reporter._add_simple("error", "internal error", str(excrepr))

def pytest_sessionstart(self) -> None:
self.suite_start_time = timing.time()
Expand All @@ -676,9 +656,8 @@ def pytest_sessionfinish(self) -> None:
)
logfile.write('<?xml version="1.0" encoding="utf-8"?>')

suite_node = Junit.testsuite(
self._get_global_properties_node(),
[x.to_xml() for x in self.node_reporters_ordered],
suite_node = ET.Element(
"testsuite",
name=self.suite_name,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
Expand All @@ -688,7 +667,13 @@ def pytest_sessionfinish(self) -> None:
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
)
logfile.write(Junit.testsuites([suite_node]).unicode(indent=0))
global_properties = self._get_global_properties_node()
if global_properties is not None:
suite_node.append(global_properties)
suite_node.extend(x.to_xml() for x in self.node_reporters_ordered)
testsuites = ET.Element("testsuites")
testsuites.append(suite_node)
logfile.write(ET.tostring(testsuites, encoding="unicode"))
logfile.close()

def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
Expand All @@ -699,14 +684,14 @@ def add_global_property(self, name: str, value: object) -> None:
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))

def _get_global_properties_node(self) -> Union[py.xml.Tag, str]:
def _get_global_properties_node(self) -> Optional[ET.Element]:
"""Return a Junit node containing custom properties, if any.
"""
if self.global_properties:
return Junit.properties(
[
Junit.property(name=name, value=value)
for name, value in self.global_properties
]
properties = ET.Element("properties")
properties.extend(
ET.Element("property", name=name, value=value)
for name, value in self.global_properties
)
return ""
return properties
return None
11 changes: 6 additions & 5 deletions testing/test_junitxml.py
Expand Up @@ -323,8 +323,9 @@ def test_function(arg):
node = dom.find_first_by_tag("testsuite")
node.assert_attr(errors=1, failures=1, tests=1)
first, second = dom.find_by_tag("testcase")
if not first or not second or first == second:
assert 0
assert first
assert second
assert first != second
fnode = first.find_first_by_tag("failure")
fnode.assert_attr(message="Exception: Call Exception")
snode = second.find_first_by_tag("error")
Expand Down Expand Up @@ -535,7 +536,7 @@ def test_fail():
node = dom.find_first_by_tag("testsuite")
tnode = node.find_first_by_tag("testcase")
fnode = tnode.find_first_by_tag("failure")
fnode.assert_attr(message="AssertionError: An error assert 0")
fnode.assert_attr(message="AssertionError: An error\nassert 0")

@parametrize_families
def test_failure_escape(self, testdir, run_and_parse, xunit_family):
Expand Down Expand Up @@ -995,14 +996,14 @@ def test_invalid_xml_escape():
# 0xD, 0xD7FF, 0xE000, 0xFFFD, 0x10000, 0x10FFFF)

for i in invalid:
got = bin_xml_escape(chr(i)).uniobj
got = bin_xml_escape(chr(i))
if i <= 0xFF:
expected = "#x%02X" % i
else:
expected = "#x%04X" % i
assert got == expected
for i in valid:
assert chr(i) == bin_xml_escape(chr(i)).uniobj
assert chr(i) == bin_xml_escape(chr(i))


def test_logxml_path_expansion(tmpdir, monkeypatch):
Expand Down

0 comments on commit d0e3c5e

Please sign in to comment.