Skip to content

Commit

Permalink
Merge pull request #1060 from bluetech/pre-typing-fixes
Browse files Browse the repository at this point in the history
Pre-typing fixes/improvements
  • Loading branch information
bluetech committed Apr 6, 2024
2 parents 3e92604 + cdff86a commit 4720808
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ repos:
additional_dependencies:
- pytest>=7.0.0
- execnet>=2.1.0
- py>=1.10.0
- types-psutil
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ select = [
"W", # pycodestyle
"T10", # flake8-debugger
"PIE", # flake8-pie
"FA", # flake8-future-annotations
"PGH", # pygrep-hooks
"PLE", # pylint error
"PLW", # pylint warning
Expand Down Expand Up @@ -135,6 +136,7 @@ lines-after-imports = 2

[tool.mypy]
mypy_path = ["src"]
files = ["src", "testing"]
# TODO: Enable this & fix errors.
# check_untyped_defs = true
disallow_any_generics = true
Expand Down
10 changes: 1 addition & 9 deletions src/xdist/dsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,6 @@ def worker_collectreport(self, node, rep):
assert not rep.passed
self._failed_worker_collectreport(node, rep)

def worker_warning_captured(self, warning_message, when, item):
"""Emitted when a node calls the pytest_warning_captured hook (deprecated in 6.0)."""
# This hook as been removed in pytest 7.1, and we can remove support once we only
# support pytest >=7.1.
kwargs = dict(warning_message=warning_message, when=when, item=item)
self.config.hook.pytest_warning_captured.call_historic(kwargs=kwargs)

def worker_warning_recorded(self, warning_message, when, nodeid, location):
"""Emitted when a node calls the pytest_warning_recorded hook."""
kwargs = dict(
Expand Down Expand Up @@ -374,10 +367,9 @@ def triggershutdown(self):
def handle_crashitem(self, nodeid, worker):
# XXX get more reporting info by recording pytest_runtest_logstart?
# XXX count no of failures and retry N times
runner = self.config.pluginmanager.getplugin("runner")
fspath = nodeid.split("::")[0]
msg = f"worker {worker.gateway.id!r} crashed while running {nodeid!r}"
rep = runner.TestReport(
rep = pytest.TestReport(
nodeid, (fspath, None, fspath), (), "failed", msg, "???"
)
rep.node = worker
Expand Down
15 changes: 8 additions & 7 deletions src/xdist/looponfail.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
the controlling process which should best never happen.
"""

from __future__ import annotations

import os
from pathlib import Path
import sys
import time
from typing import Dict
from typing import Sequence

from _pytest._io import TerminalWriter
Expand Down Expand Up @@ -45,7 +46,7 @@ def pytest_cmdline_main(config):
return 2 # looponfail only can get stop with ctrl-C anyway


def looponfail_main(config: "pytest.Config") -> None:
def looponfail_main(config: pytest.Config) -> None:
remotecontrol = RemoteControl(config)
config_roots = config.getini("looponfailroots")
if not config_roots:
Expand Down Expand Up @@ -79,9 +80,7 @@ def trace(self, *args):
def initgateway(self):
return execnet.makegateway("popen")

def setup(self, out=None):
if out is None:
out = TerminalWriter()
def setup(self):
if hasattr(self, "gateway"):
raise ValueError("already have gateway %r" % self.gateway)
self.trace("setting up worker session")
Expand All @@ -93,6 +92,8 @@ def setup(self, out=None):
)
remote_outchannel = channel.receive()

out = TerminalWriter()

def write(s):
out._file.write(s)
out._file.flush()
Expand Down Expand Up @@ -238,7 +239,7 @@ def main(self):
class StatRecorder:
def __init__(self, rootdirlist: Sequence[Path]) -> None:
self.rootdirlist = rootdirlist
self.statcache: Dict[Path, os.stat_result] = {}
self.statcache: dict[Path, os.stat_result] = {}
self.check() # snapshot state

def fil(self, p: Path) -> bool:
Expand All @@ -256,7 +257,7 @@ def waitonchange(self, checkinterval=1.0):

def check(self, removepycfiles: bool = True) -> bool:
changed = False
newstat: Dict[Path, os.stat_result] = {}
newstat: dict[Path, os.stat_result] = {}
for rootdir in self.rootdirlist:
for path in visit_path(rootdir, filter=self.fil, recurse=self.rec):
oldstat = self.statcache.pop(path, None)
Expand Down
29 changes: 16 additions & 13 deletions src/xdist/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import contextlib
import enum
import os
import sys
import time
Expand Down Expand Up @@ -57,10 +58,12 @@ def worker_title(title):
pass


class WorkerInteractor:
SHUTDOWN_MARK = object()
QUEUE_REPLACED_MARK = object()
class Marker(enum.Enum):
SHUTDOWN = 0
QUEUE_REPLACED = 1


class WorkerInteractor:
def __init__(self, config, channel):
self.config = config
self.workerid = config.workerinput.get("workerid", "?")
Expand All @@ -79,7 +82,7 @@ def _get_next_item_index(self):
is replaced concurrently in another thread.
"""
result = self.torun.get()
while result is self.QUEUE_REPLACED_MARK:
while result is Marker.QUEUE_REPLACED:
result = self.torun.get()
return result

Expand Down Expand Up @@ -114,8 +117,8 @@ def pytest_collection(self, session):
self.sendevent("collectionstart")

def handle_command(self, command):
if command is self.SHUTDOWN_MARK:
self.torun.put(self.SHUTDOWN_MARK)
if command is Marker.SHUTDOWN:
self.torun.put(Marker.SHUTDOWN)
return

name, kwargs = command
Expand All @@ -128,7 +131,7 @@ def handle_command(self, command):
for i in range(len(self.session.items)):
self.torun.put(i)
elif name == "shutdown":
self.torun.put(self.SHUTDOWN_MARK)
self.torun.put(Marker.SHUTDOWN)
elif name == "steal":
self.steal(kwargs["indices"])

Expand All @@ -149,14 +152,14 @@ def old_queue_get_nowait_noraise():
self.torun.put(i)

self.sendevent("unscheduled", indices=stolen)
old_queue.put(self.QUEUE_REPLACED_MARK)
old_queue.put(Marker.QUEUE_REPLACED)

@pytest.hookimpl
def pytest_runtestloop(self, session):
self.log("entering main loop")
self.channel.setcallback(self.handle_command, endmarker=self.SHUTDOWN_MARK)
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
self.nextitem_index = self._get_next_item_index()
while self.nextitem_index is not self.SHUTDOWN_MARK:
while self.nextitem_index is not Marker.SHUTDOWN:
self.run_one_test()
if session.shouldfail or session.shouldstop:
break
Expand All @@ -168,16 +171,16 @@ def run_one_test(self):

items = self.session.items
item = items[self.item_index]
if self.nextitem_index is self.SHUTDOWN_MARK:
if self.nextitem_index is Marker.SHUTDOWN:
nextitem = None
else:
nextitem = items[self.nextitem_index]

worker_title("[pytest-xdist running] %s" % item.nodeid)

start = time.time()
start = time.perf_counter()
self.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
duration = time.time() - start
duration = time.perf_counter() - start

worker_title("[pytest-xdist idle]")

Expand Down
7 changes: 1 addition & 6 deletions src/xdist/scheduler/each.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,7 @@ def mark_test_complete(self, node, item_index, duration=0):
self.node2pending[node].remove(item_index)

def mark_test_pending(self, item):
self.pending.insert(
0,
self.collection.index(item),
)
for node in self.node2pending:
self.check_schedule(node)
raise NotImplementedError()

def remove_node(self, node):
# KeyError if we didn't get an add_node() yet
Expand Down
17 changes: 7 additions & 10 deletions src/xdist/scheduler/loadscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(self, config, log=None):
self.collection = None

self.workqueue = OrderedDict()
self.assigned_work = OrderedDict()
self.registered_collections = OrderedDict()
self.assigned_work = {}
self.registered_collections = {}

if log is None:
self.log = Producer("loadscopesched")
Expand Down Expand Up @@ -156,7 +156,7 @@ def add_node(self, node):
bootstraps a new node.
"""
assert node not in self.assigned_work
self.assigned_work[node] = OrderedDict()
self.assigned_work[node] = {}

def remove_node(self, node):
"""Remove a node from the scheduler.
Expand Down Expand Up @@ -252,7 +252,7 @@ def _assign_work_unit(self, node):
scope, work_unit = self.workqueue.popitem(last=False)

# Keep track of the assigned work
assigned_to_node = self.assigned_work.setdefault(node, default=OrderedDict())
assigned_to_node = self.assigned_work.setdefault(node, {})
assigned_to_node[scope] = work_unit

# Ask the node to execute the workload
Expand Down Expand Up @@ -349,10 +349,10 @@ def schedule(self):
return

# Determine chunks of work (scopes)
unsorted_workqueue = OrderedDict()
unsorted_workqueue = {}
for nodeid in self.collection:
scope = self._split_scope(nodeid)
work_unit = unsorted_workqueue.setdefault(scope, default=OrderedDict())
work_unit = unsorted_workqueue.setdefault(scope, {})
work_unit[nodeid] = False

# Insert tests scopes into work queue ordered by number of tests.
Expand All @@ -368,7 +368,7 @@ def schedule(self):
self.log(f"Shutting down {extra_nodes} nodes")

for _ in range(extra_nodes):
unused_node, assigned = self.assigned_work.popitem(last=True)
unused_node, assigned = self.assigned_work.popitem()

self.log(f"Shutting down unused node {unused_node}")
unused_node.shutdown()
Expand Down Expand Up @@ -407,9 +407,6 @@ def _check_nodes_have_same_collection(self):
same_collection = False
self.log(msg)

if self.config is None:
continue

rep = pytest.CollectReport(
nodeid=node.gateway.id,
outcome="failed",
Expand Down
37 changes: 14 additions & 23 deletions src/xdist/workermanage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import enum
import fnmatch
import os
from pathlib import Path
import re
import sys
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
import uuid

Expand Down Expand Up @@ -60,7 +59,7 @@ def __init__(self, config, specs=None, defaultchdir="pyexecnetcache") -> None:
self.specs.append(spec)
self.roots = self._getrsyncdirs()
self.rsyncoptions = self._getrsyncoptions()
self._rsynced_specs: Set[Tuple[Any, Any]] = set()
self._rsynced_specs: set[tuple[Any, Any]] = set()

def rsync_roots(self, gateway):
"""Rsync the set of roots to the node's gateway cwd."""
Expand Down Expand Up @@ -89,7 +88,7 @@ def teardown_nodes(self):
def _getxspecs(self):
return [execnet.XSpec(x) for x in parse_spec_config(self.config)]

def _getrsyncdirs(self) -> List[Path]:
def _getrsyncdirs(self) -> list[Path]:
for spec in self.specs:
if not spec.popen or spec.chdir:
break
Expand Down Expand Up @@ -174,7 +173,7 @@ def __init__(
self,
sourcedir: PathLike,
*,
ignores: Optional[Sequence[PathLike]] = None,
ignores: Sequence[PathLike] | None = None,
verbose: bool = True,
) -> None:
if ignores is None:
Expand All @@ -201,7 +200,7 @@ def _report_send_file(self, gateway, modified_rel_path):
print(f"{gateway.spec}:{remotepath} <= {path}")


def make_reltoroot(roots: Sequence[Path], args: List[str]) -> List[str]:
def make_reltoroot(roots: Sequence[Path], args: list[str]) -> list[str]:
# XXX introduce/use public API for splitting pytest args
splitcode = "::"
result = []
Expand All @@ -216,7 +215,7 @@ def make_reltoroot(roots: Sequence[Path], args: List[str]) -> List[str]:
result.append(arg)
continue
for root in roots:
x: Optional[Path]
x: Path | None
try:
x = fspath.relative_to(root)
except ValueError:
Expand All @@ -230,9 +229,11 @@ def make_reltoroot(roots: Sequence[Path], args: List[str]) -> List[str]:
return result


class WorkerController:
ENDMARK = -1
class Marker(enum.Enum):
END = -1


class WorkerController:
class RemoteHook:
@pytest.hookimpl(trylast=True)
def pytest_xdist_getremotemodule(self):
Expand Down Expand Up @@ -283,7 +284,7 @@ def setup(self):
self.channel.send((self.workerinput, args, option_dict, change_sys_path))

if self.putevent:
self.channel.setcallback(self.process_from_remote, endmarker=self.ENDMARK)
self.channel.setcallback(self.process_from_remote, endmarker=Marker.END)

def ensure_teardown(self):
if hasattr(self, "channel"):
Expand Down Expand Up @@ -331,7 +332,7 @@ def process_from_remote(self, eventcall):
avoid raising exceptions or doing heavy work.
"""
try:
if eventcall == self.ENDMARK:
if eventcall is Marker.END:
err = self.channel._getremoteerror()
if not self._down:
if not err or isinstance(err, EOFError):
Expand Down Expand Up @@ -374,16 +375,6 @@ def process_from_remote(self, eventcall):
nodeid=kwargs["nodeid"],
fslocation=kwargs["nodeid"],
)
elif eventname == "warning_captured":
warning_message = unserialize_warning_message(
kwargs["warning_message_data"]
)
self.notify_inproc(
eventname,
warning_message=warning_message,
when=kwargs["when"],
item=kwargs["item"],
)
elif eventname == "warning_recorded":
warning_message = unserialize_warning_message(
kwargs["warning_message_data"]
Expand Down
Loading

0 comments on commit 4720808

Please sign in to comment.