Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_xarray as field method #123

Merged
merged 61 commits into from Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
6dc1723
Feat(Field): add to_xarray method
swapneelap Feb 1, 2022
57c5caa
Chore(Field): add examples to to_xarray docstring
swapneelap Feb 1, 2022
ee55f77
Chore(Field): fix to_xarray docstring
swapneelap Feb 1, 2022
8d1dc92
Update setup.cfg
lang-m Feb 1, 2022
baf6469
Update pyproject.toml
lang-m Feb 1, 2022
e5ea0bc
Typos in docstring.
lang-m Feb 2, 2022
bd31225
Change style for better readability.
lang-m Feb 2, 2022
896c8ef
Fix doctests.
lang-m Feb 2, 2022
5b26d11
Refactor(Field): address @lang-m comments
swapneelap Feb 2, 2022
9b00318
Fix(Field): check field.components is not None
swapneelap Feb 2, 2022
4b5e33f
Refactor(Field): address @marijanbeg comments
swapneelap Feb 3, 2022
c541737
Update pyproject.toml
lang-m Feb 3, 2022
43719a5
Update pyproject.toml
lang-m Feb 3, 2022
43c0ce0
Refactor(Field): alter docstring following @fangohr comments
swapneelap Feb 4, 2022
ac4104f
Test(Field): remove unused variable
swapneelap Feb 7, 2022
0961e6e
Fix(Field): check if unit exists in mesh.attributes
swapneelap Feb 7, 2022
9bd4a1e
Test(Field): add tests for to_xarray
swapneelap Feb 7, 2022
51d07dd
Minimal version of pre-commit config to make tests pass
lang-m Feb 7, 2022
7bb2925
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2022
86d68c7
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
lang-m Feb 7, 2022
4184cf7
fix pre-commit
lang-m Feb 7, 2022
271ac30
regex for pre-commit
lang-m Feb 7, 2022
0b5cc13
Refactor(Field): address @lang-m and @marijanbeg comments
swapneelap Feb 8, 2022
4931360
Test(Field): address @lang-m and @marijanbeg comments
swapneelap Feb 8, 2022
ed29497
Fix(Field): update docstring
swapneelap Feb 8, 2022
dce5a05
Refactor(Field): add docstring headings following @lang-m comments
swapneelap Feb 8, 2022
3bb3300
Merge branch 'master' into field-to-xarray
lang-m Feb 11, 2022
7389e80
Test(Field): add tests for from_xarray class method
swapneelap Feb 16, 2022
a064eff
Feat(Field): add from_xarray class method
swapneelap Feb 16, 2022
6bf86ff
Fix(Field): import decimal not required
swapneelap Feb 16, 2022
f8d6b97
Refactor(Field): address @marijanbeg and @lang-m comments
swapneelap Feb 17, 2022
330f39b
Refactor(Field): address comments
swapneelap Feb 17, 2022
fd963e3
Test(Field): address comments
swapneelap Feb 17, 2022
dd0c596
Rewrite creation of cell, p1, and p2.
lang-m Feb 17, 2022
41131c0
Fix wrong indexing type.
lang-m Feb 17, 2022
6a461ca
Store cell as-is instead of in a new dictionary.
lang-m Feb 17, 2022
deeaa46
Unused variable
lang-m Feb 17, 2022
da4ea0d
Test(Field): add cell attribute tests for to_xarray
swapneelap Feb 18, 2022
2aab034
Merge branch 'field-to-xarray' of github.com:ubermag/discretisedfield…
lang-m Feb 18, 2022
9b4bc59
Refactor(Field): update to_xarray and from_xarray
swapneelap Feb 21, 2022
b2e2a2f
Test(Field): update test for to_xarray and from_xarray
swapneelap Feb 21, 2022
e3e4645
Chore(Field): add docstrings to from_xarray
swapneelap Feb 21, 2022
98aa4f2
Update field.py
lang-m Feb 21, 2022
7ecebc6
Merge branch 'master' into field-to-xarray
lang-m Feb 21, 2022
1c07e04
Refactor(Field): address comments by @lang-m
swapneelap Feb 22, 2022
79aa86c
Refactor(Field): address additional comments by @lang-m
swapneelap Feb 22, 2022
812876d
Test(Field): address comments by @lang-m
swapneelap Feb 22, 2022
efc7d96
Merge branch 'master' into field-to-xarray
lang-m Mar 1, 2022
9d6c75e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2022
639f326
Remove unused variable names
lang-m Mar 1, 2022
72d0650
Merge branch 'field-to-xarray' of github.com:ubermag/discretisedfield…
lang-m Mar 1, 2022
44e58e5
Merge branch 'master' into field-to-xarray
lang-m Mar 16, 2022
89fb2e7
Docs(xarray-usage): Add docs for `to_xarray` and `from_xarray`
swapneelap Mar 22, 2022
ef1f064
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
80883fd
Docs,Refactor,Fix(xarray-usage): remove exception cases
swapneelap Mar 22, 2022
48f662e
British english spelling and typos
lang-m Mar 22, 2022
eb5ee0c
Remove print around xarray in favor of html
lang-m Mar 22, 2022
dfc22c6
Use code-like str instead of the word string.
lang-m Mar 22, 2022
4242fe5
Merge branch 'master' into field-to-xarray
lang-m Mar 23, 2022
2cac3d5
Fix(Field): update `to_xarray` to use `mesh.midpoints`
swapneelap Mar 23, 2022
9f9dca3
Test(test_field): update tests to use `mesh.midpoints`
swapneelap Mar 23, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
248 changes: 248 additions & 0 deletions discretisedfield/field.py
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
import ubermagutil.typesystem as ts
import xarray as xr

import discretisedfield as df
import discretisedfield.plotting as dfp
Expand Down Expand Up @@ -3481,3 +3482,250 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
return self.__class__(mesh[0], dim=result.shape[-1], value=result,
components=self.components)

def to_xarray(self, name='field', units=None):
"""Field value as ``xarray.DataArray``.

The function returns an ``xarray.DataArray`` with dimensions ``x``,
``y``, ``z``, and ``comp`` (``only if field.dim > 1``). The coordinates
of the geometric dimensions are derived from ``self.mesh.axis_points``,
and for vector field components from ``self.components``. Addtionally,
the values of ``self.mesh.cell``, ``self.mesh.region.p1``, and
``self.mesh.region.p2`` are stored as ``cell``, ``p1``, and ``p2``
attributes of the DataArray. The ``units`` attribute of geometric
dimensions is set to ``self.mesh.attributes['unit']``.

The name and units of the field ``DataArray`` can be set by passing
``name`` and ``units``. If the type of value passed to any of the two
arguments is not ``str``, then a ``TypeError`` is raised.

Parameters
----------
name : str, optional

String to set name of the field ``DataArray``.

units : str, optional
lang-m marked this conversation as resolved.
Show resolved Hide resolved

String to set units of the field ``DataArray``.

Returns
-------
xarray.DataArray

Field values DataArray.

Raises
------
TypeError

If either ``name`` or ``units`` argument is not a string.

Examples
swapneelap marked this conversation as resolved.
Show resolved Hide resolved
--------
1. Create a field

>>> import discretisedfield as df
...
>>> p1 = (0, 0, 0)
>>> p2 = (10, 10, 10)
>>> cell = (1, 1, 1)
>>> mesh = df.Mesh(p1=p1, p2=p2, cell=cell)
>>> field = df.Field(mesh=mesh, dim=3, value=(1, 0, 0), norm=1.)
...
>>> field
Field(...)

2. Create `xarray.DataArray` from field

>>> xa = field.to_xarray()
>>> xa
<xarray.DataArray 'field' (x: 10, y: 10, z: 10, comp: 3)>
...

3. Select values of `x` component

>>> xa.sel(comp='x')
<xarray.DataArray 'field' (x: 10, y: 10, z: 10)>
...

"""
if not isinstance(name, str):
msg = "Name argument must be a string."
raise TypeError(msg)

if units is not None and not isinstance(units, str):
lang-m marked this conversation as resolved.
Show resolved Hide resolved
msg = "Units argument must be a string."
raise TypeError(msg)

axes = ['x', 'y', 'z']

data_array_coords = {
axis: np.fromiter(self.mesh.axis_points(axis), dtype=float)
for axis in axes
}

if 'unit' in self.mesh.attributes:
lang-m marked this conversation as resolved.
Show resolved Hide resolved
geo_units_dict = dict.fromkeys(axes, self.mesh.attributes['unit'])
else:
geo_units_dict = dict.fromkeys(axes, 'm')

if self.dim > 1:
data_array_dims = axes + ['comp']
if self.components is not None:
data_array_coords['comp'] = self.components
field_array = self.array
else:
data_array_dims = axes
field_array = np.squeeze(self.array, axis=-1)

data_array = xr.DataArray(field_array,
dims=data_array_dims,
coords=data_array_coords,
name=name,
attrs=dict(units=units,
cell=self.mesh.cell,
p1=self.mesh.region.p1,
p2=self.mesh.region.p2))

for dim in geo_units_dict:
data_array[dim].attrs['units'] = geo_units_dict[dim]
lang-m marked this conversation as resolved.
Show resolved Hide resolved

return data_array

@classmethod
def from_xarray(cls, xa):
"""Create ``discretisedfield.Field`` from ``xarray.DataArray``

The class method accepts an ``xarray.DataArray`` as an argument to
return a ``discretisedfield.Field`` object. The DataArray must have
either three (``x``, ``y``, and ``z`` for a scalar field) or four
(additionally ``comp`` for a vector field) dimensions corresponding to
geometric axes and components of the field, respectively. The
coordinates of the ``x``, ``y``, and ``z`` dimensions represent the
discretisation along the respective axis and must have equally spaced
values. The coordinates of ``comp`` represent the field components
(e.g. ['x', 'y', 'z'] for a 3D vector field).

The ``DataArray`` is expected to have ``cell``, ``p1``, and ``p2``
attributes for creating ``discretisedfield.Mesh`` required by the
``discretisedfield.Field`` object. However, in the absence of these
attributes, the coordinates of ``x``, ``y``, and ``z`` dimensions are
utilized. It should be noted that ``cell`` attribute is required if
any of the geometric directions has only a single cell.

Parameters
----------
xa : xarray.DataArray

DataArray to create Field.

Returns
-------
discretisedfield.Field

Field created from DataArray.

Raises
------
TypeError

If argument is not ``xarray.DataArray``.

KeyError

If at least one of the geometric dimension coordinates has a single
value and ``cell`` attribute is missing.

ValueError

- If ``DataArray.ndim`` is not 3 or 4.
- If ``DataArray.dims`` are not either ``['x', 'y', 'z']`` or
``['x', 'y', 'z', 'comp']``
- If coordinates of ``x``, ``y``, or ``z`` are not equally
spaced

Examples
--------
1. Create a DataArray

>>> import xarray as xr
>>> import numpy as np
...
>>> xa = xr.DataArray(np.ones((20, 20, 20, 3), dtype=float),
... dims = ['x', 'y', 'z', 'comp'],
... coords = dict(x=np.arange(0, 20),
... y=np.arange(0, 20),
... z=np.arange(0, 20),
... comp=['x', 'y', 'z']),
... name = 'mag',
... attrs = dict(cell=[1., 1., 1.],
... p1=[1., 1., 1.],
... p2=[21., 21., 21.]))
>>> xa
<xarray.DataArray 'mag' (x: 20, y: 20, z: 20, comp: 3)>
...

2. Create Field from DataArray

>>> import discretisedfield as df
...
>>> field = df.Field.from_xarray(xa)
>>> field
Field(...)
>>> field.average
(1.0, 1.0, 1.0)

"""
if not isinstance(xa, xr.DataArray):
raise TypeError("Argument must be a xr.DataArray.")

if xa.ndim not in [3, 4]:
raise ValueError("DataArray dimensions must be 3 for a scalar "
"and 4 for a vector field.")

if xa.ndim == 3 and sorted(xa.dims) != ['x', 'y', 'z']:
raise ValueError("The dimensions must be 'x', 'y', and 'z'.")
elif xa.ndim == 4 and sorted(xa.dims) != ['comp', 'x', 'y', 'z']:
raise ValueError("The dimensions must be 'x', 'y', 'z',"
"and 'comp'.")

for i in 'xyz':
if xa[i].values.size > 1 and not np.allclose(
np.diff(xa[i].values), np.diff(xa[i].values).mean()):
raise ValueError(f'Coordinates of {i} must be'
' equally spaced.')

try:
cell = xa.attrs['cell']
except KeyError:
if any(len_ == 1 for len_ in xa.values.shape[:3]):
raise KeyError(
"DataArray must have a 'cell' attribute if any "
"of the geometric directions has a single cell."
) from None
cell = [np.diff(xa[i].values).mean() for i in 'xyz']

p1 = (
xa.attrs['p1'] if 'p1' in xa.attrs else
[xa[i].values[0] - c / 2 for i, c in zip('xyz', cell)]
)
p2 = (
xa.attrs['p2'] if 'p2' in xa.attrs else
[xa[i].values[-1] + c / 2 for i, c in zip('xyz', cell)]
)

if any('units' not in xa[i].attrs for i in 'xyz'):
mesh = df.Mesh(p1=p1, p2=p2, cell=cell)
else:
mesh = df.Mesh(p1=p1, p2=p2, cell=cell,
attributes={'unit': xa['z'].attrs['units']})

comp = xa.comp.values if 'comp' in xa.coords else None
val = np.expand_dims(xa.values, axis=-1) if xa.ndim == 3 else xa.values
dim = 1 if xa.ndim == 3 else val.shape[-1]
return cls(mesh=mesh,
dim=dim,
value=val,
components=comp,
dtype=xa.values.dtype)