Skip to content

Commit

Permalink
Merge pull request #1067 from wright-group/fix_complex2
Browse files Browse the repository at this point in the history
Fix complex array operations
  • Loading branch information
ddkohler committed May 24, 2022
2 parents 2d0de5e + f736cd6 commit 91554cc
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).
- better error messages for some functions
- remove unused imports
- remove unused variables
- complex array support for data object operations

## [3.4.3]

Expand Down
18 changes: 13 additions & 5 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,12 @@ def create_channel(
require_kwargs["dtype"] = np.dtype(np.float64)
else:
require_kwargs["dtype"] = dtype
if require_kwargs["dtype"].kind in "fcmM":
if require_kwargs["dtype"].kind == "f":
require_kwargs["fillvalue"] = np.nan
elif require_kwargs["dtype"].kind == "M":
require_kwargs["fillvalue"] = np.datetime64("NaT")
elif require_kwargs["dtype"].kind == "c":
require_kwargs["fillvalue"] = complex(np.nan, np.nan)
else:
require_kwargs["fillvalue"] = 0
else:
Expand Down Expand Up @@ -917,8 +921,12 @@ def create_variable(
shape = self.shape
if dtype is None:
dtype = np.dtype(np.float64)
if dtype.kind in "fcmM":
if dtype.kind in "f":
fillvalue = np.nan
elif dtype.kind in "M":
fillvalue = np.datetime64("NaT")
elif dtype.kind in "c":
fillvalue = complex(np.nan, np.nan)
else:
fillvalue = 0
else:
Expand Down Expand Up @@ -1716,7 +1724,7 @@ def split(
out_arr = np.full(omask.shape, np.nan)
imask = wt_kit.enforce_mask_shape(imask, var.shape)
out_arr[omask] = var[:][imask]
out[i].create_variable(values=out_arr, **var.attrs)
out[i].create_variable(values=out_arr, **var.attrs, dtype=var.dtype)

for ch in self.channels:
for i, (imask, omask, cut) in enumerate(zip(masks, omasks, cuts)):
Expand All @@ -1725,10 +1733,10 @@ def split(
continue
omask = wt_kit.enforce_mask_shape(omask, ch.shape)
omask.shape = tuple([s for s, c in zip(omask.shape, cut) if not c])
out_arr = np.full(omask.shape, np.nan)
out_arr = np.full(omask.shape, np.nan, dtype=ch.dtype)
imask = wt_kit.enforce_mask_shape(imask, ch.shape)
out_arr[omask] = ch[:][imask]
out[i].create_channel(values=out_arr, **ch.attrs)
out[i].create_channel(values=out_arr, **ch.attrs, dtype=ch.dtype)

if verbose:
for d in out.values():
Expand Down
6 changes: 5 additions & 1 deletion WrightTools/data/_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ def combine(data, out, item_name, new_idx, transpose, slice_):
new = out[item_name]
vals = np.empty_like(new)
# Default fill value based on whether dtype is floating or not
if vals.dtype.kind in "fcmM":
if vals.dtype.kind == "f":
vals[:] = np.nan
elif vals.dtype.kind == "M":
vals[:] = np.datetime64("NaT")
elif vals.dtype.kind == "c":
vals[:] = complex(np.nan, np.nan)
else:
vals[:] = 0
# Use advanced indexing to populate vals, a temporary array with same shape as out
Expand Down
10 changes: 9 additions & 1 deletion WrightTools/kit/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,15 @@ def share_nans(*arrs) -> tuple:
list
List of nD arrays in same order as given, with nan indicies syncronized.
"""
nans = np.zeros(joint_shape(*arrs))
kinds = {arr.dtype.kind for arr in arrs}
if "c" in kinds:
dtype = complex
elif "f" in kinds:
dtype = float
else:
dtype = arrs[0].dtype

nans = np.zeros(joint_shape(*arrs), dtype=dtype)
for arr in arrs:
nans *= arr
return tuple([a + nans for a in arrs])
Expand Down
23 changes: 23 additions & 0 deletions tests/data/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,28 @@ def test_2D_overlap_offset():
joined.close()


def test_2D_overlap_offset_complexarray():
a = wt.Data()
b = wt.Data()

a.create_variable("x", np.linspace(0, 10, 11)[:, None])
a.create_variable("y", np.linspace(0, 10, 11)[None, :])
b.create_variable("x", np.linspace(5, 15, 11)[:, None])
b.create_variable("y", np.linspace(0.5, 10.5, 11)[None, :])
a.create_channel("z", np.full(a.shape, 1j, dtype=np.complex128), dtype=np.complex128)
b.create_channel("z", np.full(b.shape, 2j, dtype=np.complex128), dtype=np.complex128)
a.transform("x", "y")
b.transform("x", "y")

joined = wt.data.join([a, b])

assert joined.shape == (16, 22)
assert joined.z[:].dtype == np.complex128
a.close()
b.close()
joined.close()


def test_2D_to_3D_overlap():
x1 = np.arange(-2.5, 2.5, 0.5)
x2 = np.arange(1, 10, 1)
Expand Down Expand Up @@ -701,6 +723,7 @@ def test_transpose():
test_1D_overlap_offset()
test_2D_no_overlap_aligned()
test_2D_no_overlap_offset()
test_2D_overlap_offset_complexarray()
test_2D_overlap_identical()
test_2D_overlap_offset()
test_1D_to_2D_aligned()
Expand Down
15 changes: 15 additions & 0 deletions tests/data/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ def test_split():
split.close()


def test_split_complexarray():
p = datasets.PyCMDS.wm_w2_w1_000
a = wt.data.from_PyCMDS(p)
a.create_channel(name="complex", values=np.complex128(a.channels[0][:]), dtype=np.complex128)
split = a.split(0, [19700], units="wn")
assert len(split) == 2
assert split[0].shape == (14, 11, 11)
assert split[1].shape == (21, 11, 11)
assert split[0].complex[:].dtype == np.complex128
assert a.units == split[0].units
a.close()
split.close()


def test_split_edge():
p = datasets.PyCMDS.wm_w2_w1_000
a = wt.data.from_PyCMDS(p)
Expand Down Expand Up @@ -182,4 +196,5 @@ def test_autotune():
test_split_expression()
test_split_hole()
test_split_constants()
test_split_complexarray()
test_autotune()

0 comments on commit 91554cc

Please sign in to comment.