In [2]:
import pandas as pd
import xarray as xr

# Utils
from sheerwater_benchmarking.utils import start_remote, salient_secret, clip_region, get_region
from sheerwater_benchmarking.metrics import summary_metrics_table_salient

# Salient functions
from salientsdk.skill import _crps_core
import salientsdk as sk

In [41]:
# Connect to the remote compute cluster before any of the large bucket reads below
# (presumably a Dask cluster, given the dask-backed arrays shown in later outputs).
cluster_settings = {'remote_name': 'genevieve', 'remote_config': 'xlarge_cluster'}
start_remote(**cluster_settings)

Output()

## Compare Salient's reported CRPS with our recomputations over the Salient evaluation period (one variable, all five weekly leads)

In [135]:
# ---- Experiment configuration ----
# Salient evaluation period (inclusive ISO dates).
start_time, end_time = '2015-01-01', '2022-12-31'

# Sheerwater variable name, and the name Salient uses for the same quantity.
variable = 'tmp2m'
var = dict(tmp2m="temp", precip="precip")[variable]  # salient naming

metric = 'crps'
region = 'africa'
# NOTE(review): neither `mask` nor `lead` is consumed as configuration anywhere below
# (the summary tables pass mask='lsm' explicitly, and all five weekly leads are compared).
mask, lead = None, 'week3'
timescale = 'sub-seasonal'

## Pull both the forecasts and the ground truth (gt) directly from the Salient data bucket

In [136]:
# Salient's "blend" forecast for this variable/timescale, read straight from the Salient bucket.
# The raw store keeps the forecast under `vals`; rename it to the Sheerwater variable name and keep
# only the evaluation period.
# NOTE(review): the path hard-codes 'africa' rather than `region`. Presumably the bucket stores only
# the Africa domain, and `region` is applied afterwards via clip_region. Confirm before changing it.
filename = f'gs://sheerwater-datalake/salient-data/v9/africa/{var}_{timescale}/blend'
fcst_ds = (
    xr.open_zarr(filename)['vals']
    .to_dataset()
    .rename(vals=variable)
    .sel(forecast_date=slice(start_time, end_time))
)
fcst_ds

Unnamed: 0,Array,Chunk
Bytes,44.19 GiB,41.59 MiB
Shape,"(1088, 5, 300, 316, 23)","(1, 5, 300, 316, 23)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 44.19 GiB 41.59 MiB Shape (1088, 5, 300, 316, 23) (1, 5, 300, 316, 23) Dask graph 1088 chunks in 3 graph layers Data type float32 numpy.ndarray",5  1088  23  316  300,

Unnamed: 0,Array,Chunk
Bytes,44.19 GiB,41.59 MiB
Shape,"(1088, 5, 300, 316, 23)","(1, 5, 300, 316, 23)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [137]:
# Matching ground truth from the same bucket (stored as `vals_actual`). It is restricted to exactly
# the forecast dates selected for fcst_ds, so the two datasets line up for scoring.
# NOTE(review): the path hard-codes 'africa' rather than `region`. Presumably the bucket stores only
# the Africa domain, and `region` is applied afterwards via clip_region. Confirm before changing it.
filename = f'gs://sheerwater-datalake/salient-data/v9/africa/{var}_{timescale}/truth'
gt_ds = (
    xr.open_zarr(filename)['vals_actual']
    .to_dataset()
    .rename(vals_actual=variable)
    .sel(forecast_date=fcst_ds.forecast_date)
)
gt_ds

Unnamed: 0,Array,Chunk
Bytes,1.92 GiB,1.81 MiB
Shape,"(1088, 5, 300, 316)","(1, 5, 300, 316)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.92 GiB 1.81 MiB Shape (1088, 5, 300, 316) (1, 5, 300, 316) Dask graph 1088 chunks in 3 graph layers Data type float32 numpy.ndarray",1088  1  316  300  5,

Unnamed: 0,Array,Chunk
Bytes,1.92 GiB,1.81 MiB
Shape,"(1088, 5, 300, 316)","(1, 5, 300, 316)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [138]:
# Recompute the CRPS directly from the bucket forecast/truth pair using Salient's own CRPS kernel.
# The result is averaged over forecast dates, clipped to `region`, and then averaged over lat/lon.
# _crps_core always computes CRPS, so fail loudly if the configured metric is anything else,
# rather than mislabelling (or failing to find) the result.
if metric != 'crps':
    raise ValueError(f"Bucket recomputation only supports metric='crps', got {metric!r}")

met_ds = _crps_core(observations=gt_ds, forecasts=fcst_ds, qnt_dim='quantiles')
met_ds = met_ds.mean('forecast_date')
met_ds = clip_region(met_ds, region)
# NOTE(review): this spatial mean is unweighted (no cos-latitude weighting). Confirm it matches how
# Salient's hindcast_summary aggregates before interpreting small differences.
met = met_ds.mean(['lat', 'lon'])[metric].compute()

In [139]:
# Convert the region-averaged CRPS (presumably one value per lead, i.e. 5 entries) to pandas,
# so it can be attached to the Salient summary table below.
met_df = met.to_pandas()

In [140]:
# Export the evaluation region's shape (presumably a GeoDataFrame, given .to_file) as GeoJSON,
# so it can be uploaded to Salient.
# The shape gets its own name instead of reusing `mask`, so the `mask` config value is not silently
# replaced by a geometry. The return value of to_file is not needed, so it is not kept.
lons, lats, region_shape = get_region(region)
filename = f'{region}.geojson'
region_shape.to_file(filename, driver="GeoJSON")

In [141]:
# Authenticate with the Salient API and upload the region GeoJSON written from get_region(region).
# This lets Salient's hindcast_summary below be evaluated over that same shape.
# Credentials are passed straight through instead of being stored as notebook globals, so they do
# not linger in the kernel namespace (e.g. visible via %whos or an accidental display).
sk.login(*salient_secret())
sk.upload_file(filename)
loc = sk.Location(shapefile=filename)

Uploading africa.geojson
Successfully uploaded africa.geojson to shapefiles



In [142]:
# Salient's own published hindcast skill over the uploaded region: the reference numbers we are
# trying to reproduce. `metric` and `timescale` come from the config cell, so this request always
# matches the bucket recomputation instead of being hard-coded separately.
ref_model = "clim"  # Works across all timescale values.

skill_summ = pd.read_csv(
    sk.hindcast_summary(
        loc=loc,
        metric=metric,
        variable=var,
        timescale=timescale,
        reference=ref_model,
        split_set="test",
        force=True,
    )
)

In [143]:
# Attach the bucket-recomputed CRPS and its % deviation from Salient's reported CRPS.
# met_df.values is assigned positionally. This assumes met_df's lead order matches skill_summ's
# 'Lead' rows (Week 1..Week 5) — TODO confirm.
# 'Salient CRPS' appears to be reported to 2 decimals, so part of any % gap is rounding alone.
bucket_crps = met_df.values
skill_summ['Salient Bucket CRPS'] = bucket_crps
reported_crps = skill_summ['Salient CRPS']
skill_summ['Bucket Diff (%)'] = 100 * (reported_crps - skill_summ['Salient Bucket CRPS']) / reported_crps

In [144]:
# Our cached (Nimbus) scores for the Salient forecast, computed against truth='salient_era5' on
# grid='salient0_25', with mask='lsm' (presumably the land-sea mask). This is the closest
# like-for-like comparison with the bucket recomputation, so it should match it closely.
# The configured `region` is used instead of a hard-coded "africa", so all comparisons share one region.
tab = summary_metrics_table_salient(start_time, end_time, variable,
                                    truth='salient_era5', metric=metric,
                                    grid='salient0_25', mask='lsm', region=region)
tab = tab.set_index('forecast')
# Divide by 7 to convert weekly precip totals to daily averages; other variables stay unscaled.
div = 7. if variable == 'precip' else 1.
tab = (tab[['week1', 'week2', 'week3', 'week4', 'week5']] / div).T

Found cache for gs://sheerwater-datalake/caches/summary_metrics_table_salient/None_2022-12-31_salient0_25_lsm_crps_africa_2015-01-01_None_salient_era5_tmp2m.delta
Opening cache gs://sheerwater-datalake/caches/summary_metrics_table_salient/None_2022-12-31_salient0_25_lsm_crps_africa_2015-01-01_None_salient_era5_tmp2m.delta


In [145]:
# Attach our cached Nimbus CRPS (salient grid / salient ERA5 truth) and its % deviation from
# Salient's reported CRPS. tab['salient'].values is assigned positionally, week1..week5.
nimbus_crps = tab['salient'].values
skill_summ['Salient Nimbus CRPS'] = nimbus_crps
reported_crps = skill_summ['Salient CRPS']
skill_summ['Nimbus Diff (%)'] = 100. * (reported_crps - skill_summ['Salient Nimbus CRPS']) / reported_crps

In [146]:
# Our cached (Nimbus) scores for the Salient forecast after regridding to grid='global0_25' and
# scoring against our own truth='era5' (mask='lsm', presumably the land-sea mask).
# Unlike the salient0_25 table, this is NOT expected to match the bucket numbers exactly.
# The gap presumably reflects the regridding and the different truth source.
# The configured `region` is used instead of a hard-coded "africa", so all comparisons share one region.
tab_regrid = summary_metrics_table_salient(start_time, end_time, variable,
                                           truth='era5', metric=metric,
                                           grid='global0_25', mask='lsm', region=region)
tab_regrid = tab_regrid.set_index('forecast')
# Divide by 7 to convert weekly precip totals to daily averages; other variables stay unscaled.
div = 7. if variable == 'precip' else 1.
tab_regrid = (tab_regrid[['week1', 'week2', 'week3', 'week4', 'week5']] / div).T

Found cache for gs://sheerwater-datalake/caches/summary_metrics_table_salient/None_2022-12-31_global0_25_lsm_crps_africa_2015-01-01_None_era5_tmp2m.delta
Opening cache gs://sheerwater-datalake/caches/summary_metrics_table_salient/None_2022-12-31_global0_25_lsm_crps_africa_2015-01-01_None_era5_tmp2m.delta


In [147]:
# Attach our regridded (global0_25 / ERA5) CRPS and its % deviation from Salient's reported CRPS.
# tab_regrid['salient'].values is assigned positionally, week1..week5.
regrid_crps = tab_regrid['salient'].values
skill_summ['Salient Regrid CRPS'] = regrid_crps
reported_crps = skill_summ['Salient CRPS']
skill_summ['Regrid Diff (%)'] = 100. * (reported_crps - skill_summ['Salient Regrid CRPS']) / reported_crps

In [148]:
# Side-by-side view: Salient's reported CRPS, our three recomputations, and each one's % gap.
# 'Salient CRPS' appears to be reported to 2 decimals, so small % gaps can come from rounding alone.
comparison_columns = [
    'Lead',
    'Salient CRPS', 'Salient Bucket CRPS', 'Salient Nimbus CRPS', 'Salient Regrid CRPS',
    'Bucket Diff (%)', 'Nimbus Diff (%)', 'Regrid Diff (%)',
]
print('Variable:', variable)
skill_summ[comparison_columns]

Variable: tmp2m


Unnamed: 0,Lead,Salient CRPS,Salient Bucket CRPS,Salient Nimbus CRPS,Salient Regrid CRPS,Bucket Diff (%),Nimbus Diff (%),Regrid Diff (%)
0,Week 1,0.31,0.315544,0.315572,0.352178,-1.788509,-1.797458,-13.605907
1,Week 2,0.57,0.572545,0.572613,0.598339,-0.446512,-0.458481,-4.971696
2,Week 3,0.71,0.717065,0.717157,0.737013,-0.995082,-1.008082,-3.804655
3,Week 4,0.74,0.751504,0.751599,0.767636,-1.554571,-1.567438,-3.734638
4,Week 5,0.76,0.762229,0.762324,0.777164,-0.293233,-0.305742,-2.258389


In [81]:
# NOTE(review): stale cell. Its execution count (In [81]) predates the cells above.
# Its saved output does not match what the current cells produce:
# Nimbus Diff (%) is 0.0, and the Nimbus column equals the Bucket CRPS values.
# It displays a subset of the columns shown by the previous cell; consider deleting it.
print('Variable:', variable)
skill_summ[['Lead', 'Salient CRPS', 'Salient Nimbus CRPS', 'Salient Regrid CRPS', 'Nimbus Diff (%)', 'Regrid Diff (%)']]

Variable: tmp2m


Unnamed: 0,Lead,Salient CRPS,Salient Nimbus CRPS,Salient Regrid CRPS,Nimbus Diff (%),Regrid Diff (%)
0,Week 1,0.31,0.315544,0.352178,0.0,-10.402099
1,Week 2,0.57,0.572545,0.598339,0.0,-4.310861
2,Week 3,0.71,0.717065,0.737013,0.0,-2.706597
3,Week 4,0.74,0.751504,0.767636,0.0,-2.101581
4,Week 5,0.76,0.762229,0.777164,0.0,-1.921755
