In [1]:
import pandas as pd
import xarray as xr

# Utils
from sheerwater_benchmarking.utils import start_remote, salient_secret, clip_region, get_region
from sheerwater_benchmarking.metrics import summary_metrics_table_salient

# Salient functions
from salientsdk.skill import _crps_core
import salientsdk as sk

In [2]:
# Attach this notebook to the remote compute cluster before any heavy work.
# NOTE(review): presumably a Dask cluster profile defined by sheerwater_benchmarking — confirm.
remote_profile = 'xlarge_cluster'
start_remote(remote_config=remote_profile)

Output()

Output()

2024-12-03 08:53:11,921 - distributed.deploy.adaptive - INFO - Adaptive scaling started: minimum=15 maximum=16


## Run tests on Salient evaluation period, for a specific variable and lead

In [3]:
# --- Evaluation window ---
start_time, end_time = '2015-01-01', '2022-12-31'

# --- What to score ---
variable = 'precip'
metric = 'crps'
lead = 'week3'
timescale = 'sub-seasonal'

# --- Where to score it ---
region = 'africa'
mask = None

# Translate our variable name into the name Salient uses.
salient_names = {"tmp2m": "temp", "precip": "precip"}
var = salient_names[variable]

## Pull both forecasts and ground truth directly from the bucket

In [4]:
# Open the blended forecast store lazily, keep only the `vals` array (renamed
# to our variable name), and restrict it to the evaluation window.
filename = f'gs://sheerwater-datalake/salient-data/v9/africa/{var}_{timescale}/blend'
fcst_ds = (
    xr.open_zarr(filename)['vals']
    .to_dataset()
    .rename(vals=variable)
    .sel(forecast_date=slice(start_time, end_time))
)
fcst_ds

Unnamed: 0,Array,Chunk
Bytes,44.19 GiB,41.59 MiB
Shape,"(1088, 5, 300, 316, 23)","(1, 5, 300, 316, 23)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 44.19 GiB 41.59 MiB Shape (1088, 5, 300, 316, 23) (1, 5, 300, 316, 23) Dask graph 1088 chunks in 3 graph layers Data type float32 numpy.ndarray",5  1088  23  316  300,

Unnamed: 0,Array,Chunk
Bytes,44.19 GiB,41.59 MiB
Shape,"(1088, 5, 300, 316, 23)","(1, 5, 300, 316, 23)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [5]:
# Open the matching truth store, keep only `vals_actual` (renamed to our
# variable name), and select exactly the forecast dates present in `fcst_ds`.
filename = f'gs://sheerwater-datalake/salient-data/v9/africa/{var}_{timescale}/truth'
gt_ds = (
    xr.open_zarr(filename)['vals_actual']
    .to_dataset()
    .rename(vals_actual=variable)
    .sel(forecast_date=fcst_ds.forecast_date)
)
gt_ds

Unnamed: 0,Array,Chunk
Bytes,1.92 GiB,1.81 MiB
Shape,"(1088, 5, 300, 316)","(1, 5, 300, 316)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.92 GiB 1.81 MiB Shape (1088, 5, 300, 316) (1, 5, 300, 316) Dask graph 1088 chunks in 3 graph layers Data type float32 numpy.ndarray",1088  1  316  300  5,

Unnamed: 0,Array,Chunk
Bytes,1.92 GiB,1.81 MiB
Shape,"(1088, 5, 300, 316)","(1, 5, 300, 316)"
Dask graph,1088 chunks in 3 graph layers,1088 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [6]:
point_lon, point_lat = (25.125, 0.125)

# Score the forecasts against truth with Salient's own CRPS routine, then
# average over forecast dates.
crps_by_date = _crps_core(observations=gt_ds, forecasts=fcst_ds, qnt_dim='quantiles')
crps_time_mean = crps_by_date.mean('forecast_date')
# Single-point alternative (not run):
# met = crps_time_mean.sel(lon=point_lon, lat=point_lat)[metric].compute()

# Clip to the configured region and average spatially; `.compute()` triggers
# the actual evaluation of the lazy arrays.
met_ds = clip_region(crps_time_mean, region)
met = met_ds.mean(['lat', 'lon'])[metric].compute()

In [7]:
# Convert the computed xarray result to a pandas object; its `.values` are
# later written into the Salient summary table as 'Salient Bucket CRPS'.
# NOTE(review): presumably one CRPS value per lead after the averages above — confirm.
met_df = met.to_pandas()

In [8]:
# Fetch the region definition and write its mask to a local GeoJSON file
# named after the region, so the file can be uploaded to Salient.
# NOTE(review): assumes `mask` is a geopandas GeoDataFrame — confirm.
lons, lats, mask = get_region(region)
filename = f'{region}.geojson'
# `to_file` is called only for its side effect (writing the file); its return
# value is not a file handle, so it is intentionally not bound to a name.
mask.to_file(filename, driver="GeoJSON")

In [9]:
# Authenticate against the Salient API with credentials from the project's
# secret helper (nothing is hard-coded in the notebook).
username, password = salient_secret()
sk.login(username, password)
# Upload the region GeoJSON written earlier, then build a Location that refers
# to it by file name. The upload must happen before the Location is used.
# NOTE(review): presumably Salient resolves `shapefile` against uploaded files — confirm.
file = sk.upload_file(filename)
loc = sk.Location(shapefile=filename)

In [10]:
# The variable that we'll be evaluating.
fld = "vals"
timescale = "sub-seasonal"
ref_model = "clim"  # Works across all timescale values.

# Ask Salient for its own hindcast CRPS summary over our uploaded region
# (test split, forced refresh); the result is handed straight to read_csv.
summary_csv = sk.hindcast_summary(
    loc=loc,
    metric="crps",
    variable=var,
    timescale=timescale,
    reference=ref_model,
    split_set="test",
    force=True,
)
skill_summ = pd.read_csv(summary_csv)

In [11]:
# Add the CRPS computed directly from the bucket data (assigned by position)
# and its percentage difference relative to Salient's reported CRPS.
skill_summ['Salient Bucket CRPS'] = met_df.values
reported_crps = skill_summ['Salient CRPS']
bucket_gap = reported_crps - skill_summ['Salient Bucket CRPS']
skill_summ['Bucket Diff (%)'] = 100 * bucket_gap / reported_crps

In [12]:
# Get value from our cached table [should match exactly for week 3]
# The configured `region` is passed (rather than a hard-coded name) so this
# table always covers the same area as the bucket-based CRPS computed with
# clip_region(..., region).
tab = summary_metrics_table_salient(start_time, end_time, variable,
                                    truth='salient_era5', metric=metric,
                                    grid='salient0_25', mask='lsm', region=region)
tab = tab.set_index('forecast')
# Divide by 7 for precip, to convert weekly totals to daily averages;
# other variables are left unscaled.
div = 7. if variable == 'precip' else 1.
# One row per lead week, one column per forecast, after the transpose.
tab = (tab[['week1', 'week2', 'week3', 'week4', 'week5']] / div).T

Found cache for gs://sheerwater-datalake/caches/summary_metrics_table_salient/2022-12-31_salient0_25_lsm_crps_africa_2015-01-01_None_salient_era5_precip.delta
Opening cache gs://sheerwater-datalake/caches/summary_metrics_table_salient/2022-12-31_salient0_25_lsm_crps_africa_2015-01-01_None_salient_era5_precip.delta


In [13]:
# Add the CRPS from our cached summary table (assigned by position) and its
# percentage difference relative to Salient's reported CRPS.
skill_summ['Salient Nimbus CRPS'] = tab['salient'].values
reported_crps = skill_summ['Salient CRPS']
nimbus_gap = reported_crps - skill_summ['Salient Nimbus CRPS']
skill_summ['Nimbus Diff (%)'] = 100. * nimbus_gap / reported_crps

In [None]:
# Get the same cached summary table, but scored against ERA5 truth on the
# regridded global0_25 grid. This variant is not expected to match Salient's
# reported numbers exactly (the displayed results differ by a few percent).
# The configured `region` is passed (rather than a hard-coded name) so this
# table always covers the same area as the rest of the comparison.
tab_regrid = summary_metrics_table_salient(start_time, end_time, variable,
                                           truth='era5', metric=metric,
                                           grid='global0_25', mask='lsm', region=region)
tab_regrid = tab_regrid.set_index('forecast')
# Divide by 7 for precip, to convert weekly totals to daily averages;
# other variables are left unscaled.
div = 7. if variable == 'precip' else 1.
# One row per lead week, one column per forecast, after the transpose.
tab_regrid = (tab_regrid[['week1', 'week2', 'week3', 'week4', 'week5']] / div).T

In [None]:
# Add the CRPS from the regridded cached table (assigned by position) and its
# percentage difference relative to Salient's reported CRPS.
skill_summ['Salient Regrid CRPS'] = tab_regrid['salient'].values
reported_crps = skill_summ['Salient CRPS']
regrid_gap = reported_crps - skill_summ['Salient Regrid CRPS']
skill_summ['Regrid Diff (%)'] = 100. * regrid_gap / reported_crps

In [None]:
# Full comparison: Salient's reported CRPS next to the bucket, Nimbus and
# regrid values, followed by each one's percentage difference.
full_summary_cols = [
    'Lead',
    'Salient CRPS',
    'Salient Bucket CRPS',
    'Salient Nimbus CRPS',
    'Salient Regrid CRPS',
    'Bucket Diff (%)',
    'Nimbus Diff (%)',
    'Regrid Diff (%)',
]
print(f'Variable: {variable}')
skill_summ[full_summary_cols]

In [74]:
# Reduced comparison that omits the bucket-derived columns: Salient's reported
# CRPS next to the Nimbus and regrid values and their percentage differences.
cached_summary_cols = [
    'Lead',
    'Salient CRPS',
    'Salient Nimbus CRPS',
    'Salient Regrid CRPS',
    'Nimbus Diff (%)',
    'Regrid Diff (%)',
]
print(f'Variable: {variable}')
skill_summ[cached_summary_cols]

Variable: precip


Unnamed: 0,Lead,Salient CRPS,Salient Nimbus CRPS,Salient Regrid CRPS,Nimbus Diff (%),Regrid Diff (%)
0,Week 1,0.43,0.426978,0.465573,0.702704,-8.27273
1,Week 2,0.55,0.544665,0.578166,0.970048,-5.121093
2,Week 3,0.6,0.594312,0.626025,0.94801,-4.337525
3,Week 4,0.62,0.611091,0.640858,1.436991,-3.364129
4,Week 5,0.63,0.617978,0.646853,1.908277,-2.675144
