Skip to content

Commit

Permalink
updates per review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
dgergel committed Apr 20, 2021
1 parent b4e0262 commit 458f292
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions skdownscale/spatial_models/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,16 @@ class SpatialDisaggregator:
temperature and other option is precipitation.
"""

def __init__(self, var='temperature'):
def __init__(self, var="temperature"):
self._var = var

if var == 'temperature':
pass
elif var == 'precipitation':
pass
else:

if var not in ['temperature', 'precipitation']:
raise NotImplementedError(
'functionality for spatial disaggregation' ' of %s has not yet been added' % var
"functionality for spatial disaggregation"
" of %s has not yet been added" % var
)

def fit(self, ds_bc, climo_coarse, var_name, lat_name='lat', lon_name='lon'):
def fit(self, ds_bc, climo_coarse, var_name, lat_name="lat", lon_name="lon"):
"""
Fit the scaling factor used in spatial disaggregation
Expand All @@ -52,16 +49,16 @@ def fit(self, ds_bc, climo_coarse, var_name, lat_name='lat', lon_name='lon'):
"""

# check that climo has been regridded to model res
if not np.array_equal(ds_bc[lat_name], climo_coarse[lat_name]):
raise ValueError('climo latitude dimension does not match model res')
if not np.array_equal(ds_bc[lon_name], climo_coarse[lon_name]):
raise ValueError('climo longitude dimension does not match model res')
if not ds_bc[lat_name].equals(climo_coarse[lat_name]):
raise ValueError("climo latitude dimension does not match model res")
if not ds_bc[lon_name].equals(climo_coarse[lon_name]):
raise ValueError("climo longitude dimension does not match model res")

scf = self._calculate_scaling_factor(ds_bc, climo_coarse, var_name, self._var)

return scf

def predict(self, scf, climo_fine, var_name, lat_name='lat', lon_name='lon'):
def predict(self, scf, climo_fine, var_name, lat_name="lat", lon_name="lon"):
"""
Predict (apply) the scaling factor to the observed climatology.
Expand All @@ -85,9 +82,9 @@ def predict(self, scf, climo_fine, var_name, lat_name='lat', lon_name='lon'):

# check that scale factor has been regridded to obs res
if not np.array_equal(scf[lat_name], climo_fine[lat_name]):
raise ValueError('scale factor latitude dimension does not match obs res')
raise ValueError("scale factor latitude dimension does not match obs res")
if not np.array_equal(scf[lon_name], climo_fine[lon_name]):
raise ValueError('scale factor longitude dimension does not match obs res')
raise ValueError("scale factor longitude dimension does not match obs res")

downscaled = self._apply_scaling_factor(scf, climo_fine, var_name, self._var)

Expand All @@ -105,9 +102,9 @@ def _calculate_scaling_factor(self, ds_bc, climo, var_name, var):
groupby_type = ds_bc.time.dt.dayofyear
gb = da.groupby(groupby_type)

if var == 'temperature':
if var == "temperature":
return gb - climo
elif var == 'precipitation':
elif var == "precipitation":
return gb / climo

def _apply_scaling_factor(self, scf, climo, var_name, var):
Expand All @@ -118,7 +115,7 @@ def _apply_scaling_factor(self, scf, climo, var_name, var):
da = scf[var_name]
sff_daily = da.groupby(groupby_type)

if var == 'temperature':
if var == "temperature":
return sff_daily + climo
elif var == 'precipitation':
elif var == "precipitation":
return sff_daily * climo

0 comments on commit 458f292

Please sign in to comment.