Skip to content

Commit

Permalink
Fix type annotations in various places.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanslenders committed Dec 12, 2023
1 parent 655b354 commit 6801f94
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 64 deletions.
15 changes: 6 additions & 9 deletions examples/asyncio-python-embed.py
Expand Up @@ -19,7 +19,7 @@
counter = [0]


async def print_counter():
async def print_counter() -> None:
"""
Coroutine that prints counters and saves it in a global variable.
"""
Expand All @@ -29,7 +29,7 @@ async def print_counter():
await asyncio.sleep(3)


async def interactive_shell():
async def interactive_shell() -> None:
"""
Coroutine that starts a Python REPL from which we can access the global
counter variable.
Expand All @@ -44,13 +44,10 @@ async def interactive_shell():
loop.stop()


def main():
asyncio.ensure_future(print_counter())
asyncio.ensure_future(interactive_shell())

loop.run_forever()
loop.close()
async def main() -> None:
asyncio.create_task(print_counter())
await interactive_shell()


if __name__ == "__main__":
main()
asyncio.run(main())
18 changes: 6 additions & 12 deletions examples/asyncio-ssh-python-embed.py
Expand Up @@ -32,31 +32,25 @@ def session_requested(self):
return ReplSSHServerSession(self.get_namespace)


def main(port=8222):
async def main(port: int = 8222) -> None:
"""
Example that starts the REPL through an SSH server.
"""
loop = asyncio.get_event_loop()

# Namespace exposed in the REPL.
environ = {"hello": "world"}

# Start SSH server.
def create_server():
def create_server() -> MySSHServer:
return MySSHServer(lambda: environ)

print("Listening on :%i" % port)
print('To connect, do "ssh localhost -p %i"' % port)

loop.run_until_complete(
asyncssh.create_server(
create_server, "", port, server_host_keys=["/etc/ssh/ssh_host_dsa_key"]
)
await asyncssh.create_server(
create_server, "", port, server_host_keys=["/etc/ssh/ssh_host_dsa_key"]
)

# Run eventloop.
loop.run_forever()
await asyncio.Future() # Wait forever.


if __name__ == "__main__":
main()
asyncio.run(main())
12 changes: 6 additions & 6 deletions examples/python-embed-with-custom-prompt.py
Expand Up @@ -2,26 +2,26 @@
"""
Example of embedding a Python REPL, and setting a custom prompt.
"""
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.formatted_text import HTML, AnyFormattedText

from ptpython.prompt_style import PromptStyle
from ptpython.repl import embed


def configure(repl):
def configure(repl) -> None:
# Probably, the best is to add a new PromptStyle to `all_prompt_styles` and
# activate it. This way, the other styles are still selectable from the
# menu.
class CustomPrompt(PromptStyle):
def in_prompt(self):
def in_prompt(self) -> AnyFormattedText:
return HTML("<ansigreen>Input[%s]</ansigreen>: ") % (
repl.current_statement_index,
)

def in2_prompt(self, width):
def in2_prompt(self, width: int) -> AnyFormattedText:
return "...: ".rjust(width)

def out_prompt(self):
def out_prompt(self) -> AnyFormattedText:
return HTML("<ansired>Result[%s]</ansired>: ") % (
repl.current_statement_index,
)
Expand All @@ -30,7 +30,7 @@ def out_prompt(self):
repl.prompt_style = "custom"


def main():
def main() -> None:
embed(globals(), locals(), configure=configure)


Expand Down
2 changes: 1 addition & 1 deletion examples/python-embed.py
Expand Up @@ -4,7 +4,7 @@
from ptpython.repl import embed


def main():
def main() -> None:
embed(globals(), locals(), vi_mode=False)


Expand Down
11 changes: 7 additions & 4 deletions examples/ssh-and-telnet-embed.py
Expand Up @@ -11,26 +11,29 @@

import asyncssh
from prompt_toolkit import print_formatted_text
from prompt_toolkit.contrib.ssh.server import PromptToolkitSSHServer
from prompt_toolkit.contrib.ssh.server import (
PromptToolkitSSHServer,
PromptToolkitSSHSession,
)
from prompt_toolkit.contrib.telnet.server import TelnetServer

from ptpython.repl import embed


def ensure_key(filename="ssh_host_key"):
def ensure_key(filename: str = "ssh_host_key") -> str:
path = pathlib.Path(filename)
if not path.exists():
rsa_key = asyncssh.generate_private_key("ssh-rsa")
path.write_bytes(rsa_key.export_private_key())
return str(path)


async def interact(connection=None):
async def interact(connection: PromptToolkitSSHSession) -> None:
global_dict = {**globals(), "print": print_formatted_text}
await embed(return_asyncio_coroutine=True, globals=global_dict)


async def main(ssh_port=8022, telnet_port=8023):
async def main(ssh_port: int = 8022, telnet_port: int = 8023) -> None:
ssh_server = PromptToolkitSSHServer(interact=interact)
await asyncssh.create_server(
lambda: ssh_server, "", ssh_port, server_host_keys=[ensure_key()]
Expand Down
26 changes: 15 additions & 11 deletions ptpython/contrib/asyncssh_repl.py
Expand Up @@ -9,20 +9,20 @@
from __future__ import annotations

import asyncio
from typing import Any, TextIO, cast
from typing import Any, AnyStr, TextIO, cast

import asyncssh
from prompt_toolkit.data_structures import Size
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.output.vt100 import Vt100_Output

from ptpython.python_input import _GetNamespace
from ptpython.python_input import _GetNamespace, _Namespace
from ptpython.repl import PythonRepl

__all__ = ["ReplSSHServerSession"]


class ReplSSHServerSession(asyncssh.SSHServerSession):
class ReplSSHServerSession(asyncssh.SSHServerSession[str]):
"""
SSH server session that runs a Python REPL.
Expand All @@ -35,7 +35,7 @@ def __init__(
) -> None:
self._chan: Any = None

def _globals() -> dict:
def _globals() -> _Namespace:
data = get_globals()
data.setdefault("print", self._print)
return data
Expand Down Expand Up @@ -79,7 +79,7 @@ def _get_size(self) -> Size:
width, height, pixwidth, pixheight = self._chan.get_terminal_size()
return Size(rows=height, columns=width)

def connection_made(self, chan):
def connection_made(self, chan: Any) -> None:
"""
Client connected, run repl in coroutine.
"""
Expand All @@ -89,7 +89,7 @@ def connection_made(self, chan):
f = asyncio.ensure_future(self.repl.run_async())

# Close channel when done.
def done(_) -> None:
def done(_: object) -> None:
chan.close()
self._chan = None

Expand All @@ -98,24 +98,28 @@ def done(_) -> None:
def shell_requested(self) -> bool:
return True

def terminal_size_changed(self, width, height, pixwidth, pixheight):
def terminal_size_changed(
self, width: int, height: int, pixwidth: int, pixheight: int
) -> None:
"""
When the terminal size changes, report back to CLI.
"""
self.repl.app._on_resize()

def data_received(self, data, datatype):
def data_received(self, data: AnyStr, datatype: int | None) -> None:
"""
When data is received, send to inputstream of the CLI and repaint.
"""
self._input_pipe.send(data)

def _print(self, *data, sep=" ", end="\n", file=None) -> None:
def _print(
self, *data: object, sep: str = " ", end: str = "\n", file: Any = None
) -> None:
"""
Alternative 'print' function that prints back into the SSH channel.
"""
# Pop keyword-only arguments. (We cannot use the syntax from the
# signature. Otherwise, Python2 will give a syntax error message when
# installing.)
data = sep.join(map(str, data))
self._chan.write(data + end)
data_as_str = sep.join(map(str, data))
self._chan.write(data_as_str + end)
58 changes: 38 additions & 20 deletions ptpython/python_input.py
Expand Up @@ -6,7 +6,7 @@

from asyncio import get_event_loop
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Mapping, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Mapping, TypeVar, Union

from prompt_toolkit.application import Application, get_app
from prompt_toolkit.auto_suggest import (
Expand All @@ -31,7 +31,7 @@
)
from prompt_toolkit.document import Document
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.filters import Condition
from prompt_toolkit.filters import Condition, FilterOrBool
from prompt_toolkit.formatted_text import AnyFormattedText
from prompt_toolkit.history import (
FileHistory,
Expand All @@ -49,8 +49,13 @@
from prompt_toolkit.key_binding.bindings.open_in_editor import (
load_open_in_editor_bindings,
)
from prompt_toolkit.key_binding.key_bindings import Binding, KeyHandlerCallable
from prompt_toolkit.key_binding.key_processor import KeyPressEvent
from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.keys import Keys
from prompt_toolkit.layout.containers import AnyContainer
from prompt_toolkit.layout.dimension import AnyDimension
from prompt_toolkit.layout.processors import Processor
from prompt_toolkit.lexers import DynamicLexer, Lexer, SimpleLexer
from prompt_toolkit.output import ColorDepth, Output
from prompt_toolkit.styles import (
Expand Down Expand Up @@ -91,22 +96,23 @@
from typing_extensions import Protocol

class _SupportsLessThan(Protocol):
# Taken from typeshed. _T is used by "sorted", which needs anything
# Taken from typeshed. _T_lt is used by "sorted", which needs anything
# sortable.
def __lt__(self, __other: Any) -> bool:
...


_T = TypeVar("_T", bound="_SupportsLessThan")
_T_lt = TypeVar("_T_lt", bound="_SupportsLessThan")
_T_kh = TypeVar("_T_kh", bound=Union[KeyHandlerCallable, Binding])


class OptionCategory(Generic[_T]):
def __init__(self, title: str, options: list[Option[_T]]) -> None:
class OptionCategory(Generic[_T_lt]):
def __init__(self, title: str, options: list[Option[_T_lt]]) -> None:
self.title = title
self.options = options


class Option(Generic[_T]):
class Option(Generic[_T_lt]):
"""
Ptpython configuration option that can be shown and modified from the
sidebar.
Expand All @@ -122,18 +128,18 @@ def __init__(
self,
title: str,
description: str,
get_current_value: Callable[[], _T],
get_current_value: Callable[[], _T_lt],
# We accept `object` as return type for the select functions, because
# often they return an unused boolean. Maybe this can be improved.
get_values: Callable[[], Mapping[_T, Callable[[], object]]],
get_values: Callable[[], Mapping[_T_lt, Callable[[], object]]],
) -> None:
self.title = title
self.description = description
self.get_current_value = get_current_value
self.get_values = get_values

@property
def values(self) -> Mapping[_T, Callable[[], object]]:
def values(self) -> Mapping[_T_lt, Callable[[], object]]:
return self.get_values()

def activate_next(self, _previous: bool = False) -> None:
Expand Down Expand Up @@ -208,10 +214,10 @@ def __init__(
_completer: Completer | None = None,
_validator: Validator | None = None,
_lexer: Lexer | None = None,
_extra_buffer_processors=None,
_extra_buffer_processors: list[Processor] | None = None,
_extra_layout_body: AnyContainer | None = None,
_extra_toolbars=None,
_input_buffer_height=None,
_extra_toolbars: list[AnyContainer] | None = None,
_input_buffer_height: AnyDimension | None = None,
) -> None:
self.get_globals: _GetNamespace = get_globals or (lambda: {})
self.get_locals: _GetNamespace = get_locals or self.get_globals
Expand Down Expand Up @@ -466,24 +472,36 @@ def get_compiler_flags(self) -> int:

return flags

@property
def add_key_binding(self) -> Callable[[_T], _T]:
def add_key_binding(
self,
*keys: Keys | str,
filter: FilterOrBool = True,
eager: FilterOrBool = False,
is_global: FilterOrBool = False,
save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
record_in_macro: FilterOrBool = True,
) -> Callable[[_T_kh], _T_kh]:
"""
Shortcut for adding new key bindings.
(Mostly useful for a config.py file, that receives
a PythonInput/Repl instance as input.)
All arguments are identical to prompt_toolkit's `KeyBindings.add`.
::
@python_input.add_key_binding(Keys.ControlX, filter=...)
def handler(event):
...
"""

def add_binding_decorator(*k, **kw):
return self.extra_key_bindings.add(*k, **kw)

return add_binding_decorator
return self.extra_key_bindings.add(
*keys,
filter=filter,
eager=eager,
is_global=is_global,
save_before=save_before,
record_in_macro=record_in_macro,
)

def install_code_colorscheme(self, name: str, style: BaseStyle) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion ptpython/repl.py
Expand Up @@ -158,7 +158,7 @@ def run(self) -> None:
clear_title()
self._remove_from_namespace()

async def run_and_show_expression_async(self, text: str):
async def run_and_show_expression_async(self, text: str) -> object:
loop = asyncio.get_event_loop()

try:
Expand Down

0 comments on commit 6801f94

Please sign in to comment.