Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dvc.output as output
from dvc.dependency.azure import AzureDependency
from dvc.dependency.gdrive import GDriveDependency
from dvc.dependency.gs import GSDependency
from dvc.dependency.hdfs import HDFSDependency
from dvc.dependency.http import HTTPDependency
Expand All @@ -19,6 +20,7 @@

DEPS = [
AzureDependency,
GDriveDependency,
GSDependency,
HDFSDependency,
HTTPDependency,
Expand All @@ -33,6 +35,7 @@
Schemes.SSH: SSHDependency,
Schemes.S3: S3Dependency,
Schemes.AZURE: AzureDependency,
Schemes.GDRIVE: GDriveDependency,
Schemes.GS: GSDependency,
Schemes.HDFS: HDFSDependency,
Schemes.HTTP: HTTPDependency,
Expand Down
7 changes: 7 additions & 0 deletions dvc/dependency/gdrive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dvc.dependency.base import BaseDependency
from dvc.output.base import BaseOutput
from dvc.remote.gdrive import GDriveRemote


class GDriveDependency(BaseDependency, BaseOutput):
    """Dependency whose data is accessed through ``GDriveRemote``."""

    # Remote implementation associated with this dependency type.
    REMOTE = GDriveRemote
20 changes: 20 additions & 0 deletions dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import io
import logging
import os
import posixpath
import re
import threading
from collections import defaultdict
from contextlib import contextmanager
from urllib.parse import urlparse

from funcy import cached_property, retry, wrap_prop, wrap_with
Expand All @@ -15,6 +17,7 @@
from dvc.remote.base import BaseRemote
from dvc.scheme import Schemes
from dvc.utils import format_link, tmp_fname
from dvc.utils.stream import IterStream

logger = logging.getLogger(__name__)
FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
Expand Down Expand Up @@ -393,6 +396,23 @@ def _gdrive_download_file(
) as pbar:
gdrive_file.GetContentFile(to_file, callback=pbar.update_to)

@contextmanager
@_gdrive_retry
def open(self, path_info, mode="r", encoding=None):
    """Open a file stored on the remote as a read-only stream.

    Args:
        path_info: location of the file; resolved to a Drive item id via
            ``self._get_item_id``.
        mode: one of ``"r"``, ``"rt"`` or ``"rb"``.
        encoding: text encoding handed to ``io.TextIOWrapper``; unused
            for ``"rb"``.

    Yields:
        An ``IterStream`` of bytes for ``"rb"``, otherwise an
        ``io.TextIOWrapper`` around it. The stream is closed when the
        ``with`` block exits.

    Raises:
        AssertionError: if ``mode`` is not one of the supported modes.
    """
    # NOTE(review): this is a generator function, so a call-level retry
    # wrapper presumably only sees generator creation and not failures
    # raised while the body runs -- confirm _gdrive_retry's behavior here.
    assert mode in {"r", "rt", "rb"}

    item_id = self._get_item_id(path_info)
    param = {"id": item_id}
    # it does not create a file on the remote
    gdrive_file = self._drive.CreateFile(param)
    fd = gdrive_file.GetContentIOBuffer()
    stream = IterStream(iter(fd))

    if mode != "rb":
        stream = io.TextIOWrapper(stream, encoding=encoding)

    try:
        yield stream
    finally:
        # Closing the text wrapper also closes the wrapped IterStream, so
        # the stream cannot be used once the context has exited.
        stream.close()

@_gdrive_retry
def _gdrive_delete_file(self, item_id):
from pydrive2.files import ApiRequestError
Expand Down
46 changes: 2 additions & 44 deletions dvc/utils/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
from contextlib import contextmanager

from dvc.utils.stream import IterStream


@contextmanager
def open_url(url, mode="r", encoding=None):
Expand Down Expand Up @@ -61,47 +63,3 @@ def gen(response):
finally:
# Ensure connection is closed
it.close()


class IterStream(io.RawIOBase):
    """Wraps an iterator yielding bytes as a file object.

    Empty chunks produced by the iterator are skipped: returning zero bytes
    from ``readinto()``/``read1()`` means EOF to the ``io`` machinery, so
    passing an empty chunk through would silently truncate the stream.
    """

    def __init__(self, iterator):
        self.iterator = iterator
        # Bytes already pulled from the iterator but not yet handed out.
        self.leftover = None

    def readable(self):
        return True

    def _next_chunk(self):
        """Return the next non-empty chunk; raise StopIteration at the end."""
        chunk = self.leftover
        self.leftover = None
        while not chunk:
            chunk = next(self.iterator)
        return chunk

    # Python 3 requires only the .readinto() method, but it still uses other
    # ones under some circumstances and falls back if those are absent. Since
    # the iterator already constructs byte strings for us, .readinto() is not
    # optimal, so we provide .read1() too.

    def readinto(self, b):
        """Copy up to ``len(b)`` bytes into ``b``; return the count (0 = EOF)."""
        try:
            chunk = self._next_chunk()
        except StopIteration:
            return 0  # indicate EOF

        n = len(b)  # We're supposed to return at most this much
        output, self.leftover = chunk[:n], chunk[n:]

        n_out = len(output)
        b[:n_out] = output
        return n_out

    readinto1 = readinto

    def read1(self, n=-1):
        """Return at most ``n`` bytes from a single chunk (``b""`` = EOF).

        With ``n`` of ``None``, zero or a negative number the whole pending
        chunk is returned.
        """
        try:
            chunk = self._next_chunk()
        except StopIteration:
            return b""

        # Return an arbitrary number of bytes
        if n is None or n <= 0:
            return chunk

        output, self.leftover = chunk[:n], chunk[n:]
        return output
45 changes: 45 additions & 0 deletions dvc/utils/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import io


class IterStream(io.RawIOBase):
    """Wraps an iterator yielding bytes as a file object.

    Empty chunks produced by the iterator are skipped: returning zero bytes
    from ``readinto()``/``read1()`` means EOF to the ``io`` machinery, so
    passing an empty chunk through would silently truncate the stream.
    """

    def __init__(self, iterator):
        self.iterator = iterator
        # Bytes already pulled from the iterator but not yet handed out.
        self.leftover = None

    def readable(self):
        return True

    def _next_chunk(self):
        """Return the next non-empty chunk; raise StopIteration at the end."""
        chunk = self.leftover
        self.leftover = None
        while not chunk:
            chunk = next(self.iterator)
        return chunk

    # Python 3 requires only the .readinto() method, but it still uses other
    # ones under some circumstances and falls back if those are absent. Since
    # the iterator already constructs byte strings for us, .readinto() is not
    # optimal, so we provide .read1() too.

    def readinto(self, b):
        """Copy up to ``len(b)`` bytes into ``b``; return the count (0 = EOF)."""
        try:
            chunk = self._next_chunk()
        except StopIteration:
            return 0  # indicate EOF

        n = len(b)  # We're supposed to return at most this much
        output, self.leftover = chunk[:n], chunk[n:]

        n_out = len(output)
        b[:n_out] = output
        return n_out

    readinto1 = readinto

    def read1(self, n=-1):
        """Return at most ``n`` bytes from a single chunk (``b""`` = EOF).

        With ``n`` of ``None``, zero or a negative number the whole pending
        chunk is returned.
        """
        try:
            chunk = self._next_chunk()
        except StopIteration:
            return b""

        # Return an arbitrary number of bytes
        if n is None or n <= 0:
            return chunk

        output, self.leftover = chunk[:n], chunk[n:]
        return output
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run(self):
# Extra dependencies for remote integrations

gs = ["google-cloud-storage==1.19.0"]
gdrive = ["pydrive2>=1.4.13"]
gdrive = ["pydrive2>=1.4.15"]
s3 = ["boto3>=1.9.201"]
azure = ["azure-storage-blob==2.1.0"]
oss = ["oss2==2.6.1"]
Expand Down
6 changes: 4 additions & 2 deletions tests/func/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dvc.main import main
from dvc.path_info import URLInfo
from dvc.utils.fs import remove
from tests.remotes import GCP, HDFS, OSS, S3, SSH, Azure, Local
from tests.remotes import GCP, HDFS, OSS, S3, SSH, Azure, GDrive, Local

remote_params = [S3, GCP, Azure, OSS, SSH, HDFS]
all_remote_params = [Local] + remote_params
Expand Down Expand Up @@ -56,7 +56,9 @@ def test_get_url_requires_dvc(tmp_dir, scm):
api.get_url("foo", repo=f"file://{tmp_dir}")


@pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
@pytest.mark.parametrize(
"remote_url", all_remote_params + [GDrive], indirect=True
)
def test_open(remote_url, tmp_dir, dvc):
run_dvc("remote", "add", "-d", "upstream", remote_url)
tmp_dir.dvc_gen("foo", "foo-text")
Expand Down
2 changes: 1 addition & 1 deletion tests/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from moto.s3 import mock_s3

from dvc.remote import GDriveRemote
from dvc.remote.gdrive import GDriveRemote
from dvc.remote.gs import GSRemote
from dvc.remote.s3 import S3Remote
from dvc.utils import env2bool
Expand Down