Merge pull request #1085 from wright-group/compression_test
enable compression
kameyer226 committed Jul 27, 2022
2 parents bc6f240 + 2688b36 commit f86bf00
Showing 4 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).
 - `Data.__getitem__` supports array slicing
 - `artists.interact2D` supports `cmap` kwarg.
 - iPython integration: autocomplete includes axis, variable, and channel names
+- Allow `create_variable` and `create_channel` to create compressed datasets

 ### Changed
 - `Data.chop` refactored to make steps modular
17 changes: 16 additions & 1 deletion WrightTools/data/_data.py
@@ -959,6 +959,12 @@ def create_channel(
             require_kwargs["dtype"] = values.dtype
         if np.prod(require_kwargs["shape"]) == 1:
             require_kwargs["chunks"] = None
+        if "compression" in kwargs:
+            require_kwargs["compression"] = kwargs["compression"]
+        if "compression_opts" in kwargs:
+            require_kwargs["compression_opts"] = kwargs["compression_opts"]
+        if "shuffle" in kwargs:
+            require_kwargs["shuffle"] = kwargs["shuffle"]
         # create dataset
         dataset_id = self.require_dataset(name=name, **require_kwargs).id
         channel = Channel(self, dataset_id, units=units, **kwargs)
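The three `if` checks added in this hunk copy a filter option into `require_kwargs` only when the caller supplied it, so omitted options are left to whatever `require_dataset` does by default. A rough sketch of that pattern in isolation follows; the helper `pick_filter_kwargs` and the tuple `FILTER_KEYS` are illustrative names assumed here, not part of WrightTools.

```python
# Hypothetical helper, not part of WrightTools: forward only the dataset
# filter options that the caller actually supplied.
FILTER_KEYS = ("compression", "compression_opts", "shuffle")


def pick_filter_kwargs(kwargs):
    """Return a new dict holding just the filter-related entries of kwargs."""
    return {key: kwargs[key] for key in FILTER_KEYS if key in kwargs}


# Start from the same base dict the diff uses, then merge in the filters.
require_kwargs = {"chunks": True}
require_kwargs.update(pick_filter_kwargs({"compression": "gzip", "label": "w1"}))
print(require_kwargs)
```

Unrelated keys such as `label` are left out of the result, and absent filter keys are simply not set.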
@@ -1017,9 +1023,18 @@ def create_variable(
             shape = values.shape
             dtype = values.dtype
             fillvalue = None
+        require_kwargs = {"chunks": True}
+        if "compression" in kwargs:
+            require_kwargs["compression"] = kwargs["compression"]
+        if "compression_opts" in kwargs:
+            require_kwargs["compression_opts"] = kwargs["compression_opts"]
+        if "shuffle" in kwargs:
+            require_kwargs["shuffle"] = kwargs["shuffle"]
+        if np.prod(shape) == 1:
+            require_kwargs["chunks"] = None
         # create dataset
         id = self.require_dataset(
-            name=name, data=values, shape=shape, dtype=dtype, fillvalue=fillvalue
+            name=name, data=values, shape=shape, dtype=dtype, fillvalue=fillvalue, **require_kwargs
         ).id
         variable = Variable(self, id, units=units, **kwargs)
         # finish
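`shuffle` is forwarded alongside `compression`. In HDF5 the shuffle filter regroups the bytes of each element (all first bytes, then all second bytes, and so on) before compression, which often helps on slowly varying numeric data. The sketch below illustrates the idea with the standard library only, using `zlib` as a stand-in for the gzip filter; `byte_shuffle` is an illustrative function assumed for this example, not an h5py or WrightTools API.

```python
import struct
import zlib


def byte_shuffle(raw, itemsize):
    """Group byte 0 of every element, then byte 1, and so on."""
    return b"".join(raw[offset::itemsize] for offset in range(itemsize))


# 10000 slowly increasing 32-bit little-endian unsigned integers.
values = range(10000)
raw = struct.pack("<%dI" % len(values), *values)

# Compress the same data with and without the byte regrouping.
plain_size = len(zlib.compress(raw, 4))
shuffled_size = len(zlib.compress(byte_shuffle(raw, 4), 4))
print(plain_size, shuffled_size)
```

For this input the shuffled stream compresses to fewer bytes, because the high-order bytes form long constant runs once they are grouped together.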
13 changes: 13 additions & 0 deletions docs/data.rst
@@ -82,9 +82,22 @@ Of course, if you find yourself processing a lot of data from a particular file
 Having trouble connecting the WrightTools `Data` structure to bare `numpy` arrays?
 We have a notebook that takes a look at how many common `numpy.ndarray` operations--
 slicing, element-wise math, broadcasting, etc.--have analogues within the WrightTools data structure:
+
 .. image:: https://mybinder.org/badge_logo.svg
    :target: https://mybinder.org/v2/gh/wright-group/WrightTools/master?filepath=examples%2Fwt%20for%20np%20users.ipynb

+Creating Compressed Datasets
+````````````````````````````
+
+WrightTools can transparently read and create compressed datasets by passing arguments to :meth:`~WrightTools.data.Data.create_variable` or :meth:`~WrightTools.data.Data.create_channel`.
+These arguments are the same as are passed to `h5py's create_dataset method <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_.
+
+.. code-block:: python
+
+   data = wt.Data(name='example')
+   data.create_variable(name='w1', units='wn', shape=(1024, 1024), compression="gzip")
+   data.create_channel(name='signal', shape=(1024, 1024), compression="gzip", compression_opts=9)
+
 Structure & Attributes
 ----------------------

9 changes: 9 additions & 0 deletions tests/data/dataset_creation.py
@@ -4,6 +4,8 @@
 import numpy as np
 import pytest

+import os
+

 def test_create_variable():
     data = wt.Data()
@@ -31,6 +33,13 @@ def test_exception():
         d.create_channel(name="w1")


+def test_create_compressed_channel():
+    data = wt.Data()
+    child1 = data.create_channel("hi", shape=(1024, 1024), compression="gzip")
+    data["hi"][:] = 0
+    assert os.path.getsize(data.filepath) < 1e6
+
+
 if __name__ == "__main__":
     test_create_variable()
     test_create_channel()
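The size check in `test_create_compressed_channel` is meaningful because the raw array would not fit under the limit: 1024 * 1024 values at 8 bytes each is about 8.4 MB, well above 1e6 bytes. That figure assumes the channel defaults to 64-bit floats, which this diff does not show. A stand-alone estimate of the compressed size, using `zlib` in place of HDF5's gzip filter and ignoring chunking and file overhead, might look like this:

```python
import zlib

# Assumption: 8 bytes per element (float64); zeros are stored as all-zero bytes.
n_bytes = 1024 * 1024 * 8
raw = bytes(n_bytes)

# Level 4 matches the gzip default documented by h5py.
compressed = zlib.compress(raw, 4)
print(n_bytes, len(compressed))
```

The compressed payload is far below the 1e6-byte threshold, so the test would fail only if the compression arguments were not reaching the dataset.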