diff --git a/dvc/dependency/http.py b/dvc/dependency/http.py index 5b216c0fe4..646ca98175 100644 --- a/dvc/dependency/http.py +++ b/dvc/dependency/http.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from dvc.path.http import HTTPPathInfo from dvc.utils.compat import urlparse, urljoin from dvc.output.base import OutputBase from dvc.remote.http import RemoteHTTP @@ -16,5 +17,4 @@ def __init__(self, stage, path, info=None, remote=None): if path.startswith("remote"): path = urljoin(self.remote.cache_dir, urlparse(path).path) - self.path_info["scheme"] = urlparse(path).scheme - self.path_info["path"] = path + self.path_info = HTTPPathInfo(url=self.url, path=path) diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 7b7176a3ac..54e76c52d9 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -2,6 +2,7 @@ import schema +from dvc.path import Schemes from dvc.utils.compat import urlparse, str from dvc.output.base import OutputBase @@ -25,11 +26,11 @@ ] OUTS_MAP = { - "hdfs": OutputHDFS, - "s3": OutputS3, - "gs": OutputGS, - "ssh": OutputSSH, - "local": OutputLOCAL, + Schemes.HDFS: OutputHDFS, + Schemes.S3: OutputS3, + Schemes.GS: OutputGS, + Schemes.SSH: OutputSSH, + Schemes.LOCAL: OutputLOCAL, } # NOTE: currently there are only 3 possible checksum names: diff --git a/dvc/output/base.py b/dvc/output/base.py index 23704b6e39..9e9be29c68 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -2,6 +2,8 @@ import re import logging +from copy import copy + from schema import Or, Optional from dvc.exceptions import DvcException @@ -84,8 +86,6 @@ def __init__( ) ) - self.path_info = {"scheme": self.scheme, "url": self.url} - def __repr__(self): return "{class_name}: '{url}'".format( class_name=type(self).__name__, url=(self.url or "No url") @@ -131,7 +131,7 @@ def scheme(self): @property def path(self): - return self.path_info["path"] + return self.path_info.path @property def cache_path(self): @@ -241,7 +241,7 @@ def commit(self): self.cache.save(self.path_info, self.info) def dumpd(self): - ret = self.info.copy() + ret = copy(self.info) ret[self.PARAM_PATH] = self.url if self.IS_DEPENDENCY: @@ -302,7 +302,7 @@ def move(self, out): self.remote.move(self.path_info, out.path_info) self.url = out.url - self.path_info = out.path_info.copy() + self.path_info = copy(out.path_info) self.save() self.commit() diff --git a/dvc/output/hdfs.py b/dvc/output/hdfs.py index 2a657745e4..dd01d74889 100644 --- a/dvc/output/hdfs.py +++ b/dvc/output/hdfs.py @@ -2,6 +2,7 @@ import posixpath +from dvc.path.hdfs import HDFSPathInfo from dvc.utils.compat import urlparse from dvc.output.base import OutputBase from dvc.remote.hdfs import RemoteHDFS @@ -34,5 +35,4 @@ def __init__( if remote: path = posixpath.join(remote.url, urlparse(path).path.lstrip("/")) user = remote.user if remote else self.group("user") - self.path_info["user"] = user - self.path_info["path"] = path + self.path_info = HDFSPathInfo(user=user, url=self.url, path=path) diff --git a/dvc/output/local.py b/dvc/output/local.py index 3d9a0cd7f8..848e21fd24 100644 --- a/dvc/output/local.py +++ b/dvc/output/local.py @@ -3,6 +3,7 @@ import os import logging +from dvc.path.local import LocalPathInfo from dvc.utils.compat import urlparse from dvc.istextfile import istextfile from dvc.exceptions import DvcException @@ -49,7 +50,7 @@ def __init__( p = os.path.join(stage.wdir, p) p = os.path.abspath(os.path.normpath(p)) - self.path_info["path"] = p + self.path_info = LocalPathInfo(url=self.url, path=p) def __str__(self): return self.rel_path @@ -64,7 +65,7 @@ def assign_to_stage_file(self, stage): from dvc.repo import Repo fullpath = os.path.abspath(stage.wdir) - self.path_info["path"] = os.path.join(fullpath, self.stage_path) + self.path_info.path = os.path.join(fullpath, self.stage_path) self.repo = Repo(self.path) diff --git a/dvc/output/s3.py b/dvc/output/s3.py index 1b0305c945..caa0399705 100644 --- a/dvc/output/s3.py +++ b/dvc/output/s3.py @@ -2,6 +2,7 @@ import posixpath +from dvc.path.utils import PathInfo from dvc.remote.s3 import RemoteS3 from dvc.utils.compat import urlparse from dvc.output.base import OutputBase @@ -36,5 +37,6 @@ def __init__( if remote: path = posixpath.join(remote.prefix, path) - self.path_info["bucket"] = bucket - self.path_info["path"] = path + self.path_info = PathInfo( + self.scheme, bucket=bucket, path=path, url=self.url + ) diff --git a/dvc/output/ssh.py b/dvc/output/ssh.py index 2c90b70949..184b41218f 100644 --- a/dvc/output/ssh.py +++ b/dvc/output/ssh.py @@ -3,6 +3,7 @@ import getpass import posixpath +from dvc.path.ssh import SSHPathInfo from dvc.utils.compat import urlparse from dvc.output.base import OutputBase from dvc.remote.ssh import RemoteSSH @@ -48,7 +49,6 @@ def __init__( else: path = parsed.path - self.path_info["host"] = host - self.path_info["port"] = port - self.path_info["user"] = user - self.path_info["path"] = path + self.path_info = SSHPathInfo( + host=host, user=user, port=port, url=self.url, path=path + ) diff --git a/dvc/path/__init__.py b/dvc/path/__init__.py new file mode 100644 index 0000000000..df7c6d3ef5 --- /dev/null +++ b/dvc/path/__init__.py @@ -0,0 +1,34 @@ +from dvc.utils.compat import urlunsplit + + +class Schemes: + SSH = "ssh" + HDFS = "hdfs" + S3 = "s3" + AZURE = "azure" + HTTP = "http" + GS = "gs" + LOCAL = "local" + OSS = "oss" + + +class BasePathInfo(object): + scheme = None + + def __init__(self, url=None, path=None): + self.url = url + self.path = path + + def __str__(self): + return self.url + + +class DefaultCloudPathInfo(BasePathInfo): + def __init__(self, bucket, url=None, path=None): + super(DefaultCloudPathInfo, self).__init__(url, path) + self.bucket = bucket + + def __str__(self): + if not self.url: + return urlunsplit((self.scheme, self.bucket, self.path, "", "")) + return self.url diff --git a/dvc/path/azure.py b/dvc/path/azure.py new file mode 100644 index 0000000000..956d91581b --- /dev/null +++ b/dvc/path/azure.py @@ -0,0 +1,5 @@ +from dvc.path import Schemes, DefaultCloudPathInfo + + +class AzurePathInfo(DefaultCloudPathInfo): + scheme = Schemes.AZURE diff --git a/dvc/path/gs.py b/dvc/path/gs.py new file mode 100644 index 0000000000..087c584bed --- /dev/null +++ b/dvc/path/gs.py @@ -0,0 +1,5 @@ +from dvc.path import Schemes, DefaultCloudPathInfo + + +class GSPathInfo(DefaultCloudPathInfo): + scheme = Schemes.GS diff --git a/dvc/path/hdfs.py b/dvc/path/hdfs.py new file mode 100644 index 0000000000..d5ce133149 --- /dev/null +++ b/dvc/path/hdfs.py @@ -0,0 +1,15 @@ +from dvc.utils.compat import urlunsplit +from dvc.path import BasePathInfo, Schemes + + +class HDFSPathInfo(BasePathInfo): + scheme = Schemes.HDFS + + def __init__(self, user, url=None, path=None): + super(HDFSPathInfo, self).__init__(url, path) + self.user = user + + def __str__(self): + if not self.url: + return urlunsplit((self.scheme, self.user, self.path, "", "")) + return self.url diff --git a/dvc/path/http.py b/dvc/path/http.py new file mode 100644 index 0000000000..c9b9bc4a5c --- /dev/null +++ b/dvc/path/http.py @@ -0,0 +1,16 @@ +from dvc.path import BasePathInfo, Schemes +from dvc.utils.compat import urlparse, urlunsplit + + +class HTTPPathInfo(BasePathInfo): + @property + def scheme(self): + if self.path: + return urlparse(self.path).scheme + else: + return Schemes.HTTP + + def __str__(self): + if not self.url: + return urlunsplit((self.scheme, self.path, "", "", "")) + return self.url diff --git a/dvc/path/local.py b/dvc/path/local.py new file mode 100644 index 0000000000..cbde5c2a48 --- /dev/null +++ b/dvc/path/local.py @@ -0,0 +1,10 @@ +import os + +from dvc.path import BasePathInfo, Schemes + + +class LocalPathInfo(BasePathInfo): + scheme = Schemes.LOCAL + + def __str__(self): + return os.path.relpath(self.path) diff --git a/dvc/path/oss.py b/dvc/path/oss.py new file mode 100644 index 0000000000..606007b262 --- /dev/null +++ b/dvc/path/oss.py @@ -0,0 +1,5 @@ +from dvc.path import Schemes, DefaultCloudPathInfo + + +class OSSPathInfo(DefaultCloudPathInfo): + scheme = Schemes.OSS diff --git a/dvc/path/s3.py b/dvc/path/s3.py new file mode 100644 index 0000000000..a1c04fa755 --- /dev/null +++ b/dvc/path/s3.py @@ -0,0 +1,5 @@ +from dvc.path import Schemes, DefaultCloudPathInfo + + +class S3PathInfo(DefaultCloudPathInfo): + scheme = Schemes.S3 diff --git a/dvc/path/ssh.py b/dvc/path/ssh.py new file mode 100644 index 0000000000..3c9c2e185c --- /dev/null +++ b/dvc/path/ssh.py @@ -0,0 +1,25 @@ +from dvc.utils.compat import urlunsplit +from dvc.path import BasePathInfo, Schemes + + +class SSHPathInfo(BasePathInfo): + scheme = Schemes.SSH + + def __init__(self, host, user, port, url=None, path=None): + super(SSHPathInfo, self).__init__(url, path) + self.host = host + self.user = user + self.port = port + + def __str__(self): + if not self.url: + return urlunsplit( + ( + self.scheme, + "{}@{}:{}".format(self.user, self.host, self.port), + self.path, + "", + "", + ) + ) + return self.url diff --git a/dvc/path/utils.py b/dvc/path/utils.py new file mode 100644 index 0000000000..cf421ede7e --- /dev/null +++ b/dvc/path/utils.py @@ -0,0 +1,25 @@ +from dvc.path import Schemes +from dvc.path.azure import AzurePathInfo +from dvc.path.gs import GSPathInfo +from dvc.path.hdfs import HDFSPathInfo +from dvc.path.http import HTTPPathInfo +from dvc.path.local import LocalPathInfo +from dvc.path.oss import OSSPathInfo +from dvc.path.s3 import S3PathInfo +from dvc.path.ssh import SSHPathInfo + +PATH_MAP = { + Schemes.SSH: SSHPathInfo, + Schemes.HDFS: HDFSPathInfo, + Schemes.S3: S3PathInfo, + Schemes.AZURE: AzurePathInfo, + Schemes.HTTP: HTTPPathInfo, + Schemes.GS: GSPathInfo, + Schemes.LOCAL: LocalPathInfo, + Schemes.OSS: OSSPathInfo, +} + + +def PathInfo(scheme, *args, **kwargs): + cls = PATH_MAP[scheme] + return cls(*args, **kwargs) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 9e2a19b6ed..263c2f724c 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -5,6 +5,9 @@ import re import logging +from dvc.path import Schemes +from dvc.path.azure import AzurePathInfo + try: from azure.storage.blob import BlockBlobService from azure.common import AzureMissingResourceHttpError @@ -30,7 +33,7 @@ def __call__(self, current, total): class RemoteAzure(RemoteBase): - scheme = "azure" + scheme = Schemes.AZURE REGEX = ( r"azure://((?P[^=;]*)?|(" # backward compatibility @@ -69,8 +72,7 @@ def __init__(self, repo, config): raise ValueError("azure storage connection string missing") self.__blob_service = None - - self.path_info = {"scheme": self.scheme, "bucket": self.bucket} + self.path_info = AzurePathInfo(bucket=self.bucket) @property def blob_service(self): @@ -90,16 +92,14 @@ def blob_service(self): return self.__blob_service def remove(self, path_info): - if path_info["scheme"] != self.scheme: + if path_info.scheme != self.scheme: raise NotImplementedError logger.debug( - "Removing azure://{}/{}".format( - path_info["bucket"], path_info["path"] - ) + "Removing azure://{}/{}".format(path_info.bucket, path_info.path) ) - self.blob_service.delete_blob(path_info["bucket"], path_info["path"]) + self.blob_service.delete_blob(path_info.bucket, path_info.path) def _list_paths(self, bucket, prefix): blob_service = self.blob_service @@ -124,32 +124,32 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != self.scheme: + if to_info.scheme != self.scheme: raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError - bucket = to_info["bucket"] - path = to_info["path"] + bucket = to_info.bucket + path = to_info.path logger.debug( "Uploading '{}' to '{}/{}'".format( - from_info["path"], bucket, path + from_info.path, bucket, path ) ) if not name: - name = os.path.basename(from_info["path"]) + name = os.path.basename(from_info.path) cb = None if no_progress_bar else Callback(name) try: self.blob_service.create_blob_from_path( - bucket, path, from_info["path"], progress_callback=cb + bucket, path, from_info.path, progress_callback=cb ) except Exception: - msg = "failed to upload '{}'".format(from_info["path"]) + msg = "failed to upload '{}'".format(from_info.path) logger.warning(msg) else: progress.finish_target(name) @@ -165,28 +165,28 @@ def download( names = self._verify_path_args(from_infos, to_infos, names) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != self.scheme: + if from_info.scheme != self.scheme: raise NotImplementedError - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError - bucket = from_info["bucket"] - path = from_info["path"] + bucket = from_info.bucket + path = from_info.path logger.debug( "Downloading '{}/{}' to '{}'".format( - bucket, path, to_info["path"] + bucket, path, to_info.path ) ) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) cb = None if no_progress_bar else Callback(name) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) + makedirs(os.path.dirname(to_info.path), exist_ok=True) try: self.blob_service.get_blob_to_path( @@ -196,7 +196,7 @@ def download( msg = "failed to download '{}/{}'".format(bucket, path) logger.warning(msg) else: - move(tmp_file, to_info["path"]) + move(tmp_file, to_info.path) if not no_progress_bar: progress.finish_target(name) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 8bd3cac835..7669336861 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -1,5 +1,8 @@ from __future__ import unicode_literals +from copy import copy + +from dvc.path.local import LocalPathInfo from dvc.utils.compat import str import os @@ -159,8 +162,8 @@ def get_file_checksum(self, path_info): def _collect_dir(self, path_info): dir_info = [] - p_info = path_info.copy() - dpath = p_info["path"] + p_info = copy(path_info) + dpath = p_info.path for root, dirs, files in self.walk(path_info): if len(files) > LARGE_DIR_SIZE: msg = ( @@ -173,7 +176,7 @@ def _collect_dir(self, path_info): for fname in files: path = self.ospath.join(root, fname) - p_info["path"] = path + p_info.path = path relpath = self.to_posixpath(self.ospath.relpath(path, dpath)) checksum = self.get_file_checksum(p_info) @@ -201,21 +204,19 @@ def get_dir_checksum(self, path_info): return checksum def _get_dir_info_checksum(self, dir_info, path_info): - to_info = path_info.copy() - to_info["path"] = self.cache.ospath.join( - self.cache.prefix, tmp_fname("") - ) + to_info = copy(path_info) + to_info.path = self.cache.ospath.join(self.cache.prefix, tmp_fname("")) tmp = tempfile.NamedTemporaryFile(delete=False).name with open(tmp, "w+") as fobj: json.dump(dir_info, fobj, sort_keys=True) - from_info = {"scheme": "local", "path": tmp} + from_info = LocalPathInfo(path=tmp) self.cache.upload([from_info], [to_info], no_progress_bar=True) checksum = self.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX - from_info = to_info.copy() - to_info["path"] = self.cache.checksum_to_path(checksum) + from_info = copy(to_info) + to_info.path = self.cache.checksum_to_path(checksum) return checksum, from_info, to_info def get_dir_cache(self, checksum): @@ -234,7 +235,7 @@ def load_dir_cache(self, checksum): fobj = tempfile.NamedTemporaryFile(delete=False) path = fobj.name - to_info = {"scheme": "local", "path": path} + to_info = LocalPathInfo(path=path) self.cache.download([path_info], [to_info], no_progress_bar=True) try: @@ -279,7 +280,7 @@ def get_checksum(self, path_info): return checksum def save_info(self, path_info): - assert path_info["scheme"] == self.scheme + assert path_info.scheme == self.scheme return {self.PARAM_CHECKSUM: self.get_checksum(path_info)} def changed(self, path_info, checksum_info): @@ -359,11 +360,11 @@ def _save_dir(self, path_info, checksum): cache_info = self.checksum_to_path_info(checksum) dir_info = self.get_dir_cache(checksum) - entry_info = path_info.copy() + entry_info = copy(path_info) for entry in dir_info: entry_checksum = entry[self.PARAM_CHECKSUM] - entry_info["path"] = self.ospath.join( - path_info["path"], entry[self.PARAM_RELPATH] + entry_info.path = self.ospath.join( + path_info.path, entry[self.PARAM_RELPATH] ) self._save_file(entry_info, entry_checksum, save_link=False) @@ -388,9 +389,9 @@ def protect(self, path_info): pass def save(self, path_info, checksum_info): - if path_info["scheme"] != self.scheme: + if path_info.scheme != self.scheme: raise RemoteActionNotImplemented( - "save {} -> {}".format(path_info["scheme"], self.scheme), + "save {} -> {}".format(path_info.scheme, self.scheme), self.scheme, ) @@ -462,8 +463,8 @@ def path_to_checksum(self, path): return self.ospath.dirname(relpath) + self.ospath.basename(relpath) def checksum_to_path_info(self, checksum): - path_info = self.path_info.copy() - path_info["path"] = self.checksum_to_path(checksum) + path_info = copy(self.path_info) + path_info.path = self.checksum_to_path(checksum) return path_info def md5s_to_path_infos(self, md5s): @@ -592,7 +593,7 @@ def _checkout_file( self.link(cache_info, path_info) self.state.save_link(path_info) if progress_callback: - progress_callback.update(path_info["url"]) + progress_callback.update(path_info.url) def makedirs(self, path_info): raise NotImplementedError @@ -609,13 +610,13 @@ def _checkout_dir( logger.debug("Linking directory '{}'.".format(path_info)) - entry_info = path_info.copy() + entry_info = copy(path_info) for entry in dir_info: relpath = entry[self.PARAM_RELPATH] checksum = entry[self.PARAM_CHECKSUM] entry_cache_info = self.checksum_to_path_info(checksum) - entry_info["url"] = self.ospath.join(path_info["url"], relpath) - entry_info["path"] = self.ospath.join(path_info["path"], relpath) + entry_info.url = self.ospath.join(path_info.url, relpath) + entry_info.path = self.ospath.join(path_info.path, relpath) entry_checksum_info = {self.PARAM_CHECKSUM: checksum} if self.changed(entry_info, entry_checksum_info): @@ -623,7 +624,7 @@ def _checkout_dir( self.safe_remove(entry_info, force=force) self.link(entry_cache_info, entry_info) if progress_callback: - progress_callback.update(entry_info["url"]) + progress_callback.update(entry_info.url) self._remove_redundant_files(path_info, dir_info, force) @@ -637,22 +638,24 @@ def _remove_redundant_files(self, path_info, dir_info, force): ) needed_files = set( - self.ospath.join(path_info["path"], entry[self.PARAM_RELPATH]) + self.ospath.join(path_info.path, entry[self.PARAM_RELPATH]) for entry in dir_info ) delta = existing_files - needed_files - d_info = path_info.copy() + d_info = copy(path_info) for path in delta: - d_info["path"] = path + d_info.path = path self.safe_remove(d_info, force) def checkout( self, path_info, checksum_info, force=False, progress_callback=None ): - scheme = path_info["scheme"] - if scheme not in ["", "local"] and scheme != self.scheme: + if ( + path_info.scheme not in ["local"] + and path_info.scheme != self.scheme + ): raise NotImplementedError checksum = checksum_info.get(self.PARAM_CHECKSUM) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index bd3919d361..b41aff3e64 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -3,6 +3,9 @@ import os import logging +from dvc.path import Schemes +from dvc.path.gs import GSPathInfo + try: from google.cloud import storage except ImportError: @@ -20,7 +23,7 @@ class RemoteGS(RemoteBase): - scheme = "gs" + scheme = Schemes.GS REGEX = r"^gs://(?P.*)$" REQUIRES = {"google.cloud.storage": storage} PARAM_CHECKSUM = "md5" @@ -38,7 +41,7 @@ def __init__(self, repo, config): self.bucket = parsed.netloc self.prefix = parsed.path.lstrip("/") - self.path_info = {"scheme": "gs", "bucket": self.bucket} + self.path_info = GSPathInfo(bucket=self.bucket) @staticmethod def compat_config(config): @@ -59,8 +62,8 @@ def get_file_checksum(self, path_info): import base64 import codecs - bucket = path_info["bucket"] - path = path_info["path"] + bucket = path_info.bucket + path = path_info.path blob = self.gs.bucket(bucket).get_blob(path) if not blob: return None @@ -72,27 +75,25 @@ def get_file_checksum(self, path_info): def copy(self, from_info, to_info, gs=None): gs = gs if gs else self.gs - blob = gs.bucket(from_info["bucket"]).get_blob(from_info["path"]) + blob = gs.bucket(from_info.bucket).get_blob(from_info.path) if not blob: - msg = "'{}' doesn't exist in the cloud".format(from_info["path"]) + msg = "'{}' doesn't exist in the cloud".format(from_info.path) raise DvcException(msg) - bucket = self.gs.bucket(to_info["bucket"]) + bucket = self.gs.bucket(to_info.bucket) bucket.copy_blob( - blob, self.gs.bucket(to_info["bucket"]), new_name=to_info["path"] + blob, self.gs.bucket(to_info.bucket), new_name=to_info.path ) def remove(self, path_info): - if path_info["scheme"] != "gs": + if path_info.scheme != "gs": raise NotImplementedError logger.debug( - "Removing gs://{}/{}".format( - path_info["bucket"], path_info["path"] - ) + "Removing gs://{}/{}".format(path_info.bucket, path_info.path) ) - blob = self.gs.bucket(path_info["bucket"]).get_blob(path_info["path"]) + blob = self.gs.bucket(path_info.bucket).get_blob(path_info.path) if not blob: return @@ -107,10 +108,10 @@ def list_cache_paths(self): def exists(self, path_info): assert not isinstance(path_info, list) - assert path_info["scheme"] == "gs" + assert path_info.scheme == "gs" - paths = self._list_paths(path_info["bucket"], path_info["path"]) - return any(path_info["path"] == path for path in paths) + paths = self._list_paths(path_info.bucket, path_info.path) + return any(path_info.path == path for path in paths) def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) @@ -118,34 +119,32 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): gs = self.gs for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != "gs": + if to_info.scheme != "gs": raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError logger.debug( "Uploading '{}' to '{}/{}'".format( - from_info["path"], to_info["bucket"], to_info["path"] + from_info.path, to_info.bucket, to_info.path ) ) if not name: - name = os.path.basename(from_info["path"]) + name = os.path.basename(from_info.path) if not no_progress_bar: progress.update_target(name, 0, None) try: - bucket = gs.bucket(to_info["bucket"]) - blob = bucket.blob(to_info["path"]) - blob.upload_from_filename(from_info["path"]) + bucket = gs.bucket(to_info.bucket) + blob = bucket.blob(to_info.path) + blob.upload_from_filename(from_info.path) except Exception: msg = "failed to upload '{}' to '{}/{}'" logger.exception( - msg.format( - from_info["path"], to_info["bucket"], to_info["path"] - ) + msg.format(from_info.path, to_info.bucket, to_info.path) ) continue @@ -164,46 +163,44 @@ def download( gs = self.gs for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != "gs": + if from_info.scheme != "gs": raise NotImplementedError - if to_info["scheme"] == "gs": + if to_info.scheme == "gs": self.copy(from_info, to_info, gs=gs) continue - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError msg = "Downloading '{}/{}' to '{}'".format( - from_info["bucket"], from_info["path"], to_info["path"] + from_info.bucket, from_info.path, to_info.path ) logger.debug(msg) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) if not no_progress_bar: # percent_cb is not available for download_to_filename, so # lets at least update progress at pathpoints(start, finish) progress.update_target(name, 0, None) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) + makedirs(os.path.dirname(to_info.path), exist_ok=True) try: - bucket = gs.bucket(from_info["bucket"]) - blob = bucket.get_blob(from_info["path"]) + bucket = gs.bucket(from_info.bucket) + blob = bucket.get_blob(from_info.path) blob.download_to_filename(tmp_file) except Exception: msg = "failed to download '{}/{}' to '{}'" logger.exception( - msg.format( - from_info["bucket"], from_info["path"], to_info["path"] - ) + msg.format(from_info.bucket, from_info.path, to_info.path) ) continue - move(tmp_file, to_info["path"]) + move(tmp_file, to_info.path) if not no_progress_bar: progress.finish_target(name) diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index abbbecad46..95b35743fc 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -8,6 +8,8 @@ from subprocess import Popen, PIPE from dvc.config import Config +from dvc.path import Schemes +from dvc.path.hdfs import HDFSPathInfo from dvc.remote.base import RemoteBase, RemoteCmdError from dvc.utils import fix_env, tmp_fname @@ -16,7 +18,7 @@ class RemoteHDFS(RemoteBase): - scheme = "hdfs" + scheme = Schemes.HDFS REGEX = r"^hdfs://((?P.*)@)?.*$" PARAM_CHECKSUM = "checksum" @@ -30,7 +32,7 @@ def __init__(self, repo, config): Config.SECTION_REMOTE_USER, getpass.getuser() ) - self.path_info = {"scheme": "hdfs", "user": self.user} + self.path_info = HDFSPathInfo(user=self.user) def hadoop_fs(self, cmd, user=None): cmd = "hadoop fs -" + cmd @@ -66,39 +68,37 @@ def _group(regex, s, gname): def get_file_checksum(self, path_info): regex = r".*\t.*\t(?P.*)" stdout = self.hadoop_fs( - "checksum {}".format(path_info["path"]), user=path_info["user"] + "checksum {}".format(path_info.path), user=path_info.user ) return self._group(regex, stdout, "checksum") def copy(self, from_info, to_info): - dname = posixpath.dirname(to_info["path"]) - self.hadoop_fs("mkdir -p {}".format(dname), user=to_info["user"]) + dname = posixpath.dirname(to_info.path) + self.hadoop_fs("mkdir -p {}".format(dname), user=to_info.user) self.hadoop_fs( - "cp -f {} {}".format(from_info["path"], to_info["path"]), - user=to_info["user"], + "cp -f {} {}".format(from_info.path, to_info.path), + user=to_info.user, ) def rm(self, path_info): - self.hadoop_fs( - "rm -f {}".format(path_info["path"]), user=path_info["user"] - ) + self.hadoop_fs("rm -f {}".format(path_info.path), user=path_info.user) def remove(self, path_info): - if path_info["scheme"] != "hdfs": + if path_info.scheme != "hdfs": raise NotImplementedError assert path_info.get("path") - logger.debug("Removing {}".format(path_info["path"])) + logger.debug("Removing {}".format(path_info.path)) self.rm(path_info) def exists(self, path_info): assert not isinstance(path_info, list) - assert path_info["scheme"] == "hdfs" + assert path_info.scheme == "hdfs" try: - self.hadoop_fs("test -e {}".format(path_info["path"])) + self.hadoop_fs("test -e {}".format(path_info.path)) return True except RemoteCmdError: return False @@ -107,27 +107,26 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != "hdfs": + if to_info.scheme != "hdfs": raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError self.hadoop_fs( - "mkdir -p {}".format(posixpath.dirname(to_info["path"])), - user=to_info["user"], + "mkdir -p {}".format(posixpath.dirname(to_info.path)), + user=to_info.user, ) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) self.hadoop_fs( - "copyFromLocal {} {}".format(from_info["path"], tmp_file), - user=to_info["user"], + "copyFromLocal {} {}".format(from_info.path, tmp_file), + user=to_info.user, ) self.hadoop_fs( - "mv {} {}".format(tmp_file, to_info["path"]), - user=to_info["user"], + "mv {} {}".format(tmp_file, to_info.path), user=to_info.user ) def download( @@ -141,28 +140,28 @@ def download( names = self._verify_path_args(from_infos, to_infos, names) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != "hdfs": + if from_info.scheme != "hdfs": raise NotImplementedError - if to_info["scheme"] == "hdfs": + if to_info.scheme == "hdfs": self.copy(from_info, to_info) continue - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError - dname = os.path.dirname(to_info["path"]) + dname = os.path.dirname(to_info.path) if not os.path.exists(dname): os.makedirs(dname) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) self.hadoop_fs( - "copyToLocal {} {}".format(from_info["path"], tmp_file), - user=from_info["user"], + "copyToLocal {} {}".format(from_info.path, tmp_file), + user=from_info.user, ) - os.rename(tmp_file, to_info["path"]) + os.rename(tmp_file, to_info.path) def list_cache_paths(self): try: diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 6834601413..64d0670c33 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -1,4 +1,7 @@ from __future__ import unicode_literals + +from dvc.path import Schemes +from dvc.path.http import HTTPPathInfo from dvc.utils.compat import open, makedirs import os @@ -30,7 +33,7 @@ def __call__(self, byts): class RemoteHTTP(RemoteBase): - scheme = "http" + scheme = Schemes.HTTP REGEX = r"^https?://.*$" REQUEST_TIMEOUT = 10 CHUNK_SIZE = 2 ** 16 @@ -40,7 +43,8 @@ def __init__(self, repo, config): super(RemoteHTTP, self).__init__(repo, config) self.cache_dir = config.get(Config.SECTION_REMOTE_URL) self.url = self.cache_dir - self.path_info = {"scheme": "http"} + + self.path_info = HTTPPathInfo() @property def prefix(self): @@ -57,23 +61,23 @@ def download( names = self._verify_path_args(to_infos, from_infos, names) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] not in ["http", "https"]: + if from_info.scheme not in ["http", "https"]: raise NotImplementedError - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError msg = "Downloading '{}' to '{}'".format( - from_info["path"], to_info["path"] + from_info.path, to_info.path ) logger.debug(msg) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) + makedirs(os.path.dirname(to_info.path), exist_ok=True) - total = self._content_length(from_info["path"]) + total = self._content_length(from_info.path) if no_progress_bar or not total: cb = None @@ -82,14 +86,11 @@ def download( try: self._download_to( - from_info["path"], - to_info["path"], - callback=cb, - resume=resume, + from_info.path, to_info.path, callback=cb, resume=resume ) except Exception: - msg = "failed to download '{}'".format(from_info["path"]) + msg = "failed to download '{}'".format(from_info.path) logger.exception(msg) continue @@ -98,8 +99,8 @@ def download( def exists(self, path_info): assert not isinstance(path_info, list) - assert path_info["scheme"] in ["http", "https"] - return bool(self._request("HEAD", path_info.get("path"))) + assert path_info.scheme in ["http", "https"] + return bool(self._request("HEAD", path_info.path)) def cache_exists(self, md5s): assert isinstance(md5s, list) @@ -113,7 +114,7 @@ def _content_length(self, url): return self._request("HEAD", url).headers.get("Content-Length") def get_file_checksum(self, path_info): - url = path_info["path"] + url = path_info.path etag = self._request("HEAD", url).headers.get("ETag") or self._request( "HEAD", url ).headers.get("Content-MD5") diff --git a/dvc/remote/local/__init__.py b/dvc/remote/local/__init__.py index 45b1e4bbe6..7d853b7022 100644 --- a/dvc/remote/local/__init__.py +++ b/dvc/remote/local/__init__.py @@ -2,6 +2,8 @@ from copy import copy +from dvc.path import BasePathInfo, Schemes +from dvc.path.local import LocalPathInfo from dvc.remote.local.slow_link_detection import slow_link_guard from dvc.utils.compat import str, makedirs @@ -42,7 +44,7 @@ class RemoteLOCAL(RemoteBase): - scheme = "local" + scheme = Schemes.LOCAL REGEX = r"^(?P.*)$" PARAM_CHECKSUM = "md5" PARAM_PATH = "path" @@ -77,9 +79,8 @@ def __init__(self, repo, config): if self.cache_dir is not None and not os.path.exists(self.cache_dir): os.mkdir(self.cache_dir) - self.path_info = {"scheme": "local"} - self._dir_info = {} + self.path_info = LocalPathInfo() @staticmethod def compat_config(config): @@ -115,18 +116,18 @@ def get(self, md5): return self.checksum_to_path(md5) def exists(self, path_info): - assert not isinstance(path_info, list) - assert path_info["scheme"] == "local" - return os.path.lexists(path_info["path"]) + assert isinstance(path_info, BasePathInfo) + assert path_info.scheme == "local" + return os.path.lexists(path_info.path) def makedirs(self, path_info): if not self.exists(path_info): - os.makedirs(path_info["path"]) + os.makedirs(path_info.path) @slow_link_guard def link(self, cache_info, path_info): - cache = cache_info["path"] - path = path_info["path"] + cache = cache_info.path + path = path_info.path assert os.path.isfile(cache) @@ -175,7 +176,7 @@ def ospath(self): return posixpath def already_cached(self, path_info): - assert path_info["scheme"] in ["", "local"] + assert path_info.scheme in ["", "local"] current_md5 = self.get_checksum(path_info) @@ -185,7 +186,7 @@ def already_cached(self, path_info): return not self.changed_cache(current_md5) def is_empty(self, path_info): - path = path_info["path"] + path = path_info.path if self.isfile(path_info) and os.path.getsize(path) == 0: return True @@ -196,29 +197,29 @@ def is_empty(self, path_info): return False def isfile(self, path_info): - return os.path.isfile(path_info["path"]) + return os.path.isfile(path_info.path) def isdir(self, path_info): - return os.path.isdir(path_info["path"]) + return os.path.isdir(path_info.path) def walk(self, path_info): - return os.walk(path_info["path"]) + return os.walk(path_info.path) def get_file_checksum(self, path_info): - return file_md5(path_info["path"])[0] + return file_md5(path_info.path)[0] def remove(self, path_info): - if path_info["scheme"] != "local": + if path_info.scheme != "local": raise NotImplementedError - remove(path_info["path"]) + remove(path_info.path) def move(self, from_info, to_info): - if from_info["scheme"] != "local" or to_info["scheme"] != "local": + if from_info.scheme != "local" or to_info.scheme != "local": raise NotImplementedError - inp = from_info["path"] - outp = to_info["path"] + inp = from_info.path + outp = to_info.path # moving in two stages to make the whole operation atomic in # case inp and outp are in different filesystems and actual @@ -235,36 +236,34 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError logger.debug( - "Uploading '{}' to '{}'".format( - from_info["path"], to_info["path"] - ) + "Uploading '{}' to '{}'".format(from_info.path, to_info.path) ) if not name: - name = os.path.basename(from_info["path"]) + name = os.path.basename(from_info.path) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) - tmp_file = tmp_fname(to_info["path"]) + makedirs(os.path.dirname(to_info.path), exist_ok=True) + tmp_file = tmp_fname(to_info.path) try: copyfile( - from_info["path"], + from_info.path, tmp_file, name=name, no_progress_bar=no_progress_bar, ) - os.rename(tmp_file, to_info["path"]) + os.rename(tmp_file, to_info.path) except Exception: logger.exception( "failed to upload '{}' to '{}'".format( - from_info["path"], to_info["path"] + from_info.path, to_info.path ) ) @@ -279,36 +278,34 @@ def download( names = self._verify_path_args(from_infos, to_infos, names) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError logger.debug( - "Downloading '{}' to '{}'".format( - from_info["path"], to_info["path"] - ) + "Downloading '{}' to '{}'".format(from_info.path, to_info.path) ) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) - tmp_file = tmp_fname(to_info["path"]) + makedirs(os.path.dirname(to_info.path), exist_ok=True) + tmp_file = tmp_fname(to_info.path) try: copyfile( - from_info["path"], + from_info.path, tmp_file, no_progress_bar=no_progress_bar, name=name, ) - move(tmp_file, to_info["path"]) + move(tmp_file, to_info.path) except Exception: logger.exception( "failed to download '{}' to '{}'".format( - from_info["path"], to_info["path"] + from_info.path, to_info.path ) ) @@ -563,7 +560,7 @@ def _unprotect_dir(path): @staticmethod def unprotect(path_info): - path = path_info["path"] + path = path_info.path if not os.path.exists(path): raise DvcException( "can't unprotect non-existing data '{}'".format(path) @@ -576,4 +573,4 @@ def unprotect(path_info): @staticmethod def protect(path_info): - os.chmod(path_info["path"], stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH) + os.chmod(path_info.path, stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 7615ee5352..b07d87999d 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -4,6 +4,9 @@ import os import logging +from dvc.path import Schemes +from dvc.path.oss import OSSPathInfo + try: import oss2 except ImportError: @@ -39,7 +42,7 @@ class RemoteOSS(RemoteBase): $ export OSS_ENDPOINT="endpoint" """ - scheme = "oss" + scheme = Schemes.OSS REGEX = r"^oss://(?P.*)?$" REQUIRES = {"oss2": oss2} PARAM_CHECKSUM = "etag" @@ -70,7 +73,7 @@ def __init__(self, repo, config): ) self._bucket = None - self.path_info = {"scheme": self.scheme, "bucket": self.bucket} + self.path_info = OSSPathInfo(bucket=self.bucket) @property def oss_service(self): @@ -93,16 +96,14 @@ def oss_service(self): return self._bucket def remove(self, path_info): - if path_info["scheme"] != self.scheme: + if path_info.scheme != self.scheme: raise NotImplementedError logger.debug( - "Removing oss://{}/{}".format( - path_info["bucket"], path_info["path"] - ) + "Removing oss://{}/{}".format(path_info.bucket, path_info.path) ) - self.oss_service.delete_object(path_info["path"]) + self.oss_service.delete_object(path_info.path) def _list_paths(self, prefix): for blob in oss2.ObjectIterator(self.oss_service, prefix=prefix): @@ -115,32 +116,32 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != self.scheme: + if to_info.scheme != self.scheme: raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError - bucket = to_info["bucket"] - path = to_info["path"] + bucket = to_info.bucket + path = to_info.path logger.debug( "Uploading '{}' to 'oss://{}/{}'".format( - from_info["path"], bucket, path + from_info.path, bucket, path ) ) if not name: - name = os.path.basename(from_info["path"]) + name = os.path.basename(from_info.path) cb = None if no_progress_bar else Callback(name) try: self.oss_service.put_object_from_file( - path, from_info["path"], progress_callback=cb + path, from_info.path, progress_callback=cb ) except Exception: - msg = "failed to upload '{}'".format(from_info["path"]) + msg = "failed to upload '{}'".format(from_info.path) logger.warning(msg) else: progress.finish_target(name) @@ -155,27 +156,27 @@ def download( ): names = self._verify_path_args(from_infos, to_infos, names) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != self.scheme: + if from_info.scheme != self.scheme: raise NotImplementedError - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError - bucket = from_info["bucket"] - path = from_info["path"] + bucket = from_info.bucket + path = from_info.path logger.debug( "Downloading 'oss://{}/{}' to '{}'".format( - bucket, path, to_info["path"] + bucket, path, to_info.path ) ) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) cb = None if no_progress_bar else Callback(name) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) + makedirs(os.path.dirname(to_info.path), exist_ok=True) try: self.oss_service.get_object_to_file( @@ -185,7 +186,7 @@ def download( msg = "failed to download 'oss://{}/{}'".format(bucket, path) logger.warning(msg) else: - move(tmp_file, to_info["path"]) + move(tmp_file, to_info.path) if not no_progress_bar: progress.finish_target(name) diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 0dfe372d09..0964090a1b 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -4,6 +4,9 @@ import threading import logging +from dvc.path import Schemes +from dvc.path.s3 import S3PathInfo + try: import boto3 except ImportError: @@ -33,7 +36,7 @@ def __call__(self, byts): class RemoteS3(RemoteBase): - scheme = "s3" + scheme = Schemes.S3 REGEX = r"^s3://(?P.*)$" REQUIRES = {"boto3": boto3} PARAM_CHECKSUM = "etag" @@ -69,7 +72,7 @@ def __init__(self, repo, config): self.bucket = parsed.netloc self.prefix = parsed.path.lstrip("/") - self.path_info = {"scheme": self.scheme, "bucket": self.bucket} + self.path_info = S3PathInfo(bucket=self.bucket) @staticmethod def compat_config(config): @@ -95,7 +98,7 @@ def get_etag(cls, s3, bucket, path): return obj["ETag"].strip('"') def get_file_checksum(self, path_info): - return self.get_etag(self.s3, path_info["bucket"], path_info["path"]) + return self.get_etag(self.s3, path_info.bucket, path_info.path) @staticmethod def get_head_object(s3, bucket, path, *args, **kwargs): @@ -111,7 +114,7 @@ def get_head_object(s3, bucket, path, *args, **kwargs): @classmethod def _copy_multipart(cls, s3, from_info, to_info, size, n_parts): mpu = s3.create_multipart_upload( - Bucket=to_info["bucket"], Key=to_info["path"] + Bucket=to_info.bucket, Key=to_info.path ) mpu_id = mpu["UploadId"] @@ -119,7 +122,7 @@ def _copy_multipart(cls, s3, from_info, to_info, size, n_parts): byte_position = 0 for i in range(1, n_parts + 1): obj = cls.get_head_object( - s3, from_info["bucket"], from_info["path"], PartNumber=i + s3, from_info.bucket, from_info.path, PartNumber=i ) part_size = obj["ContentLength"] lastbyte = byte_position + part_size - 1 @@ -129,15 +132,12 @@ def _copy_multipart(cls, s3, from_info, to_info, size, n_parts): srange = "bytes={}-{}".format(byte_position, lastbyte) part = s3.upload_part_copy( - Bucket=to_info["bucket"], - Key=to_info["path"], + Bucket=to_info.bucket, + Key=to_info.path, PartNumber=i, UploadId=mpu_id, CopySourceRange=srange, - CopySource={ - "Bucket": from_info["bucket"], - "Key": from_info["path"], - }, + CopySource={"Bucket": from_info.bucket, "Key": from_info.path}, ) parts.append( {"PartNumber": i, "ETag": part["CopyPartResult"]["ETag"]} @@ -147,8 +147,8 @@ def _copy_multipart(cls, s3, from_info, to_info, size, n_parts): assert n_parts == len(parts) s3.complete_multipart_upload( - Bucket=to_info["bucket"], - Key=to_info["path"], + Bucket=to_info.bucket, + Key=to_info.path, UploadId=mpu_id, MultipartUpload={"Parts": parts}, ) @@ -175,7 +175,7 @@ def _copy(cls, s3, from_info, to_info): # preserve etag, we need to transfer each part separately, so the # object is transfered in the same chunks as it was originally. - obj = cls.get_head_object(s3, from_info["bucket"], from_info["path"]) + obj = cls.get_head_object(s3, from_info.bucket, from_info.path) etag = obj["ETag"].strip('"') size = obj["ContentLength"] @@ -184,10 +184,10 @@ def _copy(cls, s3, from_info, to_info): n_parts = int(parts_suffix) cls._copy_multipart(s3, from_info, to_info, size, n_parts) else: - source = {"Bucket": from_info["bucket"], "Key": from_info["path"]} - s3.copy(source, to_info["bucket"], to_info["path"]) + source = {"Bucket": from_info.bucket, "Key": from_info.path} + s3.copy(source, to_info.bucket, to_info.path) - cached_etag = cls.get_etag(s3, to_info["bucket"], to_info["path"]) + cached_etag = cls.get_etag(s3, to_info.bucket, to_info.path) if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) @@ -196,18 +196,14 @@ def copy(self, from_info, to_info, s3=None): self._copy(s3, from_info, to_info) def remove(self, path_info): - if path_info["scheme"] != "s3": + if path_info.scheme != "s3": raise NotImplementedError logger.debug( - "Removing s3://{}/{}".format( - path_info["bucket"], path_info["path"] - ) + "Removing s3://{}/{}".format(path_info.bucket, path_info.path) ) - self.s3.delete_object( - Bucket=path_info["bucket"], Key=path_info["path"] - ) + self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) def _list_paths(self, bucket, prefix): """ Read config for list object api, paginate through list objects.""" @@ -230,10 +226,10 @@ def list_cache_paths(self): def exists(self, path_info): assert not isinstance(path_info, list) - assert path_info["scheme"] == "s3" + assert path_info.scheme == "s3" - paths = self._list_paths(path_info["bucket"], path_info["path"]) - return any(path_info["path"] == path for path in paths) + paths = self._list_paths(path_info.bucket, path_info.path) + return any(path_info.path == path for path in paths) def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) @@ -241,33 +237,30 @@ def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): s3 = self.s3 for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != "s3": + if to_info.scheme != "s3": raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError logger.debug( "Uploading '{}' to '{}/{}'".format( - from_info["path"], to_info["bucket"], to_info["path"] + from_info.path, to_info.bucket, to_info.path ) ) if not name: - name = os.path.basename(from_info["path"]) + name = os.path.basename(from_info.path) - total = os.path.getsize(from_info["path"]) + total = os.path.getsize(from_info.path) cb = None if no_progress_bar else Callback(name, total) try: s3.upload_file( - from_info["path"], - to_info["bucket"], - to_info["path"], - Callback=cb, + from_info.path, to_info.bucket, to_info.path, Callback=cb ) except Exception: - msg = "failed to upload '{}'".format(from_info["path"]) + msg = "failed to upload '{}'".format(from_info.path) logger.exception(msg) continue @@ -286,50 +279,47 @@ def download( s3 = self.s3 for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != "s3": + if from_info.scheme != "s3": raise NotImplementedError - if to_info["scheme"] == "s3": + if to_info.scheme == "s3": self.copy(from_info, to_info, s3=s3) continue - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError msg = "Downloading '{}/{}' to '{}'".format( - from_info["bucket"], from_info["path"], to_info["path"] + from_info.bucket, from_info.path, to_info.path ) logger.debug(msg) - tmp_file = tmp_fname(to_info["path"]) + tmp_file = tmp_fname(to_info.path) if not name: - name = os.path.basename(to_info["path"]) + name = os.path.basename(to_info.path) - makedirs(os.path.dirname(to_info["path"]), exist_ok=True) + makedirs(os.path.dirname(to_info.path), exist_ok=True) try: if no_progress_bar: cb = None else: total = s3.head_object( - Bucket=from_info["bucket"], Key=from_info["path"] + Bucket=from_info.bucket, Key=from_info.path )["ContentLength"] cb = Callback(name, total) s3.download_file( - from_info["bucket"], - from_info["path"], - tmp_file, - Callback=cb, + from_info.bucket, from_info.path, tmp_file, Callback=cb ) except Exception: msg = "failed to download '{}/{}'".format( - from_info["bucket"], from_info["path"] + from_info.bucket, from_info.path ) logger.exception(msg) continue - move(tmp_file, to_info["path"]) + move(tmp_file, to_info.path) if not no_progress_bar: progress.finish_target(name) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 0ef060efcf..b9b924570b 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -4,6 +4,9 @@ import getpass import logging +from dvc.path import Schemes +from dvc.path.ssh import SSHPathInfo + try: import paramiko except ImportError: @@ -20,7 +23,7 @@ class RemoteSSH(RemoteBase): - scheme = "ssh" + scheme = Schemes.SSH # NOTE: we support both URL-like (ssh://[user@]host.xz[:port]/path) and # SCP-like (ssh://[user@]host.xz:/absolute/path) urls. @@ -65,12 +68,9 @@ def __init__(self, repo, config): Config.SECTION_REMOTE_ASK_PASSWORD, False ) - self.path_info = { - "scheme": "ssh", - "host": self.host, - "user": self.user, - "port": self.port, - } + self.path_info = SSHPathInfo( + host=self.host, user=self.user, port=self.port + ) @staticmethod def ssh_config_filename(): @@ -101,7 +101,9 @@ def _try_get_ssh_config_keyfile(user_ssh_config): return identity_file[0] return None - def ssh(self, host=None, user=None, port=None, **kwargs): + def ssh(self, path_info): + host, user, port = path_info.host, path_info.user, path_info.port + logger.debug( "Establishing ssh connection with '{host}' " "through port '{port}' as user '{user}'".format( @@ -128,48 +130,42 @@ def ssh(self, host=None, user=None, port=None, **kwargs): def exists(self, path_info): assert not isinstance(path_info, list) - assert path_info["scheme"] == self.scheme + assert path_info.scheme == self.scheme - with self.ssh(**path_info) as ssh: - return ssh.exists(path_info["path"]) + with self.ssh(path_info) as ssh: + return ssh.exists(path_info.path) def get_file_checksum(self, path_info): - if path_info["scheme"] != self.scheme: + if path_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(**path_info) as ssh: - return ssh.md5(path_info["path"]) + with self.ssh(path_info) as ssh: + return ssh.md5(path_info.path) def isdir(self, path_info): - with self.ssh(**path_info) as ssh: - return ssh.isdir(path_info["path"]) + with self.ssh(path_info) as ssh: + return ssh.isdir(path_info.path) def copy(self, from_info, to_info): - if ( - from_info["scheme"] != self.scheme - or to_info["scheme"] != self.scheme - ): + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(**from_info) as ssh: - ssh.cp(from_info["path"], to_info["path"]) + with self.ssh(from_info) as ssh: + ssh.cp(from_info.path, to_info.path) def remove(self, path_info): - if path_info["scheme"] != self.scheme: + if path_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(**path_info) as ssh: - ssh.remove(path_info["path"]) + with self.ssh(path_info) as ssh: + ssh.remove(path_info.path) def move(self, from_info, to_info): - if ( - from_info["scheme"] != self.scheme - or to_info["scheme"] != self.scheme - ): + if from_info.scheme != self.scheme or to_info.scheme != self.scheme: raise NotImplementedError - with self.ssh(**from_info) as ssh: - ssh.move(from_info["path"], to_info["path"]) + with self.ssh(from_info) as ssh: + ssh.move(from_info.path, to_info.path) def download( self, @@ -180,40 +176,38 @@ def download( resume=False, ): names = self._verify_path_args(from_infos, to_infos, names) - ssh = self.ssh(**from_infos[0]) + ssh = self.ssh(from_infos[0]) for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info["scheme"] != self.scheme: + if from_info.scheme != self.scheme: raise NotImplementedError - if to_info["scheme"] == self.scheme: - ssh.cp(from_info["path"], to_info["path"]) + if to_info.scheme == self.scheme: + ssh.cp(from_info.path, to_info.path) continue - if to_info["scheme"] != "local": + if to_info.scheme != "local": raise NotImplementedError logger.debug( "Downloading '{host}/{path}' to '{dest}'".format( - host=from_info["host"], - path=from_info["path"], - dest=to_info["path"], + host=from_info.host, path=from_info.path, dest=to_info.path ) ) try: ssh.download( - from_info["path"], - to_info["path"], + from_info.path, + to_info.path, progress_title=name, no_progress_bar=no_progress_bar, ) except Exception: logger.exception( "failed to download '{host}/{path}' to '{dest}'".format( - host=from_info["host"], - path=from_info["path"], - dest=to_info["path"], + host=from_info.host, + path=from_info.path, + dest=to_info.path, ) ) continue @@ -223,40 +217,40 @@ def download( def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): names = self._verify_path_args(to_infos, from_infos, names) - with self.ssh(**to_infos[0]) as ssh: + with self.ssh(to_infos[0]) as ssh: for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info["scheme"] != self.scheme: + if to_info.scheme != self.scheme: raise NotImplementedError - if from_info["scheme"] != "local": + if from_info.scheme != "local": raise NotImplementedError try: ssh.upload( - from_info["path"], - to_info["path"], + from_info.path, + to_info.path, progress_title=name, no_progress_bar=no_progress_bar, ) except Exception: logger.exception( "failed to upload '{host}/{path}' to '{dest}'".format( - host=from_info["host"], - path=from_info["path"], - dest=to_info["path"], + host=from_info.host, + path=from_info.path, + dest=to_info.path, ) ) pass def list_cache_paths(self): - with self.ssh(**self.path_info) as ssh: + with self.ssh(self.path_info) as ssh: return list(ssh.walk_files(self.prefix)) def walk(self, path_info): - with self.ssh(**path_info) as ssh: - for entry in ssh.walk(path_info["path"]): + with self.ssh(path_info) as ssh: + for entry in ssh.walk(path_info.path): yield entry def makedirs(self, path_info): - with self.ssh(**path_info) as ssh: - ssh.makedirs(path_info["path"]) + with self.ssh(path_info) as ssh: + ssh.makedirs(path_info.path) diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 659541bd1c..4e6a477f41 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -12,6 +12,7 @@ TargetNotDirectoryError, ) from dvc.ignore import DvcIgnoreFileHandler +from dvc.path.local import LocalPathInfo logger = logging.getLogger(__name__) @@ -116,7 +117,7 @@ def init(root_dir=os.curdir, no_scm=False, force=False): return Repo(root_dir) def unprotect(self, target): - path_info = {"schema": "local", "path": target} + path_info = LocalPathInfo(path=target) return self.cache.local.unprotect(path_info) def _ignore(self): @@ -282,7 +283,7 @@ def used_cache( ) for out in stage.outs: - scheme = out.path_info["scheme"] + scheme = out.path_info.scheme cache[scheme].extend( self._collect_used_cache( out, diff --git a/dvc/repo/metrics/modify.py b/dvc/repo/metrics/modify.py index 2ce5fbe79a..71fe51b2ae 100644 --- a/dvc/repo/metrics/modify.py +++ b/dvc/repo/metrics/modify.py @@ -11,7 +11,7 @@ def modify(repo, path, typ=None, xpath=None, delete=False): if out.scheme != "local": msg = "output '{}' scheme '{}' is not supported for metrics" - raise DvcException(msg.format(out.path, out.path_info["scheme"])) + raise DvcException(msg.format(out.path, out.path_info.scheme)) if typ is not None: typ = typ.lower().strip() diff --git a/dvc/repo/tag/add.py b/dvc/repo/tag/add.py index 4dd866da3c..d5d8278ac1 100644 --- a/dvc/repo/tag/add.py +++ b/dvc/repo/tag/add.py @@ -1,5 +1,5 @@ import logging - +from copy import copy logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ def add(self, tag, target=None, with_deps=False, recursive=False): if not out.info: logger.warning("missing checksum info for '{}'".format(out)) continue - out.tags[tag] = out.info.copy() + out.tags[tag] = copy(out.info) changed = True if changed: stage.dump() diff --git a/dvc/state.py b/dvc/state.py index cf3ef8fdae..293ece7673 100644 --- a/dvc/state.py +++ b/dvc/state.py @@ -371,10 +371,10 @@ def save(self, path_info, checksum): path_info (dict): path_info to save checksum for. checksum (str): checksum to save. """ - assert path_info["scheme"] == "local" + assert path_info.scheme == "local" assert checksum is not None - path = path_info["path"] + path = path_info.path assert os.path.exists(path) actual_mtime, actual_size = get_mtime_and_size(path) @@ -402,8 +402,8 @@ def get(self, path_info): str or None: checksum for the specified path info or None if it doesn't exist in the state database. """ - assert path_info["scheme"] == "local" - path = path_info["path"] + assert path_info.scheme == "local" + path = path_info.path if not os.path.exists(path): return None @@ -429,8 +429,8 @@ def save_link(self, path_info): Args: path_info (dict): path info to add to the list of links. """ - assert path_info["scheme"] == "local" - path = path_info["path"] + assert path_info.scheme == "local" + path = path_info.path if not os.path.exists(path): return diff --git a/dvc/utils/compat.py b/dvc/utils/compat.py index e0ee72078f..2c3ec53e0d 100644 --- a/dvc/utils/compat.py +++ b/dvc/utils/compat.py @@ -91,7 +91,7 @@ def _makedirs(name, mode=0o777, exist_ok=False): if is_py2: - from urlparse import urlparse, urljoin # noqa: F401 + from urlparse import urlparse, urljoin, urlsplit, urlunsplit # noqa: F401 from BaseHTTPServer import HTTPServer # noqa: F401 from SimpleHTTPServer import SimpleHTTPRequestHandler # noqa: F401 import ConfigParser # noqa: F401 @@ -131,7 +131,12 @@ def __exit__(self, *args): elif is_py3: from pathlib import Path # noqa: F401 from os import makedirs # noqa: F401 - from urllib.parse import urlparse, urljoin # noqa: F401 + from urllib.parse import ( # noqa: F401 + urlparse, # noqa: F401 + urljoin, # noqa: F401 + urlsplit, # noqa: F401 + urlunsplit, # noqa: F401 + ) from io import StringIO, BytesIO # noqa: F401 from http.server import ( # noqa: F401 HTTPServer, # noqa: F401 diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 84651be52e..a0c527a39d 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -1,5 +1,6 @@ import os import configobj +from dvc.path.local import LocalPathInfo from dvc.remote.base import RemoteBase from mock import patch @@ -198,7 +199,7 @@ def test_dir_checksum_should_be_key_order_agnostic(dvc): with open(file2, "w") as fobj: fobj.write("2") - path_info = {"scheme": "local", "path": data_dir} + path_info = LocalPathInfo(path=data_dir) with dvc.state: with patch.object( RemoteBase, diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index fae7b692b7..718a8a41cf 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -1,8 +1,10 @@ import uuid import posixpath +from copy import copy import boto3 import pytest +from dvc.path.s3 import S3PathInfo from dvc.remote.s3 import RemoteS3 from tests.func.test_data_cloud import TEST_AWS_REPO_BUCKET, _should_test_aws @@ -10,13 +12,11 @@ def _get_src_dst(): prefix = str(uuid.uuid4()) - from_info = { - "scheme": "s3", - "bucket": TEST_AWS_REPO_BUCKET, - "path": posixpath.join(prefix, "from"), - } - to_info = from_info.copy() - to_info["path"] = posixpath.join(prefix, "to") + from_info = S3PathInfo( + bucket=TEST_AWS_REPO_BUCKET, path=posixpath.join(prefix, "from") + ) + to_info = copy(from_info) + to_info.path = posixpath.join(prefix, "to") return from_info, to_info @@ -27,9 +27,7 @@ def test_copy_singlepart_preserve_etag(): pytest.skip() s3 = boto3.client("s3") - s3.put_object( - Bucket=from_info["bucket"], Key=from_info["path"], Body="data" - ) + s3.put_object(Bucket=from_info.bucket, Key=from_info.path, Body="data") RemoteS3._copy(s3, from_info, to_info) @@ -71,5 +69,5 @@ def test_copy_multipart_preserve_etag(): pytest.skip() s3 = boto3.client("s3") - _upload_multipart(s3, from_info["bucket"], from_info["path"]) + _upload_multipart(s3, from_info.bucket, from_info.path) RemoteS3._copy(s3, from_info, to_info) diff --git a/tests/func/test_state.py b/tests/func/test_state.py index 3e620f242e..5c0fbc96c6 100644 --- a/tests/func/test_state.py +++ b/tests/func/test_state.py @@ -1,5 +1,6 @@ import os import mock +from dvc.path.local import LocalPathInfo from dvc.utils.compat import str @@ -14,7 +15,7 @@ class TestState(TestDvc): def test_update(self): path = os.path.join(self.dvc.root_dir, self.FOO) - path_info = {"scheme": "local", "path": path} + path_info = LocalPathInfo(path=path) md5 = file_md5(path)[0] state = State(self.dvc, self.dvc.config.config) @@ -77,6 +78,6 @@ def test_transforms_inode(self, get_inode_mock): get_inode_mock.side_effect = self.mock_get_inode(path, inode) with state: - state.save({"scheme": "local", "path": path}, md5) + state.save(LocalPathInfo(path=path), md5) ret = state.get_state_record_for_inode(inode) self.assertIsNotNone(ret) diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 12be6a37cf..a63bfa55bc 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -1,6 +1,7 @@ import mock from unittest import TestCase +from dvc.path import BasePathInfo from dvc.remote.base import RemoteBase, RemoteCmdError, RemoteMissingDepsError @@ -43,10 +44,10 @@ def test(self): remote = RemoteBase(None, config) remote.PARAM_CHECKSUM = "checksum" - remote.path_info = {} + remote.path_info = BasePathInfo(None, None) remote.url = "" remote.prefix = "" - path_info = {"scheme": None, "path": "example"} + path_info = BasePathInfo("example", None) checksum_info = {remote.PARAM_CHECKSUM: "1234567890"} with mock.patch.object(remote, "_checkout") as mock_checkout: