Skip to content

Commit

Permalink
Merge pull request #96 from life4/code-generation
Browse files Browse the repository at this point in the history
Code generation (`python3 -m deal decorate` CLI command)
  • Loading branch information
orsinium committed Nov 18, 2021
2 parents 6084e90 + c30eda6 commit 366ca42
Show file tree
Hide file tree
Showing 22 changed files with 1,027 additions and 72 deletions.
4 changes: 2 additions & 2 deletions deal/_cli/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import TextIO

Expand All @@ -18,5 +18,5 @@ def print(self, *args) -> None:
def init_parser(parser: ArgumentParser) -> None:
raise NotImplementedError

def __call__(self, args) -> int:
def __call__(self, args: Namespace) -> int:
raise NotImplementedError
63 changes: 63 additions & 0 deletions deal/_cli/_decorate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from argparse import ArgumentParser
from pathlib import Path

from .._colors import get_colors
from ..linter import TransformationType, Transformer
from ._base import Command
from ._common import get_paths


class DecorateCommand(Command):
"""Add decorators to your code.
```bash
python3 -m deal decorate project/
```
Options:
+ `--types`: types of decorators to apply. All are enabled by default.
+ `--double-quotes`: use double quotes. Single quotes are used by default.
+ `--nocolor`: do not use colors in the console output.
The exit code is always 0. If you want to test the code for missed decorators,
use the `lint` command instead.
"""

@staticmethod
def init_parser(parser: ArgumentParser) -> None:
parser.add_argument(
'--types',
nargs='*',
choices=[tt.value for tt in TransformationType],
default=['has', 'raises', 'safe', 'import'],
help='types of decorators to apply',
)
parser.add_argument(
'--double-quotes',
action='store_true',
help='use double quotes',
)
parser.add_argument('--nocolor', action='store_true', help='colorless output')
parser.add_argument('paths', nargs='*', default='.')

def __call__(self, args) -> int:
types = {TransformationType(t) for t in args.types}
colors = get_colors(args)
for arg in args.paths:
for path in get_paths(Path(arg)):
self.print('{magenta}{path}{end}'.format(path=path, **colors))
original_code = path.read_text(encoding='utf8')
tr = Transformer(
content=original_code,
path=path,
types=types,
)
if args.double_quotes:
tr = tr._replace(quote='"')
modified_code = tr.transform()
if original_code == modified_code:
self.print(' {blue}no changes{end}'.format(**colors))
else:
path.write_text(modified_code)
self.print(' {green}decorated{end}'.format(**colors))
return 0
35 changes: 21 additions & 14 deletions deal/_cli/_main.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,40 @@
import sys
from argparse import ArgumentParser
from pathlib import Path
from types import MappingProxyType
from typing import Mapping, Sequence, TextIO, Type

from ._base import Command
from ._lint import LintCommand
from ._memtest import MemtestCommand
from ._prove import ProveCommand
from ._stub import StubCommand
from ._test import TestCommand


CommandsType = Mapping[str, Type[Command]]
COMMANDS: CommandsType = MappingProxyType(dict(
lint=LintCommand,
memtest=MemtestCommand,
prove=ProveCommand,
stub=StubCommand,
test=TestCommand,
))


def get_commands() -> CommandsType:
from ._decorate import DecorateCommand
from ._lint import LintCommand
from ._memtest import MemtestCommand
from ._prove import ProveCommand
from ._stub import StubCommand
from ._test import TestCommand

return dict(
decorate=DecorateCommand,
lint=LintCommand,
memtest=MemtestCommand,
prove=ProveCommand,
stub=StubCommand,
test=TestCommand,
)


def main(
argv: Sequence[str], *,
commands: CommandsType = COMMANDS,
commands: CommandsType = None,
root: Path = None,
stream: TextIO = sys.stdout,
) -> int:
if commands is None:
commands = get_commands()
if root is None: # pragma: no cover
root = Path()
parser = ArgumentParser(prog='python3 -m deal')
Expand Down
7 changes: 5 additions & 2 deletions deal/_cli/_prove.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
class DealTheorem(Theorem):
@staticmethod
def get_contracts(func: astroid.FunctionDef) -> Iterator[Contract]:
for name, args in get_contracts(func):
yield Contract(name=name, args=args)
for contract in get_contracts(func):
yield Contract(
name=contract.name,
args=contract.args, # type: ignore[arg-type]
)


def run_solver(
Expand Down
9 changes: 8 additions & 1 deletion deal/linter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from ._checker import Checker
from ._stub import StubsManager, generate_stub
from ._transformer import TransformationType, Transformer


__all__ = ['Checker', 'StubsManager', 'generate_stub']
__all__ = [
'Checker',
'generate_stub',
'StubsManager',
'TransformationType',
'Transformer',
]
12 changes: 10 additions & 2 deletions deal/linter/_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import enum
from copy import copy
from pathlib import Path
from typing import Dict, FrozenSet, Iterable, List
from typing import Dict, FrozenSet, Iterable, List, Type, Union

import astroid

Expand All @@ -24,25 +24,33 @@ class Category(enum.Enum):
PURE = 'pure'
RAISES = 'raises'
SAFE = 'safe'
INHERIT = 'inherit'

@property
def brackets_optional(self) -> bool:
return self in {Category.SAFE, Category.PURE}


class Contract:
args: tuple
category: Category
func_args: ast.arguments
context: Dict[str, ast.stmt]
line: int

def __init__(
self,
args: Iterable,
category: Category,
func_args: ast.arguments,
context: Dict[str, ast.stmt] = None,
line: int = 0,
):
self.args = tuple(args)
self.category = category
self.func_args = func_args
self.context = context or dict()
self.line = line

@cached_property
def body(self) -> ast.AST:
Expand Down Expand Up @@ -110,7 +118,7 @@ def _resolve_name(contract):
return contract # pragma: no cover

@cached_property
def exceptions(self) -> list:
def exceptions(self) -> List[Union[str, Type[Exception]]]:
from ._extractors import get_name

excs = []
Expand Down
3 changes: 1 addition & 2 deletions deal/linter/_extractors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class Token(NamedTuple):
line: int
col: int
value: Optional[object] = None
# marker name or error message:
marker: Optional[str] = None
marker: Optional[str] = None # marker name or error message


def traverse(body: List) -> Iterator:
Expand Down
26 changes: 20 additions & 6 deletions deal/linter/_extractors/contracts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, NamedTuple, Optional, Union

import astroid

Expand All @@ -20,7 +20,13 @@
Attr = Union[ast.Attribute, astroid.Attribute]


def get_contracts(func) -> Iterator[Tuple[str, list]]:
class ContractInfo(NamedTuple):
name: str
args: List[Union[ast.expr, astroid.Expr]]
line: int


def get_contracts(func) -> Iterator[ContractInfo]:
if isinstance(func, ast.FunctionDef):
yield from _get_contracts(func.decorator_list)
return
Expand All @@ -32,13 +38,17 @@ def get_contracts(func) -> Iterator[Tuple[str, list]]:
yield from _get_contracts(func.decorators.nodes)


def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]:
def _get_contracts(decorators: list) -> Iterator[ContractInfo]:
for contract in decorators:
if isinstance(contract, TOKENS.ATTR):
name = get_name(contract)
if name not in SUPPORTED_MARKERS:
continue
yield name.split('.')[-1], []
yield ContractInfo(
name=name.split('.')[-1],
args=[],
line=contract.lineno,
)
if name == 'deal.inherit':
yield from _resolve_inherit(contract)

Expand All @@ -50,7 +60,11 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]:
yield from _get_contracts(contract.args)
if name not in SUPPORTED_CONTRACTS:
continue
yield name.split('.')[-1], contract.args
yield ContractInfo(
name=name.split('.')[-1],
args=contract.args,
line=contract.lineno,
)

# infer assigned value
if isinstance(contract, astroid.Name):
Expand All @@ -68,7 +82,7 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]:
yield from _get_contracts([expr.value])


def _resolve_inherit(contract: Attr) -> Iterator[Tuple[str, List[astroid.Expr]]]:
def _resolve_inherit(contract: Attr) -> Iterator[ContractInfo]:
if not isinstance(contract, astroid.Attribute):
return
cls = _get_parent_class(contract)
Expand Down
6 changes: 3 additions & 3 deletions deal/linter/_extractors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def _exceptions_from_func(expr) -> Iterator[Token]:
yield Token(value=error.value, line=expr.lineno, col=expr.col_offset)

# get explicitly specified exceptions from `@deal.raises`
for category, args in get_contracts(value):
if category != 'raises':
for contract in get_contracts(value):
if contract.name != 'raises':
continue
for arg in args:
for arg in contract.args:
name = get_name(arg)
if name is None:
continue
Expand Down
6 changes: 3 additions & 3 deletions deal/linter/_extractors/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ def _markers_from_func(expr: astroid.NodeNG, inferred: tuple) -> Iterator[Token]
)

# get explicitly specified markers from `@deal.has`
for category, args in get_contracts(value):
if category != 'has':
for contract in get_contracts(value):
if contract.name != 'has':
continue
for arg in args:
for arg in contract.args:
value = get_value(arg)
if type(value) is not str:
continue
Expand Down
7 changes: 3 additions & 4 deletions deal/linter/_extractors/pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ def handle_call(expr: astroid.Call, context: Dict[str, ast.stmt] = None) -> Iter
continue
code = f'def f({func.args.as_string()}):0'
func_args = ast.parse(code).body[0].args # type: ignore
for category, contract_args in get_contracts(func):
if category != 'pre':
for cinfo in get_contracts(func):
if cinfo.name != 'pre':
continue

contract = Contract(
args=contract_args,
args=cinfo.args,
category=Category.PRE,
func_args=func_args,
context=context,
Expand Down
20 changes: 14 additions & 6 deletions deal/linter/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ def from_ast(cls, tree: ast.Module) -> List['Func']:
if not isinstance(expr, ast.FunctionDef):
continue
contracts = []
for category, args in get_contracts(expr):
for cinfo in get_contracts(expr):
contract = Contract(
args=args,
args=cinfo.args,
func_args=expr.args,
category=Category(category),
category=Category(cinfo.name),
context=definitions,
line=cinfo.line,
)
contracts.append(contract)
funcs.append(cls(
Expand All @@ -68,12 +69,13 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']:

# collect contracts
contracts = []
for category, args in get_contracts(expr):
for cinfo in get_contracts(expr):
contract = Contract(
args=args,
args=cinfo.args,
func_args=func_args,
category=Category(category),
category=Category(cinfo.name),
context=definitions,
line=cinfo.line,
)
contracts.append(contract)
funcs.append(cls(
Expand All @@ -86,6 +88,12 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']:
))
return funcs

def has_contract(self, *categories: Category) -> bool:
for contract in self.contracts:
if contract.category in categories:
return True
return False

def __repr__(self) -> str:
cats = ', '.join(contract.category.value for contract in self.contracts)
return f'{type(self).__name__}({cats})'

0 comments on commit 366ca42

Please sign in to comment.