diff --git a/dvc/command/imp_url.py b/dvc/command/imp_url.py index 3f1c1b03d6..012ea80c0e 100644 --- a/dvc/command/imp_url.py +++ b/dvc/command/imp_url.py @@ -14,10 +14,7 @@ class CmdImportUrl(CmdBase): def run(self): try: self.repo.imp_url( - self.args.url, - out=self.args.out, - resume=self.args.resume, - fname=self.args.file, + self.args.url, out=self.args.out, fname=self.args.file ) except DvcException: logger.exception( @@ -54,12 +51,6 @@ def add_parser(subparsers, parent_parser): "ssh://example.com:/path/to/file\n" "remote://myremote/path/to/file (see `dvc remote`)", ) - import_parser.add_argument( - "--resume", - action="store_true", - default=False, - help="Resume previously started download.", - ) import_parser.add_argument( "out", nargs="?", help="Destination path to put files to." ) diff --git a/dvc/output/base.py b/dvc/output/base.py index fbbac18b93..c887518be3 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -278,8 +278,8 @@ def verify_metric(self): "verify metric is not supported for {}".format(self.scheme) ) - def download(self, to, resume=False): - self.remote.download([self.path_info], [to.path_info], resume=resume) + def download(self, to): + self.remote.download(self.path_info, to.path_info) def checkout(self, force=False, progress_callback=None, tag=None): if not self.use_cache: diff --git a/dvc/remote/base.py b/dvc/remote/base.py index c4a88467d8..e1df48d748 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -8,7 +8,6 @@ import logging import tempfile import itertools -from contextlib import contextmanager from operator import itemgetter from multiprocessing import cpu_count from concurrent.futures import as_completed, ThreadPoolExecutor @@ -211,7 +210,7 @@ def _get_dir_info_checksum(self, dir_info): from_info = PathInfo(tmp) to_info = self.cache.path_info / tmp_fname("") - self.cache.upload([from_info], [to_info], no_progress_bar=True) + self.cache.upload(from_info, to_info, no_progress_bar=True) checksum = self.get_file_checksum(to_info) + self.CHECKSUM_DIR_SUFFIX return checksum, to_info @@ -233,7 +232,7 @@ def load_dir_cache(self, checksum): fobj = tempfile.NamedTemporaryFile(delete=False) path = fobj.name to_info = PathInfo(path) - self.cache.download([path_info], [to_info], no_progress_bar=True) + self.cache.download(path_info, to_info, no_progress_bar=True) try: with open(path, "r") as fobj: @@ -417,113 +416,81 @@ def _save(self, path_info, checksum): return self._save_file(path_info, checksum) - @contextmanager - def transfer_context(self): - yield None - - def upload(self, from_infos, to_infos, names=None, no_progress_bar=False): + def upload(self, from_info, to_info, name=None, no_progress_bar=False): if not hasattr(self, "_upload"): raise RemoteActionNotImplemented("upload", self.scheme) - names = self._verify_path_args(to_infos, from_infos, names) - fails = 0 - with self.transfer_context() as ctx: - for from_info, to_info, name in zip(from_infos, to_infos, names): - if to_info.scheme != self.scheme: - raise NotImplementedError + if to_info.scheme != self.scheme: + raise NotImplementedError - if from_info.scheme != "local": - raise NotImplementedError + if from_info.scheme != "local": + raise NotImplementedError - msg = "Uploading '{}' to '{}'" - logger.debug(msg.format(from_info, to_info)) + logger.debug("Uploading '{}' to '{}'".format(from_info, to_info)) - if not name: - name = from_info.name + name = name or from_info.name - if not no_progress_bar: - progress.update_target(name, 0, None) + if not no_progress_bar: + progress.update_target(name, 0, None) - try: - self._upload( - from_info.fspath, - to_info, - name=name, - ctx=ctx, - no_progress_bar=no_progress_bar, - ) - except Exception: - fails += 1 - msg = "failed to upload '{}' to '{}'" - logger.exception(msg.format(from_info, to_info)) - continue - - if not no_progress_bar: - progress.finish_target(name) - - return fails - - def download( - self, - from_infos, - to_infos, - names=None, - no_progress_bar=False, - resume=False, - ): + try: + self._upload( + from_info.fspath, + to_info, + name=name, + no_progress_bar=no_progress_bar, + ) + except Exception: + msg = "failed to upload '{}' to '{}'" + logger.exception(msg.format(from_info, to_info)) + return 1 # 1 fail + + if not no_progress_bar: + progress.finish_target(name) + + return 0 + + def download(self, from_info, to_info, name=None, no_progress_bar=False): if not hasattr(self, "_download"): raise RemoteActionNotImplemented("download", self.scheme) - names = self._verify_path_args(from_infos, to_infos, names) - fails = 0 - - with self.transfer_context() as ctx: - for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info.scheme != self.scheme: - raise NotImplementedError + if from_info.scheme != self.scheme: + raise NotImplementedError - if to_info.scheme == self.scheme != "local": - self.copy(from_info, to_info, ctx=ctx) - continue + if to_info.scheme == self.scheme != "local": + self.copy(from_info, to_info) + return 0 - if to_info.scheme != "local": - raise NotImplementedError + if to_info.scheme != "local": + raise NotImplementedError - msg = "Downloading '{}' to '{}'".format(from_info, to_info) - logger.debug(msg) + logger.debug("Downloading '{}' to '{}'".format(from_info, to_info)) - tmp_file = tmp_fname(to_info) - if not name: - name = to_info.name + name = name or to_info.name - if not no_progress_bar: - # real progress is not always available, - # lets at least show start and finish - progress.update_target(name, 0, None) + if not no_progress_bar: + # real progress is not always available, + # lets at least show start and finish + progress.update_target(name, 0, None) - makedirs(fspath_py35(to_info.parent), exist_ok=True) + makedirs(fspath_py35(to_info.parent), exist_ok=True) + tmp_file = tmp_fname(to_info) - try: - self._download( - from_info, - tmp_file, - name=name, - ctx=ctx, - resume=resume, - no_progress_bar=no_progress_bar, - ) - except Exception: - fails += 1 - msg = "failed to download '{}' to '{}'" - logger.exception(msg.format(from_info, to_info)) - continue + try: + self._download( + from_info, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + except Exception: + msg = "failed to download '{}' to '{}'" + logger.exception(msg.format(from_info, to_info)) + return 1 # 1 fail - move(tmp_file, fspath_py35(to_info)) + move(tmp_file, fspath_py35(to_info)) - if not no_progress_bar: - progress.finish_target(name) + if not no_progress_bar: + progress.finish_target(name) - return fails + return 0 def remove(self, path_info): raise RemoteActionNotImplemented("remove", self.scheme) @@ -532,26 +499,12 @@ def move(self, from_info, to_info): self.copy(from_info, to_info) self.remove(from_info) - def copy(self, from_info, to_info, ctx=None): + def copy(self, from_info, to_info): raise RemoteActionNotImplemented("copy", self.scheme) def exists(self, path_info): raise NotImplementedError - @classmethod - def _verify_path_args(cls, from_infos, to_infos, names=None): - assert isinstance(from_infos, list) - assert isinstance(to_infos, list) - assert len(from_infos) == len(to_infos) - - if not names: - names = len(to_infos) * [None] - else: - assert isinstance(names, list) - assert len(names) == len(to_infos) - - return names - def path_to_checksum(self, path): return "".join(self.path_cls(path).parts[-2:]) diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 04a9aaf34d..95ce4c1c37 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -2,7 +2,7 @@ import logging import itertools -from contextlib import contextmanager +from funcy import cached_property try: from google.cloud import storage @@ -42,7 +42,7 @@ def compat_config(config): ret[Config.SECTION_REMOTE_URL] = url return ret - @property + @cached_property def gs(self): return ( storage.Client.from_service_account_json(self.credentialpath) @@ -64,16 +64,14 @@ def get_file_checksum(self, path_info): md5 = base64.b64decode(b64_md5) return codecs.getencoder("hex")(md5)[0].decode("utf-8") - def copy(self, from_info, to_info, ctx=None): - gs = ctx or self.gs - - from_bucket = gs.bucket(from_info.bucket) + def copy(self, from_info, to_info): + from_bucket = self.gs.bucket(from_info.bucket) blob = from_bucket.get_blob(from_info.path) if not blob: msg = "'{}' doesn't exist in the cloud".format(from_info.path) raise DvcException(msg) - to_bucket = gs.bucket(to_info.bucket) + to_bucket = self.gs.bucket(to_info.bucket) from_bucket.copy_blob(blob, to_bucket, new_name=to_info.path) def remove(self, path_info): @@ -87,10 +85,8 @@ def remove(self, path_info): blob.delete() - def _list_paths(self, bucket, prefix, gs=None): - gs = gs or self.gs - - for blob in gs.bucket(bucket).list_blobs(prefix=prefix): + def _list_paths(self, bucket, prefix): + for blob in self.gs.bucket(bucket).list_blobs(prefix=prefix): yield blob.name def list_cache_paths(self): @@ -102,28 +98,21 @@ def exists(self, path_info): def batch_exists(self, path_infos, callback): paths = [] - gs = self.gs for path_info in path_infos: - paths.append( - self._list_paths(path_info.bucket, path_info.path, gs) - ) + paths.append(self._list_paths(path_info.bucket, path_info.path)) callback.update(str(path_info)) paths = set(itertools.chain.from_iterable(paths)) return [path_info.path in paths for path_info in path_infos] - @contextmanager - def transfer_context(self): - yield self.gs - - def _upload(self, from_file, to_info, ctx=None, **_kwargs): - bucket = ctx.bucket(to_info.bucket) + def _upload(self, from_file, to_info, **_kwargs): + bucket = self.gs.bucket(to_info.bucket) blob = bucket.blob(to_info.path) blob.upload_from_filename(from_file) - def _download(self, from_info, to_file, ctx=None, **_kwargs): - bucket = ctx.bucket(from_info.bucket) + def _download(self, from_info, to_file, **_kwargs): + bucket = self.gs.bucket(from_info.bucket) blob = bucket.get_blob(from_info.path) blob.download_to_filename(to_file) diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 136172a629..715465293a 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -2,9 +2,8 @@ from dvc.scheme import Schemes -from dvc.utils.compat import open, makedirs, fspath_py35 +from dvc.utils.compat import open -import os import threading import requests import logging @@ -13,7 +12,6 @@ from dvc.exceptions import DvcException from dvc.config import Config from dvc.remote.base import RemoteBASE -from dvc.utils import move logger = logging.getLogger(__name__) @@ -44,54 +42,25 @@ def __init__(self, repo, config): url = config.get(Config.SECTION_REMOTE_URL) self.path_info = self.path_cls(url) if url else None - def download( - self, - from_infos, - to_infos, - names=None, - no_progress_bar=False, - resume=False, - ): - names = self._verify_path_args(to_infos, from_infos, names) - fails = 0 - - for to_info, from_info, name in zip(to_infos, from_infos, names): - if from_info.scheme != self.scheme: - raise NotImplementedError - - if to_info.scheme != "local": - raise NotImplementedError - - msg = "Downloading '{}' to '{}'".format(from_info, to_info) - logger.debug(msg) - - if not name: - name = to_info.name - - makedirs(fspath_py35(to_info.parent), exist_ok=True) - + def _download(self, from_info, to_file, name=None, no_progress_bar=False): + callback = None + if not no_progress_bar: total = self._content_length(from_info.url) + if total: + callback = ProgressBarCallback(name, total) - if no_progress_bar or not total: - cb = None - else: - cb = ProgressBarCallback(name, total) + request = self._request("GET", from_info.url, stream=True) - try: - self._download_to( - from_info.url, to_info.fspath, callback=cb, resume=resume - ) + with open(to_file, "wb") as fd: + transferred_bytes = 0 - except Exception: - fails += 1 - msg = "failed to download '{}'".format(from_info) - logger.exception(msg) - continue + for chunk in request.iter_content(chunk_size=self.CHUNK_SIZE): + fd.write(chunk) + fd.flush() + transferred_bytes += len(chunk) - if not no_progress_bar: - progress.finish_target(name) - - return fails + if callback: + callback(transferred_bytes) def exists(self, path_info): return bool(self._request("HEAD", path_info.url)) @@ -128,55 +97,6 @@ def get_file_checksum(self, path_info): return etag - def _download_to(self, url, target_file, callback=None, resume=False): - request = self._request("GET", url, stream=True) - partial_file = target_file + ".part" - - mode, transferred_bytes = self._determine_mode_get_transferred_bytes( - partial_file, resume - ) - - self._validate_existing_file_size(transferred_bytes, partial_file) - - self._write_request_content( - mode, partial_file, request, transferred_bytes, callback - ) - - move(partial_file, target_file) - - def _write_request_content( - self, mode, partial_file, request, transferred_bytes, callback=None - ): - with open(partial_file, mode) as fd: - - for index, chunk in enumerate( - request.iter_content(chunk_size=self.CHUNK_SIZE) - ): - chunk_number = index + 1 - if chunk_number * self.CHUNK_SIZE > transferred_bytes: - fd.write(chunk) - fd.flush() - transferred_bytes += len(chunk) - - if callback: - callback(transferred_bytes) - - def _validate_existing_file_size(self, bytes_transferred, partial_file): - if bytes_transferred % self.CHUNK_SIZE != 0: - raise DvcException( - "File {}, might be corrupted, please remove " - "it and retry importing".format(partial_file) - ) - - def _determine_mode_get_transferred_bytes(self, partial_file, resume): - if os.path.exists(partial_file) and resume: - mode = "ab" - bytes_transfered = os.path.getsize(partial_file) - else: - mode = "wb" - bytes_transfered = 0 - return mode, bytes_transfered - def _request(self, method, url, **kwargs): kwargs.setdefault("allow_redirects", True) kwargs.setdefault("timeout", self.REQUEST_TIMEOUT) diff --git a/dvc/remote/local/__init__.py b/dvc/remote/local/__init__.py index fa72369d5b..786e92ed38 100644 --- a/dvc/remote/local/__init__.py +++ b/dvc/remote/local/__init__.py @@ -24,7 +24,6 @@ remove, move, copyfile, - to_chunks, tmp_fname, file_md5, walk_files, @@ -341,7 +340,7 @@ def _fill_statuses(self, checksum_info_dir, local_exists, remote_exists): status = STATUS_MAP[(md5 in local, md5 in remote)] info["status"] = status - def _get_chunks(self, download, remote, status_info, status, jobs): + def _get_plans(self, download, remote, status_info, status): cache = [] path_infos = [] names = [] @@ -360,11 +359,7 @@ def _get_chunks(self, download, remote, status_info, status, jobs): to_infos = path_infos from_infos = cache - return ( - to_chunks(from_infos, num_chunks=jobs), - to_chunks(to_infos, num_chunks=jobs), - to_chunks(names, num_chunks=jobs), - ) + return from_infos, to_infos, names def _process( self, @@ -399,16 +394,16 @@ def _process( download=download, ) - chunks = self._get_chunks(download, remote, status_info, status, jobs) + plans = self._get_plans(download, remote, status_info, status) - if len(chunks[0]) == 0: + if len(plans[0]) == 0: return 0 if jobs > 1: with ThreadPoolExecutor(max_workers=jobs) as executor: - fails = sum(executor.map(func, *chunks)) + fails = sum(executor.map(func, *plans)) else: - fails = sum(map(func, *chunks)) + fails = sum(map(func, *plans)) if fails: msg = "{} file(s) failed to {}" @@ -416,7 +411,7 @@ def _process( msg.format(fails, "download" if download else "upload") ) - return len(chunks[0]) + return len(plans[0]) def push(self, checksum_infos, remote, jobs=None, show_checksums=False): return self._process( diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 0622ad2eb6..deb095170f 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -4,7 +4,7 @@ import threading import logging import itertools -from contextlib import contextmanager +from funcy import cached_property try: import boto3 @@ -77,7 +77,7 @@ def compat_config(config): ret[Config.SECTION_REMOTE_URL] = url return ret - @property + @cached_property def s3(self): session = boto3.session.Session( profile_name=self.profile, region_name=self.region @@ -191,8 +191,8 @@ def _copy(cls, s3, from_info, to_info, extra_args): if etag != cached_etag: raise ETagMismatchError(etag, cached_etag) - def copy(self, from_info, to_info, ctx=None): - self._copy(ctx or self.s3, from_info, to_info, self.extra_args) + def copy(self, from_info, to_info): + self._copy(self.s3, from_info, to_info, self.extra_args) def remove(self, path_info): if path_info.scheme != "s3": @@ -201,15 +201,14 @@ def remove(self, path_info): logger.debug("Removing {}".format(path_info)) self.s3.delete_object(Bucket=path_info.bucket, Key=path_info.path) - def _list_paths(self, bucket, prefix, s3=None): + def _list_paths(self, bucket, prefix): """ Read config for list object api, paginate through list objects.""" - s3 = s3 or self.s3 kwargs = {"Bucket": bucket, "Prefix": prefix} if self.list_objects: list_objects_api = "list_objects" else: list_objects_api = "list_objects_v2" - paginator = s3.get_paginator(list_objects_api) + paginator = self.s3.get_paginator(list_objects_api) for page in paginator.paginate(**kwargs): contents = page.get("Contents", None) if not contents: @@ -226,27 +225,18 @@ def exists(self, path_info): def batch_exists(self, path_infos, callback): paths = [] - s3 = self.s3 for path_info in path_infos: - paths.append( - self._list_paths(path_info.bucket, path_info.path, s3) - ) + paths.append(self._list_paths(path_info.bucket, path_info.path)) callback.update(str(path_info)) paths = set(itertools.chain.from_iterable(paths)) return [path_info.path in paths for path_info in path_infos] - @contextmanager - def transfer_context(self): - yield self.s3 - - def _upload( - self, from_file, to_info, name=None, ctx=None, no_progress_bar=False - ): + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): total = os.path.getsize(from_file) cb = None if no_progress_bar else Callback(name, total) - ctx.upload_file( + self.s3.upload_file( from_file, to_info.bucket, to_info.path, @@ -254,25 +244,15 @@ def _upload( ExtraArgs=self.extra_args, ) - def _download( - self, - from_info, - to_file, - name=None, - ctx=None, - no_progress_bar=False, - resume=False, - ): - s3 = ctx - + def _download(self, from_info, to_file, name=None, no_progress_bar=False): if no_progress_bar: cb = None else: - total = s3.head_object( + total = self.s3.head_object( Bucket=from_info.bucket, Key=from_info.path )["ContentLength"] cb = Callback(name, total) - s3.download_file( + self.s3.download_file( from_info.bucket, from_info.path, to_file, Callback=cb ) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 3e34e9ce9c..ed5e5889bd 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -169,15 +169,12 @@ def isdir(self, path_info): with self.ssh(path_info) as ssh: return ssh.isdir(path_info.path) - def copy(self, from_info, to_info, ctx=None): + def copy(self, from_info, to_info): if from_info.scheme != self.scheme or to_info.scheme != self.scheme: raise NotImplementedError - if ctx: - ctx.cp(from_info.path, to_info.path) - else: - with self.ssh(from_info) as ssh: - ssh.cp(from_info.path, to_info.path) + with self.ssh(from_info) as ssh: + ssh.cp(from_info.path, to_info.path) def remove(self, path_info): if path_info.scheme != self.scheme: @@ -193,36 +190,25 @@ def move(self, from_info, to_info): with self.ssh(from_info) as ssh: ssh.move(from_info.path, to_info.path) - def transfer_context(self): - return self.ssh(self.path_info) - - def _download( - self, - from_info, - to_file, - name=None, - ctx=None, - no_progress_bar=False, - resume=False, - ): + def _download(self, from_info, to_file, name=None, no_progress_bar=False): assert from_info.isin(self.path_info) - ctx.download( - from_info.path, - to_file, - progress_title=name, - no_progress_bar=no_progress_bar, - ) + with self.ssh(self.path_info) as ssh: + ssh.download( + from_info.path, + to_file, + progress_title=name, + no_progress_bar=no_progress_bar, + ) - def _upload( - self, from_file, to_info, name=None, ctx=None, no_progress_bar=False - ): + def _upload(self, from_file, to_info, name=None, no_progress_bar=False): assert to_info.isin(self.path_info) - ctx.upload( - from_file, - to_info.path, - progress_title=name, - no_progress_bar=no_progress_bar, - ) + with self.ssh(self.path_info) as ssh: + ssh.upload( + from_file, + to_info.path, + progress_title=name, + no_progress_bar=no_progress_bar, + ) def list_cache_paths(self): with self.ssh(self.path_info) as ssh: diff --git a/dvc/repo/imp_url.py b/dvc/repo/imp_url.py index c0eae33143..8d42d6bb17 100644 --- a/dvc/repo/imp_url.py +++ b/dvc/repo/imp_url.py @@ -3,9 +3,7 @@ @scm_context -def imp_url( - self, url, out=None, resume=False, fname=None, erepo=None, locked=False -): +def imp_url(self, url, out=None, fname=None, erepo=None, locked=False): from dvc.stage import Stage out = out or pathlib.PurePath(url).name @@ -21,7 +19,7 @@ def imp_url( self.check_dag(self.stages() + [stage]) with self.state: - stage.run(resume=resume) + stage.run() stage.locked = locked diff --git a/dvc/stage.py b/dvc/stage.py index 72367031f3..bea8f30ee9 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -797,7 +797,7 @@ def _run(self): if (p is None) or (p.returncode != 0): raise StageCmdFailedError(self) - def run(self, dry=False, resume=False, no_commit=False, force=False): + def run(self, dry=False, no_commit=False, force=False): if (self.cmd or self.is_import) and not self.locked and not dry: self.remove_outs(ignore_remove=False, force=False) @@ -820,7 +820,7 @@ def run(self, dry=False, resume=False, no_commit=False, force=False): if not force and self._already_cached(): self.outs[0].checkout() else: - self.deps[0].download(self.outs[0], resume=resume) + self.deps[0].download(self.outs[0]) elif self.is_data_source: msg = "Verifying data sources in '{}'".format(self.relpath) diff --git a/scripts/completion/dvc.bash b/scripts/completion/dvc.bash index 18c64b6abd..c9b0b5f680 100644 --- a/scripts/completion/dvc.bash +++ b/scripts/completion/dvc.bash @@ -26,7 +26,7 @@ _dvc_fetch='--show-checksums -j --jobs -r --remote -a --all-branches -T --all-ta _dvc_get_url='' _dvc_get='-o --out --rev' _dvc_gc='-a --all-branches -T --all-tags -c --cloud -r --remote -f --force -p --projects -j --jobs' -_dvc_import_url='--resume -f --file' +_dvc_import_url='-f --file' _dvc_import='-o --out --rev' _dvc_init='--no-scm -f --force' _dvc_install='' diff --git a/scripts/completion/dvc.zsh b/scripts/completion/dvc.zsh index 06b961e740..1a8669d8eb 100644 --- a/scripts/completion/dvc.zsh +++ b/scripts/completion/dvc.zsh @@ -133,7 +133,6 @@ _dvc_gc=( ) _dvc_importurl=( - "--resume[Resume previously started download.]" {-f,--file}"[Specify name of the DVC-file it generates.]:File:_files" "1:URL:" "2:Output:" diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py index c257061d47..4938d8e492 100644 --- a/tests/func/test_import_url.py +++ b/tests/func/test_import_url.py @@ -5,12 +5,10 @@ import os from uuid import uuid4 -from dvc.utils.compat import urljoin from dvc.main import main -from mock import patch, mock_open, call +from mock import patch from tests.basic_env import TestDvc from tests.utils import spy -from tests.utils.httpd import StaticFileServer class TestCmdImport(TestDvc): @@ -43,71 +41,6 @@ def test(self): self.assertEqual(fd.read(), "content") -class TestInterruptedDownload(TestDvc): - def _prepare_interrupted_download(self, port): - import_url = urljoin("http://localhost:{}/".format(port), self.FOO) - import_output = "imported_file" - tmp_file_name = import_output + ".part" - tmp_file_path = os.path.realpath( - os.path.join(self._root_dir, tmp_file_name) - ) - self._import_with_interrupt(import_output, import_url) - self.assertTrue(os.path.exists(tmp_file_name)) - self.assertFalse(os.path.exists(import_output)) - return import_output, import_url, tmp_file_path - - def _import_with_interrupt(self, import_output, import_url): - def interrupting_generator(): - yield self.FOO[0].encode("utf8") - raise KeyboardInterrupt - - with patch( - "requests.models.Response.iter_content", - return_value=interrupting_generator(), - ): - with patch( - "dvc.remote.http.RemoteHTTP._content_length", return_value=3 - ): - result = main(["import-url", import_url, import_output]) - self.assertEqual(result, 252) - - -class TestShouldResumeDownload(TestInterruptedDownload): - @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) - def test(self): - with StaticFileServer() as httpd: - output, url, file_path = self._prepare_interrupted_download( - httpd.server_port - ) - - m = mock_open() - with patch("dvc.remote.http.open", m): - result = main(["import-url", "--resume", url, output]) - self.assertEqual(result, 0) - m.assert_called_once_with(file_path, "ab") - m_handle = m() - expected_calls = [call(b"o"), call(b"o")] - m_handle.write.assert_has_calls(expected_calls, any_order=False) - - -class TestShouldNotResumeDownload(TestInterruptedDownload): - @patch("dvc.remote.http.RemoteHTTP.CHUNK_SIZE", 1) - def test(self): - with StaticFileServer() as httpd: - output, url, file_path = self._prepare_interrupted_download( - httpd.server_port - ) - - m = mock_open() - with patch("dvc.remote.http.open", m): - result = main(["import-url", url, output]) - self.assertEqual(result, 0) - m.assert_called_once_with(file_path, "wb") - m_handle = m() - expected_calls = [call(b"f"), call(b"o"), call(b"o")] - m_handle.write.assert_has_calls(expected_calls, any_order=False) - - class TestShouldRemoveOutsBeforeImport(TestDvc): def setUp(self): super(TestShouldRemoveOutsBeforeImport, self).setUp() diff --git a/tests/unit/command/test_imp_url.py b/tests/unit/command/test_imp_url.py index 2e11ff520b..3ec0efbaac 100644 --- a/tests/unit/command/test_imp_url.py +++ b/tests/unit/command/test_imp_url.py @@ -6,9 +6,7 @@ def test_import_url(mocker, dvc_repo): - cli_args = parse_args( - ["import-url", "src", "out", "--resume", "--file", "file"] - ) + cli_args = parse_args(["import-url", "src", "out", "--file", "file"]) assert cli_args.func == CmdImportUrl cmd = cli_args.func(cli_args) @@ -16,7 +14,7 @@ def test_import_url(mocker, dvc_repo): assert cmd.run() == 0 - m.assert_called_once_with("src", out="out", resume=True, fname="file") + m.assert_called_once_with("src", out="out", fname="file") def test_failed_import_url(mocker, caplog, dvc_repo):