/
interface.py
95 lines (74 loc) · 2.77 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import warnings
from hub.utils.store_control import StoreControlClient
from hub.marray.array import HubArray
import numpy as np
from hub.exceptions import WrongTypeError
import hub.backend.storage
from hub.backend.storage import StorageFactory
from hub.marray.dataset import Dataset
from hub.marray.bucket import HubBucket
def _get_path(name, public=False):
tag = 'latest'
if len(name.split(':')) == 2:
tag = name.split(':')[1]
name = name.split(':')[0]
if len(name.split('/')) == 1:
name = '{}/{}'.format(name, name)
user, dataset = name.split('/')
path = user+'/'+dataset+'/'+tag
return path
def S3(bucket=None, public=False, aws_access_key_id=None, aws_secret_access_key=None, parallel=25):
    """Return a HubBucket backed by an S3 storage backend.

    All arguments are forwarded unchanged to ``hub.backend.storage.S3``.
    """
    backend = hub.backend.storage.S3(
        bucket, public, aws_access_key_id, aws_secret_access_key, parallel)
    return HubBucket(backend, 's3')
def GS():
    """Return a HubBucket backed by a Google Cloud Storage backend."""
    gs_backend = hub.backend.storage.GS()
    return HubBucket(gs_backend, 'gs')
def FS(bucket=None):
    """Return a HubBucket backed by a local filesystem backend."""
    fs_backend = hub.backend.storage.FS(bucket)
    return HubBucket(fs_backend, 'fs')
def array(shape=None, name=None, dtype='float', chunk_size=None, backend='s3', caching=False, storage=None, compression='zlib', compression_level=6):
    """Create a new HubArray, or load an existing one when no shape is given.

    Parameters
    ----------
    shape : sequence of int, optional
        Shape of the new array. When falsy, the existing array under
        ``name`` is opened via ``load`` instead.
    name : str
        Required identifier "username/dataset:version" (version defaults
        to "latest").
    dtype : str or numpy dtype, default 'float'
    chunk_size : sequence of int, optional
        Per-dimension chunk shape; defaults to a single slice along axis 0.
    backend : str, default 's3'
        Storage protocol used when ``storage`` is not supplied.
    caching : bool, default False
    storage : optional
        Pre-built storage object; when None one is created via
        ``StorageFactory``.
    compression : str, default 'zlib'
    compression_level : int, default 6

    Returns
    -------
    HubArray

    Raises
    ------
    Exception
        If no name is provided.
    WrongTypeError
        If ``dtype`` is not a recognized numpy dtype.
    ValueError
        If shape/chunk_size are inconsistent or non-integer.
    """
    if not name:
        raise Exception(
            'No name provided, please name your array - hub.array(..., name="username/dataset:version") '
        )
    path = _get_path(name)

    # No shape means "open the existing array under this name".
    if not shape:
        return load(name)

    # np.dtype raises TypeError for unknown dtypes; the original bare
    # `except:` swallowed everything, including KeyboardInterrupt.
    try:
        dtype = np.dtype(dtype).name
    except TypeError:
        raise WrongTypeError('Dtype {} is not supported '.format(dtype))

    # Auto chunking: one slice along the first axis by default.
    if chunk_size is None:
        chunk_size = list(shape)
        chunk_size[0] = 1

    # Input checking. Explicit raises instead of asserts (asserts are
    # stripped under -O); np.issubdtype replaces np.sctypes, which was
    # removed in NumPy 2.0.
    if len(chunk_size) != len(shape):
        raise ValueError(
            'chunk_size {} must have the same rank as shape {}'.format(chunk_size, shape))
    if not np.issubdtype(np.asarray(shape).dtype, np.integer):
        raise ValueError('shape {} must contain only integers'.format(shape))
    if not np.issubdtype(np.asarray(chunk_size).dtype, np.integer):
        raise ValueError('chunk_size {} must contain only integers'.format(chunk_size))

    if storage is None:
        storage = StorageFactory(protocols=backend, caching=caching)
    return HubArray(
        shape=shape,
        dtype=dtype,
        chunk_shape=chunk_size,
        key=path,
        protocol=storage.protocol,
        storage=storage,
        compression=compression,
        compression_level=compression_level
    )
def dataset(arrays=None, name=None):
    """Create or open a Dataset under the normalized ``name``.

    With no ``arrays``, the existing dataset at ``name`` is opened;
    otherwise the given arrays are wrapped into a new Dataset.
    TODO check inputs validity.
    """
    key = _get_path(name)
    return Dataset(key=key) if arrays is None else Dataset(arrays, key)
def load(name, backend='s3', storage=None):
    """Open an existing HubArray by name.

    Names in the known public list (imagenet, cifar, coco, mnist) are
    flagged as public datasets. When ``storage`` is None a backend is
    created from ``backend`` via StorageFactory.
    """
    public = name in ('imagenet', 'cifar', 'coco', 'mnist')
    path = _get_path(name, public)
    store = storage if storage is not None else StorageFactory(protocols=backend)
    return HubArray(key=path, public=public, storage=store)
# FIXME implement deletion of repositories
def delete(name):
    """Delete the repository stored under ``name`` (not implemented).

    The previous body referenced an undefined name ``s3`` and therefore
    always raised NameError at runtime; raise NotImplementedError
    explicitly so callers get an honest, catchable signal.

    Raises
    ------
    NotImplementedError
        Always, until repository deletion is implemented.
    """
    raise NotImplementedError(
        'hub.delete is not implemented yet; repository "{}" must be removed manually'.format(name)
    )