# Create biweekly renku datasets from `climetlab-s2s-ai-challenge`

adapted from renku_datasets_biweekly

* downloads climetlab data
* computes weekly and biweekly aggregates and saves them to zarr following the challenge naming conventions. The variable names are appended to the file names to avoid overwriting existing files.

For me: run this in s2s-ai-new3-copy

More information:

https://confluence.ecmwf.int/display/S2S/Parameters

https://github.com/ecmwf-lab/climetlab-s2s-ai-challenge

download settings: 

https://climetlab.readthedocs.io/en/latest/guide/settings.html

In particular, you can change the timeout used when downloading from a URL.

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
import xskillscore as xs
import pandas as pd

import climetlab_s2s_ai_challenge
import climetlab as cml
print(f'Climetlab version : {cml.__version__}')
print(f'Climetlab-s2s-ai-challenge plugin version : {climetlab_s2s_ai_challenge.__version__}')

# keep attributes through xarray operations and use the plain-text repr
xr.set_options(keep_attrs=True, display_style='text')

In [None]:
# Directory in which climetlab caches the downloaded files.
# Set your own path here, e.g. '../../Data/s2s_ai/cache'.
cache_path_cml = "/work/S2S_AI/cache4"
cml.settings.set("cache-directory", cache_path_cml)  # make climetlab use this directory

In [None]:
# Directory below which the aggregated zarr stores are written and read (e.g. '../../Data/s2s_ai').
cache_path = "../data"

# Download and cache

Download all files for the observations, forecast and hindcast.

In [None]:
# shortcut
#from scripts import download
#download()

## hindcast and forecast `input`

### adapt below to your variable (adapt both, cache_path_cml and varlist_forecast)
#### first check the size of the data here: https://storage.ecmwf.europeanweather.cloud/s2s-ai-challenge/data/training-input/0.3.0/netcdf/index.html
Small data sets use 40 GB of cache, larger ones up to 300 GB.

In [None]:
# Variables to download and aggregate; more can be added, e.g. ['tp', 't2m'] or 'z500'.
varlist_forecast = ['sst']

# caching path for climetlab
# set your own path: I created a new folder for each variable
cache_path_cml = "/work/S2S_AI/cache4/sst"  # e.g. '../../Data/s2s_ai/cache/sm20'
cml.settings.set("cache-directory", cache_path_cml)

# Forecast centers to process; other options: 'ncep', 'eccc'.
center_list = ['ecmwf']

# The ML-community labels are equivalent to the NWP-community labels used below.
# NOTE: the aggregation cells derive the file name from 'hindcast'/'forecast' in the label,
# so they expect the NWP-community labels.
#forecast_dataset_labels = ['training-input','test-input'] # ML community
# equiv to
forecast_dataset_labels = ['hindcast-input','forecast-input'] # NWP community

In [None]:
%%time
# Download the input data of every center and dataset label into the climetlab cache.
# takes ~ 10-30 min to download for one model one variable depending on number of model realizations
# and download settings https://climetlab.readthedocs.io/en/latest/guide/settings.html 
for center in center_list:
    for ds in forecast_dataset_labels:
        print(ds)  # progress: dataset label currently being downloaded
        # The returned dataset is discarded here; the aggregation cells load the data again.
        cml.load_dataset(f"s2s-ai-challenge-{ds}", origin=center, parameter=varlist_forecast, format='netcdf').to_xarray()

## observations `output-reference`

In [None]:
# Observation variables to download; alternatives noted by the author: ['tp', 't2m'], 'tcw'.
# NOTE: the aggregation cells only process 't2m' and 'tp' for observations.
varlist_obs = ['tp']

# starting dates forecast_time in 2020
dates = xr.cftime_range(start='20200102',freq='7D', periods=53).strftime('%Y%m%d').to_list()

# The ML-community labels are kept for reference only; they are equivalent to the
# NWP-community labels used below (same convention as for forecast_dataset_labels).
#obs_dataset_labels = ['training-output-reference','test-output-reference'] # ML community
# equiv to
obs_dataset_labels = ['hindcast-like-observations','forecast-like-observations'] # NWP community

In [None]:
#### we are mostly interested in downloading more variables for the input...
#%%time
# takes 10min to download
#for ds in obs_dataset_labels:
#    print(ds)
#    # only netcdf, no format choice
#    cml.load_dataset(f"s2s-ai-challenge-{ds}", date=dates, parameter=varlist_obs).to_xarray()

In [None]:
# download obs_time to create output-reference/observations for models other than ecmwf and eccc,
# i.e. ncep or any S2S or SubX model
#obs_time = cml.load_dataset(f"s2s-ai-challenge-observations", parameter=['t2m', 'pr']).to_xarray()

# create bi-weekly aggregates

In [None]:
from scripts import aggregate_biweekly, ensure_attributes

In [None]:
# Show the configuration before running the aggregation.
# NOTE(review): the aggregation loops iterate over forecast_dataset_labels,
# whereas obs_dataset_labels is printed here.
print(center_list)
print(obs_dataset_labels)
print(varlist_forecast)

## biweekly

To aggregate the observations instead, change the `for dsl in ...` line so that it iterates over `obs_dataset_labels`.

In [None]:
# Aggregate the data biweekly and save one zarr store per center and dataset label.
# To aggregate observations instead of forecasts, iterate over obs_dataset_labels in the `dsl` line.
for c, center in enumerate(center_list):  # forecast centers (could also take models)
    for dsl in forecast_dataset_labels:  # climetlab dataset labels
        # Reset for every dataset label: a skipped parameter must neither leave
        # ds_biweekly undefined nor let the result of a previous label be saved again.
        ds_biweekly = None

        # read in data to ds, one variable at a time
        for parameter in varlist_forecast:  # variables
            if c != 0 and 'observation' in dsl:  # only do once for observations
                continue
            print(f"datasetlabel: {dsl}, center: {center}, parameter: {parameter}")

            if 'input' in dsl:  # forecasts/hindcasts
                ds = cml.load_dataset(f"s2s-ai-challenge-{dsl}", origin=center, parameter=parameter, format='netcdf').to_xarray()
            elif 'observation' in dsl:  # obs only netcdf, no choice
                if parameter not in ['t2m', 'tp']:
                    continue
                ds = cml.load_dataset(f"s2s-ai-challenge-{dsl}", parameter=parameter, date=dates).to_xarray()
            else:
                # previously an unknown label silently re-used whatever `ds` was left over
                raise ValueError(f"cannot load dataset label {dsl!r}: expected an 'input' or 'observation' label")

            # map: apply the biweekly aggregation to each variable in the dataset
            if ds_biweekly is None:  # first variable that was actually loaded
                ds_biweekly = ds.map(aggregate_biweekly)
            else:  # add more vars
                ds_biweekly[parameter] = ds.map(aggregate_biweekly)[parameter]

            ds_biweekly = ds_biweekly.map(ensure_attributes, biweekly=True)
            ds_biweekly = ds_biweekly.sortby('forecast_time')

        if ds_biweekly is None:  # every parameter was skipped: nothing to save
            continue

        # ds_biweekly is now ready
        # NOTE(review): none of the NWP-community labels contains 'test', so with the labels
        # defined in this notebook the else branch is always taken -- confirm the intent.
        if 'test' in dsl:
            ds_biweekly = ds_biweekly.chunk('auto')
        else:
            ds_biweekly = ds_biweekly.chunk({'forecast_time':'auto','lead_time':-1,'longitude':-1,'latitude':-1})

        # time span that goes into the file name
        if 'hindcast' in dsl:
            time = f'{int(ds_biweekly.forecast_time.dt.year.min())}-{int(ds_biweekly.forecast_time.dt.year.max())}'
        elif 'forecast' in dsl:
            time = '2020'
        else:
            raise ValueError(f"cannot derive the time span for dataset label {dsl!r}: expected a 'hindcast' or 'forecast' label")

        # forecasts/hindcasts are prefixed with the center; observations (the only
        # other kind that passes the loading step above) are not
        name = f'{center}_{dsl}' if 'input' in dsl else dsl

        # pattern: {model_if_not_observations}{observations/forecast/hindcast}_{time}_biweekly_deterministic_{variables}.zarr
        params = '_'.join(varlist_forecast)
        zp = f'{cache_path}/{name}_{time}_biweekly_deterministic_{params}.zarr'
        ds_biweekly.attrs.update({'postprocessed_by':'https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/blob/master/notebooks/renku_datasets_biweekly.ipynb'})
        print(f'save to: {zp}')
        ds_biweekly.astype('float32').to_zarr(zp, consolidated=True, mode='w')

## create weekly aggregates

In [None]:
from scripts import add_valid_time_from_forecast_reference_time_and_lead_time
def aggregate_weekly(da):
    """
    Aggregate initialized S2S forecasts weekly for xr.DataArrays.
    Use ds.map(aggregate_weekly) for xr.Datasets.

    Weeks 2 to 6 are returned. This function does not return values for week 1,
    since lead_time = 0 is not available for certain variables (e.g. t2m) or is zero (e.g. tp).

    Sum variables (tp and ttr) are aggregated as the value at the first day of the
    following week minus the value at the first day of the week. All other variables
    (e.g. t2m) are aggregated as the mean over the seven daily lead times of the week.

    Applies to the ECMWF S2S data model: https://confluence.ecmwf.int/display/S2S/Parameters

    Parameters
    ----------
    da : xarray dataarray with a `lead_time` coordinate that contains the daily
        lead times (as timedeltas) of day 7 to day 41, and day 42 for tp and ttr

    Returns
    -------
    da_weekly : xarray dataarray containing weekly aggregated quantities; the
        `lead_time` coordinate holds the first day of each week
    """
    week_starts = [7, 14, 21, 28, 35]  # first day of week 2, 3, 4, 5 and 6

    # take first day of weekly average as new coordinate
    weekly_lead = [pd.Timedelta(f"{start} d") for start in week_starts]

    if da.name in ['tp', 'ttr']:  # weekly difference for sum variables: tp and ttr
        weeks = [
            da.sel(lead_time=pd.Timedelta(f"{start + 7} d"))
            - da.sel(lead_time=pd.Timedelta(f"{start} d"))
            for start in week_starts
        ]
    else:  # t2m, see climetlab_s2s_ai_challenge.CF_CELL_METHODS # weekly: mean [first day, first day + 6]
        weeks = []
        for start in week_starts:
            days = [pd.Timedelta(f"{day} d") for day in range(start, start + 7)]
            # selecting with a DataArray keeps `lead_time` as the dimension to average over
            window = xr.DataArray(days, dims='lead_time', coords={'lead_time': days})
            weeks.append(da.sel(lead_time=window).mean('lead_time'))

    da_weekly = xr.concat(weeks, 'lead_time').assign_coords(lead_time=weekly_lead)

    da_weekly = add_valid_time_from_forecast_reference_time_and_lead_time(da_weekly)

    # Describe the weekly windows that were actually computed above
    # (the attributes used to describe the biweekly week 3-4 / week 5-6 aggregates).
    lead_time_attrs = {
        'long_name': 'forecast_period',
        'description': 'Forecast period is the time interval between the forecast reference time and the validity time.',
        'aggregate': 'The pd.Timedelta corresponds to the first day of a weekly aggregate.',
    }
    for week, start in enumerate(week_starts, start=2):
        lead_time_attrs[f'week{week}_t2m'] = f'mean[day {start}, {start + 6}]'
        lead_time_attrs[f'week{week}_tp'] = f'day {start + 7} minus day {start}'
    da_weekly['lead_time'].attrs = lead_time_attrs

    return da_weekly

In [None]:
# Compute the weekly aggregates and save one zarr store per center and dataset label.
# The cached data for all parameters must be available.
# All variables are loaded again if you do this for several variables at once.
# To aggregate observations instead of forecasts, iterate over obs_dataset_labels in the `dsl` line.
for c, center in enumerate(center_list):  # forecast centers (could also take models)
    for dsl in forecast_dataset_labels:  # climetlab dataset labels
        # Reset for every dataset label: a skipped parameter must neither leave
        # ds_weekly undefined nor let the result of a previous label be saved again.
        ds_weekly = None

        # read in data to ds, one variable at a time
        for parameter in varlist_forecast:  # variables
            if c != 0 and 'observation' in dsl:  # only do once for observations
                continue
            print(f"datasetlabel: {dsl}, center: {center}, parameter: {parameter}")

            if 'input' in dsl:  # forecasts/hindcasts
                ds = cml.load_dataset(f"s2s-ai-challenge-{dsl}", origin=center, parameter=parameter, format='netcdf').to_xarray()
            elif 'observation' in dsl:  # obs only netcdf, no choice
                if parameter not in ['t2m', 'tp']:
                    continue
                ds = cml.load_dataset(f"s2s-ai-challenge-{dsl}", parameter=parameter, date=dates).to_xarray()
            else:
                # previously an unknown label silently re-used whatever `ds` was left over
                raise ValueError(f"cannot load dataset label {dsl!r}: expected an 'input' or 'observation' label")

            # map: apply the weekly aggregation to each variable in the dataset
            if ds_weekly is None:  # first variable that was actually loaded
                ds_weekly = ds.map(aggregate_weekly)
            else:  # add more vars
                ds_weekly[parameter] = ds.map(aggregate_weekly)[parameter]

            ds_weekly = ds_weekly.map(ensure_attributes, biweekly=False)
            ds_weekly = ds_weekly.sortby('forecast_time')

        if ds_weekly is None:  # every parameter was skipped: nothing to save
            continue

        # ds_weekly is now ready
        # NOTE(review): none of the NWP-community labels contains 'test', so with the labels
        # defined in this notebook the else branch is always taken -- confirm the intent.
        if 'test' in dsl:
            ds_weekly = ds_weekly.chunk('auto')
        else:
            ds_weekly = ds_weekly.chunk({'forecast_time':'auto','lead_time':-1,'longitude':-1,'latitude':-1})

        # time span that goes into the file name
        if 'hindcast' in dsl:
            time = f'{int(ds_weekly.forecast_time.dt.year.min())}-{int(ds_weekly.forecast_time.dt.year.max())}'
        elif 'forecast' in dsl:
            time = '2020'
        else:
            raise ValueError(f"cannot derive the time span for dataset label {dsl!r}: expected a 'hindcast' or 'forecast' label")

        # forecasts/hindcasts are prefixed with the center; observations (the only
        # other kind that passes the loading step above) are not
        name = f'{center}_{dsl}' if 'input' in dsl else dsl

        # pattern: {model_if_not_observations}{observations/forecast/hindcast}_{time}_weekly_deterministic_{variables}.zarr
        params = '_'.join(varlist_forecast)
        zp = f'{cache_path}/{name}_{time}_weekly_deterministic_{params}.zarr'
        ds_weekly.attrs.update({'postprocessed_by':'https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/blob/master/notebooks/renku_datasets_biweekly.ipynb'})
        print(f'save to: {zp}')
        ds_weekly.astype('float32').to_zarr(zp, consolidated=True, mode='w')

In [None]:
#ds_w_hindcast = ds_hindcast.map(aggregate_weekly)

## try out cml.load_dataset

In [None]:
# Load the ECMWF t2m hindcast input; this is quick once the data has been downloaded to the cache.
# (dataset labels, see forecast_dataset_labels: 'hindcast-input', 'forecast-input')
parameter = ['t2m']  # or varlist_forecast
ds_hindcast = cml.load_dataset(
    "s2s-ai-challenge-hindcast-input",
    origin='ecmwf',
    parameter=parameter,
    format='netcdf',
).to_xarray()


In [None]:
ds_hindcast

In [None]:
# Load the ECMWF tp hindcast input.
ds_hindcast_tp = cml.load_dataset(
    "s2s-ai-challenge-hindcast-input",
    origin='ecmwf',
    parameter='tp',
    format='netcdf',
).to_xarray()


In [None]:
ds_hindcast_tp

In [None]:
# Plot t2m of the first forecast_time, averaged over realization, latitude and longitude.
ds_hindcast.isel(forecast_time = 0).mean(('realization','latitude','longitude')).t2m.plot()

# validate weekly aggregation

In [None]:
# Biweekly reference: the biweekly hindcast store whose file name carries no variable suffix.
hind_2000_2019 = xr.open_zarr(f'{cache_path}/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr', consolidated=True)

In [None]:
hind_2000_2019

In [None]:
# Weekly aggregates of tp and t2m. The file name follows the pattern used when the weekly
# aggregates are saved, {name}_{time}_weekly_deterministic_{params}.zarr, i.e. the store
# written with varlist_forecast = ['tp', 't2m'].
hind_2000_2019_weekly = xr.open_zarr(f'{cache_path}/ecmwf_hindcast-input_2000-2019_weekly_deterministic_tp_t2m.zarr', consolidated=True)

In [None]:
hind_2000_2019_weekly

In [None]:
hind_2000_2019_weekly.isel(lead_time = 1)

### tp 34

In [None]:
# Weekly lead_time indices 1 and 2 are the weeks starting on day 14 and day 21;
# adding them gives the tp sum of weeks 3+4.
biweekly_from_weekly_tp_34 = hind_2000_2019_weekly.tp.isel(lead_time = 1) +hind_2000_2019_weekly.tp.isel(lead_time = 2)

In [None]:
biweekly_from_weekly_tp_34

In [None]:
# Difference between the biweekly tp at lead_time index 0 (presumably weeks 3+4 -- confirm)
# and the value rebuilt from the weekly data, summed over realization and forecast_time.
(hind_2000_2019.isel(lead_time = 0).tp - biweekly_from_weekly_tp_34).sum(('realization', 'forecast_time')).plot()

### tp 56

In [None]:
# Weekly lead_time indices 3 and 4 are the weeks starting on day 28 and day 35;
# adding them gives the tp sum of weeks 5+6.
biweekly_from_weekly_tp_56 = hind_2000_2019_weekly.tp.isel(lead_time = 3) +hind_2000_2019_weekly.tp.isel(lead_time = 4)

In [None]:
# Difference between the biweekly tp at lead_time index 1 (presumably weeks 5+6 -- confirm)
# and the value rebuilt from the weekly data, summed over realization and forecast_time.
(hind_2000_2019.isel(lead_time = 1).tp - biweekly_from_weekly_tp_56).sum(('realization', 'forecast_time')).plot()

### t2m 34

In [None]:
# t2m is a mean variable: the weeks 3+4 value is the average of the two weekly means
# (weekly lead_time indices 1 and 2, i.e. the weeks starting on day 14 and day 21).
biweekly_from_weekly_t2m_34 = (hind_2000_2019_weekly.t2m.isel(lead_time = 1) + hind_2000_2019_weekly.t2m.isel(lead_time = 2))/2

In [None]:
# Difference between the biweekly t2m at lead_time index 0 (presumably weeks 3+4 -- confirm)
# and the value rebuilt from the weekly data, summed over realization and forecast_time.
(hind_2000_2019.isel(lead_time = 0).t2m - biweekly_from_weekly_t2m_34).sum(('realization', 'forecast_time')).plot()

In [None]:
# For comparison: the rebuilt weeks 3+4 t2m itself, averaged over realization and forecast_time.
biweekly_from_weekly_t2m_34.mean(('realization', 'forecast_time')).plot()

### t2m 56

In [None]:
# t2m is a mean variable: the weeks 5+6 value is the average of the two weekly means
# (weekly lead_time indices 3 and 4, i.e. the weeks starting on day 28 and day 35).
biweekly_from_weekly_t2m_56 = (hind_2000_2019_weekly.t2m.isel(lead_time = 3) +hind_2000_2019_weekly.t2m.isel(lead_time = 4))/2

In [None]:
# Difference between the biweekly t2m at lead_time index 1 (presumably weeks 5+6 -- confirm)
# and the value rebuilt from the weekly data, summed over realization and forecast_time.
(hind_2000_2019.isel(lead_time = 1).t2m - biweekly_from_weekly_t2m_56).sum(('realization', 'forecast_time')).plot()