Skip to content

Commit

Permalink
Begin using Self as return type for __aenter__ methods
Browse files Browse the repository at this point in the history
This commit bumps the minimum version of typing_extensions to 4.0.0 to
allow the use of "Self" as a type annotation, and changes all the places
where __aenter__ is defined to use this annotation. Thanks go to
Matthew Bradbury for suggesting this and providing a PR for this
change in SSHProcess. This commit is a superset of that PR.
  • Loading branch information
ronf committed May 18, 2024
1 parent 5468dca commit 38820c7
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 28 deletions.
4 changes: 2 additions & 2 deletions asyncssh/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import sys
from types import TracebackType
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
from typing_extensions import Protocol, Self

from .listener import SSHForwardListener
from .misc import async_context_manager, maybe_wait_closed
Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(self, agent_path: _AgentPath):
self._writer: Optional[AgentWriter] = None
self._lock = asyncio.Lock()

async def __aenter__(self) -> 'SSHAgentClient':
async def __aenter__(self) -> Self:
"""Allow SSHAgentClient to be used as an async context manager"""

return self
Expand Down
24 changes: 11 additions & 13 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from typing import TYPE_CHECKING, Any, AnyStr, Awaitable, Callable, Dict
from typing import List, Mapping, Optional, Sequence, Set, Tuple, Type
from typing import TypeVar, Union, cast
from typing_extensions import Protocol
from typing_extensions import Protocol, Self

from .agent import SSHAgentClient, SSHAgentListener

Expand Down Expand Up @@ -182,9 +182,8 @@
_ServerFactory = Callable[[], SSHServer]
_ProtocolFactory = Union[_ClientFactory, _ServerFactory]

_Conn = TypeVar('_Conn', 'SSHClientConnection', 'SSHServerConnection')
_ConnSelf = TypeVar('_ConnSelf', bound='SSHConnection')
_OptionsSelf = TypeVar('_OptionsSelf', bound='SSHConnectionOptions')
_Conn = TypeVar('_Conn', bound='SSHConnection')
_Options = TypeVar('_Options', bound='SSHConnectionOptions')

class _TunnelProtocol(Protocol):
"""Base protocol for connections to tunnel SSH over"""
Expand Down Expand Up @@ -382,7 +381,7 @@ async def _open_tunnel(tunnels: object, passphrase: Optional[BytesOrStr],
return None


async def _connect(options: '_OptionsSelf', config: DefTuple[ConfigPaths],
async def _connect(options: '_Options', config: DefTuple[ConfigPaths],
loop: asyncio.AbstractEventLoop, flags: int,
sock: Optional[socket.socket],
conn_factory: Callable[[], _Conn], msg: str) -> _Conn:
Expand Down Expand Up @@ -456,7 +455,7 @@ async def _connect(options: '_OptionsSelf', config: DefTuple[ConfigPaths],
await conn.wait_closed()


async def _listen(options: '_OptionsSelf', config: DefTuple[ConfigPaths],
async def _listen(options: '_Options', config: DefTuple[ConfigPaths],
loop: asyncio.AbstractEventLoop, flags: int,
backlog: int, sock: Optional[socket.socket],
reuse_address: bool, reuse_port: bool,
Expand Down Expand Up @@ -677,7 +676,7 @@ def __init__(self, server: asyncio.AbstractServer,
self._server = server
self._options = options

async def __aenter__(self) -> 'SSHAcceptor':
async def __aenter__(self) -> Self:
return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
Expand Down Expand Up @@ -942,7 +941,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,

self._disable_trivial_auth = False

async def __aenter__(self: _ConnSelf) -> _ConnSelf:
async def __aenter__(self) -> Self:
"""Allow SSHConnection to be used as an async context manager"""

return self
Expand Down Expand Up @@ -6964,19 +6963,18 @@ class SSHConnectionOptions(Options):
keepalive_internal: float
keepalive_count_max: int

def __init__(self, options: Optional['_OptionsSelf'] = None,
**kwargs: object):
def __init__(self, options: Optional['_Options'] = None, **kwargs: object):
last_config = options.config if options else None
super().__init__(options=options, last_config=last_config, **kwargs)

@classmethod
async def construct(cls, options: Optional['_OptionsSelf'] = None,
**kwargs: object) -> _OptionsSelf:
async def construct(cls, options: Optional['_Options'] = None,
**kwargs: object) -> _Options:
"""Construct a new options object from within an async task"""

loop = asyncio.get_event_loop()

return cast(_OptionsSelf, await loop.run_in_executor(
return cast(_Options, await loop.run_in_executor(
None, functools.partial(cls, options, loop=loop, **kwargs)))

# pylint: disable=arguments-differ
Expand Down
3 changes: 2 additions & 1 deletion asyncssh/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from typing import Type, cast
from typing_extensions import Self

from .misc import ChannelOpenError, SockAddr

Expand Down Expand Up @@ -55,7 +56,7 @@ def __init__(self, peer: Optional['SSHForwarder'] = None,

self._extra = extra

async def __aenter__(self) -> 'SSHForwarder':
async def __aenter__(self) -> Self:
return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
Expand Down
3 changes: 2 additions & 1 deletion asyncssh/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, AnyStr, Callable, Generic, List, Optional
from typing import Sequence, Tuple, Type, Union
from typing_extensions import Self

from .forward import SSHForwarderCoro
from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder
Expand Down Expand Up @@ -54,7 +55,7 @@ class SSHListener:
def __init__(self) -> None:
self._tunnel: Optional['SSHConnection'] = None

async def __aenter__(self) -> 'SSHListener':
async def __aenter__(self) -> Self:
return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
Expand Down
4 changes: 2 additions & 2 deletions asyncssh/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import Any, AnyStr, Awaitable, Callable, Dict, Generic, IO
from typing import Iterable, List, Mapping, Optional, Set, TextIO
from typing import Tuple, Type, TypeVar, Union, cast
from typing_extensions import Protocol
from typing_extensions import Protocol, Self

from .channel import SSHChannel, SSHClientChannel, SSHServerChannel

Expand Down Expand Up @@ -747,7 +747,7 @@ def __init__(self, *args) -> None:

self._paused_write_streams: Set[Optional[int]] = set()

async def __aenter__(self) -> 'SSHProcess[AnyStr]':
async def __aenter__(self) -> Self:
"""Allow SSHProcess to be used as an async context manager"""

return self
Expand Down
4 changes: 2 additions & 2 deletions asyncssh/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, List, NoReturn, Optional
from typing import Sequence, Tuple, Type, Union, cast
from typing_extensions import Protocol
from typing_extensions import Protocol, Self

from .constants import DEFAULT_LANG
from .constants import FILEXFER_TYPE_REGULAR, FILEXFER_TYPE_DIRECTORY
Expand Down Expand Up @@ -240,7 +240,7 @@ def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]',

self._logger = reader.logger.get_child('sftp')

async def __aenter__(self) -> '_SCPHandler': # pragma: no cover
async def __aenter__(self) -> Self: # pragma: no cover
"""Allow _SCPHandler to be used as an async context manager"""

return self
Expand Down
12 changes: 6 additions & 6 deletions asyncssh/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from typing import TYPE_CHECKING, AnyStr, AsyncIterator, Awaitable, Callable
from typing import Dict, Generic, IO, Iterable, List, Mapping, Optional
from typing import Sequence, Set, Tuple, Type, TypeVar, Union, cast, overload
from typing_extensions import Literal, Protocol
from typing_extensions import Literal, Protocol, Self

from . import constants
from .constants import DEFAULT_LANG
Expand Down Expand Up @@ -208,7 +208,7 @@ def scandir(self, path: bytes) -> AsyncIterator['SFTPName']:
class SFTPFileProtocol(Protocol):
"""Protocol for accessing a file via an SFTP server"""

async def __aenter__(self) -> 'SFTPFileProtocol':
async def __aenter__(self) -> Self:
"""Allow SFTPFileProtocol to be used as an async context manager"""

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
Expand Down Expand Up @@ -2999,7 +2999,7 @@ def __init__(self, handler: SFTPClientHandler, handle: bytes,
self._max_requests = max_requests
self._offset = None if appending else 0

async def __aenter__(self) -> 'SFTPClientFile':
async def __aenter__(self) -> Self:
"""Allow SFTPClientFile to be used as an async context manager"""

return self
Expand Down Expand Up @@ -3463,7 +3463,7 @@ def __init__(self, handler: SFTPClientHandler,
self._path_errors = path_errors
self._cwd: Optional[bytes] = None

async def __aenter__(self) -> 'SFTPClient':
async def __aenter__(self) -> Self:
"""Allow SFTPClient to be used as an async context manager"""

return self
Expand Down Expand Up @@ -7277,7 +7277,7 @@ class LocalFile:
def __init__(self, file: _SFTPFileObj):
self._file = file

async def __aenter__(self) -> 'LocalFile': # pragma: no cover
async def __aenter__(self) -> Self: # pragma: no cover
"""Allow LocalFile to be used as an async context manager"""

return self
Expand Down Expand Up @@ -7407,7 +7407,7 @@ def __init__(self, server: SFTPServer, file_obj: object):
self._server = server
self._file_obj = file_obj

async def __aenter__(self) -> 'SFTPServerFile': # pragma: no cover
async def __aenter__(self) -> Self: # pragma: no cover
"""Allow SFTPServerFile to be used as an async context manager"""

return self
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
long_description = long_description,
platforms = 'Any',
python_requires = '>= 3.6',
install_requires = ['cryptography >= 39.0', 'typing_extensions >= 3.6'],
install_requires = [
'cryptography >= 39.0',
'typing_extensions >= 4.0.0'],
extras_require = {
'bcrypt': ['bcrypt >= 3.1.3'],
'fido2': ['fido2 >= 0.9.2'],
Expand Down

0 comments on commit 38820c7

Please sign in to comment.