Commit 7a2880c: add sample in dataset.rst

TsumiNa committed Jan 22, 2018
1 parent 2b8b79b
Showing 3 changed files with 173 additions and 70 deletions.
1 change: 0 additions & 1 deletion samples/playground.py
@@ -231,7 +231,6 @@ def cb_plot(y, y_cb, fname: str=None, describe: str=None):
#%%
import torch
import torch.nn as nn
-from XenonPy import BaseNet
from scipy.stats import boxcox
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

68 changes: 55 additions & 13 deletions tests/test_datatools.py
@@ -89,44 +89,86 @@ def test__fetch_data(setup):

 def test_saver1(setup):
     saver = Saver(setup['user_dataset'])
-    assert len(saver) == 0, 'no files'
+    assert len(saver._files) == 0, 'no files'


 def test_saver2(setup):
     saver = Saver(setup['user_dataset'])
-    saver(list('asdf'), list('qwer'))
-    assert len(saver) == 2, 'should got 2 files'
+    saver(list('abcd'), list('efgh'))
+    assert len(saver._files) == 2, 'should get 2 files'


-def test_saver_last(setup):
+def test_saver3(setup):
     saver = Saver(setup['user_dataset'])
+    saver(key1=list('asdf'), key2=list('qwer'))
+    assert len(saver._files) == 2, 'should get 2 files'
+    assert len(saver._name_files['key1']) == 1, 'should get 1 file'
+    assert len(saver._name_files['key2']) == 1, 'should get 1 file'


+def test_saver4(setup):
+    saver = Saver(setup['user_dataset'])
+    saver(list('asdf'), key1=list('qwer'))
+    assert len(saver._files) == 3, 'should get 3 files'
+    assert len(saver._name_files['key1']) == 2, 'should get 2 files'
+    assert len(saver._name_files['key2']) == 1, 'should get 1 file'


+def test_saver_last1(setup):
+    saver = Saver(setup['user_dataset'])
+    last = saver.last()
+    assert last == list('asdf'), 'should retrieve the same data'


+def test_saver_last2(setup):
+    saver = Saver(setup['user_dataset'])
+    last = saver.last('key1')
+    assert last == list('qwer'), 'should retrieve the same data'


-def test_saver_getitem(setup):
+def test_saver_getitem1(setup):
     saver = Saver(setup['user_dataset'])
     item = saver[:]
     assert item[1] == list('efgh'), 'should retrieve the same data'
+    item = saver[1]
+    assert item == list('efgh'), 'should retrieve the same data'


+def test_saver_getitem2(setup):
+    saver = Saver(setup['user_dataset'])
+    item = saver['key2', :]
+    assert item[0] == list('qwer'), 'should retrieve the same data'
+    item = saver['key1', 1]
+    assert item == list('qwer'), 'should retrieve the same data'


 def test_saver_delete1(setup):
     saver = Saver(setup['user_dataset'])
     saver.rm(0)
-    assert len(saver) == 1, 'should got 1 files'
+    assert len(saver._files) == 2, 'should get 2 files'


 def test_saver_delete2(setup):
     saver = Saver(setup['user_dataset'])
-    assert len(saver) == 1, 'should got 1 files'
-    saver(list('asdf'), list('qwer'))
-    assert len(saver) == 3, 'should got 3 files'
-    saver.rm(slice(0, 2))
-    assert len(saver) == 1, 'should got 1 files'
+    saver(key1=list('qwer'))
+    assert len(saver._name_files['key1']) == 3, 'should get 3 files'
+    saver.rm(slice(0, 2), 'key1')
+    assert len(saver._name_files['key1']) == 1, 'should get 1 file'


-def test_saver_delete3(setup):
+def test_saver_clean1(setup):
     saver = Saver(setup['user_dataset'])
     saver_dir = Path.home() / '.xenonpy' / 'userdata' / setup['user_dataset']
-    saver.rm()
+    saver.clean('key1')
+    assert 'key1' not in saver._name_files, 'key1 data should be gone'


+def test_saver_clean2(setup):
+    saver = Saver(setup['user_dataset'])
+    saver_dir = Path.home() / '.xenonpy' / 'userdata' / setup['user_dataset']
+    saver.clean()
     assert not saver_dir.exists(), 'no saver dir'


if __name__ == '__main__':
    pytest.main()
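
Taken together, the tests above exercise the new Saver workflow. A minimal usage sketch (an editor's illustration based on the tests, assuming the import path follows the file location and using a made-up dataset name 'mydata'):

    from xenonpy.utils.datatools import Saver

    saver = Saver('mydata')                 # files live under ~/.xenonpy/userdata/mydata
    saver([1, 2, 3], key1=[4, 5, 6])        # one unnamed and one named save
    assert saver.last() == [1, 2, 3]        # newest unnamed data
    assert saver.last('key1') == [4, 5, 6]  # newest data saved under 'key1'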
174 changes: 118 additions & 56 deletions xenonpy/utils/datatools.py
@@ -2,8 +2,8 @@
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

-from collections.abc import Iterator
-from datetime import datetime as dt
+import uuid
+from collections import defaultdict
from os import remove
from os.path import getmtime
from pathlib import Path
@@ -241,7 +241,7 @@ def elements_completed(self):
        return self('elements_completed')


-class Saver(Iterator):
+class Saver(object):
"""
Save data in a convenient way:
@@ -272,101 +272,163 @@ class Saver(Iterator):
    See Also: :doc:`dataset`
    """

    def __init__(self, dataset=None):
        """
        Parameters
        ----------
        dataset: str
            The dir to save and load data
        """
-        self._it = 0  # for iterator
-
        self._path = Path.home() / __pkg_cfg_root__ / 'userdata' / dataset
        if not self._path.exists():
            self._path.mkdir(parents=True)
        self._files = None
+        self._name_files = None
        self._make_file_index()

    def _make_file_index(self):
-        self._files = [f for f in self._path.iterdir() if f.match('*.pkl.*')]
-        self._files.sort(key=lambda f: getmtime(str(f)))
-
-    def _load_data(self, file):
+        self._files = list()
+        self._name_files = defaultdict(list)
+        files = [f for f in self._path.iterdir() if f.match('*.pkl.*')]
+
+        for f in files:
+            # named data
+            if len(f.suffixes) == 3:
+                dname = f.suffixes[0].lstrip('.')
+                self._name_files[dname].append(f)
+                continue
+
+            # unnamed data
+            self._files.append(f)
+
+        self._sort_files(self._files)
+        for v in self._name_files.values():
+            self._sort_files(v)

+    @classmethod
+    def _sort_files(cls, files):
+        if files is not None:
+            files.sort(key=lambda f: getmtime(str(f)))

+    @classmethod
+    def _load_data(cls, file):
        if file.suffix == '.pd_':
            return pd.read_pickle(str(file))
        else:
            return jl.load(file)

-    def last(self):
+    def _save_data(self, data, filename=None):
+        uid = str(uuid.uuid1()).replace('-', '')
+        prefix = uid + '.' + filename if filename else uid
+        if not self._path.exists():
+            self._path.mkdir()
+
+        if isinstance(data, pd.DataFrame):
+            file = self._path / (prefix + '.pkl.pd_')
+            pd.to_pickle(data, str(file))
+        else:
+            file = self._path / (prefix + '.pkl.z')
+            jl.dump(data, file)
+
+        return file
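
    # Illustrative note (assumed from the code above, not shown in the commit):
    # the scheme yields file names such as
    #     <uuid>.pkl.z         an unnamed, non-DataFrame object (joblib dump)
    #     <uuid>.key1.pkl.pd_  a DataFrame saved under the name 'key1'
    # which is why _make_file_index treats three suffixes as named data.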

+    def last(self, data_name: str = None):
        """
        Return last saved data.

+        Args
+        ----
+        data_name: str
+            Data's name. Omit to access unnamed data.
+
        Return
        ------
        ret: any python object
            Data stored in `*.pkl` file.
        """
-        return self._load_data(self._files[-1])
+        if data_name is None:
+            return self._load_data(self._files[-1])
+        return self._load_data(self._name_files[data_name][-1])

def rm(self, index=None):
def rm(self, index, dname: str = None):
"""
Delete file(s) with given index.
Parameters
----------
index: int or slice
Data index of dataset can be a slice object returned by :func:`slice`
Omit will be delete hole dir.
Index of data. Data sorted by datetime.
dname: str
Data's name. Omit for access unnamed data.
"""
-        if index is None:
-            rmtree(str(self._path))
-            self._files = list()
-            return
-
-        if isinstance(index, int):
-            remove(str(self._files[index]))
+        if not dname:
+            files = self._files[index]
+            if not isinstance(files, list):
+                remove(str(files))
+            else:
+                for f in files:
+                    remove(str(f))
+            del self._files[index]
+            return

-        if isinstance(index, slice):
-            del_files = self._files[index]
-            for f in del_files:
-                remove(str(f))
-
-        self._make_file_index()
+        files = self._name_files[dname][index]
+        if not isinstance(files, list):
+            remove(str(files))
+        else:
+            for f in files:
+                remove(str(f))
+        del self._name_files[dname][index]
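
    # Illustrative usage (assumed from the code above, not part of the commit):
    #     saver.rm(0)                    delete the oldest unnamed file
    #     saver.rm(slice(0, 2), 'key1')  delete the two oldest files named 'key1'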

-    def __getitem__(self, item):
-        """"""
-        files = self._files[item]
-        if not isinstance(files, list):
-            return self._load_data(files)
-        return [self._load_data(f) for f in files]
-
-    def __iter__(self):
-        self._it = 0
-        return self
-
-    def __len__(self):
-        return len(self._files)
+    def clean(self, data_name: str = None):
+        """
+        Remove all data of a given name. Omit to remove the whole dataset.
+
+        Parameters
+        ----------
+        data_name: str
+            Data's name. Omit to remove the whole dataset.
+        """
+        if data_name is None:
+            rmtree(str(self._path))
+            self._files = list()
+            self._name_files = defaultdict(list)
+            return
+
+        for f in self._name_files[data_name]:
+            remove(str(f))
+        del self._name_files[data_name]
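
    # Illustrative usage (assumed, not shown in the commit):
    #     saver.clean('key1')  remove every file saved under 'key1'
    #     saver.clean()        remove the whole dataset directory via rmtree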

-    def __next__(self):
-        try:
-            file = self._files[self._it]
-            self._it += 1
-            return self._load_data(file)
-        except IndexError:
-            raise StopIteration
-
-    def __call__(self, *data, name: str = None):
+    def __getitem__(self, item):
+
+        # load file
+        def _load_file(files, item):
+            _files = files[item]
+            if not isinstance(_files, list):
+                return self._load_data(_files)
+            return [self._load_data(f) for f in _files]
+
+        if isinstance(item, tuple):
+            try:
+                key, index = item
+            except ValueError:
+                raise ValueError('expected 2 parameters: [str, int or slice]')
+            if not isinstance(key, str) or \
+                    (not isinstance(index, int) and not isinstance(index, slice)):
+                raise ValueError('expected 2 parameters: [str, int or slice]')
+            return _load_file(self._name_files[key], index)
+
+        if isinstance(item, str):
+            return self.__getitem__((item, slice(None, None, None)))
+
+        return _load_file(self._files, item)
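
    # Accepted index forms (illustrative, assumed from the code above):
    #     saver[1]          the second-oldest unnamed item
    #     saver[0:2]        a list of the two oldest unnamed items
    #     saver['key1']     every item saved under 'key1' (a bare str means [key, :])
    #     saver['key1', 0]  the oldest item saved under 'key1'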

+    def __call__(self, *data, **named_data):
        for d in data:
-            file_name = dt.now().strftime('%Y-%m-%d_%H-%M-%S_%f')
-            if not self._path.exists():
-                self._path.mkdir()
-
-            if isinstance(d, pd.DataFrame):
-                file = self._path / (file_name + '.pkl.pd_')
-                pd.to_pickle(d, str(file))
-            else:
-                file = self._path / (file_name + '.pkl.z')
-                jl.dump(d, file)
+            f = self._save_data(d)
+            self._files.append(f)

-        self._make_file_index()
+        for k, v in named_data.items():
+            f = self._save_data(v, k)
+            self._name_files[k].append(f)
