Skip to content

Commit

Permalink
Index variables (#94)
Browse files Browse the repository at this point in the history
* add index variable

* handle index vars in @process decorated classes

* handle index vars in Model + add properties

* update doc (API)

* update tests

* fix Model.index_vars property

* add index vars as coords in output xarray Dataset

* update tests

* black

* update release notes
  • Loading branch information
benbovy committed Jan 17, 2020
1 parent 8a226cf commit b2753bc
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 33 deletions.
3 changes: 3 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ process names and values are objects of ``Process`` subclasses

Model.all_vars
Model.all_vars_dict
Model.index_vars
Model.index_vars_dict
Model.input_vars
Model.input_vars_dict
Model.dependent_processes
Expand Down Expand Up @@ -143,6 +145,7 @@ Variable
:toctree: _api_generated/

variable
index
foreign
group
on_demand
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ Enhancements
- :func:`xsimlab.variable` now has a ``converter`` parameter that can be used to
convert any input value before (maybe) validating it and setting the variable
(:issue:`92`).
- Added :func:`xsimlab.index` for setting index variables (e.g., coordinate
labels). Using the xarray extension, those variables are automatically added
in the output Dataset as coordinates (:issue:`94`).

Bug fixes
~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xsimlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# flake8: noqa

from .xr_accessor import SimlabAccessor, create_setup
from .variable import variable, on_demand, foreign, group
from .variable import variable, index, on_demand, foreign, group
from .process import (
filter_variables,
process,
Expand Down
29 changes: 22 additions & 7 deletions xsimlab/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def _maybe_save_output_vars(self, istep):

self.update_output_store(var_list)

def _to_xr_variable(self, key, clock):
def _to_xr_variable(self, key, clock, index=False):
"""Convert an output variable to a xarray.Variable object.
Maybe transpose the variable to match the dimension order
Expand All @@ -321,9 +321,13 @@ def _to_xr_variable(self, key, clock):
p_obj = self.model[p_name]
var = variables_dict(type(p_obj))[var_name]

data = self.output_store[key]
if clock is None:
data = data[0]
if index:
# get index directly from simulation store
data = self.store[key]
else:
data = self.output_store[key]
if clock is None:
data = data[0]

dims = _get_dims_from_variable(data, var, clock)
original_dims = self._transposed_vars.get(key)
Expand All @@ -349,20 +353,31 @@ def _to_xr_variable(self, key, clock):

def _get_output_dataset(self):
"""Return a new dataset as a copy of the input dataset updated with
output variables.
output variables and index variables.
"""
out_ds = self.dataset.copy()

# add/update output variables
xr_vars = {}

for key, clock in self.output_vars.items():
var_name = "__".join(key)
xr_vars[var_name] = self._to_xr_variable(key, clock)
xr_vars[var_name] = self._to_xr_variable(key, clock, index=False)

out_ds = self.dataset.copy()
out_ds.update(xr_vars)

# remove output_vars attributes in output dataset
out_ds.xsimlab._reset_output_vars(self.model, {})

# add/update index variables
xr_coords = {}

for key in self.model.index_vars():
_, var_name = key
xr_coords[var_name] = self._to_xr_variable(key, None, index=True)

out_ds.coords.update(xr_coords)

return out_ds

def _get_runtime_datasets(self):
Expand Down
61 changes: 40 additions & 21 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
get_target_variable,
SimulationStage,
)
from .utils import AttrMapping, ContextMixin, variables_dict
from .utils import AttrMapping, ContextMixin
from .formatting import repr_model


Expand Down Expand Up @@ -98,7 +98,7 @@ def _get_var_key(self, p_name, var):

var_type = var.metadata["var_type"]

if var_type == VarType.VARIABLE:
if var_type in (VarType.VARIABLE, VarType.INDEX):
store_key = (p_name, var.name)

elif var_type == VarType.ON_DEMAND:
Expand Down Expand Up @@ -210,15 +210,19 @@ def filter_out(var):

raise ValueError(f"Conflict(s) found in given variable intents:\n{msg}")

def get_all_variables(self):
"""Get all variables in the model as a list of
def get_variables(self, **kwargs):
"""Get variables in the model as a list of
``(process_name, var_name)`` tuples.
**kwargs may be used to return only a subset of the variables.
"""
all_keys = []

for p_name, p_cls in self._processes_cls.items():
all_keys += [(p_name, var_name) for var_name in variables_dict(p_cls)]
all_keys += [
(p_name, var_name) for var_name in filter_variables(p_cls, **kwargs)
]

return all_keys

Expand Down Expand Up @@ -434,9 +438,12 @@ def __init__(self, processes):
builder.bind_processes(self)
builder.set_process_keys()

self._all_vars = builder.get_all_variables()
self._all_vars = builder.get_variables()
self._all_vars_dict = None

self._index_vars = builder.get_variables(var_type=VarType.INDEX)
self._index_vars_dict = None

builder.ensure_no_intent_conflict()

self._input_vars = builder.get_input_variables()
Expand All @@ -450,6 +457,19 @@ def __init__(self, processes):
super(Model, self).__init__(self._processes)
self._initialized = True

def _get_vars_dict_from_cache(self, attr_name):
    """Return the per-process grouping of a list of variable keys.

    ``attr_name`` is the name of an attribute holding an iterable of
    ``(process_name, var_name)`` tuples. The grouped result (a dict
    mapping each process name to the list of its variable names) is
    stored in the attribute named ``attr_name + "_dict"`` the first
    time it is computed and reused on later calls.
    """
    cache_name = attr_name + "_dict"
    cached = getattr(self, cache_name)

    # anything other than None means the grouping was already built
    if cached is not None:
        return cached

    grouped = {}
    for p_name, var_name in getattr(self, attr_name):
        grouped.setdefault(p_name, []).append(var_name)

    setattr(self, cache_name, grouped)
    return getattr(self, cache_name)

@property
def all_vars(self):
"""Returns all variables in the model as a list of
Expand All @@ -464,15 +484,22 @@ def all_vars_dict(self):
variable names grouped by process.
"""
if self._all_vars_dict is None:
all_vars = defaultdict(list)
return self._get_vars_dict_from_cache("_all_vars")

for p_name, var_name in self._all_vars:
all_vars[p_name].append(var_name)
def index_vars(self):
"""Returns all index variables in the model as a list of
``(process_name, var_name)`` tuples (or an empty list).
self._all_vars_dict = dict(all_vars)
"""
return self._index_vars

@property
def index_vars_dict(self):
"""Returns all index variables in the model as a dictionary of lists of
variable names grouped by process.
return self._all_vars_dict
"""
return self._get_vars_dict_from_cache("_index_vars")

@property
def input_vars(self):
Expand All @@ -494,15 +521,7 @@ def input_vars_dict(self):
variable names grouped by process is returned.
"""
if self._input_vars_dict is None:
inputs = defaultdict(list)

for p_name, var_name in self._input_vars:
inputs[p_name].append(var_name)

self._input_vars_dict = dict(inputs)

return self._input_vars_dict
return self._get_vars_dict_from_cache("_input_vars")

@property
def dependent_processes(self):
Expand Down
1 change: 1 addition & 0 deletions xsimlab/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class _ProcessBuilder:

_make_prop_funcs = {
VarType.VARIABLE: _make_property_variable,
VarType.INDEX: _make_property_variable,
VarType.ON_DEMAND: _make_property_on_demand,
VarType.FOREIGN: _make_property_variable,
VarType.GROUP: _make_property_group,
Expand Down
6 changes: 6 additions & 0 deletions xsimlab/tests/fixture_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ class InitProfile:
n_points = xs.variable(
description="nb. of profile points", converter=int, static=True
)

x = xs.index(dims="x")
u = xs.foreign(Profile, "u", intent="out")

def initialize(self):
self.x = np.arange(self.n_points)

self.u = np.zeros(self.n_points)
self.u[0] = 1.0

Expand Down Expand Up @@ -193,4 +197,6 @@ def out_dataset(in_dataset):
)
out_ds["add__u_diff"] = ("out", [1, 3, 4])

out_ds["x"] = ("x", [0.0, 1.0, 2.0, 3.0, 4.0])

return out_ds
13 changes: 12 additions & 1 deletion xsimlab/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def test_bind_processes(self, model):
[
(
"init_profile",
{"n_points": ("init_profile", "n_points"), "u": ("profile", "u")},
{
"n_points": ("init_profile", "n_points"),
"x": ("init_profile", "x"),
"u": ("profile", "u"),
},
{},
),
(
Expand Down Expand Up @@ -164,6 +168,13 @@ def test_all_vars_dict(self, model):
)
assert "u" in model.all_vars_dict["profile"]

def test_index_vars_dict(self, model):
    """Check the keys, value types and content of ``Model.index_vars_dict``."""
    grouped = model.index_vars_dict

    # every key must be the name of a process in the model
    assert all(p_name in model for p_name in grouped)
    # every value must be a list of variable names
    assert all(isinstance(p_vars, list) for p_vars in grouped.values())
    assert "x" in grouped["init_profile"]

def test_input_vars_dict(self, model):
assert all([p_name in model for p_name in model.input_vars_dict])
assert all(
Expand Down
7 changes: 6 additions & 1 deletion xsimlab/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import attr

from xsimlab.tests.fixture_process import AnotherProcess, ExampleProcess
from xsimlab.variable import _as_dim_tuple, _as_group_tuple, foreign
from xsimlab.variable import _as_dim_tuple, _as_group_tuple, foreign, index


@pytest.mark.parametrize(
Expand Down Expand Up @@ -52,6 +52,11 @@ def test_as_group_tuple(groups, group, expected):
assert actual == expected


def test_index():
    """An index variable declared with empty (scalar) dims must be rejected."""
    scalar_dims = ()

    with pytest.raises(ValueError, match=r".*not accept scalar values.*"):
        index(scalar_dims)


def test_foreign():
with pytest.raises(ValueError) as excinfo:
foreign(ExampleProcess, "some_var", intent="inout")
Expand Down
52 changes: 50 additions & 2 deletions xsimlab/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class VarType(Enum):
VARIABLE = "variable"
INDEX = "index"
ON_DEMAND = "on_demand"
FOREIGN = "foreign"
GROUP = "group"
Expand Down Expand Up @@ -124,7 +125,6 @@ def variable(
tuple corresponds to a 1-d variable and a n-length tuple corresponds to
a n-d variable. A list of str or tuple items may also be provided if
the variable accepts different numbers of dimensions.
This should not include a time dimension, which may always be added.
intent : {'in', 'out', 'inout'}, optional
Defines whether the variable is an input (i.e., the process needs the
variable's value for its computation), an output (i.e., the process
Expand Down Expand Up @@ -205,6 +205,55 @@ def variable(
)


def index(dims, groups=None, description="", attrs=None):
    """Create a variable aimed at indexing data.

    The process class in which this variable is declared should set its
    value (the variable is always created with intent='out') with an index
    object or an index-compatible object. For example, xarray may accept
    1-d arrays, :class:`pandas.Index`, :class:`pandas.MultiIndex`, etc.

    As a simple example, index variable(s) should be used for setting
    coordinate labels along the dimension(s) of a cartesian grid.

    Parameters
    ----------
    dims : str or tuple or list
        Dimension label(s) of the variable. A string or a 1-length
        tuple corresponds to a 1-d variable and a n-length tuple
        corresponds to a n-d variable. A list of str or tuple items may
        also be provided if the variable accepts different numbers of
        dimensions. Note that an index variable does not accept scalar
        values.
    groups : str or list, optional
        Variable group(s).
    description : str, optional
        Short description of the variable.
    attrs : dict, optional
        Dictionary of additional metadata (e.g., standard_name,
        units, math_symbol...).

    Raises
    ------
    ValueError
        If the normalized dimensions contain an empty tuple, i.e., a
        scalar value would be accepted.

    See Also
    --------
    :func:`variable`

    """
    dim_specs = _as_dim_tuple(dims)

    # an empty tuple among the normalized dims denotes a scalar variable
    if () in dim_specs:
        raise ValueError("An index variable does not accept scalar values")

    meta = {
        "var_type": VarType.INDEX,
        "dims": dim_specs,
        "intent": VarIntent.OUT,
        "groups": _as_group_tuple(groups, None),
        "attrs": attrs or {},
        "description": description,
    }

    # excluded from the generated __init__ and __repr__ of the process class
    return attr.attrib(metadata=meta, init=False, repr=False)


def on_demand(dims=(), group=None, groups=None, description="", attrs=None):
"""Create a variable that is computed on demand.
Expand Down Expand Up @@ -233,7 +282,6 @@ def on_demand(dims=(), group=None, groups=None, description="", attrs=None):
tuple corresponds to a 1-d variable and a n-length tuple corresponds to
a n-d variable. A list of str or tuple items may also be provided if
the variable accepts different numbers of dimensions.
This should not include a time dimension, which may always be added.
group : str, optional
Variable group (deprecated, use ``groups`` instead).
groups : str or list, optional
Expand Down

0 comments on commit b2753bc

Please sign in to comment.