Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/_pytest/assertion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl

Expand Down Expand Up @@ -82,13 +83,13 @@ def __init__(self, config, mode):

def install_importhook(config):
"""Try to install the rewrite hook, raise SystemError if it fails."""
config._assertstate = AssertionState(config, "rewrite")
config._assertstate.hook = hook = rewrite.AssertionRewritingHook(config)
config._store[assertstate_key] = AssertionState(config, "rewrite")
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config._assertstate.trace("installed rewrite import hook")
config._store[assertstate_key].trace("installed rewrite import hook")

def undo():
hook = config._assertstate.hook
hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)

Expand All @@ -100,7 +101,7 @@ def pytest_collection(session: "Session") -> None:
# this hook is only called when test modules are collected
# so for example not in the master process of pytest-xdist
# (which does not collect test modules)
assertstate = getattr(session.config, "_assertstate", None)
assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
Expand Down Expand Up @@ -163,7 +164,7 @@ def call_assertion_pass_hook(lineno, orig, expl):


def pytest_sessionfinish(session):
assertstate = getattr(session.config, "_assertstate", None)
assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
Expand Down
13 changes: 11 additions & 2 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,18 @@
format_explanation as _format_explanation,
)
from _pytest.compat import fspath
from _pytest.compat import TYPE_CHECKING
from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import Path
from _pytest.pathlib import PurePath
from _pytest.store import StoreKey

if TYPE_CHECKING:
from _pytest.assertion import AssertionState # noqa: F401


assertstate_key = StoreKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
Expand Down Expand Up @@ -65,7 +74,7 @@ def set_session(self, session):
def find_spec(self, name, path=None, target=None):
if self._writing_pyc:
return None
state = self.config._assertstate
state = self.config._store[assertstate_key]
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name)
Expand Down Expand Up @@ -104,7 +113,7 @@ def create_module(self, spec):

def exec_module(self, module):
fn = Path(module.__spec__.origin)
state = self.config._assertstate
state = self.config._store[assertstate_key]

self._rewritten_names.add(module.__name__)

Expand Down
17 changes: 12 additions & 5 deletions src/_pytest/mark/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,28 @@
import platform
import sys
import traceback
from typing import Any
from typing import Dict

from ..outcomes import fail
from ..outcomes import TEST_OUTCOME
from _pytest.config import Config
from _pytest.store import StoreKey


def cached_eval(config, expr, d):
if not hasattr(config, "_evalcache"):
config._evalcache = {}
evalcache_key = StoreKey[Dict[str, Any]]()


def cached_eval(config: Config, expr: str, d: Dict[str, object]) -> Any:
default = {} # type: Dict[str, object]
evalcache = config._store.setdefault(evalcache_key, default)
try:
return config._evalcache[expr]
return evalcache[expr]
except KeyError:
import _pytest._code

exprcode = _pytest._code.compile(expr, mode="eval")
config._evalcache[expr] = x = eval(exprcode, d)
evalcache[expr] = x = eval(exprcode, d)
return x


Expand Down
9 changes: 9 additions & 0 deletions src/_pytest/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ def get(self, key: StoreKey[T], default: D) -> Union[T, D]:
except KeyError:
return default

def setdefault(self, key: StoreKey[T], default: T) -> T:
"""Return the value of key if already set, otherwise set the value
of key to default and return default."""
try:
return self[key]
except KeyError:
self[key] = default
return default

def __delitem__(self, key: StoreKey[T]) -> None:
"""Delete the value for key.

Expand Down
8 changes: 8 additions & 0 deletions testing/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def test_store() -> None:
with pytest.raises(KeyError):
store[key1]

# setdefault
store[key1] = "existing"
assert store.setdefault(key1, "default") == "existing"
assert store[key1] == "existing"
key_setdefault = StoreKey[bytes]()
assert store.setdefault(key_setdefault, b"default") == b"default"
assert store[key_setdefault] == b"default"

# Can't accidentally add attributes to store object itself.
with pytest.raises(AttributeError):
store.foo = "nope" # type: ignore[attr-defined] # noqa: F821
Expand Down