diff --git a/dvc/config.py b/dvc/config.py
index d89b2fc6f0..3d500b685d 100644
--- a/dvc/config.py
+++ b/dvc/config.py
@@ -1,18 +1,24 @@
 """DVC config objects."""
-
 from __future__ import unicode_literals
-from dvc.utils.compat import str, open
-
-import os
-import re
 import copy
 import errno
-import configobj
 import logging
+import os
+import re
 
-from schema import Schema, Optional, And, Use, Regex, SchemaError
-from dvc.exceptions import DvcException, NotDvcRepoError
+import configobj
+from schema import And
+from schema import Optional
+from schema import Regex
+from schema import Schema
+from schema import SchemaError
+from schema import Use
+
+from dvc.exceptions import DvcException
+from dvc.exceptions import NotDvcRepoError
+from dvc.utils.compat import open
+from dvc.utils.compat import str
 
 logger = logging.getLogger(__name__)
 
@@ -26,19 +32,16 @@ class ConfigError(DvcException):
     """
 
     def __init__(self, msg, cause=None):
-        super(ConfigError, self).__init__(
-            "config file error: {}".format(msg), cause=cause
-        )
+        super(ConfigError, self).__init__("config file error: {}".format(msg),
+                                          cause=cause)
 
 
 class NoRemoteError(ConfigError):
     def __init__(self, command, cause=None):
-        msg = (
-            "no remote specified. Setup default remote with\n"
-            "    dvc config core.remote <name>\n"
-            "or use:\n"
-            "    dvc {} -r <name>\n".format(command)
-        )
+        msg = ("no remote specified. Setup default remote with\n"
+               "    dvc config core.remote <name>\n"
+               "or use:\n"
+               "    dvc {} -r <name>\n".format(command))
 
         super(NoRemoteError, self).__init__(msg, cause=cause)
 
@@ -147,6 +150,8 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
     CONFIG = "config"
     CONFIG_LOCAL = "config.local"
 
+    CREDENTIALPATH = "credentialpath"
+
     LEVEL_LOCAL = 0
     LEVEL_REPO = 1
     LEVEL_GLOBAL = 2
@@ -157,8 +162,7 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
     SECTION_CORE = "core"
     SECTION_CORE_LOGLEVEL = "loglevel"
     SECTION_CORE_LOGLEVEL_SCHEMA = And(
-        Use(str.lower), Choices("info", "debug", "warning", "error")
-    )
+        Use(str.lower), Choices("info", "debug", "warning", "error"))
     SECTION_CORE_REMOTE = "remote"
     SECTION_CORE_INTERACTIVE_SCHEMA = BOOL_SCHEMA
     SECTION_CORE_INTERACTIVE = "interactive"
@@ -197,25 +201,22 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
     }
 
     SECTION_CORE_SCHEMA = {
-        Optional(SECTION_CORE_LOGLEVEL): And(
-            str, Use(str.lower), SECTION_CORE_LOGLEVEL_SCHEMA
-        ),
-        Optional(SECTION_CORE_REMOTE, default=""): And(str, Use(str.lower)),
-        Optional(
-            SECTION_CORE_INTERACTIVE, default=False
-        ): SECTION_CORE_INTERACTIVE_SCHEMA,
-        Optional(
-            SECTION_CORE_ANALYTICS, default=True
-        ): SECTION_CORE_ANALYTICS_SCHEMA,
-        Optional(
-            SECTION_CORE_CHECKSUM_JOBS, default=None
-        ): SECTION_CORE_CHECKSUM_JOBS_SCHEMA,
+        Optional(SECTION_CORE_LOGLEVEL):
+        And(str, Use(str.lower), SECTION_CORE_LOGLEVEL_SCHEMA),
+        Optional(SECTION_CORE_REMOTE, default=""):
+        And(str, Use(str.lower)),
+        Optional(SECTION_CORE_INTERACTIVE, default=False):
+        SECTION_CORE_INTERACTIVE_SCHEMA,
+        Optional(SECTION_CORE_ANALYTICS, default=True):
+        SECTION_CORE_ANALYTICS_SCHEMA,
+        Optional(SECTION_CORE_CHECKSUM_JOBS, default=None):
+        SECTION_CORE_CHECKSUM_JOBS_SCHEMA,
     }
 
     # backward compatibility
     SECTION_AWS = "aws"
     SECTION_AWS_STORAGEPATH = "storagepath"
-    SECTION_AWS_CREDENTIALPATH = "credentialpath"
+    SECTION_AWS_CREDENTIALPATH = CREDENTIALPATH
     SECTION_AWS_ENDPOINT_URL = "endpointurl"
     SECTION_AWS_LIST_OBJECTS = "listobjects"
     SECTION_AWS_REGION = "region"
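A note on the SECTION_CORE_SCHEMA rewrite above: the behavior is unchanged, only the key wrapping style differs. For reference, a minimal sketch of how schema's Optional(..., default=...) entries behave (the option names here are illustrative, not tied to this patch):

```python
from schema import And, Optional, Schema, Use

# Missing keys are filled with their declared defaults during validation,
# and And(str, Use(str.lower)) normalizes values to lowercase.
core = Schema({
    Optional("remote", default=""): And(str, Use(str.lower)),
    Optional("interactive", default=False): bool,
})

print(core.validate({}))                      # {'remote': '', 'interactive': False}
print(core.validate({"remote": "MyRemote"}))  # {'remote': 'myremote', 'interactive': False}
```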
@@ -238,7 +239,7 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
     # backward compatibility
     SECTION_GCP = "gcp"
     SECTION_GCP_STORAGEPATH = SECTION_AWS_STORAGEPATH
-    SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH
+    SECTION_GCP_CREDENTIALPATH = CREDENTIALPATH
     SECTION_GCP_PROJECTNAME = "projectname"
     SECTION_GCP_SCHEMA = {
         SECTION_GCP_STORAGEPATH: str,
@@ -255,6 +256,10 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
     SECTION_OSS_ACCESS_KEY_ID = "oss_key_id"
     SECTION_OSS_ACCESS_KEY_SECRET = "oss_key_secret"
     SECTION_OSS_ENDPOINT = "oss_endpoint"
+    # GDrive options
+    SECTION_GDRIVE_CLIENT_ID = "gdrive_client_id"
+    SECTION_GDRIVE_CLIENT_SECRET = "gdrive_client_secret"
+    SECTION_GDRIVE_USER_CREDENTIALS_FILE = "gdrive_user_credentials_file"
 
     SECTION_REMOTE_REGEX = r'^\s*remote\s*"(?P<name>.*)"\s*$'
     SECTION_REMOTE_FMT = 'remote "{}"'
@@ -271,7 +276,7 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
         SECTION_REMOTE_URL: str,
         Optional(SECTION_AWS_REGION): str,
         Optional(SECTION_AWS_PROFILE): str,
-        Optional(SECTION_AWS_CREDENTIALPATH): str,
+        Optional(CREDENTIALPATH): str,
         Optional(SECTION_AWS_ENDPOINT_URL): str,
         Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA,
         Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA,
@@ -291,6 +296,9 @@ class Config(object):  # pylint: disable=too-many-instance-attributes
         Optional(SECTION_OSS_ACCESS_KEY_ID): str,
         Optional(SECTION_OSS_ACCESS_KEY_SECRET): str,
         Optional(SECTION_OSS_ENDPOINT): str,
+        Optional(SECTION_GDRIVE_CLIENT_ID): str,
+        Optional(SECTION_GDRIVE_CLIENT_SECRET): str,
+        Optional(SECTION_GDRIVE_USER_CREDENTIALS_FILE): str,
         Optional(PRIVATE_CWD): str,
         Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): BOOL_SCHEMA,
     }
@@ -339,9 +347,8 @@ def get_global_config_dir():
         """
         from appdirs import user_config_dir
 
-        return user_config_dir(
-            appname=Config.APPNAME, appauthor=Config.APPAUTHOR
-        )
+        return user_config_dir(appname=Config.APPNAME,
+                               appauthor=Config.APPAUTHOR)
 
     @staticmethod
     def get_system_config_dir():
@@ -352,9 +359,8 @@ def get_system_config_dir():
         """
         from appdirs import site_config_dir
 
-        return site_config_dir(
-            appname=Config.APPNAME, appauthor=Config.APPAUTHOR
-        )
+        return site_config_dir(appname=Config.APPNAME,
+                               appauthor=Config.APPAUTHOR)
 
     @staticmethod
     def init(dvc_dir):
@@ -397,13 +403,11 @@ def _resolve_paths(self, config):
         return ret
 
     def _load_configs(self):
-        system_config_file = os.path.join(
-            self.get_system_config_dir(), self.CONFIG
-        )
+        system_config_file = os.path.join(self.get_system_config_dir(),
+                                          self.CONFIG)
 
-        global_config_file = os.path.join(
-            self.get_global_config_dir(), self.CONFIG
-        )
+        global_config_file = os.path.join(self.get_global_config_dir(),
+                                          self.CONFIG)
 
         self._system_config = configobj.ConfigObj(system_config_file)
         self._global_config = configobj.ConfigObj(global_config_file)
@@ -437,10 +441,10 @@ def load(self):
         self.config = configobj.ConfigObj()
 
         for c in [
-            self._system_config,
-            self._global_config,
-            self._repo_config,
-            self._local_config,
+                self._system_config,
+                self._global_config,
+                self._repo_config,
+                self._local_config,
         ]:
             c = self._resolve_paths(c)
             c = self._lower(c)
@@ -516,9 +520,8 @@ def unset(self, section, opt=None, level=None, force=False):
             if opt not in config[section].keys():
                 if force:
                     return
-                raise ConfigError(
-                    "option '{}.{}' doesn't exist".format(section, opt)
-                )
+                raise ConfigError("option '{}.{}' doesn't exist".format(
+                    section, opt))
 
             del config[section][opt]
 
             if not config[section]:
@@ -551,8 +554,7 @@ def set(self, section, opt, value, level=None, force=True):
         elif not force:
             raise ConfigError(
                 "Section '{}' already exists. Use `-f|--force` to overwrite "
-                "section with new value.".format(section)
-            )
+                "section with new value.".format(section))
 
         config[section][opt] = value
 
         self.save(config)
@@ -574,9 +576,8 @@ def get(self, section, opt=None, level=None):
             raise ConfigError("section '{}' doesn't exist".format(section))
 
         if opt not in config[section].keys():
-            raise ConfigError(
-                "option '{}.{}' doesn't exist".format(section, opt)
-            )
+            raise ConfigError("option '{}.{}' doesn't exist".format(
+                section, opt))
 
         return config[section][opt]
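With the new SECTION_GDRIVE_* keys in the remote schema, a gdrive remote section in .dvc/config reads back like any other. A minimal sketch using configobj (the remote name "myremote" is hypothetical):

```python
import configobj

# Hypothetical .dvc/config written by `dvc remote add` / `dvc remote modify`.
conf = configobj.ConfigObj(".dvc/config")
section = conf['remote "myremote"']

client_id = section.get("gdrive_client_id")
client_secret = section.get("gdrive_client_secret")
# Same fallback the remote itself uses when the option is unset.
creds_file = section.get("gdrive_user_credentials_file",
                         ".dvc/tmp/gdrive-user-credentials.json")
```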
diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py
index e8ffe81f45..c70b09ec69 100644
--- a/dvc/remote/__init__.py
+++ b/dvc/remote/__init__.py
@@ -1,20 +1,20 @@
 from __future__ import unicode_literals
 
+from .config import RemoteConfig
 from dvc.remote.azure import RemoteAZURE
+from dvc.remote.gdrive import RemoteGDrive
 from dvc.remote.gs import RemoteGS
 from dvc.remote.hdfs import RemoteHDFS
-from dvc.remote.local import RemoteLOCAL
-from dvc.remote.s3 import RemoteS3
-from dvc.remote.ssh import RemoteSSH
 from dvc.remote.http import RemoteHTTP
 from dvc.remote.https import RemoteHTTPS
+from dvc.remote.local import RemoteLOCAL
 from dvc.remote.oss import RemoteOSS
-
-from .config import RemoteConfig
-
+from dvc.remote.s3 import RemoteS3
+from dvc.remote.ssh import RemoteSSH
 
 REMOTES = [
     RemoteAZURE,
+    RemoteGDrive,
     RemoteGS,
     RemoteHDFS,
     RemoteHTTP,
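RemoteGDrive slots into REMOTES alphabetically. For illustration only (this helper is not part of the patch), remote classes are typically matched against a URL by their REGEX class attribute:

```python
import re

def resolve_remote_class(remotes, url):
    # Hypothetical helper: RemoteGDrive declares REGEX = r"^gdrive://.*$",
    # so a gdrive:// URL would resolve to it.
    for cls in remotes:
        if re.match(cls.REGEX, url):
            return cls
    return None
```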
diff --git a/dvc/remote/gdrive/__init__.py b/dvc/remote/gdrive/__init__.py
new file mode 100644
index 0000000000..95ef42999c
--- /dev/null
+++ b/dvc/remote/gdrive/__init__.py
@@ -0,0 +1,266 @@
+from __future__ import unicode_literals
+
+import logging
+import os
+import posixpath
+
+from backoff import expo
+from backoff import on_exception
+from funcy import cached_property
+
+from dvc.config import Config
+from dvc.exceptions import DvcException
+from dvc.path_info import CloudURLInfo
+from dvc.remote.base import RemoteBASE
+from dvc.remote.gdrive.pydrive import RequestCreateFolder
+from dvc.remote.gdrive.pydrive import RequestDownloadFile
+from dvc.remote.gdrive.pydrive import RequestListFile
+from dvc.remote.gdrive.pydrive import RequestListFilePaginated
+from dvc.remote.gdrive.pydrive import RequestUploadFile
+from dvc.remote.gdrive.utils import FOLDER_MIME_TYPE
+from dvc.scheme import Schemes
+
+logger = logging.getLogger(__name__)
+
+
+class GDriveURLInfo(CloudURLInfo):
+    @property
+    def netloc(self):
+        return self.parsed.netloc
+
+
+class RemoteGDrive(RemoteBASE):
+    scheme = Schemes.GDRIVE
+    path_cls = GDriveURLInfo
+    REGEX = r"^gdrive://.*$"
+    REQUIRES = {"pydrive": "pydrive"}
+    GDRIVE_USER_CREDENTIALS_DATA = "GDRIVE_USER_CREDENTIALS_DATA"
+    DEFAULT_USER_CREDENTIALS_FILE = ".dvc/tmp/gdrive-user-credentials.json"
+
+    def __init__(self, repo, config):
+        super(RemoteGDrive, self).__init__(repo, config)
+        self.no_traverse = False
+        self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL])
+        self.config = config
+        self.init_drive()
+
+    def init_drive(self):
+        self.gdrive_client_id = self.config.get(
+            Config.SECTION_GDRIVE_CLIENT_ID, None)
+        self.gdrive_client_secret = self.config.get(
+            Config.SECTION_GDRIVE_CLIENT_SECRET, None)
+        if not self.gdrive_client_id or not self.gdrive_client_secret:
+            raise DvcException("Please specify Google Drive's client id and "
+                               "secret in DVC's config. Learn more at "
+                               "https://man.dvc.org/remote/add.")
+        self.gdrive_user_credentials_path = self.config.get(
+            Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE,
+            self.DEFAULT_USER_CREDENTIALS_FILE,
+        )
+
+        self.root_id = self.get_path_id(self.path_info, create=True)
+        self.cached_dirs, self.cached_ids = self.cache_root_dirs()
+
+    @on_exception(expo, DvcException, max_tries=8)
+    def execute_request(self, request):
+        try:
+            result = request.execute()
+        except Exception as exception:
+            retry_codes = ["403", "500", "502", "503", "504"]
+            if any(code in str(exception) for code in retry_codes):
+                raise DvcException("Google API request failed")
+            raise
+        return result
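execute_request leans on the backoff package: on_exception(expo, DvcException, max_tries=8) re-runs the wrapped call with exponentially growing sleeps whenever it raises DvcException, which is why retryable HTTP status codes are converted into DvcException above. A self-contained sketch of that retry behavior:

```python
from backoff import expo, on_exception

attempts = []

@on_exception(expo, RuntimeError, max_tries=4)
def flaky():
    # Fails twice, then succeeds; backoff sleeps ~1s, ~2s in between.
    attempts.append(1)
    if len(attempts) < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(flaky())  # "ok" after two retries
```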
+
+    def list_drive_item(self, query):
+        list_request = RequestListFilePaginated(self.drive, query)
+        page_list = self.execute_request(list_request)
+        while page_list:
+            for item in page_list:
+                yield item
+            page_list = self.execute_request(list_request)
+
+    def cache_root_dirs(self):
+        cached_dirs = {}
+        cached_ids = {}
+        for dir1 in self.list_drive_item(
+                "'{}' in parents and trashed=false".format(self.root_id)):
+            cached_dirs.setdefault(dir1["title"], []).append(dir1["id"])
+            cached_ids[dir1["id"]] = dir1["title"]
+        return cached_dirs, cached_ids
+
+    @cached_property
+    def drive(self):
+        from pydrive.auth import GoogleAuth
+        from pydrive.drive import GoogleDrive
+
+        if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
+            with open(self.gdrive_user_credentials_path,
+                      "w") as credentials_file:
+                credentials_file.write(
+                    os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA))
+
+        GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
+        GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
+            "client_id": self.gdrive_client_id,
+            "client_secret": self.gdrive_client_secret,
+            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+            "token_uri": "https://oauth2.googleapis.com/token",
+            "revoke_uri": "https://oauth2.googleapis.com/revoke",
+            "redirect_uri": "",
+        }
+        GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
+        GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
+        GoogleAuth.DEFAULT_SETTINGS[
+            "save_credentials_file"] = self.gdrive_user_credentials_path
+        GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
+        GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
+            "https://www.googleapis.com/auth/drive",
+            "https://www.googleapis.com/auth/drive.appdata",
+        ]
+
+        # Pass non existent settings path to force DEFAULT_SETTINGS loading
+        gauth = GoogleAuth(settings_file="")
+        gauth.CommandLineAuth()
+
+        if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
+            os.remove(self.gdrive_user_credentials_path)
+
+        gdrive = GoogleDrive(gauth)
+        return gdrive
+
+    def create_drive_item(self, parent_id, title):
+        upload_request = RequestCreateFolder({
+            "drive": self.drive,
+            "title": title,
+            "parent_id": parent_id
+        })
+        result = self.execute_request(upload_request)
+        return result
+
+    def get_drive_item(self, name, parents_ids):
+        if not parents_ids:
+            return None
+        query = " or ".join("'{}' in parents".format(parent_id)
+                            for parent_id in parents_ids)
+
+        query += " and trashed=false and title='{}'".format(name)
+
+        list_request = RequestListFile(self.drive, query)
+        item_list = self.execute_request(list_request)
+        return next(iter(item_list), None)
+
+    def resolve_remote_file(self, parents_ids, path_parts, create):
+        for path_part in path_parts:
+            item = self.get_drive_item(path_part, parents_ids)
+            if not item and create:
+                item = self.create_drive_item(parents_ids[0], path_part)
+            elif not item:
+                return None
+            parents_ids = [item["id"]]
+        return item
+
+    def subtract_root_path(self, parts):
+        if not hasattr(self, "root_id"):
+            return parts, [self.path_info.netloc]
+
+        for part in self.path_info.path.split("/"):
+            if parts and parts[0] == part:
+                parts.pop(0)
+            else:
+                break
+        return parts, [self.root_id]
+
+    def get_path_id_from_cache(self, path_info):
+        files_ids = []
+        parts, parents_ids = self.subtract_root_path(path_info.path.split("/"))
+        if (hasattr(self, "cached_dirs") and path_info != self.path_info
+                and parts and (parts[0] in self.cached_dirs)):
+            parents_ids = self.cached_dirs[parts[0]]
+            files_ids = self.cached_dirs[parts[0]]
+            parts.pop(0)
+
+        return files_ids, parents_ids, parts
+
+    def get_path_id(self, path_info, create=False):
+        files_ids, parents_ids, parts = self.get_path_id_from_cache(path_info)
+
+        if not parts and files_ids:
+            return files_ids[0]
+
+        file1 = self.resolve_remote_file(parents_ids, parts, create)
+        return file1["id"] if file1 else ""
+
+    def exists(self, path_info):
+        return self.get_path_id(path_info) != ""
+
+    def _upload(self, from_file, to_info, name, no_progress_bar):
+        dirname = to_info.parent
+        if dirname:
+            parent_id = self.get_path_id(dirname, True)
+        else:
+            parent_id = to_info.netloc
+
+        upload_request = RequestUploadFile(
+            {
+                "drive": self.drive,
+                "title": to_info.name,
+                "parent_id": parent_id,
+            },
+            no_progress_bar,
+            from_file,
+            name,
+        )
+        self.execute_request(upload_request)
+
+    def _download(self, from_info, to_file, name, no_progress_bar):
+        file_id = self.get_path_id(from_info)
+        download_request = RequestDownloadFile({
+            "drive": self.drive,
+            "file_id": file_id,
+            "to_file": to_file,
+            "progress_name": name,
+            "no_progress_bar": no_progress_bar,
+        })
+        self.execute_request(download_request)
+
+    def list_cache_paths(self):
+        file_id = self.get_path_id(self.path_info)
+        prefix = self.path_info.path
+        for path in self.list_path(file_id):
+            yield posixpath.join(prefix, path)
+
+    def list_file_path(self, drive_file):
+        if drive_file["mimeType"] == FOLDER_MIME_TYPE:
+            for i in self.list_path(drive_file["id"]):
+                yield posixpath.join(drive_file["title"], i)
+        else:
+            yield drive_file["title"]
+
+    def list_path(self, parent_id):
+        for file1 in self.list_drive_item(
+                "'{}' in parents and trashed=false".format(parent_id)):
+            for path in self.list_file_path(file1):
+                yield path
+
+    def all(self):
+        if not hasattr(self, "cached_ids") or not self.cached_ids:
+            return
+
+        query = " or ".join("'{}' in parents".format(dir_id)
+                            for dir_id in self.cached_ids)
+
+        query += " and trashed=false"
+        for file1 in self.list_drive_item(query):
+            parent_id = file1["parents"][0]["id"]
+            path = posixpath.join(self.cached_ids[parent_id], file1["title"])
+            try:
+                yield self.path_to_checksum(path)
+            except ValueError:
+                # We ignore all the non-cache looking files
+                logger.debug('Ignoring path as "non-cache looking"')
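The Drive queries above are plain Drive v2 query strings built by hand. A standalone helper (hypothetical, mirroring the strings assembled in get_drive_item and all) makes the shape explicit:

```python
def children_query(parents_ids, title=None):
    # "'<id>' in parents" terms OR-ed together, then trashed/title filters,
    # exactly as get_drive_item() and all() compose their queries.
    query = " or ".join("'{}' in parents".format(p) for p in parents_ids)
    query += " and trashed=false"
    if title is not None:
        query += " and title='{}'".format(title)
    return query

print(children_query(["root"], title="data"))
# 'root' in parents and trashed=false and title='data'
```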
diff --git a/dvc/remote/gdrive/pydrive.py b/dvc/remote/gdrive/pydrive.py
new file mode 100644
index 0000000000..5c8595c56e
--- /dev/null
+++ b/dvc/remote/gdrive/pydrive.py
@@ -0,0 +1,111 @@
+import os
+
+from dvc.remote.gdrive.utils import FOLDER_MIME_TYPE
+from dvc.remote.gdrive.utils import TrackFileReadProgress
+
+
+class RequestBASE:
+    def __init__(self, drive):
+        self.drive = drive
+
+    def execute(self):
+        raise NotImplementedError
+
+
+class RequestListFile(RequestBASE):
+    def __init__(self, drive, query):
+        super(RequestListFile, self).__init__(drive)
+        self.query = query
+
+    def execute(self):
+        return self.drive.ListFile({
+            "q": self.query,
+            "maxResults": 1000
+        }).GetList()
+
+
+class RequestListFilePaginated(RequestBASE):
+    def __init__(self, drive, query):
+        super(RequestListFilePaginated, self).__init__(drive)
+        self.query = query
+        self.iter = None
+
+    def execute(self):
+        if not self.iter:
+            self.iter = iter(
+                self.drive.ListFile({
+                    "q": self.query,
+                    "maxResults": 1000
+                }))
+        return next(self.iter, None)
+
+
+class RequestCreateFolder(RequestBASE):
+    def __init__(self, args):
+        super(RequestCreateFolder, self).__init__(args["drive"])
+        self.title = args["title"]
+        self.parent_id = args["parent_id"]
+
+    def execute(self):
+        item = self.drive.CreateFile({
+            "title": self.title,
+            "parents": [{"id": self.parent_id}],
+            "mimeType": FOLDER_MIME_TYPE,
+        })
+        item.Upload()
+        return item
+
+
+class RequestUploadFile(RequestBASE):
+    def __init__(self,
+                 args,
+                 no_progress_bar=True,
+                 from_file="",
+                 progress_name=""):
+        super(RequestUploadFile, self).__init__(args["drive"])
+        self.title = args["title"]
+        self.parent_id = args["parent_id"]
+        self.no_progress_bar = no_progress_bar
+        self.from_file = from_file
+        self.progress_name = progress_name
+
+    def upload(self, item):
+        with open(self.from_file, "rb") as from_file:
+            if not self.no_progress_bar:
+                from_file = TrackFileReadProgress(self.progress_name,
+                                                  from_file)
+            if os.stat(self.from_file).st_size:
+                item.content = from_file
+            item.Upload()
+
+    def execute(self):
+        item = self.drive.CreateFile({
+            "title": self.title,
+            "parents": [{"id": self.parent_id}]
+        })
+        self.upload(item)
+        return item
+
+
+class RequestDownloadFile(RequestBASE):
+    def __init__(self, args):
+        super(RequestDownloadFile, self).__init__(args["drive"])
+        self.file_id = args["file_id"]
+        self.to_file = args["to_file"]
+        self.progress_name = args["progress_name"]
+        self.no_progress_bar = args["no_progress_bar"]
+
+    def execute(self):
+        from dvc.progress import Tqdm
+
+        gdrive_file = self.drive.CreateFile({"id": self.file_id})
+        if not self.no_progress_bar:
+            tqdm = Tqdm(desc=self.progress_name,
+                        total=int(gdrive_file["fileSize"]))
+        gdrive_file.GetContentFile(self.to_file)
+        if not self.no_progress_bar:
+            tqdm.close()
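Given an authenticated GoogleDrive handle (the kind RemoteGDrive.drive builds), the request wrappers are used as below. This is only a sketch; it needs a real OAuth client configuration to run:

```python
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

from dvc.remote.gdrive.pydrive import RequestListFile

gauth = GoogleAuth()
gauth.CommandLineAuth()  # interactive OAuth flow in a terminal
drive = GoogleDrive(gauth)

# List everything directly under the Drive root, the same query
# RemoteGDrive.cache_root_dirs() issues for its root folder id.
request = RequestListFile(drive, "'root' in parents and trashed=false")
for item in request.execute():
    print(item["title"], item["id"])
```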
diff --git a/dvc/remote/gdrive/utils.py b/dvc/remote/gdrive/utils.py
new file mode 100644
index 0000000000..e067e2b737
--- /dev/null
+++ b/dvc/remote/gdrive/utils.py
@@ -0,0 +1,31 @@
+import os
+
+from dvc.progress import Tqdm
+
+FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
+
+
+class TrackFileReadProgress(object):
+    UPDATE_AFTER_READ_COUNT = 30
+
+    def __init__(self, progress_name, fobj):
+        self.progress_name = progress_name
+        self.fobj = fobj
+        self.file_size = os.fstat(fobj.fileno()).st_size
+        self.tqdm = Tqdm(desc=self.progress_name, total=self.file_size)
+        self.update_counter = 0
+
+    def read(self, size):
+        if self.update_counter == 0:
+            self.tqdm.update_to(self.fobj.tell())
+            self.update_counter = self.UPDATE_AFTER_READ_COUNT
+        else:
+            self.update_counter -= 1
+        return self.fobj.read(size)
+
+    def close(self):
+        self.fobj.close()
+        self.tqdm.close()
+
+    def __getattr__(self, attr):
+        return getattr(self.fobj, attr)
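TrackFileReadProgress wraps any real file object (it needs fileno() for the size) and forwards everything else through __getattr__. A quick sketch with a hypothetical local file:

```python
from dvc.remote.gdrive.utils import TrackFileReadProgress

fobj = open("data.bin", "rb")  # hypothetical file
tracked = TrackFileReadProgress("data.bin", fobj)
while tracked.read(8192):  # the bar refreshes every 30 reads
    pass
tracked.close()  # closes both the file and the Tqdm bar
```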
diff --git a/dvc/scheme.py b/dvc/scheme.py
index e12b768f58..5f7a8d1a28 100644
--- a/dvc/scheme.py
+++ b/dvc/scheme.py
@@ -9,5 +9,6 @@ class Schemes:
     HTTP = "http"
     HTTPS = "https"
     GS = "gs"
+    GDRIVE = "gdrive"
     LOCAL = "local"
     OSS = "oss"
diff --git a/setup.py b/setup.py
index 47bd6c46bb..cf772a1786 100644
--- a/setup.py
+++ b/setup.py
@@ -1,13 +1,14 @@
-from setuptools import setup, find_packages
-from setuptools.command.build_py import build_py as _build_py
 import os
 import sys
 
+from setuptools import find_packages
+from setuptools import setup
+from setuptools.command.build_py import build_py as _build_py
+
+import fastentrypoints  # noqa: F401
+
 # Prevents pkg_resources import in entry point script,
 # see https://github.com/ninjaaron/fast-entry_points.
 # This saves about 200 ms on startup time for non-wheel installs.
-import fastentrypoints  # noqa: F401
-
 
 # https://packaging.python.org/guides/single-sourcing-package-version/
 pkg_dir = os.path.dirname(os.path.abspath(__file__))
@@ -86,6 +87,7 @@ def run(self):
 
 # Extra dependencies for remote integrations
 gs = ["google-cloud-storage==1.19.0"]
+gdrive = ["pydrive==1.3.1", "backoff>=1.8.1"]
s3 = ["boto3==1.9.115"]
 azure = ["azure-storage-blob==2.1.0"]
 oss = ["oss2==2.6.1"]
@@ -96,7 +98,7 @@ def run(self):
 # we can start shipping it by default.
 ssh_gssapi = ["paramiko[gssapi]>=2.5.0"]
 hdfs = ["pyarrow==0.14.0"]
-all_remotes = gs + s3 + azure + ssh + oss
+all_remotes = gs + s3 + azure + ssh + oss + gdrive
 
 if os.name != "nt" or sys.version_info[0] != 2:
     # NOTE: there are no pyarrow wheels for python2 on windows
@@ -147,6 +149,7 @@ def run(self):
     extras_require={
         "all": all_remotes,
         "gs": gs,
+        "gdrive": gdrive,
         "s3": s3,
         "azure": azure,
         "oss": oss,
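With this extra in place, the Google Drive dependencies (pydrive and backoff) install via pip install "dvc[gdrive]", and they are pulled in by "dvc[all]" as well, since gdrive is now part of all_remotes.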
diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py
index 6ceeb07bb3..23a3e8f281 100644
--- a/tests/func/test_data_cloud.py
+++ b/tests/func/test_data_cloud.py
@@ -1,47 +1,52 @@
-from subprocess import CalledProcessError
-from subprocess import check_output, Popen
-from unittest import SkipTest
-import os
-import uuid
-import shutil
-import getpass
-import platform
 import copy
+import getpass
 import logging
-import pytest
+import os
+import platform
+import shutil
+import uuid
+from subprocess import CalledProcessError
+from subprocess import check_output
+from subprocess import Popen
+from unittest import SkipTest
 
+import pytest
 from mock import patch
 
-from dvc.utils.compat import str
-from dvc.utils import env2bool
-from dvc.main import main
-from dvc.config import Config
 from dvc.cache import NamedCache
+from dvc.config import Config
 from dvc.data_cloud import DataCloud
-from dvc.remote import (
-    RemoteS3,
-    RemoteGS,
-    RemoteAZURE,
-    RemoteOSS,
-    RemoteLOCAL,
-    RemoteSSH,
-    RemoteHDFS,
-    RemoteHTTP,
-)
-from dvc.remote.base import STATUS_OK, STATUS_NEW, STATUS_DELETED
+from dvc.main import main
+from dvc.remote import RemoteAZURE
+from dvc.remote import RemoteGDrive
+from dvc.remote import RemoteGS
+from dvc.remote import RemoteHDFS
+from dvc.remote import RemoteHTTP
+from dvc.remote import RemoteLOCAL
+from dvc.remote import RemoteOSS
+from dvc.remote import RemoteS3
+from dvc.remote import RemoteSSH
+from dvc.remote.base import STATUS_DELETED
+from dvc.remote.base import STATUS_NEW
+from dvc.remote.base import STATUS_OK
+from dvc.utils import env2bool
 from dvc.utils import file_md5
-from dvc.utils.stage import load_stage_file, dump_stage_file
-
+from dvc.utils.compat import str
+from dvc.utils.stage import dump_stage_file
+from dvc.utils.stage import load_stage_file
 from tests.basic_env import TestDvc
 from tests.utils import spy
 
-
 TEST_REMOTE = "upstream"
 TEST_SECTION = 'remote "{}"'.format(TEST_REMOTE)
 TEST_CONFIG = {
     Config.SECTION_CACHE: {},
-    Config.SECTION_CORE: {Config.SECTION_CORE_REMOTE: TEST_REMOTE},
-    TEST_SECTION: {Config.SECTION_REMOTE_URL: ""},
+    Config.SECTION_CORE: {
+        Config.SECTION_CORE_REMOTE: TEST_REMOTE
+    },
+    TEST_SECTION: {
+        Config.SECTION_REMOTE_URL: ""
+    },
 }
 
 TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-test")
@@ -52,11 +57,14 @@
     os.environ.get(
         "GOOGLE_APPLICATION_CREDENTIALS",
         os.path.join("scripts", "ci", "gcp-creds.json"),
-    )
-)
+    ))
 # Ensure that absolute path is used
 os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE
 
+TEST_GDRIVE_CLIENT_ID = (
+    "719861249063-v4an78j9grdtuuuqg3lnm0sugna6v3lh.apps.googleusercontent.com")
+TEST_GDRIVE_CLIENT_SECRET = "2fy_HyzSwkxkGzEken7hThXb"
+
 
 def _should_test_aws():
     do_test = env2bool("DVC_TEST_AWS", undefined=None)
@@ -69,6 +77,13 @@ def _should_test_aws():
     return False
 
 
+def _should_test_gdrive():
+    if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
+        return True
+
+    return False
+
+
 def _should_test_gcp():
     do_test = env2bool("DVC_TEST_GCP", undefined=None)
     if do_test is not None:
@@ -78,15 +93,13 @@ def _should_test_gcp():
         return False
 
     try:
-        check_output(
-            [
-                "gcloud",
-                "auth",
-                "activate-service-account",
-                "--key-file",
-                TEST_GCP_CREDS_FILE,
-            ]
-        )
+        check_output([
+            "gcloud",
+            "auth",
+            "activate-service-account",
+            "--key-file",
+            TEST_GCP_CREDS_FILE,
+        ])
     except (CalledProcessError, OSError):
         return False
     return True
@@ -98,8 +111,7 @@ def _should_test_azure():
         return do_test
 
     return os.getenv("AZURE_STORAGE_CONTAINER_NAME") and os.getenv(
-        "AZURE_STORAGE_CONNECTION_STRING"
-    )
+        "AZURE_STORAGE_CONNECTION_STRING")
 
 
 def _should_test_oss():
@@ -107,11 +119,8 @@ def _should_test_oss():
     if do_test is not None:
         return do_test
 
-    return (
-        os.getenv("OSS_ENDPOINT")
-        and os.getenv("OSS_ACCESS_KEY_ID")
-        and os.getenv("OSS_ACCESS_KEY_SECRET")
-    )
+    return (os.getenv("OSS_ENDPOINT") and os.getenv("OSS_ACCESS_KEY_ID")
+            and os.getenv("OSS_ACCESS_KEY_SECRET"))
 
 
 def _should_test_ssh():
@@ -136,9 +145,9 @@ def _should_test_hdfs():
         return False
 
     try:
-        check_output(
-            ["hadoop", "version"], shell=True, executable=os.getenv("SHELL")
-        )
+        check_output(["hadoop", "version"],
+                     shell=True,
+                     executable=os.getenv("SHELL"))
    except (CalledProcessError, IOError):
         return False
 
@@ -163,9 +172,8 @@ def get_local_url():
 
 
 def get_ssh_url():
-    return "ssh://{}@127.0.0.1:22{}".format(
-        getpass.getuser(), get_local_storagepath()
-    )
+    return "ssh://{}@127.0.0.1:22{}".format(getpass.getuser(),
+                                            get_local_storagepath())
 
 
 def get_ssh_url_mocked(user, port):
@@ -188,9 +196,8 @@ def get_ssh_url_mocked(user, port):
 
 
 def get_hdfs_url():
-    return "hdfs://{}@127.0.0.1{}".format(
-        getpass.getuser(), get_local_storagepath()
-    )
+    return "hdfs://{}@127.0.0.1{}".format(getpass.getuser(),
+                                          get_local_storagepath())
 
 
 def get_aws_storagepath():
@@ -201,6 +208,10 @@ def get_aws_url():
     return "s3://" + get_aws_storagepath()
 
 
+def get_gdrive_url():
+    return "gdrive://root/" + str(uuid.uuid4())
+
+
 def get_gcp_storagepath():
     return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4())
 
@@ -263,9 +274,8 @@ def _get_keyfile(self):
 
     def _ensure_should_run(self):
         if not self._should_test():
-            raise SkipTest(
-                "Test {} is disabled".format(self.__class__.__name__)
-            )
+            raise SkipTest("Test {} is disabled".format(
+                self.__class__.__name__))
 
     def _setup_cloud(self):
         self._ensure_should_run()
@@ -374,6 +384,33 @@ def _get_cloud_class(self):
         return RemoteS3
 
 
+class TestRemoteGDrive(TestDataCloudBase):
+    def _should_test(self):
+        return _should_test_gdrive()
+
+    def _setup_cloud(self):
+        self._ensure_should_run()
+
+        repo = self._get_url()
+
+        config = copy.deepcopy(TEST_CONFIG)
+        config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo
+        config[TEST_SECTION][
+            Config.SECTION_GDRIVE_CLIENT_ID] = TEST_GDRIVE_CLIENT_ID
+        config[TEST_SECTION][
+            Config.SECTION_GDRIVE_CLIENT_SECRET] = TEST_GDRIVE_CLIENT_SECRET
+        self.dvc.config.config = config
+        self.cloud = DataCloud(self.dvc)
+
+        self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class())
+
+    def _get_url(self):
+        return get_gdrive_url()
+
+    def _get_cloud_class(self):
+        return RemoteGDrive
+
+
 class TestRemoteGS(TestDataCloudBase):
     def _should_test(self):
         return _should_test_gcp()
@@ -386,8 +423,7 @@ def _setup_cloud(self):
         config = copy.deepcopy(TEST_CONFIG)
         config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo
         config[TEST_SECTION][
-            Config.SECTION_GCP_CREDENTIALPATH
-        ] = TEST_GCP_CREDS_FILE
+            Config.SECTION_GCP_CREDENTIALPATH] = TEST_GCP_CREDS_FILE
         self.dvc.config.config = config
         self.cloud = DataCloud(self.dvc)
 
@@ -569,9 +605,8 @@ def _test(self):
 
     def test(self):
         if not self._should_test():
-            raise SkipTest(
-                "Test {} is disabled".format(self.__class__.__name__)
-            )
+            raise SkipTest("Test {} is disabled".format(
+                self.__class__.__name__))
 
         self._test()
 
@@ -620,6 +655,32 @@ def _test(self):
         self._test_cloud(TEST_REMOTE)
 
 
+class TestRemoteGDriveCLI(TestDataCloudCLIBase):
+    def _should_test(self):
+        return _should_test_gdrive()
+
+    def _test(self):
+        url = get_gdrive_url()
+
+        self.main(["remote", "add", TEST_REMOTE, url])
+        self.main([
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            Config.SECTION_GDRIVE_CLIENT_ID,
+            TEST_GDRIVE_CLIENT_ID,
+        ])
+        self.main([
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            Config.SECTION_GDRIVE_CLIENT_SECRET,
+            TEST_GDRIVE_CLIENT_SECRET,
+        ])
+
+        self._test_cloud(TEST_REMOTE)
+
+
 class TestRemoteGSCLI(TestDataCloudCLIBase):
     def _should_test(self):
         return _should_test_gcp()
@@ -628,15 +689,13 @@ def _test(self):
         url = get_gcp_url()
 
         self.main(["remote", "add", TEST_REMOTE, url])
-        self.main(
-            [
-                "remote",
-                "modify",
-                TEST_REMOTE,
-                "credentialpath",
-                TEST_GCP_CREDS_FILE,
-            ]
-        )
+        self.main([
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            "credentialpath",
+            TEST_GCP_CREDS_FILE,
+        ])
 
         self._test_cloud(TEST_REMOTE)
 
@@ -701,8 +760,7 @@ def _test(self):
         expected_warning = (
             "Output 'bar'(Stage: 'bar.dvc') is missing version info."
             " Cache for it will not be collected."
-            " Use dvc repro to get your pipeline up to date."
-        )
+            " Use dvc repro to get your pipeline up to date.")
 
         assert expected_warning in self._caplog.text
 
@@ -795,9 +853,8 @@ def test(self):
 class TestCheckSumRecalculation(TestDvc):
     def test(self):
         test_get_file_checksum = spy(RemoteLOCAL.get_file_checksum)
-        with patch.object(
-            RemoteLOCAL, "get_file_checksum", test_get_file_checksum
-        ):
+        with patch.object(RemoteLOCAL, "get_file_checksum",
+                          test_get_file_checksum):
             url = get_local_url()
             ret = main(["remote", "add", "-d", TEST_REMOTE, url])
             self.assertEqual(ret, 0)
@@ -833,14 +890,11 @@ def setUp(self):
         checksum_bar = file_md5(self.BAR)[0]
         self.message_header = (
             "Some of the cache files do not exist neither locally "
-            "nor on remote. Missing cache files: "
-        )
+            "nor on remote. Missing cache files: ")
         self.message_bar_part = "name: {}, md5: {}".format(
-            self.BAR, checksum_bar
-        )
+            self.BAR, checksum_bar)
         self.message_foo_part = "name: {}, md5: {}".format(
-            self.FOO, checksum_foo
-        )
+            self.FOO, checksum_foo)
 
     def test(self):
         self._caplog.clear()
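Note that _should_test_gdrive() keys off the GDRIVE_USER_CREDENTIALS_DATA environment variable, so TestRemoteGDrive and TestRemoteGDriveCLI run only when CI exports previously saved user credentials there; RemoteGDrive.drive writes that data to the user-credentials file before authenticating and deletes it afterwards.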
diff --git a/tests/unit/remote/gdrive/__init__.py b/tests/unit/remote/gdrive/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/remote/gdrive/conftest.py b/tests/unit/remote/gdrive/conftest.py
new file mode 100644
index 0000000000..035ca15094
--- /dev/null
+++ b/tests/unit/remote/gdrive/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+
+from dvc.remote.gdrive import RemoteGDrive
+
+
+@pytest.fixture
+def gdrive(repo):
+    ret = RemoteGDrive(None, {"url": "gdrive://root/data"})
+    return ret
diff --git a/tests/unit/remote/gdrive/test_gdrive.py b/tests/unit/remote/gdrive/test_gdrive.py
new file mode 100644
index 0000000000..012adf12c9
--- /dev/null
+++ b/tests/unit/remote/gdrive/test_gdrive.py
@@ -0,0 +1,10 @@
+import mock
+
+from dvc.remote.gdrive import RemoteGDrive
+
+
+@mock.patch("dvc.remote.gdrive.RemoteGDrive.init_drive")
+def test_init_drive(repo):
+    url = "gdrive://root/data"
+    gdrive = RemoteGDrive(repo, {"url": url})
+    assert str(gdrive.path_info) == url
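A subtlety in test_init_drive: because mock.patch is applied as a decorator, the patched init_drive mock is what arrives as the test's first positional argument (named repo here), so pytest does not try to resolve a repo fixture and the constructor runs without any network access.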