From dd6307ca4dd05f7570a219d088771aa56c395174 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Wed, 30 Jun 2021 23:38:32 +0300 Subject: [PATCH] ssh: don't use path_info for connections --- dvc/fs/ssh/__init__.py | 42 +++++++++++++++++++++--------------------- dvc/objects/db/ssh.py | 4 ++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dvc/fs/ssh/__init__.py b/dvc/fs/ssh/__init__.py index 192e3ef3fe..496e250438 100644 --- a/dvc/fs/ssh/__init__.py +++ b/dvc/fs/ssh/__init__.py @@ -119,16 +119,16 @@ def ensure_credentials(self): if self.ask_password and self.password is None: self.password = ask_password(self.host, self.user, self.port) - def ssh(self, path_info): + def ssh(self): self.ensure_credentials() from .connection import SSHConnection return get_connection( SSHConnection, - path_info.host, - username=path_info.user, - port=path_info.port, + self.host, + username=self.user, + port=self.port, key_filename=self.keyfile, timeout=self.timeout, password=self.password, @@ -141,7 +141,7 @@ def ssh(self, path_info): def open(self, path_info, mode="r", encoding=None, **kwargs): assert mode in {"r", "rt", "rb", "wb"} - with self.ssh(path_info) as ssh, closing( + with self.ssh() as ssh, closing( ssh.sftp.open(path_info.path, mode) ) as fd: if "b" in mode: @@ -150,19 +150,19 @@ def open(self, path_info, mode="r", encoding=None, **kwargs): yield io.TextIOWrapper(fd, encoding=encoding) def exists(self, path_info) -> bool: - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: return ssh.exists(path_info.path) def isdir(self, path_info): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: return ssh.isdir(path_info.path) def isfile(self, path_info): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: return ssh.isfile(path_info.path) def walk_files(self, path_info, **kwargs): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: for fname in ssh.walk_files(path_info.path): yield path_info.replace(path=fname) @@ -170,32 +170,32 @@ def remove(self, path_info): if path_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: ssh.remove(path_info.path) def makedirs(self, path_info): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: ssh.makedirs(path_info.path) def move(self, from_info, to_info): if from_info.scheme != self.scheme or to_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.move(from_info.path, to_info.path) def copy(self, from_info, to_info): if not from_info.scheme == to_info.scheme == self.scheme: raise NotImplementedError - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.atomic_copy(from_info.path, to_info.path) def symlink(self, from_info, to_info): if not from_info.scheme == to_info.scheme == self.scheme: raise NotImplementedError - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.symlink(from_info.path, to_info.path) def hardlink(self, from_info, to_info): @@ -205,7 +205,7 @@ def hardlink(self, from_info, to_info): # See dvc/remote/local/__init__.py - hardlink() if self.getsize(from_info) == 0: - with self.ssh(to_info) as ssh: + with self.ssh() as ssh: ssh.sftp.open(to_info.path, "w").close() logger.debug( @@ -215,18 +215,18 @@ def hardlink(self, from_info, to_info): ) return - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.hardlink(from_info.path, to_info.path) def reflink(self, from_info, to_info): if from_info.scheme != self.scheme or to_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.reflink(from_info.path, to_info.path) def md5(self, path_info): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: return HashInfo( "md5", ssh.md5(path_info.path), @@ -234,7 +234,7 @@ def md5(self, path_info): ) def info(self, path_info): - with self.ssh(path_info) as ssh: + with self.ssh() as ssh: return ssh.info(path_info.path) def _upload_fobj(self, fobj, to_info, **kwargs): @@ -243,7 +243,7 @@ def _upload_fobj(self, fobj, to_info, **kwargs): shutil.copyfileobj(fobj, fdest) def _download(self, from_info, to_file, name=None, no_progress_bar=False): - with self.ssh(from_info) as ssh: + with self.ssh() as ssh: ssh.download( from_info.path, to_file, @@ -254,7 +254,7 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs ): - with self.ssh(to_info) as ssh: + with self.ssh() as ssh: ssh.upload( from_file, to_info.path, diff --git a/dvc/objects/db/ssh.py b/dvc/objects/db/ssh.py index ded0d41dbe..1f981f1004 100644 --- a/dvc/objects/db/ssh.py +++ b/dvc/objects/db/ssh.py @@ -28,7 +28,7 @@ def _exists(chunk_and_channel): callback(path) return ret - with self.fs.ssh(path_infos[0]) as ssh: + with self.fs.ssh() as ssh: channels = ssh.open_max_sftp_channels() max_workers = len(channels) @@ -78,7 +78,7 @@ def _list_paths(self, prefix=None, progress_callback=None): root = posixpath.join(self.path_info.path, prefix[:2]) else: root = self.path_info.path - with self.fs.ssh(self.path_info) as ssh: + with self.fs.ssh() as ssh: if prefix and not ssh.exists(root): return # If we simply return an iterator then with above closes instantly