diff --git a/tgpy/api.py b/tgpy/api.py index d71bec4..30d3fe3 100644 --- a/tgpy/api.py +++ b/tgpy/api.py @@ -1,4 +1,3 @@ -import logging from typing import Any, Callable @@ -15,18 +14,5 @@ def __init__(self): def add_code_transformer(self, name: str, transformer: Callable[[str], str]): self.code_transformers.append((name, transformer)) - def _apply_code_transformers(self, code: str) -> str: - for _, transformer in self.code_transformers: - try: - code = transformer(code) - except Exception: - logger = logging.getLogger(__name__) - logger.exception( - f'Error while applying code transformer {transformer}', - exc_info=True, - ) - raise - return code - api = API() diff --git a/tgpy/run_code/autoawait.py b/tgpy/run_code/autoawait.py new file mode 100644 index 0000000..6ee6416 --- /dev/null +++ b/tgpy/run_code/autoawait.py @@ -0,0 +1,43 @@ +import ast +import tokenize +from io import BytesIO + +AWAIT_REPLACEMENT_ATTRIBUTE = '__tgpy_await__' + + +def tokenize_string(s: str) -> list[tokenize.TokenInfo]: + return list(tokenize.tokenize(BytesIO(s.encode('utf-8')).readline)) + + +def untokenize_to_string(tokens: list[tokenize.TokenInfo]) -> str: + return tokenize.untokenize(tokens).decode('utf-8') + + +def pre_transform(code: str) -> str: + tokens = tokenize_string(code) + for i, tok in enumerate(tokens): + if i == 0: + continue + prev_tok = tokens[i - 1] + if ( + tok.type == tokenize.NAME + and tok.string == 'await' + and prev_tok.type == tokenize.OP + and prev_tok.string == '.' + ): + tokens[i] = tok._replace(string=AWAIT_REPLACEMENT_ATTRIBUTE) + return untokenize_to_string(tokens) + + +class AwaitTransformer(ast.NodeTransformer): + def visit_Attribute(self, node: ast.Attribute): + node = self.generic_visit(node) + if node.attr == AWAIT_REPLACEMENT_ATTRIBUTE: + return ast.Await(value=node.value) + else: + return node + + +transformer = AwaitTransformer() + +__all__ = ['pre_transform', 'transformer'] diff --git a/tgpy/run_code/meval.py b/tgpy/run_code/meval.py index 18c8d1f..81413c5 100644 --- a/tgpy/run_code/meval.py +++ b/tgpy/run_code/meval.py @@ -9,6 +9,8 @@ from typing import Any, Iterator from tgpy import app +from tgpy.run_code import autoawait +from tgpy.run_code.utils import apply_code_transformers def shallow_walk(node) -> Iterator: @@ -65,9 +67,9 @@ async def meval( # Copy data to args we are sending kwargs[global_args][glob] = globs[glob] - # noinspection PyProtectedMember - str_code = app.api._apply_code_transformers(str_code) + str_code = apply_code_transformers(app, str_code) root = ast.parse(str_code, '', 'exec') + autoawait.transformer.visit(root) ret_name = '_ret' ok = False diff --git a/tgpy/run_code/parse_code.py b/tgpy/run_code/parse_code.py index 07e0b66..9d1f117 100644 --- a/tgpy/run_code/parse_code.py +++ b/tgpy/run_code/parse_code.py @@ -1,6 +1,7 @@ import ast from tgpy import app +from tgpy.run_code.utils import apply_code_transformers class _Result: @@ -78,8 +79,7 @@ def parse_code(text: str, locs: dict) -> _Result: """Parse given text and decide should it be evaluated as Python code""" result = _Result() - # noinspection PyProtectedMember - text = app.api._apply_code_transformers(text) + text = apply_code_transformers(app, text) try: root = ast.parse(text, '', 'exec') diff --git a/tgpy/run_code/utils.py b/tgpy/run_code/utils.py index 7e2b7f8..378fca0 100644 --- a/tgpy/run_code/utils.py +++ b/tgpy/run_code/utils.py @@ -1,8 +1,13 @@ +import logging import sys +import tokenize import traceback from telethon.tl import TLObject +from tgpy import App +from tgpy.run_code import autoawait + def convert_result(result): if isinstance(result, TLObject): @@ -15,3 +20,21 @@ def format_traceback(): exc_type, exc_value, exc_traceback = sys.exc_info() exc_traceback = exc_traceback.tb_next.tb_next return traceback.format_exception(exc_type, exc_value, exc_traceback) + + +def apply_code_transformers(app: App, code: str) -> str: + try: + code = autoawait.pre_transform(code) + except tokenize.TokenError: + pass + for _, transformer in app.api.code_transformers: + try: + code = transformer(code) + except Exception: + logger = logging.getLogger(__name__) + logger.exception( + f'Error while applying code transformer {transformer}', + exc_info=True, + ) + raise + return code