Skip to content

Commit

Permalink
extra micro optimizations on type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
penguinolog committed Apr 30, 2020
1 parent 4bf4d6b commit ee5fb17
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 63 deletions.
45 changes: 22 additions & 23 deletions exec_helpers/_ssh_base.py
Expand Up @@ -86,16 +86,17 @@
from exec_helpers.api import OptionalTimeoutT
from exec_helpers.proc_enums import ExitCodeT

_OptionalSSHAuthMapT = typing.Optional[typing.Union[typing.Dict[str, SSHAuth], SSHAuthMapping]]
_OptionalSSHConfigArgT = typing.Union[
str,
ParamikoSSHConfig,
typing.Dict[str, typing.Dict[str, typing.Union[str, int, bool, typing.List[str]]]],
HostsSSHConfigs,
None,
]
_SSHConnChainT = typing.List[typing.Tuple[SSHConfig, SSHAuth]]
_OptSSHAuthT = typing.Optional[SSHAuth]
_OptionalSSHAuthMapT = typing.Optional[typing.Union[typing.Dict[str, SSHAuth], SSHAuthMapping]]
_OptionalSSHConfigArgT = typing.Union[
str,
ParamikoSSHConfig,
typing.Dict[str, typing.Dict[str, typing.Union[str, int, bool, typing.List[str]]]],
HostsSSHConfigs,
None,
]
_SSHConnChainT = typing.List[typing.Tuple[SSHConfig, SSHAuth]]
_OptSSHAuthT = typing.Optional[SSHAuth]

_RType = typing.TypeVar("_RType")


Expand Down Expand Up @@ -224,9 +225,7 @@ def normalize_path(tgt: typing.Callable[..., _RType]) -> typing.Callable[..., _R
"""

@wraps(tgt)
def wrapper(
self: typing.Any, path: typing.Union[str, PurePath], *args: typing.Any, **kwargs: typing.Any
) -> _RType:
def wrapper(self: typing.Any, path: typing.Union[str, PurePath], *args: typing.Any, **kwargs: typing.Any) -> _RType:
"""Normalize path type before use in corresponding method.
:param self: owner instance
Expand Down Expand Up @@ -322,10 +321,10 @@ def __init__(
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
*,
auth: _OptSSHAuthT = None,
auth: "_OptSSHAuthT" = None,
verbose: bool = True,
ssh_config: _OptionalSSHConfigArgT = None,
ssh_auth_map: _OptionalSSHAuthMapT = None,
ssh_config: "_OptionalSSHConfigArgT" = None,
ssh_auth_map: "_OptionalSSHAuthMapT" = None,
sock: "typing.Optional[typing.Union[ProxyCommand, Channel, socket]]" = None,
keepalive: typing.Union[int, bool] = 1,
) -> None:
Expand Down Expand Up @@ -387,7 +386,7 @@ def __init__(

# Build connection chain once and use it for connection later
if sock is None:
self.__conn_chain: _SSHConnChainT = self.__build_connection_chain()
self.__conn_chain: "_SSHConnChainT" = self.__build_connection_chain()
else:
self.__conn_chain = []

Expand All @@ -401,13 +400,13 @@ def __rebuild_ssh_config(self) -> None:
)
)

def __build_connection_chain(self) -> _SSHConnChainT:
def __build_connection_chain(self) -> "_SSHConnChainT":
"""Build ssh connection chain to reach destination host.
:return: list of SSHConfig - SSHAuth pairs in order of connection
:rtype: typing.List[typing.Tuple[SSHConfig, SSHAuth]]
"""
conn_chain: _SSHConnChainT = []
conn_chain: "_SSHConnChainT" = []

config = self.ssh_config[self.hostname]
default_auth = SSHAuth(username=config.user, key_filename=config.identityfile)
Expand Down Expand Up @@ -1198,10 +1197,10 @@ def proxy_to(
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
*,
auth: _OptSSHAuthT = None,
auth: "_OptSSHAuthT" = None,
verbose: bool = True,
ssh_config: _OptionalSSHConfigArgT = None,
ssh_auth_map: _OptionalSSHAuthMapT = None,
ssh_config: "_OptionalSSHConfigArgT" = None,
ssh_auth_map: "_OptionalSSHAuthMapT" = None,
keepalive: typing.Union[int, bool] = 1,
) -> "SSHClientBase":
"""Start new SSH connection using current as proxy.
Expand Down Expand Up @@ -1265,7 +1264,7 @@ def execute_through_host(
hostname: str,
command: "CommandT",
*,
auth: _OptSSHAuthT = None,
auth: "_OptSSHAuthT" = None,
port: typing.Optional[int] = None,
verbose: bool = False,
timeout: "OptionalTimeoutT" = DEFAULT_TIMEOUT,
Expand Down
10 changes: 5 additions & 5 deletions exec_helpers/exec_result.py
Expand Up @@ -63,9 +63,9 @@
from exec_helpers.proc_enums import ExitCodeT

_OptLoggerT = typing.Optional[Logger]
_OptBytesIterableT = typing.Optional[typing.Iterable[bytes]]

LOGGER: "Logger" = getLogger(__name__)
_OptBytesIterableT = typing.Optional[typing.Iterable[bytes]]


def _get_str_from_bin(src: bytearray) -> str:
Expand Down Expand Up @@ -178,8 +178,8 @@ def __init__(
self,
cmd: str,
stdin: typing.Union[bytes, str, bytearray, None] = None,
stdout: _OptBytesIterableT = None,
stderr: _OptBytesIterableT = None,
stdout: "_OptBytesIterableT" = None,
stderr: "_OptBytesIterableT" = None,
exit_code: "ExitCodeT" = INVALID,
*,
started: typing.Optional[datetime] = None,
Expand Down Expand Up @@ -344,7 +344,7 @@ def _poll_stream(
)
return dst

def read_stdout(self, src: _OptBytesIterableT = None, log: "_OptLoggerT" = None, verbose: bool = False,) -> None:
def read_stdout(self, src: "_OptBytesIterableT" = None, log: "_OptLoggerT" = None, verbose: bool = False,) -> None:
"""Read stdout file-like object to stdout.
:param src: source
Expand All @@ -366,7 +366,7 @@ def read_stdout(self, src: _OptBytesIterableT = None, log: "_OptLoggerT" = None,
self._stdout_str = self._stdout_brief = None
self._stdout += tuple(self._poll_stream(src, log, verbose))

def read_stderr(self, src: _OptBytesIterableT = None, log: "_OptLoggerT" = None, verbose: bool = False,) -> None:
def read_stderr(self, src: "_OptBytesIterableT" = None, log: "_OptLoggerT" = None, verbose: bool = False,) -> None:
"""Read stderr file-like object to stdout.
:param src: source
Expand Down
74 changes: 39 additions & 35 deletions exec_helpers/subprocess.py
Expand Up @@ -38,16 +38,19 @@
from threaded import threadpooled

# Package Implementation
from exec_helpers import constants
from exec_helpers import exceptions
from exec_helpers import exec_result
from exec_helpers import proc_enums
from exec_helpers.api import ExecHelper
from exec_helpers.api import ExecuteAsyncResult
from exec_helpers.constants import DEFAULT_TIMEOUT
from exec_helpers.exceptions import CalledProcessError
from exec_helpers.exceptions import ExecHelperNoKillError
from exec_helpers.exceptions import ExecHelperTimeoutError
from exec_helpers.exec_result import ExecResult
from exec_helpers.proc_enums import EXPECTED

# Local Implementation
from . import _log_templates
from . import _subprocess_helpers
from ._log_templates import CMD_WAIT_ERROR
from ._subprocess_helpers import kill_proc_tree
from ._subprocess_helpers import subprocess_kw

if typing.TYPE_CHECKING:
# pylint: disable=ungrouped-imports
Expand All @@ -58,11 +61,12 @@
from exec_helpers.api import OptionalTimeoutT
from exec_helpers.proc_enums import ExitCodeT

_OptionalIOBytes = typing.Optional[typing.IO[bytes]]

EnvT = typing.Optional[
typing.Union[typing.Mapping[bytes, typing.Union[bytes, str]], typing.Mapping[str, typing.Union[bytes, str]]]
]
CwdT = typing.Optional[typing.Union[str, bytes, Path]]
_OptionalIOBytes = typing.Optional[typing.IO[bytes]]


# noinspection PyTypeHints
Expand All @@ -81,7 +85,7 @@ def interface(self) -> "Popen[bytes]":
return super().interface # type: ignore

@property
def stdin(self) -> _OptionalIOBytes: # type: ignore
def stdin(self) -> "_OptionalIOBytes": # type: ignore
"""Override original NamedTuple with proper typing.
:return: STDIN interface
Expand All @@ -90,7 +94,7 @@ def stdin(self) -> _OptionalIOBytes: # type: ignore
return super().stdin

@property
def stderr(self) -> _OptionalIOBytes: # type: ignore
def stderr(self) -> "_OptionalIOBytes": # type: ignore
"""Override original NamedTuple with proper typing.
:return: STDERR interface
Expand All @@ -99,7 +103,7 @@ def stderr(self) -> _OptionalIOBytes: # type: ignore
return super().stderr

@property
def stdout(self) -> _OptionalIOBytes: # type: ignore
def stdout(self) -> "_OptionalIOBytes": # type: ignore
"""Override original NamedTuple with proper typing.
:return: STDOUT interface
Expand Down Expand Up @@ -150,7 +154,7 @@ def _exec_command( # type: ignore
log_mask_re: typing.Optional[str] = None,
stdin: "OptionalStdinT" = None,
**kwargs: typing.Any,
) -> exec_result.ExecResult:
) -> ExecResult:
"""Get exit status from channel with timeout.
:param command: Command for execution
Expand Down Expand Up @@ -197,7 +201,7 @@ def close_streams() -> None:
# Store command with hidden data
cmd_for_log: str = self._mask_command(cmd=command, log_mask_re=log_mask_re)

result = exec_result.ExecResult(cmd=cmd_for_log, stdin=stdin, started=async_result.started)
result = ExecResult(cmd=cmd_for_log, stdin=stdin, started=async_result.started)

# noinspection PyNoneFunctionAssignment,PyTypeChecker
stdout_future: "Future[None]" = poll_stdout()
Expand All @@ -211,10 +215,10 @@ def close_streams() -> None:
return result
except TimeoutExpired:
# kill -9 for all subprocesses
_subprocess_helpers.kill_proc_tree(async_result.interface.pid)
kill_proc_tree(async_result.interface.pid)
exit_signal: typing.Optional[int] = async_result.interface.poll()
if exit_signal is None:
raise exceptions.ExecHelperNoKillError(result=result, timeout=timeout) # type: ignore
raise ExecHelperNoKillError(result=result, timeout=timeout) # type: ignore
result.exit_code = exit_signal
finally:
stdout_future.cancel()
Expand All @@ -228,9 +232,9 @@ def close_streams() -> None:
result.set_timestamp()
close_streams()

wait_err_msg: str = _log_templates.CMD_WAIT_ERROR.format(result=result, timeout=timeout)
wait_err_msg: str = CMD_WAIT_ERROR.format(result=result, timeout=timeout)
self.logger.debug(wait_err_msg)
raise exceptions.ExecHelperTimeoutError(result=result, timeout=timeout) # type: ignore
raise ExecHelperTimeoutError(result=result, timeout=timeout) # type: ignore

# noinspection PyMethodOverriding
def _execute_async( # pylint: disable=arguments-differ
Expand Down Expand Up @@ -300,11 +304,11 @@ def _execute_async( # pylint: disable=arguments-differ
cwd=cwd,
env=env,
universal_newlines=False,
**_subprocess_helpers.subprocess_kw,
**subprocess_kw,
)

if stdin is None:
process_stdin: _OptionalIOBytes = process.stdin
process_stdin: "_OptionalIOBytes" = process.stdin
elif process.stdin is None:
self.logger.warning("STDIN pipe is not set, but STDIN data is available to send.")
process_stdin = None
Expand All @@ -321,7 +325,7 @@ def _execute_async( # pylint: disable=arguments-differ
elif exc.errno in (EPIPE, ESHUTDOWN):
self.logger.warning("STDIN Send failed: broken PIPE")
else:
_subprocess_helpers.kill_proc_tree(process.pid)
kill_proc_tree(process.pid)
process.kill()
raise
try:
Expand All @@ -342,7 +346,7 @@ def execute( # pylint: disable=arguments-differ
self,
command: "CommandT",
verbose: bool = False,
timeout: "OptionalTimeoutT" = constants.DEFAULT_TIMEOUT,
timeout: "OptionalTimeoutT" = DEFAULT_TIMEOUT,
*,
log_mask_re: typing.Optional[str] = None,
stdin: "OptionalStdinT" = None,
Expand All @@ -352,7 +356,7 @@ def execute( # pylint: disable=arguments-differ
env: EnvT = None,
env_patch: EnvT = None,
**kwargs: typing.Any,
) -> exec_result.ExecResult:
) -> ExecResult:
"""Execute command and wait for return code.
:param command: Command for execution
Expand Down Expand Up @@ -404,7 +408,7 @@ def __call__(
self,
command: "CommandT",
verbose: bool = False,
timeout: "OptionalTimeoutT" = constants.DEFAULT_TIMEOUT,
timeout: "OptionalTimeoutT" = DEFAULT_TIMEOUT,
*,
log_mask_re: typing.Optional[str] = None,
stdin: "OptionalStdinT" = None,
Expand All @@ -414,7 +418,7 @@ def __call__(
env: EnvT = None,
env_patch: EnvT = None,
**kwargs: typing.Any,
) -> exec_result.ExecResult:
) -> ExecResult:
"""Execute command and wait for return code.
:param command: Command for execution
Expand Down Expand Up @@ -465,9 +469,9 @@ def check_call( # pylint: disable=arguments-differ
self,
command: "CommandT",
verbose: bool = False,
timeout: "OptionalTimeoutT" = constants.DEFAULT_TIMEOUT,
timeout: "OptionalTimeoutT" = DEFAULT_TIMEOUT,
error_info: typing.Optional[str] = None,
expected: "typing.Iterable[ExitCodeT]" = (proc_enums.EXPECTED,),
expected: "typing.Iterable[ExitCodeT]" = (EXPECTED,),
raise_on_err: bool = True,
*,
log_mask_re: typing.Optional[str] = None,
Expand All @@ -477,9 +481,9 @@ def check_call( # pylint: disable=arguments-differ
cwd: CwdT = None,
env: EnvT = None,
env_patch: EnvT = None,
exception_class: "CalledProcessErrorSubClassT" = exceptions.CalledProcessError,
exception_class: "CalledProcessErrorSubClassT" = CalledProcessError,
**kwargs: typing.Any,
) -> exec_result.ExecResult:
) -> ExecResult:
"""Execute command and check for return code.
:param command: Command for execution
Expand All @@ -491,7 +495,7 @@ def check_call( # pylint: disable=arguments-differ
:param error_info: Text for error details, if fail happens
:type error_info: typing.Optional[str]
:param expected: expected return codes (0 by default)
:type expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]]
:type expected: typing.Iterable[typing.Union[int, ExitCodes]]
:param raise_on_err: Raise exception on unexpected return code
:type raise_on_err: bool
:param log_mask_re: regex lookup rule to mask command for logger.
Expand All @@ -510,7 +514,7 @@ def check_call( # pylint: disable=arguments-differ
:param env_patch: Defines the environment variables to ADD for the new process.
:type env_patch: typing.Optional[typing.Mapping[typing.Union[str, bytes], typing.Union[str, bytes]]]
:param exception_class: Exception class for errors. Subclass of CalledProcessError is mandatory.
:type exception_class: typing.Type[exceptions.CalledProcessError]
:type exception_class: typing.Type[CalledProcessError]
:param kwargs: additional parameters for call.
:type kwargs: typing.Any
:return: Execution result
Expand Down Expand Up @@ -544,21 +548,21 @@ def check_stderr( # pylint: disable=arguments-differ
self,
command: "CommandT",
verbose: bool = False,
timeout: "OptionalTimeoutT" = constants.DEFAULT_TIMEOUT,
timeout: "OptionalTimeoutT" = DEFAULT_TIMEOUT,
error_info: typing.Optional[str] = None,
raise_on_err: bool = True,
*,
expected: "typing.Iterable[ExitCodeT]" = (proc_enums.EXPECTED,),
expected: "typing.Iterable[ExitCodeT]" = (EXPECTED,),
log_mask_re: typing.Optional[str] = None,
stdin: "OptionalStdinT" = None,
open_stdout: bool = True,
open_stderr: bool = True,
cwd: CwdT = None,
env: EnvT = None,
env_patch: EnvT = None,
exception_class: "CalledProcessErrorSubClassT" = exceptions.CalledProcessError,
exception_class: "CalledProcessErrorSubClassT" = CalledProcessError,
**kwargs: typing.Any,
) -> exec_result.ExecResult:
) -> ExecResult:
"""Execute command expecting return code 0 and empty STDERR.
:param command: Command for execution
Expand All @@ -572,7 +576,7 @@ def check_stderr( # pylint: disable=arguments-differ
:param raise_on_err: Raise exception on unexpected return code
:type raise_on_err: bool
:param expected: expected return codes (0 by default)
:type expected: typing.Iterable[typing.Union[int, proc_enums.ExitCodes]]
:type expected: typing.Iterable[typing.Union[int, ExitCodes]]
:param log_mask_re: regex lookup rule to mask command for logger.
all MATCHED groups will be replaced by '<*masked*>'
:type log_mask_re: typing.Optional[str]
Expand All @@ -589,7 +593,7 @@ def check_stderr( # pylint: disable=arguments-differ
:param env_patch: Defines the environment variables to ADD for the new process.
:type env_patch: typing.Optional[typing.Mapping[typing.Union[str, bytes], typing.Union[str, bytes]]]
:param exception_class: Exception class for errors. Subclass of CalledProcessError is mandatory.
:type exception_class: typing.Type[exceptions.CalledProcessError]
:type exception_class: typing.Type[CalledProcessError]
:param kwargs: additional parameters for call.
:type kwargs: typing.Any
:return: Execution result
Expand Down

0 comments on commit ee5fb17

Please sign in to comment.