WIP: google cloud storage class #252

Closed · wants to merge 3 commits
178 changes: 178 additions & 0 deletions zarr/storage.py
@@ -1886,3 +1886,181 @@ def __delitem__(self, key):
with self._mutex:
self._invalidate_keys()
self._invalidate_value(key)


# utility functions for object stores


def _strip_prefix_from_path(path, prefix):
# normalized things will not have any leading or trailing slashes
path_norm = normalize_storage_path(path)
prefix_norm = normalize_storage_path(prefix)
if path_norm.startswith(prefix_norm):
return path_norm[(len(prefix_norm)+1):]
else:
return path


def _append_path_to_prefix(path, prefix):
return '/'.join([normalize_storage_path(prefix),
normalize_storage_path(path)])


def atexit_rmgcspath(bucket, path):
from google.cloud import storage
client = storage.Client()
bucket = client.get_bucket(bucket)
bucket.delete_blobs(bucket.list_blobs(prefix=path))


class GCSStore(MutableMapping):
"""Storage class using a Google Cloud Storage (GCS)

Parameters
----------
bucket_name : string
The name of the GCS bucket
prefix : string, optional
The prefix within the bucket (i.e. subdirectory)
client_kwargs : dict, optional
Extra options passed to ``google.cloud.storage.Client`` when connecting
to GCS

Notes
-----
In order to use this store, you must install the Google Cloud Storage
`Python Client Library <https://cloud.google.com/storage/docs/reference/libraries>`_.
You must also provide valid application credentials, either by setting the
``GOOGLE_APPLICATION_CREDENTIALS`` environment variable or via
`default credentials <https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login>`_.
"""

def __init__(self, bucket_name, prefix=None, client_kwargs={}):
Review comment (Member):
Might be worth adding an option to use an anonymous client. E.g., add an anonymous=False keyword argument, then use storage.Client.create_anonymous_client() when creating the client if the user has passed anonymous=True.
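A rough sketch of how that could be wired into initialize_bucket (the anonymous attribute here is hypothetical, assumed to be set by __init__ from the suggested keyword argument; an anonymous client only works for publicly readable buckets):

    from google.cloud import storage
    if self.anonymous:
        # unauthenticated client, only usable with public buckets
        client = storage.Client.create_anonymous_client()
    else:
        client = storage.Client(**self.client_kwargs)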


self.bucket_name = bucket_name
self.prefix = normalize_storage_path(prefix)
self.client_kwargs = client_kwargs
self.initialize_bucket()

def initialize_bucket(self):
from google.cloud import storage
# run `gcloud auth application-default login` from shell
client = storage.Client(**self.client_kwargs)
self.bucket = client.get_bucket(self.bucket_name)
Review comment (Member):
Note that it's also possible to do:

self.bucket = storage.Bucket(client, name=self.bucket_name)

...which involves no network communication. Not sure this is a good idea in general, as one may want to retrieve the bucket info, but just mentioning it.
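For reference, a variant of initialize_bucket along those lines might look roughly like this (untested sketch; because no request is made, a missing bucket would only be noticed on first access):

    def initialize_bucket(self):
        from google.cloud import storage
        import google.api_core.exceptions as exceptions
        client = storage.Client(**self.client_kwargs)
        # build the Bucket object locally instead of calling get_bucket,
        # so no network round-trip happens here
        self.bucket = storage.Bucket(client, name=self.bucket_name)
        self.exceptions = exceptions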

# need to properly handle exceptions
import google.api_core.exceptions as exceptions
self.exceptions = exceptions

# needed for pickling
def __getstate__(self):
state = self.__dict__.copy()
del state['bucket']
del state['exceptions']
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.initialize_bucket()

def __enter__(self):
return self

def __exit__(self, *args):
pass

def full_path(self, path=None):
return _append_path_to_prefix(path, self.prefix)

def list_gcs_directory_blobs(self, path):
"""Return list of all blobs *directly* under a gcs prefix."""
prefix = normalize_storage_path(path) + '/'
return [blob.name for blob in
self.bucket.list_blobs(prefix=prefix, delimiter='/')]

# from https://github.com/GoogleCloudPlatform/google-cloud-python/issues/920
def list_gcs_subdirectories(self, path):
"""Return set of all "subdirectories" from a gcs prefix."""
prefix = normalize_storage_path(path) + '/'
iterator = self.bucket.list_blobs(prefix=prefix, delimiter='/')
prefixes = set()
for page in iterator.pages:
prefixes.update(page.prefixes)
# need to strip trailing slash to be consistent with os.listdir
return [path[:-1] for path in prefixes]

def list_gcs_directory(self, prefix, strip_prefix=True):
"""Return a list of all blobs and subdirectories from a gcs prefix."""
items = set()
items.update(self.list_gcs_directory_blobs(prefix))
items.update(self.list_gcs_subdirectories(prefix))
items = list(items)
if strip_prefix:
items = [_strip_prefix_from_path(path, prefix) for path in items]
return items

def listdir(self, path=None):
dir_path = self.full_path(path)
return sorted(self.list_gcs_directory(dir_path, strip_prefix=True))

def rmdir(self, path=None):
# make sure it's a directory
dir_path = normalize_storage_path(self.full_path(path)) + '/'
self.bucket.delete_blobs(self.bucket.list_blobs(prefix=dir_path))

def getsize(self, path=None):
# this function should *not* be recursive
# a lot of slash trickery is required to make this work right
full_path = self.full_path(path)
blob = self.bucket.get_blob(full_path)
if blob is not None:
return blob.size
else:
dir_path = normalize_storage_path(full_path) + '/'
blobs = self.bucket.list_blobs(prefix=dir_path, delimiter='/')
size = 0
for blob in blobs:
size += blob.size
return size

def clear(self):
self.rmdir()

def __getitem__(self, key):
blob_name = self.full_path(key)
blob = self.bucket.get_blob(blob_name)
Review comment (Member):
An alternative here is to do:

        from google.cloud import storage
        blob = storage.Blob(blob_name, self.bucket)

...which involves less network communication (profiling shows the number of calls to the 'read' method of '_ssl._SSLSocket' objects dropping from 3 to 1) and reduces the time to retrieve small objects by around 50%.

If this change was made, some rethinking of error handling may be needed, as the point at which a non-existing blob was detected might change.
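One possible shape for __getitem__ along those lines, with the error handling moved to the download call (untested sketch):

    def __getitem__(self, key):
        from google.cloud import storage
        blob_name = self.full_path(key)
        # construct the Blob locally rather than fetching its metadata first
        blob = storage.Blob(blob_name, self.bucket)
        try:
            return blob.download_as_string()
        except self.exceptions.NotFound:
            # a missing blob is only detected when the download fails
            raise KeyError('Blob %s not found' % blob_name)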

if blob:
return blob.download_as_string()
else:
raise KeyError('Blob %s not found' % blob_name)

def __setitem__(self, key, value):
blob_name = self.full_path(key)
blob = self.bucket.blob(blob_name)
blob.upload_from_string(value)

def __delitem__(self, key):
blob_name = self.full_path(key)
try:
self.bucket.delete_blob(blob_name)
except self.exceptions.NotFound as er:
raise KeyError(er.message)

def __contains__(self, key):
blob_name = self.full_path(key)
return self.bucket.get_blob(blob_name) is not None

def __eq__(self, other):
return (
isinstance(other, GCSStore) and
self.bucket_name == other.bucket_name and
self.prefix == other.prefix
)

def __iter__(self):
blobs = self.bucket.list_blobs(prefix=self.prefix)
for blob in blobs:
yield _strip_prefix_from_path(blob.name, self.prefix)

def __len__(self):
iterator = self.bucket.list_blobs(prefix=self.prefix)
return len(list(iterator))
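For context, a minimal usage sketch for the new store (the bucket name and prefix below are placeholders; it assumes the bucket already exists and application credentials are configured as described in the docstring):

    import zarr
    from zarr.storage import GCSStore

    store = GCSStore('my-zarr-bucket', prefix='example')
    root = zarr.group(store=store)
    z = root.zeros('foo', shape=(1000, 1000), chunks=(100, 100), dtype='i4')
    z[:] = 42  # each chunk is written to GCS as an individual blob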
26 changes: 24 additions & 2 deletions zarr/tests/test_core.py
@@ -7,7 +7,7 @@
import pickle
import os
import warnings

import uuid

import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
@@ -16,7 +16,7 @@

from zarr.storage import (DirectoryStore, init_array, init_group, NestedDirectoryStore,
DBMStore, LMDBStore, atexit_rmtree, atexit_rmglob,
LRUStoreCache)
LRUStoreCache, GCSStore, atexit_rmgcspath)
from zarr.core import Array
from zarr.errors import PermissionError
from zarr.compat import PY2, text_type, binary_type
@@ -1698,3 +1698,25 @@ def create_array(read_only=False, **kwargs):
init_array(store, **kwargs)
return Array(store, read_only=read_only, cache_metadata=cache_metadata,
cache_attrs=cache_attrs)


try:
from google.cloud import storage as gcstorage
except ImportError: # pragma: no cover
gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSArray(TestArray):

def create_array(self, read_only=False, **kwargs):
bucket = 'zarr-test'
prefix = str(uuid.uuid4())  # use a string prefix for the store and cleanup helper
atexit.register(atexit_rmgcspath, bucket, prefix)
store = GCSStore(bucket, prefix)
cache_metadata = kwargs.pop('cache_metadata', True)
cache_attrs = kwargs.pop('cache_attrs', True)
kwargs.setdefault('compressor', Zlib(1))
init_array(store, **kwargs)
return Array(store, read_only=read_only, cache_metadata=cache_metadata,
cache_attrs=cache_attrs)
30 changes: 29 additions & 1 deletion zarr/tests/test_storage.py
@@ -8,6 +8,7 @@
import array
import shutil
import os
import uuid


import numpy as np
@@ -19,7 +20,8 @@
DirectoryStore, ZipStore, init_group, group_meta_key,
getsize, migrate_1to2, TempStore, atexit_rmtree,
NestedDirectoryStore, default_compressor, DBMStore,
LMDBStore, atexit_rmglob, LRUStoreCache)
LMDBStore, atexit_rmglob, LRUStoreCache, GCSStore,
atexit_rmgcspath)
from zarr.meta import (decode_array_metadata, encode_array_metadata, ZARR_FORMAT,
decode_group_metadata, encode_group_metadata)
from zarr.compat import PY2
@@ -1235,3 +1237,29 @@ def test_format_compatibility():
else:
assert compressor.codec_id == z.compressor.codec_id
assert compressor.get_config() == z.compressor.get_config()


try:
from google.cloud import storage as gcstorage
except ImportError: # pragma: no cover
gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSStore(StoreTests, unittest.TestCase):

def create_store(self):
# would need to be replaced with a dedicated test bucket
bucket = 'zarr-test'
prefix = str(uuid.uuid4())  # use a string prefix for the store and cleanup helper
atexit.register(atexit_rmgcspath, bucket, prefix)
store = GCSStore(bucket, prefix)
return store

def test_context_manager(self):
with self.create_store() as store:
store['foo'] = b'bar'
store['baz'] = b'qux'
assert 2 == len(store)