Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code/source: some cleanups #7438

Merged
merged 13 commits into from
Jul 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions changelog/7438.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Some changes were made to the internal ``_pytest._code.source``, listed here
for the benefit of plugin authors who may be using it:

- The ``deindent`` argument to ``Source()`` has been removed, now it is always true.
- Support for zero or multiple arguments to ``Source()`` has been removed.
- Support for comparing ``Source`` with an ``str`` has been removed.
- The methods ``Source.isparseable()`` and ``Source.putaround()`` have been removed.
- The method ``Source.compile()`` and function ``_pytest._code.compile()`` have
been removed; use plain ``compile()`` instead.
- The function ``_pytest._code.source.getsource()`` has been removed; use
``Source()`` directly instead.
2 changes: 0 additions & 2 deletions src/_pytest/_code/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .code import getrawcode
from .code import Traceback
from .code import TracebackEntry
from .source import compile_ as compile
from .source import Source

__all__ = [
Expand All @@ -19,6 +18,5 @@
"getrawcode",
"Traceback",
"TracebackEntry",
"compile",
"Source",
]
220 changes: 23 additions & 197 deletions src/_pytest/_code/source.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,43 @@
import ast
import inspect
import linecache
import sys
import textwrap
import tokenize
import warnings
from bisect import bisect_right
from types import CodeType
from types import FrameType
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import py

from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING

if TYPE_CHECKING:
from typing_extensions import Literal


class Source:
""" an immutable object holding a source code fragment,
possibly deindenting it.
"""An immutable object holding a source code fragment.

When using Source(...), the source lines are deindented.
"""

_compilecounter = 0

def __init__(self, *parts, **kwargs) -> None:
self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True)
for part in parts:
if not part:
partlines = [] # type: List[str]
elif isinstance(part, Source):
partlines = part.lines
elif isinstance(part, (tuple, list)):
partlines = [x.rstrip("\n") for x in part]
elif isinstance(part, str):
partlines = part.split("\n")
else:
partlines = getsource(part, deindent=de).lines
if de:
partlines = deindent(partlines)
lines.extend(partlines)

def __eq__(self, other):
try:
return self.lines == other.lines
except AttributeError:
if isinstance(other, str):
return str(self) == other
return False
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines = [] # type: List[str]
elif isinstance(obj, Source):
self.lines = obj.lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
else:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
self.lines = deindent(src.split("\n"))

def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines

# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
Expand Down Expand Up @@ -97,19 +79,6 @@ def strip(self) -> "Source":
source.lines[:] = self.lines[start:end]
return source

def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with
'before' and 'after' wrapped around it.
"""
beforesource = Source(before)
aftersource = Source(after)
newsource = Source()
lines = [(indent + line) for line in self.lines]
newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource

def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with
all lines indented by the given indent-string.
Expand Down Expand Up @@ -140,142 +109,9 @@ def deindent(self) -> "Source":
newsource.lines[:] = deindent(self.lines)
return newsource

def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically
deindenting it by default.
"""
if deindent:
source = str(self.deindent())
else:
source = str(self)
try:
ast.parse(source)
except (SyntaxError, ValueError, TypeError):
return False
else:
return True

def __str__(self) -> str:
return "\n".join(self.lines)

@overload
def compile(
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()

@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()

def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None
invent an artificial filename which displays
the source/line position of the caller frame.
"""
if not filename or py.path.local(filename).check(file=0):
if _genframe is None:
_genframe = sys._getframe(1) # the caller
fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
base = "<%d-codegen " % self._compilecounter
self.__class__._compilecounter += 1
if not filename:
filename = base + "%s:%d>" % (fn, lineno)
else:
filename = base + "%r %s:%d>" % (filename, fn, lineno)
source = "\n".join(self.lines) + "\n"
try:
co = compile(source, filename, mode, flag)
except SyntaxError as ex:
# re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno]
if ex.offset:
msglines.append(" " * ex.offset + "^")
msglines.append("(code was compiled probably from here: %s)" % filename)
newex = SyntaxError("\n".join(msglines))
newex.offset = ex.offset
newex.lineno = ex.lineno
newex.text = ex.text
raise newex from ex
else:
if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co


#
# public API shortcut functions
#


@overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()


@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()


def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
and any recursively created code objects.
"""
if isinstance(source, ast.AST):
# XXX should Source support having AST?
assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller
s = Source(source)
return s.compile(filename, mode, flags, _genframe=_genframe)


#
# helper functions
Expand Down Expand Up @@ -307,17 +143,7 @@ def getrawcode(obj, trycall: bool = True):
return obj


def getsource(obj, **kwargs) -> Source:
obj = getrawcode(obj)
try:
strsrc = inspect.getsource(obj)
except IndentationError:
strsrc = '"Buggy python version consider upgrading, cannot get source"'
assert isinstance(strsrc, str)
return Source(strsrc, **kwargs)


def deindent(lines: Sequence[str]) -> List[str]:
def deindent(lines: Iterable[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()


Expand Down
4 changes: 2 additions & 2 deletions src/_pytest/skipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import attr

import _pytest._code
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import hookimpl
Expand Down Expand Up @@ -105,7 +104,8 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
if hasattr(item, "obj"):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try:
condition_code = _pytest._code.compile(condition, mode="eval")
filename = "<{} condition>".format(mark.name)
condition_code = compile(condition, filename, "eval")
result = eval(condition_code, globals_)
except SyntaxError as exc:
msglines = [
Expand Down
3 changes: 2 additions & 1 deletion testing/code/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from _pytest._code import Code
from _pytest._code import ExceptionInfo
from _pytest._code import Frame
from _pytest._code import Source
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ReprFuncArgs

Expand Down Expand Up @@ -67,7 +68,7 @@ def func() -> FrameType:

f = Frame(func())
with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == ""
assert f.statement == Source("")


def test_code_from_func() -> None:
Expand Down
Loading