Skip to content

Commit

Permalink
Merge 38eb7cf into 2f6e775
Browse files Browse the repository at this point in the history
  • Loading branch information
Spencer Hill committed Sep 23, 2019
2 parents 2f6e775 + 38eb7cf commit 56100ed
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 235 deletions.
1 change: 1 addition & 0 deletions .stickler.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
linters:
flake8:
fixer: true
python: 3
18 changes: 2 additions & 16 deletions aospy/automate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,9 @@ def _merge_dicts(*dict_args):
return result


def _input_func_py2_py3():
"""Find function for reading user input that works on Python 2 and 3.
See e.g. http://stackoverflow.com/questions/21731043
"""
try:
input = raw_input
except NameError:
import builtins
input = builtins.input
return input


def _user_verify(prompt='Perform these computations? [y/n] '):
    """Prompt the user for verification before proceeding.

    Parameters
    ----------
    prompt : str, optional
        Text shown to the user when asking for confirmation.

    Raises
    ------
    AospyException
        If the user's response does not begin with 'y' or 'Y'.
    """
    # ``startswith`` (rather than indexing ``[0]``) avoids an IndexError
    # when the user submits an empty response; an empty answer is treated
    # as a refusal.
    if not input(prompt).lower().startswith('y'):
        raise AospyException('Execution cancelled by user.')


Expand Down
65 changes: 18 additions & 47 deletions aospy/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import os
import pprint
import warnings

import numpy as np
import xarray as xr
Expand All @@ -11,7 +10,6 @@
ETA_STR,
GRID_ATTRS,
TIME_STR,
TIME_BOUNDS_STR,
)
from .utils import times, io

Expand Down Expand Up @@ -112,6 +110,7 @@ def set_grid_attrs_as_coords(ds):
-------
Dataset
Dataset with grid attributes set as coordinates
"""
grid_attrs_in_ds = set(GRID_ATTRS.keys()).intersection(
set(ds.coords) | set(ds.data_vars))
Expand All @@ -130,6 +129,7 @@ def _maybe_cast_to_float64(da):
Returns
-------
DataArray
"""
if da.dtype == np.float32:
logging.warning('Datapoints were stored using the np.float32 datatype.'
Expand Down Expand Up @@ -162,6 +162,7 @@ def _sel_var(ds, var, upcast_float32=True):
------
KeyError
If the variable is not in the Dataset
"""
for name in var.names:
try:
Expand All @@ -176,46 +177,6 @@ def _sel_var(ds, var, upcast_float32=True):
raise LookupError(msg)


def _prep_time_data(ds):
    """Prepare time coordinate information in Dataset for use in aospy.

    1. If the Dataset contains a time bounds coordinate, add attributes
       representing the true beginning and end dates of the time interval used
       to construct the Dataset
    2. If the Dataset contains a time bounds coordinate, overwrite the time
       coordinate values with the averages of the time bounds at each timestep
    3. Decode the times into np.datetime64 objects for time indexing

    Parameters
    ----------
    ds : Dataset
        Pre-processed Dataset with time coordinate renamed to
        internal_names.TIME_STR

    Returns
    -------
    Dataset
        The processed Dataset
    """
    # Make the time coordinate a proper index before any bounds handling.
    ds = times.ensure_time_as_index(ds)
    if TIME_BOUNDS_STR in ds:
        # Bounds present: attach CF time-average metadata, then replace the
        # time values with the midpoint of each interval's bounds.
        ds = times.ensure_time_avg_has_cf_metadata(ds)
        ds[TIME_STR] = times.average_time_bounds(ds)
    else:
        # No bounds available; fall back to uniform weights per timestep.
        logging.warning("dt array not found. Assuming equally spaced "
                        "values in time, even though this may not be "
                        "the case")
        ds = times.add_uniform_time_weights(ds)
    # Suppress enable_cftimeindex is a no-op warning; we'll keep setting it for
    # now to maintain backwards compatibility for older xarray versions.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        with xr.set_options(enable_cftimeindex=True):
            # Decode only times here; coords/grid attrs are handled elsewhere.
            ds = xr.decode_cf(ds, decode_times=True, decode_coords=False,
                              mask_and_scale=True)
    return ds


def _load_data_from_disk(file_set, preprocess_func=lambda ds: ds,
data_vars='minimal', coords='minimal',
grid_attrs=None, **kwargs):
Expand Down Expand Up @@ -243,14 +204,21 @@ def _load_data_from_disk(file_set, preprocess_func=lambda ds: ds,
Returns
-------
Dataset
"""
apply_preload_user_commands(file_set)
func = _preprocess_and_rename_grid_attrs(preprocess_func, grid_attrs,
**kwargs)
return xr.open_mfdataset(file_set, preprocess=func, concat_dim=TIME_STR,
decode_times=False, decode_coords=False,
mask_and_scale=True, data_vars=data_vars,
coords=coords)
return xr.open_mfdataset(
file_set,
preprocess=func,
combine='by_coords',
decode_times=False,
decode_coords=False,
mask_and_scale=True,
data_vars=data_vars,
coords=coords,
)


def apply_preload_user_commands(file_set, cmd=io.dmget):
Expand All @@ -259,6 +227,7 @@ def apply_preload_user_commands(file_set, cmd=io.dmget):
For example, on the NOAA Geophysical Fluid Dynamics Laboratory
computational cluster, data that is saved on their tape archive
must be accessed via a `dmget` (or `hsmget`) command before being used.
"""
if cmd is not None:
cmd(file_set)
Expand Down Expand Up @@ -301,6 +270,7 @@ def load_variable(self, var=None, start_date=None, end_date=None,
-------
da : DataArray
DataArray for the specified variable, date range, and interval in
"""
file_set = self._generate_file_set(var=var, start_date=start_date,
end_date=end_date, **DataAttrs)
Expand All @@ -310,7 +280,7 @@ def load_variable(self, var=None, start_date=None, end_date=None,
time_offset=time_offset, grid_attrs=grid_attrs, **DataAttrs
)
if var.def_time:
ds = _prep_time_data(ds)
ds = times.prep_time_data(ds)
start_date = times.maybe_convert_to_index_date_type(
ds.indexes[TIME_STR], start_date)
end_date = times.maybe_convert_to_index_date_type(
Expand All @@ -330,6 +300,7 @@ def _load_or_get_from_model(self, var, start_date=None, end_date=None,
Supports both access of grid attributes either through the DataLoader
or through an optionally-provided Model object. Defaults to using
the version found in the DataLoader first.
"""
grid_attrs = None if model is None else model.grid_attrs

Expand Down
3 changes: 2 additions & 1 deletion aospy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def _get_grid_files(self):
try:
ds = xr.open_dataset(path, decode_times=False)
except (TypeError, AttributeError):
ds = xr.open_mfdataset(path, decode_times=False).load()
ds = xr.open_mfdataset(path, decode_times=False,
combine='by_coords').load()
except (RuntimeError, OSError) as e:
msg = str(e) + ': {}'.format(path)
raise RuntimeError(msg)
Expand Down
68 changes: 68 additions & 0 deletions aospy/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""pytest conftest.py file for sharing fixtures across modules."""
import datetime

from cftime import DatetimeNoLeap
import numpy as np
import pytest
import xarray as xr

from aospy.internal_names import (
LON_STR,
TIME_STR,
TIME_BOUNDS_STR,
BOUNDS_STR,
)


# The same (start, end) date pair expressed in each of the date types the
# test suite exercises, keyed by a short label for the type.
_DATE_RANGES = {
    'datetime': (datetime.datetime(2000, 1, 1),
                 datetime.datetime(2002, 12, 31)),
    'datetime64': (np.datetime64('2000-01-01'),
                   np.datetime64('2002-12-31')),
    'cftime': (DatetimeNoLeap(2000, 1, 1),
               DatetimeNoLeap(2002, 12, 31)),
    'str': ('2000', '2002')
}


@pytest.fixture()
def alt_lat_str():
    """Non-standard latitude coordinate name used to build test Datasets."""
    return 'LATITUDE'


@pytest.fixture()
def var_name():
    """Name given to the data variable in the test Datasets."""
    return 'a'


@pytest.fixture()
def ds_with_time_bounds(alt_lat_str, var_name):
    """Dataset with one all-zero variable plus a matching time-bounds array.

    Time values sit at (roughly) monthly midpoints, with bounds spanning
    each month; units attributes are encoded relative to 2000-01-01.
    """
    time_vals = np.array([15, 46, 74])
    bound_vals = np.array([[0, 31], [31, 59], [59, 90]])
    bound_idx = np.array([0, 1])
    lat_vals = [0]
    lon_vals = [0]
    arr = xr.DataArray(np.zeros((3, 1, 1)),
                       coords=[time_vals, lat_vals, lon_vals],
                       dims=[TIME_STR, alt_lat_str, LON_STR],
                       name=var_name)
    ds = arr.to_dataset()
    ds[TIME_BOUNDS_STR] = xr.DataArray(bound_vals,
                                       coords=[time_vals, bound_idx],
                                       dims=[TIME_STR, BOUNDS_STR],
                                       name=TIME_BOUNDS_STR)
    units = 'days since 2000-01-01 00:00:00'
    for coord in (TIME_STR, TIME_BOUNDS_STR):
        ds[coord].attrs['units'] = units
    return ds


@pytest.fixture()
def ds_inst(ds_with_time_bounds):
    """Instantaneous-data variant of ``ds_with_time_bounds``.

    Drops the time-bounds variable and replaces the time coordinate with
    hourly values, as instantaneous (non-time-averaged) data carries no
    bounds information.
    """
    inst_time = np.array([3, 6, 9])
    inst_units_str = 'hours since 2000-01-01 00:00:00'
    ds_inst = ds_with_time_bounds.copy()
    # Bug fix: Dataset.drop returns a new Dataset rather than operating
    # in place; the original call discarded the result, so the bounds
    # variable was never actually removed from the fixture.
    ds_inst = ds_inst.drop(TIME_BOUNDS_STR)
    ds_inst[TIME_STR].values = inst_time
    ds_inst[TIME_STR].attrs['units'] = inst_units_str
    return ds_inst
68 changes: 37 additions & 31 deletions aospy/test/test_automate.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,46 @@
from multiprocessing import cpu_count
from os.path import isfile
import shutil
import sys
import itertools
from unittest import mock

import distributed
import pytest

from aospy import Var, Proj
from aospy.automate import (_get_attr_by_tag, _permuted_dicts_of_specs,
_get_all_objs_of_type, _merge_dicts,
_input_func_py2_py3, AospyException,
_user_verify, CalcSuite, _MODELS_STR, _RUNS_STR,
_VARIABLES_STR, _REGIONS_STR,
_compute_or_skip_on_error, submit_mult_calcs,
_n_workers_for_local_cluster,
_prune_invalid_time_reductions)
from aospy.automate import (
_user_verify,
_MODELS_STR,
_RUNS_STR,
_VARIABLES_STR,
_REGIONS_STR,
_compute_or_skip_on_error,
_get_all_objs_of_type,
_get_attr_by_tag,
_merge_dicts,
_n_workers_for_local_cluster,
_permuted_dicts_of_specs,
_prune_invalid_time_reductions,
AospyException,
CalcSuite,
submit_mult_calcs,
)
from .data.objects import examples as lib
from .data.objects.examples import (
example_proj, example_model, example_run, var_not_time_defined,
condensation_rain, convection_rain, precip, ps, sphum, globe, sahel, bk,
p, dp
example_proj,
example_model,
example_run,
var_not_time_defined,
condensation_rain,
convection_rain,
precip,
ps,
sphum,
globe,
sahel,
bk,
p,
dp,
)


Expand Down Expand Up @@ -128,19 +148,12 @@ def test_merge_dicts():
assert expected == _merge_dicts(dict1, dict2, dict3, dict4)


def test_input_func_py2_py3():
    """The returned reader matches the builtin for the running interpreter."""
    returned = _input_func_py2_py3()
    if sys.version.startswith('3'):
        import builtins
        assert returned is builtins.input
    elif sys.version.startswith('2'):
        assert returned is raw_input  # noqa: F821


def test_user_verify():
    """Affirmative input proceeds; negative input raises AospyException."""
    # ``_user_verify`` reads from stdin via the builtin ``input``, so the
    # prompt is mocked rather than passed as a function argument.
    with mock.patch('builtins.input', return_value='YES'):
        _user_verify()
    with pytest.raises(AospyException):
        with mock.patch('builtins.input', return_value='no'):
            _user_verify()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -235,8 +248,7 @@ def assert_calc_files_exist(calcs, write_to_tar, dtypes_out_time):
assert not isfile(calc.path_tar_out)


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.filterwarnings('ignore:Using or importing the ABCs from')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=True, write_to_tar=False),
Expand All @@ -251,8 +263,6 @@ def test_submit_mult_calcs_external_client(calcsuite_init_specs_single_calc,
calcsuite_init_specs_single_calc['output_time_regional_reductions'])


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=False, write_to_tar=False),
Expand All @@ -278,8 +288,6 @@ def test_submit_mult_calcs_no_calcs(calcsuite_init_specs):
submit_mult_calcs(specs)


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=True, write_to_tar=False),
Expand All @@ -294,8 +302,6 @@ def test_submit_two_calcs_external_client(calcsuite_init_specs_two_calcs,
calcsuite_init_specs_two_calcs['output_time_regional_reductions'])


@pytest.mark.skipif(sys.version.startswith('2'),
reason='https://github.com/spencerahill/aospy/issues/259')
@pytest.mark.parametrize(
('exec_options'),
[dict(parallelize=False, write_to_tar=False),
Expand Down

0 comments on commit 56100ed

Please sign in to comment.