Skip to content

Commit

Permalink
Merge f5f1bac into 412e71e
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyjcosta5 committed Aug 6, 2019
2 parents 412e71e + f5f1bac commit bef6c4e
Show file tree
Hide file tree
Showing 21 changed files with 1,582 additions and 74 deletions.
35 changes: 18 additions & 17 deletions pyUSID/viz/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import ImageGrid
from ..io.dtype_utils import get_exponent
import dask.array as da

if sys.version_info.major == 3:
unicode = str
Expand Down Expand Up @@ -346,7 +347,7 @@ def make_linear_alpha_cmap(name, solid_color, normalization_val, min_alpha=0, ma
"""
if not isinstance(name, (str, unicode)):
raise TypeError('name should be a string')
if not isinstance(solid_color, (list, tuple, np.ndarray)):
if not isinstance(solid_color, (list, tuple, np.ndarray, da.core.Array)):
raise TypeError('solid_color must be a list of numbers')
if not len(solid_color) == 4:
raise ValueError('solid-color should have fourth values')
Expand Down Expand Up @@ -438,9 +439,9 @@ def rainbow_plot(axis, x_vec, y_vec, num_steps=32, **kwargs):
"""
if not isinstance(axis, mpl.axes.Axes):
raise TypeError('axis must be a matplotlib.axes.Axes object')
if not isinstance(x_vec, (list, tuple, np.ndarray)):
if not isinstance(x_vec, (list, tuple, np.ndarray, da.core.Array)):
raise TypeError('x_vec must be array-like of numbers')
if not isinstance(x_vec, (list, tuple, np.ndarray)):
if not isinstance(x_vec, (list, tuple, np.ndarray, da.core.Array)):
raise TypeError('x_vec must be array-like of numbers')
x_vec = np.array(x_vec)
y_vec = np.array(y_vec)
Expand Down Expand Up @@ -499,13 +500,13 @@ def plot_line_family(axis, x_vec, line_family, line_names=None, label_prefix='',
"""
if not isinstance(axis, mpl.axes.Axes):
raise TypeError('axis must be a matplotlib.axes.Axes object')
if not isinstance(x_vec, (list, tuple, np.ndarray)):
if not isinstance(x_vec, (list, tuple, np.ndarray, da.core.Array)):
raise TypeError('x_vec must be array-like of numbers')
x_vec = np.array(x_vec)
assert x_vec.ndim == 1, 'x_vec must be a 1D array'
if not isinstance(line_family, list):
line_family = np.array(line_family)
if not isinstance(line_family, np.ndarray):
if not isinstance(line_family, (np.ndarray, da.core.Array)):
raise TypeError('line_family must be a 2d array of numbers')
assert line_family.ndim == 2, 'line_family must be a 2D array'
# assert x_vec.shape[1] == line_family.shape[1], \
Expand Down Expand Up @@ -596,7 +597,7 @@ def plot_map(axis, img, show_xy_ticks=True, show_cbar=True, x_vec=None, y_vec=No
"""
if not isinstance(axis, mpl.axes.Axes):
raise TypeError('axis must be a matplotlib.axes.Axes object')
if not isinstance(img, np.ndarray):
if not isinstance(img, (np.ndarray, da.core.Array)):
raise TypeError('img should be a numpy array')
if not img.ndim == 2:
raise ValueError('img should be a 2D array')
Expand Down Expand Up @@ -658,7 +659,7 @@ def set_ticks_for_axis(tick_vals, is_x):
print(tick_labs)
tick_vals = np.linspace(0, tick_vals, img_size)
else:
if not isinstance(tick_vals, (np.ndarray, list, tuple, range)) or len(tick_vals) != img_size:
if not isinstance(tick_vals, (np.ndarray, list, tuple, range, da.core.Array)) or len(tick_vals) != img_size:
raise ValueError(
'{} should be array-like with shape equal to axis {} of img'.format(tick_vals_var_name,
img_axis))
Expand Down Expand Up @@ -798,11 +799,11 @@ def plot_curves(excit_wfms, datasets, line_colors=[], dataset_names=[], evenly_s
for var, var_name, dim_size in zip([datasets, excit_wfms], ['datasets', 'excit_wfms'], [2, 1]):
mesg = '{} should be {}D arrays or iterables (list or tuples) of {}D arrays' \
'.'.format(var_name, dim_size, dim_size)
if isinstance(var, (h5py.Dataset, np.ndarray)):
if isinstance(var, (h5py.Dataset, np.ndarray, da.core.Array)):
if not len(var.shape) == dim_size:
raise ValueError(mesg)
elif isinstance(var, (list, tuple)):
if not np.all([isinstance(dset, (h5py.Dataset, np.ndarray)) for dset in datasets]):
if not np.all([isinstance(dset, (h5py.Dataset, np.ndarray, da.core.Array)) for dset in datasets]):
raise TypeError(mesg)
else:
raise TypeError(mesg)
Expand All @@ -811,12 +812,12 @@ def plot_curves(excit_wfms, datasets, line_colors=[], dataset_names=[], evenly_s
# 0 = one excitation waveform and one dataset
# 1 = one excitation waveform but many datasets
# 2 = one excitation waveform for each of many dataset
if isinstance(datasets, (h5py.Dataset, np.ndarray)):
if isinstance(datasets, (h5py.Dataset, np.ndarray, da.core.Array)):
# can be numpy array or h5py.dataset
num_pos = datasets.shape[0]
num_points = datasets.shape[1]
datasets = [datasets]
if isinstance(excit_wfms, (np.ndarray, h5py.Dataset)):
if isinstance(excit_wfms, (np.ndarray, h5py.Dataset, da.core.Array)):
excit_wfms = [excit_wfms]
elif isinstance(excit_wfms, list):
if len(excit_wfms) == num_points:
Expand All @@ -835,7 +836,7 @@ def plot_curves(excit_wfms, datasets, line_colors=[], dataset_names=[], evenly_s
num_points_es = list()

for dataset in datasets:
if not isinstance(dataset, (h5py.Dataset, np.ndarray)):
if not isinstance(dataset, (h5py.Dataset, np.ndarray, da.core.Array)):
raise TypeError('datasets can be a list of 2D h5py.Dataset or numpy array objects')
if len(dataset.shape) != 2:
raise ValueError('Each datset should be a 2D array')
Expand Down Expand Up @@ -971,11 +972,11 @@ def plot_complex_spectra(map_stack, x_vec=None, num_comps=4, title=None, x_label
---------
fig, axes
"""
if not isinstance(map_stack, np.ndarray) or not map_stack.ndim in [2, 3]:
if not isinstance(map_stack, (np.ndarray, da.core.Array)) or not map_stack.ndim in [2, 3]:
raise TypeError('map_stack should be a 2/3 dimensional array arranged as [component, row, col] or '
'[component, spectra')
if x_vec is not None:
if not isinstance(x_vec, (list, tuple, np.ndarray)):
if not isinstance(x_vec, (list, tuple, np.ndarray, da.core.Array)):
raise TypeError('x_vec should be a 1D array')
x_vec = np.array(x_vec)
if x_vec.ndim != 1:
Expand Down Expand Up @@ -1071,7 +1072,7 @@ def plot_scree(scree, title='Scree', **kwargs):
if isinstance(scree, (list, tuple)):
scree = np.array(scree)

if not (isinstance(scree, np.ndarray) or isinstance(scree, h5py.Dataset)):
if not (isinstance(scree, (np.ndarray, da.core.Array)) or isinstance(scree, h5py.Dataset)):
raise TypeError('scree must be a 1D array or Dataset')
if not isinstance(title, (str, unicode)):
raise TypeError('title must be a string')
Expand Down Expand Up @@ -1147,7 +1148,7 @@ def plot_map_stack(map_stack, num_comps=9, stdevs=2, color_bar_mode=None, evenly
---------
fig, axes
"""
if not isinstance(map_stack, np.ndarray) or not map_stack.ndim == 3:
if not isinstance(map_stack, (np.ndarray, da.core.Array)) or not map_stack.ndim == 3:
raise TypeError('map_stack should be a 3 dimensional array arranged as [component, row, col]')
if num_comps is None:
num_comps = 4 # Default
Expand Down Expand Up @@ -1181,7 +1182,7 @@ def plot_map_stack(map_stack, num_comps=9, stdevs=2, color_bar_mode=None, evenly
if not isinstance(var, bool):
raise TypeError(var_name + ' should be a bool')
for var, var_name in zip([fig_mult, pad_mult], ['fig_mult', 'pad_mult']):
if not isinstance(var, (list, tuple, np.ndarray)) or len(var) != 2:
if not isinstance(var, (list, tuple, np.ndarray, da.core.Array)) or len(var) != 2:
raise TypeError(var_name + ' should be a tuple / list / numpy array of size 2')
if not np.all([x > 0 and isinstance(x, Number) for x in var]):
raise ValueError(var_name + ' should contain positive numbers')
Expand Down
6 changes: 4 additions & 2 deletions tests/io/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def make_sparse_sampling_file():
if os.path.exists(sparse_sampling_path):
os.remove(sparse_sampling_path)

h5_main = None

with h5py.File(sparse_sampling_path) as h5_f:
h5_meas_grp = h5_f.create_group('Measurement_000')

Expand Down Expand Up @@ -193,7 +195,7 @@ def make_sparse_sampling_file():
# Link ancillary
for dset in [h5_pos_inds, h5_pos_vals, h5_spec_inds, h5_spec_vals]:
h5_main.attrs[dset.name.split('/')[-1]] = dset.ref

return h5_meas_grp

def make_incomplete_measurement_file():
if os.path.exists(incomplete_measurement_path):
Expand Down Expand Up @@ -484,4 +486,4 @@ def make_beps_file(rev_spec=False):

# Now need to link as main!
for dset in [h5_pos_inds, h5_pos_vals, h5_results_2_spec_inds, h5_results_2_spec_vals]:
h5_results_2_main.attrs[dset.name.split('/')[-1]] = dset.ref
h5_results_2_main.attrs[dset.name.split('/')[-1]] = dset.ref
3 changes: 1 addition & 2 deletions tests/io/hdf_utils/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
sys.path.append("../../pyUSID/")
from pyUSID.io import hdf_utils

from tests.io import data_utils

from tests.processing import data_utils

if sys.version_info.major == 3:
unicode = str
Expand Down
33 changes: 16 additions & 17 deletions tests/io/hdf_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
sys.path.append("../../pyUSID/")
from pyUSID.io import hdf_utils, write_utils, USIDataset

from tests.io import data_utils

from tests.processing import data_utils

if sys.version_info.major == 3:
unicode = str
Expand All @@ -32,7 +31,7 @@ def setUp(self):
data_utils.make_relaxation_file()

def tearDown(self):
for file_path in [data_utils.std_beps_path,
for file_path in [data_utils.std_beps_path,
data_utils.sparse_sampling_path,
data_utils.incomplete_measurement_path,
data_utils.relaxation_path]:
Expand Down Expand Up @@ -525,10 +524,10 @@ def test_small(self):
self.assertTrue(np.allclose(main_data, usid_main[()]))

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_pos_inds, usid_main.h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)
pos_data, h5_main=usid_main, is_spectral=False)

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_spec_inds, usid_main.h5_spec_vals, spec_names, spec_units,
spec_data, h5_main=usid_main, is_spectral=True)
spec_data, h5_main=usid_main, is_spectral=True)
os.remove(file_path)

def test_dask(self):
Expand Down Expand Up @@ -568,10 +567,10 @@ def test_dask(self):
self.assertTrue(np.allclose(main_data, usid_main[()]))

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_pos_inds, usid_main.h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)
pos_data, h5_main=usid_main, is_spectral=False)

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_spec_inds, usid_main.h5_spec_vals, spec_names, spec_units,
spec_data, h5_main=usid_main, is_spectral=True)
spec_data, h5_main=usid_main, is_spectral=True)
os.remove(file_path)

def test_empty(self):
Expand Down Expand Up @@ -610,10 +609,10 @@ def test_empty(self):
self.assertEqual(main_data, usid_main.shape)

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_pos_inds, usid_main.h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)
pos_data, h5_main=usid_main, is_spectral=False)

data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_spec_inds, usid_main.h5_spec_vals, spec_names, spec_units,
spec_data, h5_main=usid_main, is_spectral=True)
spec_data, h5_main=usid_main, is_spectral=True)
os.remove(file_path)

def test_write_main_existing_spec_aux(self):
Expand Down Expand Up @@ -644,15 +643,15 @@ def test_write_main_existing_spec_aux(self):

with h5py.File(file_path) as h5_f:
h5_spec_inds, h5_spec_vals = hdf_utils.write_ind_val_dsets(h5_f, spec_dims, is_spectral=True)
data_utils.validate_aux_dset_pair(self,h5_f, h5_spec_inds, h5_spec_vals, spec_names, spec_units, spec_data,
is_spectral=True)
data_utils.validate_aux_dset_pair(self, h5_f, h5_spec_inds, h5_spec_vals, spec_names, spec_units, spec_data,
is_spectral=True)

usid_main = hdf_utils.write_main_dataset(h5_f, main_data, main_data_name, quantity, dset_units, pos_dims,
None, h5_spec_inds=h5_spec_inds, h5_spec_vals=h5_spec_vals,
main_dset_attrs=None)

data_utils.validate_aux_dset_pair(self,h5_f, usid_main.h5_pos_inds, usid_main.h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)
data_utils.validate_aux_dset_pair(self, h5_f, usid_main.h5_pos_inds, usid_main.h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)

os.remove(file_path)

Expand Down Expand Up @@ -691,11 +690,11 @@ def test_existing_both_aux(self):
h5_pos_vals=h5_pos_vals, h5_pos_inds=h5_pos_inds,
main_dset_attrs=None)

data_utils.validate_aux_dset_pair(self,h5_f, h5_pos_inds, h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)
data_utils.validate_aux_dset_pair(self, h5_f, h5_pos_inds, h5_pos_vals, pos_names, pos_units,
pos_data, h5_main=usid_main, is_spectral=False)

data_utils.validate_aux_dset_pair(self,h5_f, h5_spec_inds, h5_spec_vals, spec_names,spec_units,
spec_data, h5_main=usid_main, is_spectral=True)
data_utils.validate_aux_dset_pair(self, h5_f, h5_spec_inds, h5_spec_vals, spec_names, spec_units,
spec_data, h5_main=usid_main, is_spectral=True)
os.remove(file_path)

def test_prod_sizes_mismatch(self):
Expand Down

0 comments on commit bef6c4e

Please sign in to comment.