diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 91bd96c232..803deccfbe 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -49,9 +49,15 @@ del SCHEMA[BaseOutput.PARAM_METRIC] SCHEMA.update(RepoDependency.REPO_SCHEMA) SCHEMA.update(ParamsDependency.PARAM_SCHEMA) +SCHEMA.update({BaseOutput.PARAM_FILTER: str}) def _get(stage, p, info): + if isinstance(p, dict): + p = list(p.items()) + assert len(p) == 1 + p, extra_info = p[0] # PARAM_FILTER + info.update(extra_info) parsed = urlparse(p) if p else None if parsed and parsed.scheme == "remote": tree = get_cloud_tree(stage.repo, name=parsed.netloc) diff --git a/dvc/output/base.py b/dvc/output/base.py index f17e3df9e5..d5f303c090 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -52,6 +52,7 @@ class BaseOutput: PARAM_PATH = "path" PARAM_CACHE = "cache" + PARAM_FILTER = "cmd" PARAM_METRIC = "metric" PARAM_METRIC_TYPE = "type" PARAM_METRIC_XPATH = "xpath" @@ -176,6 +177,10 @@ def checksum(self): def checksum(self, checksum): self.info[self.tree.PARAM_CHECKSUM] = checksum + @property + def filter_cmd(self): + return self.info.get(self.PARAM_FILTER) + def get_checksum(self): return self.tree.get_hash(self.path_info) @@ -188,7 +193,7 @@ def exists(self): return self.tree.exists(self.path_info) def save_info(self): - return self.tree.save_info(self.path_info) + return self.tree.save_info(self.path_info, cmd=self.filter_cmd) def changed_checksum(self): return self.checksum != self.get_checksum() @@ -313,6 +318,8 @@ def dumpd(self): if self.persist: ret[self.PARAM_PERSIST] = self.persist + if self.filter_cmd: + ret[self.PARAM_FILTER] = self.filter_cmd return ret def verify_metric(self): diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index 8011d0318c..1f30e5024a 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -226,7 +226,8 @@ def isdvc(self, path, **kwargs): def isexec(self, path): # pylint: disable=unused-argument return False - def get_file_hash(self, 
path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError outs = self._find_outs(path_info, strict=False) if len(outs) != 1: raise OutputNotFoundError @@ -404,7 +405,7 @@ def walk_files(self, top, **kwargs): # pylint: disable=arguments-differ for fname in files: yield PathInfo(root) / fname - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): """Return file checksum for specified path. If path_info is a DVC out, the pre-computed checksum for the file @@ -418,7 +419,7 @@ def get_file_hash(self, path_info): return self.dvctree.get_file_hash(path_info) except OutputNotFoundError: pass - return file_md5(path_info, self)[0] + return file_md5(path_info, self, cmd=cmd)[0] def copytree(self, top, dest): top = PathInfo(top) diff --git a/dvc/schema.py b/dvc/schema.py index 0948f1f613..3005e1c375 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -17,10 +17,12 @@ StageParams.PARAM_ALWAYS_CHANGED: bool, } -DATA_SCHEMA = {**CHECKSUMS_SCHEMA, Required("path"): str} +DATA_SCHEMA = {**CHECKSUMS_SCHEMA, Required(BaseOutput.PARAM_PATH): str} LOCK_FILE_STAGE_SCHEMA = { Required(StageParams.PARAM_CMD): str, - StageParams.PARAM_DEPS: [DATA_SCHEMA], + StageParams.PARAM_DEPS: [ + {**DATA_SCHEMA, Optional(BaseOutput.PARAM_FILTER): str} + ], StageParams.PARAM_PARAMS: {str: {str: object}}, StageParams.PARAM_OUTS: [DATA_SCHEMA], } @@ -51,7 +53,9 @@ str: { StageParams.PARAM_CMD: str, Optional(StageParams.PARAM_WDIR): str, - Optional(StageParams.PARAM_DEPS): [str], + Optional(StageParams.PARAM_DEPS): [ + Any(str, {str: {BaseOutput.PARAM_FILTER: str}}) + ], Optional(StageParams.PARAM_PARAMS): [ Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA) ], diff --git a/dvc/tree/azure.py b/dvc/tree/azure.py index 19f9ecdcb9..1829cff04b 100644 --- a/dvc/tree/azure.py +++ b/dvc/tree/azure.py @@ -139,7 +139,8 @@ def remove(self, path_info): logger.debug(f"Removing {path_info}") self.blob_service.delete_blob(path_info.bucket, path_info.path) - 
def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError return self.get_etag(path_info) def _upload( diff --git a/dvc/tree/base.py b/dvc/tree/base.py index 1860bafab6..828a379411 100644 --- a/dvc/tree/base.py +++ b/dvc/tree/base.py @@ -63,6 +63,7 @@ class BaseTree: CACHE_MODE = None SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} PARAM_CHECKSUM = None + PARAM_FILTER = None state = StateNoop() @@ -236,7 +237,7 @@ def is_dir_hash(cls, hash_): return False return hash_.endswith(cls.CHECKSUM_DIR_SUFFIX) - def get_hash(self, path_info, **kwargs): + def get_hash(self, path_info, cmd=None, **kwargs): assert path_info and ( isinstance(path_info, str) or path_info.scheme == self.scheme ) @@ -265,14 +266,14 @@ def get_hash(self, path_info, **kwargs): if self.isdir(path_info): hash_ = self.get_dir_hash(path_info, **kwargs) else: - hash_ = self.get_file_hash(path_info) + hash_ = self.get_file_hash(path_info, cmd=cmd) if hash_ and self.exists(path_info): self.state.save(path_info, hash_) return hash_ - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): raise NotImplementedError def get_dir_hash(self, path_info, **kwargs): @@ -293,8 +294,11 @@ def path_to_hash(self, path): return "".join(parts) - def save_info(self, path_info, **kwargs): - return {self.PARAM_CHECKSUM: self.get_hash(path_info, **kwargs)} + def save_info(self, path_info, cmd=None, **kwargs): + ret = {self.PARAM_CHECKSUM: self.get_hash(path_info, cmd=cmd, **kwargs)} + if cmd: + ret[self.PARAM_FILTER] = cmd + return ret def _calculate_hashes(self, file_infos): file_infos = list(file_infos) diff --git a/dvc/tree/gdrive.py b/dvc/tree/gdrive.py index abf48d17ab..799663adad 100644 --- a/dvc/tree/gdrive.py +++ b/dvc/tree/gdrive.py @@ -573,7 +573,7 @@ def remove(self, path_info): item_id = self._get_item_id(path_info) self.gdrive_delete_file(item_id) - def get_file_hash(self, 
path_info, cmd=None): raise NotImplementedError def _upload(self, from_file, to_info, name=None, no_progress_bar=False): diff --git a/dvc/tree/gs.py b/dvc/tree/gs.py index 767154de35..df4473abab 100644 --- a/dvc/tree/gs.py +++ b/dvc/tree/gs.py @@ -182,7 +182,8 @@ def copy(self, from_info, to_info): to_bucket = self.gs.bucket(to_info.bucket) from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError import base64 import codecs diff --git a/dvc/tree/hdfs.py b/dvc/tree/hdfs.py index cdc9a2c339..d064718eb8 100644 --- a/dvc/tree/hdfs.py +++ b/dvc/tree/hdfs.py @@ -161,7 +161,8 @@ def _group(regex, s, gname): assert match is not None return match.group(gname) - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError # NOTE: pyarrow doesn't support checksum, so we need to use hadoop regex = r".*\t.*\t(?P.*)" stdout = self.hadoop_fs( diff --git a/dvc/tree/http.py b/dvc/tree/http.py index abed95ca4f..b62042fe3c 100644 --- a/dvc/tree/http.py +++ b/dvc/tree/http.py @@ -125,7 +125,8 @@ def request(self, method, url, **kwargs): def exists(self, path_info, use_dvcignore=True): return bool(self.request("HEAD", path_info.url)) - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError url = path_info.url headers = self.request("HEAD", url).headers etag = headers.get("ETag") or headers.get("Content-MD5") diff --git a/dvc/tree/local.py b/dvc/tree/local.py index bfce3345f8..f686e5f437 100644 --- a/dvc/tree/local.py +++ b/dvc/tree/local.py @@ -29,6 +29,7 @@ class LocalTree(BaseTree): scheme = Schemes.LOCAL PATH_CLS = PathInfo PARAM_CHECKSUM = "md5" + PARAM_FILTER = "cmd" PARAM_PATH = "path" TRAVERSE_PREFIX_LEN = 2 UNPACKED_DIR_SUFFIX = ".unpacked" @@ -297,8 +298,8 @@ def is_protected(self, path_info): return stat.S_IMODE(mode) == 
self.CACHE_MODE - def get_file_hash(self, path_info): - return file_md5(path_info)[0] + def get_file_hash(self, path_info, cmd=None): + return file_md5(path_info, cmd=cmd)[0] @staticmethod def getsize(path_info): diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index 44368d7461..7a4a45c7c5 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -317,7 +317,8 @@ def _copy(cls, s3, from_info, to_info, extra_args): if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError return self.get_etag(self.s3, path_info.bucket, path_info.path) def _upload(self, from_file, to_info, name=None, no_progress_bar=False): diff --git a/dvc/tree/ssh/__init__.py b/dvc/tree/ssh/__init__.py index 26bee19388..20560f8e6f 100644 --- a/dvc/tree/ssh/__init__.py +++ b/dvc/tree/ssh/__init__.py @@ -233,7 +233,8 @@ def reflink(self, from_info, to_info): with self.ssh(from_info) as ssh: ssh.reflink(from_info.path, to_info.path) - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError if path_info.scheme != self.scheme: raise NotImplementedError diff --git a/dvc/tree/webdav.py b/dvc/tree/webdav.py index 115de31767..282851707e 100644 --- a/dvc/tree/webdav.py +++ b/dvc/tree/webdav.py @@ -131,7 +131,8 @@ def exists(self, path_info, use_dvcignore=True): return self._client.check(path_info.path) # Gets file hash 'etag' - def get_file_hash(self, path_info): + def get_file_hash(self, path_info, cmd=None): + assert not cmd, NotImplementedError # Use webdav client info method to get etag etag = self._client.info(path_info.path)["etag"].strip('"') diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index 8fece228ff..15966ba0ad 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -7,7 +7,9 @@ import math import os import re +import subprocess import sys +import tempfile import time import colorama @@ 
-43,8 +45,10 @@ def _fobj_md5(fobj, hash_md5, binary, progress_func=None): progress_func(len(data)) -def file_md5(fname, tree=None): - """ get the (md5 hexdigest, md5 digest) of a file """ +def file_md5(fname, tree=None, cmd=None): + """ + Returns (md5_hexdigest, md5_digest) of `cmd file` (default: `cmd=cat`) + """ from dvc.progress import Tqdm from dvc.istextfile import istextfile @@ -58,6 +62,21 @@ def file_md5(fname, tree=None): open_func = open if exists_func(fname): + filtered = None + if cmd: + p = subprocess.Popen( + cmd.split() + [fname], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, err = p.communicate() + if p.returncode != 0: + logger.error("filtering:%s %s", cmd, fname) + raise RuntimeError(err) + with tempfile.NamedTemporaryFile(delete=False) as fobj: + logger.debug("filtering:%s %s > %s", cmd, fname, fobj.name) + fobj.write(out) + fname = filtered = fobj.name hash_md5 = hashlib.md5() binary = not istextfile(fname, tree=tree) size = stat_func(fname).st_size @@ -80,6 +99,10 @@ def file_md5(fname, tree=None): with open_func(fname, "rb") as fobj: _fobj_md5(fobj, hash_md5, binary, pbar.update) + if filtered is not None: + from dvc.utils.fs import remove + + remove(filtered) return (hash_md5.hexdigest(), hash_md5.digest()) return (None, None)