Skip to content

Commit

Permalink
switch on hugh level API only for ssh connect
Browse files Browse the repository at this point in the history
* hostname is mandatory
* support ssh compression
  • Loading branch information
penguinolog committed Nov 22, 2019
1 parent 69fa286 commit 0a8ba0d
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 44 deletions.
11 changes: 8 additions & 3 deletions doc/source/SSHClient.rst
Expand Up @@ -503,19 +503,24 @@ API: SSHClient and SSHAuth.
:param tgt: Target
:type tgt: file

.. py:method:: connect(client, hostname=None, port=22, log=True, )
.. py:method:: connect(client, hostname, port=22, log=True, *, sock=None, compress=False)
Connect SSH client object using credentials.

:param client: SSH Client (low level)
:type client: ``typing.Union[paramiko.client.SSHClient, paramiko.transport.Transport]``
:type client: ``paramiko.SSHClient``
:param hostname: remote hostname
:type hostname: ``str``
:param port: remote ssh port
:type port: ``int``
:param log: Log on generic connection failure
:type log: ``bool``
:raises paramiko.AuthenticationException: Authentication failed.
:param sock: socket for connection. Useful for ssh proxies support
:type sock: ``typing.Optional[typing.Union[paramiko.ProxyCommand, paramiko.Channel, socket.socket]]``
:param compress: use SSH compression
:type compress: ``bool``
:raises PasswordRequiredException: No password has been set, but required.
:raises AuthenticationException: Authentication failed.


.. py:class::SSHAuthMapping(typing.Dict[str, SSHAuth])
Expand Down
16 changes: 13 additions & 3 deletions exec_helpers/_ssh_client_base.py
Expand Up @@ -392,7 +392,12 @@ def __connect(self) -> None:
self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.auth.connect(
client=self.__ssh, hostname=self.hostname, port=self.port, log=self.__verbose, sock=sock
client=self.__ssh,
hostname=self.hostname,
port=self.port,
log=self.__verbose,
sock=sock,
compress=bool(self.ssh_config[self.hostname].compression),
)
else:
self.__ssh = self.__get_client()
Expand All @@ -416,9 +421,12 @@ def __get_client(self) -> paramiko.SSHClient:
hostname=config.hostname,
port=config.port or 22,
sock=paramiko.ProxyCommand(config.proxycommand),
compress=bool(config.compression),
)
else:
auth.connect(last_ssh_client, hostname=config.hostname, port=config.port or 22)
auth.connect(
last_ssh_client, hostname=config.hostname, port=config.port or 22, compress=bool(config.compression)
)

for config, auth in self.__conn_chain[1:]: # start has another logic, so do it out of cycle
ssh = paramiko.SSHClient()
Expand All @@ -428,7 +436,9 @@ def __get_client(self) -> paramiko.SSHClient:
sock = last_ssh_client.get_transport().open_channel(
kind="direct-tcpip", dest_addr=(config.hostname, config.port or 22), src_addr=(config.proxyjump, 0),
)
auth.connect(ssh, hostname=config.hostname, port=config.port or 22, sock=sock)
auth.connect(
ssh, hostname=config.hostname, port=config.port or 22, sock=sock, compress=bool(config.compression)
)
last_ssh_client = ssh
continue

Expand Down
8 changes: 4 additions & 4 deletions exec_helpers/api.py
Expand Up @@ -199,7 +199,7 @@ def mask(text: str, rules: str) -> str:
:return: source with all MATCHED groups replaced by '<*masked*>'
"""
indexes: typing.List[int] = [0] # Start of the line
masked: str = ""
masked: typing.List[str] = []

# places to exclude
for match in re.finditer(rules, text):
Expand All @@ -211,11 +211,11 @@ def mask(text: str, rules: str) -> str:
for idx in range(0, len(indexes) - 2, 2):
start: int = indexes[idx]
end: int = indexes[idx + 1]
masked += text[start:end] + "<*masked*>"
masked.append(text[start:end] + "<*masked*>")

# noinspection PyPep8
masked += text[indexes[-2] : indexes[-1]] # final part
return masked
masked.append(text[indexes[-2] : indexes[-1]]) # final part
return "".join(masked)

result: str = cmd.rstrip()

Expand Down
26 changes: 16 additions & 10 deletions exec_helpers/ssh_auth.py
Expand Up @@ -137,17 +137,18 @@ def enter_password(self, tgt: typing.BinaryIO) -> None:

def connect(
self,
client: typing.Union[paramiko.SSHClient, paramiko.Transport],
hostname: typing.Optional[str] = None,
client: paramiko.SSHClient,
hostname: str,
port: int = 22,
log: bool = True,
*,
sock: typing.Optional[typing.Union[paramiko.ProxyCommand, paramiko.Channel, socket.socket]] = None,
compress: bool = False,
) -> None:
"""Connect SSH client object using credentials.
:param client: SSH Client (low level)
:type client: typing.Union[paramiko.SSHClient, paramiko.Transport]
:type client: paramiko.SSHClient
:param hostname: remote hostname
:type hostname: str
:param port: remote ssh port
Expand All @@ -156,16 +157,13 @@ def connect(
:type log: bool
:param sock: socket for connection. Useful for ssh proxies support
:type sock: typing.Optional[typing.Union[paramiko.ProxyCommand, paramiko.Channel, socket.socket]]
:param compress: use SSH compression
:type compress: bool
:raises PasswordRequiredException: No password has been set, but required.
:raises AuthenticationException: Authentication failed.
"""
kwargs: typing.Dict[str, typing.Any] = {"username": self.username, "password": self.__password}
if hostname is not None:
kwargs["hostname"] = hostname
kwargs["port"] = port
kwargs: typing.Dict[str, typing.Any] = {}

if self.key_filename is not None:
kwargs["key_filename"] = self.key_filename
if self.__passphrase is not None:
kwargs["passphrase"] = self.__passphrase
if sock is not None:
Expand All @@ -177,7 +175,15 @@ def connect(
for key in keys:
kwargs["pkey"] = key
try:
client.connect(**kwargs)
client.connect(
hostname=hostname,
port=port,
username=self.username,
password=self.__password,
key_filename=self.key_filename,
compress=compress,
**kwargs,
)
if self.__key != key:
self.__key = key
LOGGER.debug(f"Main key has been updated, public key is: \n{self.public_key}")
Expand Down
26 changes: 16 additions & 10 deletions exec_helpers/ssh_auth.pyx
Expand Up @@ -125,17 +125,18 @@ cdef class SSHAuth:

def connect(
self,
client: typing.Union[paramiko.SSHClient, paramiko.Transport],
hostname: typing.Optional[str] = None,
client: paramiko.SSHClient,
str hostname,
unsigned int port = 22,
bint log = True,
*,
sock: typing.Optional[typing.Union[paramiko.ProxyCommand, paramiko.Channel, socket.socket]] = None,
bint compress = False,
) -> None:
"""Connect SSH client object using credentials.

:param client: SSH Client (low level)
:type client: typing.Union[paramiko.SSHClient, paramiko.Transport]
:type client: paramiko.SSHClient
:param hostname: remote hostname
:type hostname: str
:param port: remote ssh port
Expand All @@ -144,16 +145,13 @@ cdef class SSHAuth:
:type log: bool
:param sock: socket for connection. Useful for ssh proxies support
:type sock: typing.Optional[typing.Union[paramiko.ProxyCommand, paramiko.Channel, socket.socket]]
:param compress: use SSH compression
:type compress: bool
:raises PasswordRequiredException: No password has been set, but required.
:raises AuthenticationException: Authentication failed.
"""
kwargs = {"username": self.username, "password": self.password} # type: typing.Dict[str, typing.Any]
if hostname is not None:
kwargs["hostname"] = hostname
kwargs["port"] = port
kwargs = {} # type: typing.Dict[str, typing.Any]

if self.key_filename is not None:
kwargs["key_filename"] = self.key_filename
if self.passphrase is not None:
kwargs["passphrase"] = self.passphrase
if sock is not None:
Expand All @@ -165,7 +163,15 @@ cdef class SSHAuth:
for key in keys:
kwargs["pkey"] = key
try:
client.connect(**kwargs)
client.connect(
hostname=hostname,
port=port,
username=self.username,
password=self.password,
key_filename=self.key_filename,
compress=compress,
**kwargs,
)
if self.key != key:
self.key = key
LOGGER.debug(f"Main key has been updated, public key is: \n{self.public_key}")
Expand Down
7 changes: 0 additions & 7 deletions test/test_ssh_client_execute.py
Expand Up @@ -273,13 +273,6 @@ def test_001_execute_async(ssh, paramiko_ssh_client, ssh_transport_channel, chan
else:
assert res.stdout is None

assert paramiko_ssh_client.mock_calls == [
mock.call(),
mock.call().set_missing_host_key_policy("AutoAddPolicy"),
mock.call().connect(hostname="127.0.0.1", password="pass", pkey=None, port=22, username="user"),
mock.call().get_transport(),
]

transport_calls = []
if get_pty:
transport_calls.append(
Expand Down
24 changes: 22 additions & 2 deletions test/test_ssh_client_execute_throw_host.py
Expand Up @@ -138,14 +138,24 @@ def test_01_execute_through_host_no_creds(
]
connect.assert_has_calls(
[
mock.call(hostname=host, password=password, pkey=None, port=port, username=username),
mock.call(
hostname=host,
password=password,
pkey=None,
port=port,
username=username,
compress=False,
key_filename=None,
),
mock.call(
hostname=target,
port=port,
username=username,
password=password,
pkey=None,
sock=ssh_intermediate_channel(),
compress=False,
key_filename=None,
),
]
)
Expand All @@ -164,14 +174,24 @@ def test_02_execute_through_host_with_creds(
]
connect.assert_has_calls(
[
mock.call(hostname=host, password=password, pkey=None, port=port, username=username),
mock.call(
hostname=host,
password=password,
pkey=None,
port=port,
username=username,
compress=False,
key_filename=None,
),
mock.call(
hostname=target,
port=port,
username=username_2,
password=password_2,
pkey=None,
sock=ssh_intermediate_channel(),
compress=False,
key_filename=None,
),
]
)
Expand Down
12 changes: 10 additions & 2 deletions test/test_ssh_client_init_basic.py
Expand Up @@ -124,7 +124,15 @@ def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_aut
if auth is None:
expected_calls = [
_ssh.set_missing_host_key_policy("AutoAddPolicy"),
_ssh.connect(hostname=host, password=password, pkey=None, port=port, username=username),
_ssh.connect(
hostname=host,
password=password,
pkey=None,
port=port,
username=username,
compress=False,
key_filename=None,
),
]

assert expected_calls == paramiko_ssh_client().mock_calls
Expand All @@ -144,7 +152,7 @@ def test_init_base(paramiko_ssh_client, auto_add_policy, run_parameters, ssh_aut
# ssh config for main connection is synchronised with connection parameters
expected_config_dict = {host: {"hostname": host, "port": ssh.port}}
if ssh.auth.username:
expected_config_dict[host]['user'] = ssh.auth.username
expected_config_dict[host]["user"] = ssh.auth.username

assert ssh.ssh_config == expected_config_dict
assert ssh.ssh_config[host].hostname == host
12 changes: 9 additions & 3 deletions test/test_ssh_client_init_special.py
Expand Up @@ -78,7 +78,9 @@ def test_001_require_key(paramiko_ssh_client, auto_add_policy, ssh_auth_logger):

pkey = private_keys[0]

kwargs = dict(hostname=host, pkey=None, port=port, username=username, password=None)
kwargs = dict(
hostname=host, pkey=None, port=port, username=username, password=None, compress=False, key_filename=None,
)
kwargs1 = {key: kwargs[key] for key in kwargs}
kwargs1["pkey"] = pkey

Expand Down Expand Up @@ -111,7 +113,9 @@ def test_002_use_next_key(paramiko_ssh_client, auto_add_policy, ssh_auth_logger)

ssh_auth_logger.debug.assert_called_once_with(f"Main key has been updated, public key is: \n{ssh.auth.public_key}")

kwargs = dict(hostname=host, pkey=None, port=port, username=username, password=None)
kwargs = dict(
hostname=host, pkey=None, port=port, username=username, password=None, compress=False, key_filename=None,
)
kwargs0 = {key: kwargs[key] for key in kwargs}
kwargs0["pkey"] = private_keys[0]
kwargs1 = {key: kwargs[key] for key in kwargs}
Expand Down Expand Up @@ -314,7 +318,9 @@ def test_012_re_connect(paramiko_ssh_client, auto_add_policy, ssh_auth_logger):
_ssh.close(),
_ssh,
_ssh.set_missing_host_key_policy("AutoAddPolicy"),
_ssh.connect(hostname="127.0.0.1", password=None, pkey=None, port=22, username=None),
_ssh.connect(
hostname="127.0.0.1", password=None, pkey=None, port=22, username=None, compress=False, key_filename=None,
),
]

assert paramiko_ssh_client.mock_calls == expected_calls
Expand Down

0 comments on commit 0a8ba0d

Please sign in to comment.