1 change: 1 addition & 0 deletions aospy/__config__.py
@@ -3,6 +3,7 @@
 default_colormap = 'RdBu'
 user_path = os.path.join(os.getenv('HOME'), 'aospy_user', 'aospy_user')
 
+ETA_STR = 'sigma'
 LON_STR = 'lon'
 LAT_STR = 'lat'
 LON_BOUNDS_STR = 'lon_bounds'
210 changes: 119 additions & 91 deletions aospy/calc.py
@@ -12,12 +12,13 @@

 from . import Constant, Var
 from .__config__ import (LAT_STR, LON_STR, LAT_BOUNDS_STR, LON_BOUNDS_STR,
-                          PHALF_STR, PFULL_STR, PLEVEL_STR, TIME_STR, YEAR_STR)
+                          PHALF_STR, PFULL_STR, PLEVEL_STR, TIME_STR, YEAR_STR,
+                          ETA_STR)
 from .io import (_data_in_label, _data_out_label, _ens_label, _yr_label, dmget,
                  data_in_name_gfdl)
 from .timedate import TimeManager, _get_time
 from .utils import (get_parent_attr, apply_time_offset, monthly_mean_ts,
-                    monthly_mean_at_each_ind, pfull_from_ps,
+                    monthly_mean_at_each_ind, pfull_from_ps, to_hpa,
                     to_pfull_from_phalf, dp_from_ps, dp_from_p, int_dp_g)


@@ -186,9 +187,7 @@ def _path_archive(self):

     def _print_verbose(self, *args):
         """Print diagnostic message."""
-        if not self.verbose:
-            pass
-        else:
+        if self.verbose:
             try:
                 print('{} {}'.format(args[0], args[1]),
                       '({})'.format(time.ctime()))
@@ -241,8 +240,8 @@ def _get_input_data_paths_one_dir(self, name, data_in_direc, n=0):
             elif os.path.isfile(full):
                 paths.append(full)
             else:
-                print("Warning: specified netCDF file `{}` "
-                      "not found".format(nc))
+                warnings.warn("Warning: specified netCDF file `{}` "
+                              "not found".format(nc))
         # Remove duplicate entries.
         files = list(set(paths))
         files.sort()
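Switching from print to warnings.warn makes the missing-file notice filterable and testable by callers. A minimal standalone sketch of that behavior (not aospy code; the filename is hypothetical):

import warnings

# Unlike a bare print, a warning can be recorded or silenced by the caller.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    warnings.warn("specified netCDF file `foo.nc` not found")
    assert 'not found' in str(caught[-1].message)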
@@ -261,7 +260,7 @@ def _get_input_data_paths_gfdl_dir_struct(self, name, data_in_direc,
         dtype_lbl = self.dtype_in_time
         if self.intvl_in == 'daily':
             domain += '_daily'
-        if self.dtype_in_vert == 'sigma' and name != 'ps':
+        if self.dtype_in_vert == ETA_STR and name != 'ps':
             domain += '_level'
         if self.dtype_in_time == 'inst':
             domain += '_inst'
@@ -374,9 +373,11 @@ def _add_grid_attributes(self, ds, n=0):
             except ValueError:
                 ds = ds
             ds = ds.set_coords(name_int)
-            if not ds[name_int].equals(model_attr):
-                warnings.warn("Model coordinates for '{}' "
-                              "do not match those in Run".format(name_int))
+            if not np.array_equal(ds[name_int], model_attr):
+                msg = ("Model coordinates for '{}' do not match those in "
+                       "Run: {} vs. {}".format(name_int, ds[name_int],
+                                               model_attr))
+                warnings.warn(msg)
         else:
             # Bring in coord from model object if it exists.
             if model_attr is not None:
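The switch to np.array_equal loosens the comparison: DataArray.equals also requires matching dimensions and coordinates, whereas np.array_equal compares values alone. A small standalone illustration (array values are hypothetical):

import numpy as np
import xarray as xr

da = xr.DataArray([1., 2.], dims=['lat'])
other = np.array([1., 2.])
print(np.array_equal(da, other))       # True: values match
print(da.equals(xr.DataArray(other)))  # False: dims differ ('lat' vs 'dim_0')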
@@ -397,21 +398,11 @@ def _create_input_data_obj(self, var, start_date=False,
         # specify what method to call to access the files on the filesystem.
         dmget(paths)
         ds_chunks = []
-        # 2015-10-16 S. Hill: Can we use the xarray.open_mfdataset function here
-        # instead of this logic of making individual datasets and then
-        # calling xarray.concat? Or does the year<1678 logic make this not
-        # possible?
-
-        # 2015-10-16 19:06:00 S. Clark: The year<1678 logic is independent of
-        # using xarray.open_mfdataset. The main reason I held off on using it
-        # here was that it opens a can of worms with regard to performance;
-        # we'd need to add some logic to make sure the data were chunked in a
-        # reasonable way (and logic to change the chunking if need be).
         for file_ in paths:
             test = xr.open_dataset(file_, decode_cf=False,
-                                   drop_variables=['time_bounds', 'nv',
-                                                   'average_T1',
-                                                   'average_T2'])
+                                   drop_variables=['time_bounds', 'nv',
+                                                   'average_T1', 'average_T2'])
+            # Workaround for years < 1678 causing overflows.
             if start_date.year < 1678:
                 for v in [TIME_STR]:
                     test[v].attrs['units'] = ('days since 1900-01-01 '
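For context on the workaround above: numpy's datetime64[ns], which xarray uses for decoded times, can only represent roughly the years 1678 through 2262, so earlier model years overflow unless the units attribute is first pointed at a dummy reference date. The bounds can be checked directly:

import pandas as pd

# datetime64[ns] spans only ~584 years centered on 1970.
print(pd.Timestamp.min)  # 1677-09-21 00:12:43.145224193
print(pd.Timestamp.max)  # 2262-04-11 23:47:16.854775807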
@@ -442,53 +433,59 @@ def _create_input_data_obj(self, var, start_date=False,
                     self.dt = (self._to_desired_dates(dt) /
                                np.timedelta64(1, 's'))
                     break
-        # At least one variable has to get us the pfull array, if its needed.
+        # At least one variable has to get us the pfull array, if it's needed.
         if set_pfull:
             try:
-                self.pfull = ds[PFULL_STR]
+                self.pfull_coord = ds[PFULL_STR]
             except KeyError:
                 pass
         return arr

+    def _get_pressure_from_p_coords(self, ps, name='p', n=0):
+        if np.any(self.pressure):
+            pressure = self.pressure
+        else:
+            pressure = self.model[n].level
+        if name == 'p':
+            return pressure
+        if name == 'dp':
+            return dp_from_p(pressure, ps)
+        raise ValueError("name must be 'p' or 'dp':"
+                         "'{}'".format(name))
+
+    def _get_pressure_from_eta_coords(self, ps, name='p', n=0):
+        bk = self.model[n].bk
+        pk = self.model[n].pk
+        pfull_coord = self.model[n].pfull
+        if name == 'p':
+            return pfull_from_ps(bk, pk, ps, pfull_coord)
+        if name == 'dp':
+            return dp_from_ps(bk, pk, ps, pfull_coord)
+        raise ValueError("name must be 'p' or 'dp':"
+                         "'{}'".format(name))

     def _get_pressure_vals(self, var, start_date, end_date, n=0):
         """Get pressure array, whether sigma or standard levels."""
-        ps = self._create_input_data_obj(self.ps, start_date, end_date)
-
+        try:
+            ps = self._ps_data
+        except AttributeError:
+            self._ps_data = self._create_input_data_obj(self.ps, start_date,
+                                                        end_date)
+            ps = self._ps_data
         if self.dtype_in_vert == 'pressure':
-            if np.any(self.pressure):
-                pressure = self.pressure
-            else:
-                pressure = self.model[n].level
-            if var.name == 'p':
-                data = pressure
-            elif var.name == 'dp':
-                data = dp_from_p(pressure, ps)
-
-        elif self.dtype_in_vert == 'sigma':
-            bk = self.model[n].bk
-            pk = self.model[n].pk
-            pfull_coord = self.model[n].pfull
-            if var.name == 'p':
-                data = pfull_from_ps(bk, pk, ps, pfull_coord)
-            elif var.name == 'dp':
-                data = dp_from_ps(bk, pk, ps, pfull_coord)
-            else:
-                raise ValueError("var.name must be 'p' or 'dp':"
-                                 "'{}'".format(var.name))
-        else:
-            raise ValueError("`dtype_in_vert` must be either 'pressure' or "
-                             "'sigma' for pressure data")
-        return data
+            return self._get_pressure_from_p_coords(ps, name=var.name, n=n)
+        if self.dtype_in_vert == ETA_STR:
+            return self._get_pressure_from_eta_coords(ps, name=var.name, n=n)
+        raise ValueError("`dtype_in_vert` must be either 'pressure' or "
+                         "'sigma' for pressure data")

     def _correct_gfdl_inst_time(self, arr):
         """Correct off-by-one error in GFDL instantaneous model data."""
-        time = arr[TIME_STR]
-        if self.intvl_in == '3hr':
-            offset = -3
-        elif self.intvl_in == '6hr':
-            offset = -6
-        time = apply_time_offset(time, hours=offset)
-        arr[TIME_STR] = time
+        if self.intvl_in.endswith('hr'):
+            offset = -1*int(self.intvl_in[0])
+        else:
+            raise NotImplementedError
+        arr[TIME_STR] = apply_time_offset(arr[TIME_STR], hours=offset)
        return arr

def _get_input_data(self, var, start_date, end_date, n):
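The two helper methods split out above encapsulate the standard hybrid sigma-pressure definition: at each half-level k, p_k = pk_k + bk_k * ps. A rough standalone sketch of that formula, with toy coefficient values (aospy's pfull_from_ps additionally interpolates from half to full levels and handles xarray labeling):

import numpy as np

def phalf_from_ps_sketch(bk, pk, ps):
    # Hybrid (eta) coordinates: pressure at half-level k is pk_k + bk_k * ps.
    return pk[:, np.newaxis] + bk[:, np.newaxis] * ps[np.newaxis, :]

pk = np.array([0., 50., 0.])    # Pa; pure-pressure contribution
bk = np.array([0., 0.5, 1.])    # dimensionless sigma contribution
ps = np.array([1.0e5, 9.5e4])   # Pa; surface pressure for two columns
print(phalf_from_ps_sketch(bk, pk, ps))  # shape (3, 2); bottom row equals ps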
@@ -507,7 +504,7 @@ def _get_input_data(self, var, start_date, end_date, n):
         # Pressure handled specially due to complications from sigma vs. p.
         elif var.name in ('p', 'dp'):
             data = self._get_pressure_vals(var, start_date, end_date)
-            if self.dtype_in_vert == 'sigma':
+            if self.dtype_in_vert == ETA_STR:
                 if self.dtype_in_time == 'inst':
                     data = self._correct_gfdl_inst_time(data)
             return self._to_desired_dates(data)
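The generalized offset in _correct_gfdl_inst_time derives the shift from the interval label itself ('3hr' gives -3, '6hr' gives -6); note that int(self.intvl_in[0]) reads only the first character, so it assumes single-digit hour intervals. A hedged sketch of the underlying shift, using plain pandas/numpy rather than aospy's apply_time_offset:

import numpy as np
import pandas as pd
import xarray as xr

intvl_in = '6hr'                # hypothetical interval label
offset = -1 * int(intvl_in[0])  # -> -6
times = xr.DataArray(pd.date_range('2000-01-01 06:00', periods=3, freq='6H'),
                     dims=['time'])
# Shift interval-end labels back to interval starts.
print((times + np.timedelta64(offset, 'h')).values[0])  # 2000-01-01T00:00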
@@ -518,15 +515,14 @@ def _get_input_data(self, var, start_date, end_date, n):
             data = getattr(self.model[n], var.name)
         else:
             set_dt = True if not hasattr(self, 'dt') else False
-            cond_pfull = (not hasattr(self, 'pfull') and var.def_vert and
-                          self.dtype_in_vert == 'sigma')
-            set_pfull = True if cond_pfull else False
+            cond_pfull = ((not hasattr(self, 'pfull')) and var.def_vert and
+                          self.dtype_in_vert == ETA_STR)
             data = self._create_input_data_obj(var, start_date, end_date, n=n,
                                                set_dt=set_dt,
-                                               set_pfull=set_pfull)
+                                               set_pfull=cond_pfull)
         # Force all data to be at full pressure levels, not half levels.
-        if self.dtype_in_vert == 'sigma' and var.def_vert == 'phalf':
-            data = to_pfull_from_phalf(data, self.pfull)
+        if self.dtype_in_vert == ETA_STR and var.def_vert == 'phalf':
+            data = to_pfull_from_phalf(data, self.pfull_coord)
         # Correct GFDL instantaneous data time indexing problem.
         if var.def_time:
             if self.dtype_in_time == 'inst':
@@ -586,7 +582,9 @@ def _compute(self, data_in, monthly_mean=False):
         local_ts = self._local_ts(*data_in)
         if self.dtype_in_time == 'inst':
             dt = xr.DataArray(np.ones(np.shape(local_ts[TIME_STR])),
-                                dims=[TIME_STR], coords=[local_ts[TIME_STR]])
+                              dims=[TIME_STR], coords=[local_ts[TIME_STR]])
+            if not hasattr(self, 'dt'):
+                self.dt = dt
         else:
             dt = self.dt
         if monthly_mean:
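For instantaneous input every timestep carries equal weight, so the dt array is just ones along the time axis; the new hasattr guard also caches it for later reuse. A standalone sketch of the equal-weight average this produces (toy values):

import numpy as np
import pandas as pd
import xarray as xr

local_ts = xr.DataArray([1., 2., 4.], dims=['time'],
                        coords={'time': pd.date_range('2000-01', periods=3)})
dt = xr.DataArray(np.ones(np.shape(local_ts['time'])),
                  dims=['time'], coords=[local_ts['time']])
# With uniform weights this reduces to a plain time mean (~2.33).
print(((local_ts * dt).sum('time') / dt.sum('time')).item())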
@@ -641,6 +639,15 @@ def _time_reduce(self, arr, reduction):

     def region_calcs(self, arr, func, n=0):
         """Perform a calculation for all regions."""
+        # Get pressure values for data output on hybrid vertical coordinates.
+        bool_pfull = (self.def_vert and self.dtype_in_vert == ETA_STR and
+                      self.dtype_out_vert is False)
+        if bool_pfull:
+            pfull = self._full_to_yearly_ts(self._prep_data(
+                self._get_input_data(Var('p'), self.start_date, self.end_date,
+                                     0), self.var.func_input_dtype
+            ), self.dt).rename('pressure')
+        # Loop over the regions, performing the calculation.
         reg_dat = {}
         for reg in self.region.values():
             # Just pass along the data if averaged already.
@@ -650,9 +657,36 @@
             else:
                 method = getattr(reg, func)
             data_out = method(arr)
-            reg_dat.update({reg.name: data_out})
+            if bool_pfull:
+                # Don't apply e.g. standard deviation to coordinates.
+                if func not in ['av', 'ts']:
+                    method = reg.ts
+                # Convert Pa to hPa
+                coord = method(pfull) * 1e-2
+                data_out = data_out.assign_coords(
+                    **{reg.name + '_pressure': coord}
+                )
+            reg_dat.update(**{reg.name: data_out})
         return reg_dat
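The assign_coords call above attaches the region-mean pressure as an auxiliary coordinate named '<region>_pressure'. A minimal sketch with hypothetical region name and values:

import xarray as xr

data_out = xr.DataArray([300., 280.], dims=['pfull'],
                        coords={'pfull': [500., 850.]})    # reference levels, hPa
coord = xr.DataArray([505., 860.], dims=['pfull'],
                     coords={'pfull': data_out['pfull']})  # region-mean p, hPa
data_out = data_out.assign_coords(**{'tropics' + '_pressure': coord})
print(data_out.coords['tropics_pressure'].values)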

+    def _apply_all_time_reductions(self, full_ts, monthly_ts, eddy_ts):
+        # Determine which are regional, eddy, time-mean.
+        reduc_specs = [r.split('.') for r in self.dtype_out_time]
+        reduced = {}
+        for reduc, specs in zip(self.dtype_out_time, reduc_specs):
+            func = specs[-1]
+            if 'eddy' in specs:
+                data = eddy_ts
+            elif 'time-mean' in specs:
+                data = monthly_ts
+            else:
+                data = full_ts
+            if 'reg' in specs:
+                reduced.update({reduc: self.region_calcs(data, func)})
+            else:
+                reduced.update({reduc: self._time_reduce(data, func)})
+        return reduced

def compute(self):
"""Perform all desired calculations on the data and save externally."""
# Load the input data from disk.
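The reduction specs consumed by _apply_all_time_reductions above are dot-separated strings: the last token names the reduction and earlier tokens select the data variant. A sketch with a hypothetical spec:

# 'reg.time-mean.av' -> regional average of the monthly-mean series.
specs = 'reg.time-mean.av'.split('.')  # ['reg', 'time-mean', 'av']
func = specs[-1]                       # 'av'
print('reg' in specs, 'time-mean' in specs, 'eddy' in specs)  # True True False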
@@ -690,24 +724,10 @@ def compute(self):
         # Apply time reduction methods.
         if self.def_time:
             self._print_verbose("Applying desired time-reduction methods.")
-            # Determine which are regional, eddy, time-mean.
-            reduc_specs = [r.split('.') for r in self.dtype_out_time]
-            reduced = {}
-            for reduc, specs in zip(self.dtype_out_time, reduc_specs):
-                func = specs[-1]
-                if 'eddy' in specs:
-                    data = eddy_ts
-                elif 'time-mean' in specs:
-                    data = monthly_ts
-                else:
-                    data = full_ts
-                if 'reg' in specs:
-                    reduced.update({reduc: self.region_calcs(data, func)})
-                else:
-                    reduced.update({reduc: self._time_reduce(data, func)})
+            reduced = self._apply_all_time_reductions(full_ts, monthly_ts,
+                                                      eddy_ts)
         else:
             reduced = {'': full_ts}
-
         # Save to disk.
         self._print_verbose("Writing desired gridded outputs to disk.")
         for dtype_time, data in reduced.items():
@@ -787,7 +807,18 @@ def _load_from_scratch(self, dtype_out_time, dtype_out_vert=False,
         ds = xr.open_dataset(self.path_scratch[dtype_out_time],
                              engine='scipy')
         if region:
-            return ds[region.name]
+            arr = ds[region.name]
+            # Use region-specific pressure values if available.
+            if self.dtype_in_vert == ETA_STR and not dtype_out_vert:
+                try:
+                    reg_pfull_str = region.name + '_pressure'
+                    arr = arr.drop([r for r in arr.coords.iterkeys()
+                                    if r not in (PFULL_STR, reg_pfull_str)])
+                except ValueError:
+                    return arr
+                else:
+                    arr = arr.rename({PFULL_STR: PFULL_STR + '_ref'})
+                    return arr.rename({reg_pfull_str: PFULL_STR})
         return ds[self.name]
def _load_from_archive(self, dtype_out_time, dtype_out_vert=False):
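The coordinate shuffle in _load_from_scratch keeps the reference 'pfull' levels under a '_ref' name and promotes the region's own mean pressures to the primary vertical coordinate. A sketch of the same two renames on a toy array (region name and values hypothetical):

import xarray as xr

arr = xr.DataArray([300., 280.], dims=['pfull'],
                   coords={'pfull': [500., 850.]})
arr = arr.assign_coords(tropics_pressure=('pfull', [505., 860.]))
arr = arr.rename({'pfull': 'pfull_ref'})         # keep reference levels
arr = arr.rename({'tropics_pressure': 'pfull'})  # promote region pressures
print(arr.coords['pfull'].values)                # [505. 860.]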
@@ -804,17 +835,14 @@ def _load_from_archive(self, dtype_out_time, dtype_out_vert=False):
     def _get_data_subset(self, data, region=False, time=False,
                          vert=False, lat=False, lon=False, n=0):
         """Subset the data array to the specified time/level/lat/lon, etc."""
-        # if region:
-        #     if type(region) is str:
-        #         data = data[region]
-        #     elif type(region) is Region:
-        #         data = data[region.name]
+        if region:
+            raise NotImplementedError
         if np.any(time):
             data = data[time]
             if 'monthly_from_' in self.dtype_in_time:
                 data = np.mean(data, axis=0)[np.newaxis, :]
         if np.any(vert):
-            if self.dtype_in_vert != 'sigma':
+            if self.dtype_in_vert != ETA_STR:
                 if np.max(self.model[n].level) > 1e4:
                     # Convert from Pa to hPa.
                     lev_hpa = self.model[n].level*1e-2
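The magnitude check above is a unit heuristic: pressure levels greater than 1e4 are taken to be in Pa and scaled to hPa. A standalone sketch with toy values:

import numpy as np

level = np.array([100000., 85000., 50000.])  # plausibly Pa
# Values this large cannot be hPa, so rescale; small values pass through.
lev_hpa = level * 1e-2 if np.max(level) > 1e4 else level
print(lev_hpa)  # [1000.  850.  500.]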