# Train ML model to correct predictions of week 3-4 & 5-6

This notebook creates a Machine Learning model `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts. The predictions are compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).

# Synopsis

## Method: `ML-based mean bias reduction`

- calculate the ML-based bias from the 2000-2019 deterministic ensemble mean forecast
- remove that ML-based bias from the 2020 deterministic ensemble mean forecast
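
As a point of reference, the plain (non-ML) version of mean bias reduction can be sketched in a few lines. This is only an illustration on invented numbers at a single grid point, not the model trained in this notebook; `hindcast`, `observed` and `forecast_2020` are hypothetical names.

```python
# Minimal sketch of mean bias reduction at a single grid point.
# All values are invented for illustration; they are not challenge data.

def mean(values):
    return sum(values) / len(values)

# Hypothetical ensemble mean forecasts and matching observations (training period).
hindcast = [281.0, 283.5, 279.0, 284.5]
observed = [280.0, 282.0, 278.5, 283.5]

# Bias = mean forecast minus mean observation over the training period.
bias = mean(hindcast) - mean(observed)

# Remove the training-period bias from new forecasts.
forecast_2020 = [282.0, 285.0]
corrected_2020 = [f - bias for f in forecast_2020]

print(bias, corrected_2020)
```

The ML-based variant in this notebook replaces the constant `bias` with a learned, state-dependent correction.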

## Data used

type: renku datasets

Training-input for Machine Learning model:
- hindcasts of models:
    - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`

Forecast-input for Machine Learning model:
- real-time 2020 forecasts of models:
    - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`

Compare Machine Learning model forecast against ground truth:
- `CPC` observations:
    - `hindcast-like-observations_biweekly_deterministic.zarr`
    - `forecast-like-observations_2020_biweekly_deterministic.zarr`

## Resources used
for training, details in reproducibility

- platform: renku
- memory: 8 GB
- processors: 2 CPU
- storage required: 10 GB

## Safeguards

All points have to be [x] checked. If not, your submission is invalid.

Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.
(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)

### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) 

If the organizers suspect overfitting, your contribution can be disqualified.

  - [x] We did not use 2020 observations in training (explicit overfitting and cheating)
  - [x] We did not repeatedly verify our model on 2020 observations and incrementally improve our RPSS (implicit overfitting)
  - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.
  - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
  - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
  - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
  - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).

### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition), if not possible: no prize
  - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
  - [x] Code is well documented, readable and reproducible.
  - [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training pipelines, which take weeks to train.

# Todos to improve template

This is just a demo.

- [ ] use multiple predictor variables and two predicted variables
- [ ] for both `lead_time`s in one go
- [ ] consider seasonality, for now all `forecast_time` months are mixed
- [ ] make probabilistic predictions with `category` dim, for now works deterministic

# Imports

In [None]:
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Sequential
import tensorflow.keras as keras

import matplotlib.pyplot as plt
import numpy as np


import xarray as xr
xr.set_options(display_style='text')



from dask.utils import format_bytes
import xskillscore as xs

%matplotlib inline 
#for figures


#for prediction
from scripts import make_probabilistic
from scripts import add_valid_time_from_forecast_reference_time_and_lead_time
from scripts import skill_by_year
from scripts import add_year_week_coords


from helper_ml_data import load_data

# Get training data

preprocessing of input data may be done in separate notebook/script

In [None]:
path_data = 'local_n'

## Hindcast

get weekly initialized hindcasts

In [None]:
hind_2000_2019 = load_data(data = 'hind_2000-2019', aggregation = 'biweekly', path = path_data)

In [None]:
hind_2000_2019 = hind_2000_2019.isel(depth_below_and_layer = 0 ).reset_coords('depth_below_and_layer', drop=True)

In [None]:
hind_2000_2019

## Observations
corresponding to hindcasts

In [None]:
obs_2000_2019 = load_data(data = 'obs_2000-2019', aggregation = 'biweekly', path = path_data)

In [None]:
#obs_2000_2019

terciled

In [None]:
obs_2000_2019_terciled = load_data(data = 'obs_terciled_2000-2019', aggregation = 'biweekly', path = path_data)

In [None]:
#obs_2000_2019_terciled

### Select region

* use the whole globe as input since global teleconnections are a major source of predictability, at least over Europe.
* create predictions for a much smaller domain, since otherwise the basis functions (and the associated multiplication) use too much memory.

to make life easier for the beginning

In [None]:
input_lat = slice(90,-90)
input_lon = slice(0, 360)

output_lat = slice(90,0)
output_lon = slice(0,90)

In [None]:
hind_2000_2019 = hind_2000_2019.sel(longitude = input_lon, latitude = input_lat)
obs_2000_2019 = obs_2000_2019.sel(longitude = output_lon, latitude = output_lat)
obs_2000_2019_terciled = obs_2000_2019_terciled.sel(longitude = output_lon, latitude = output_lat)

In [None]:
obs_2000_2019.t2m.isel(lead_time = 0, forecast_time = 0).plot()

## Train Validation split

In [None]:
# time is the forecast_time
time_train_start,time_train_end='2000','2017' # train
time_valid_start,time_valid_end='2018','2019' # valid

## Weatherbench

based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)

In [None]:
# run once only and dont commit
#!git clone https://github.com/pangeo-data/WeatherBench/

In [None]:
import sys
sys.path.insert(1, 'WeatherBench')
from WeatherBench.src.train_nn import PeriodicConv2D, create_predictions#DataGenerator, 

### define some vars

In [None]:
v='t2m'
bs=32

https://s2s-ai-challenge.github.io/

We deal with two fundamentally different variables here: 
- Total precipitation is precipitation flux pr accumulated over lead_time until valid_time and therefore describes a point observation. 
- 2m temperature is averaged over lead_time(valid_time) and therefore describes an average observation. 

The submission file data model unifies both approaches and assigns 14 days for week 3-4 and 28 days for week 5-6 marking the first day of the biweekly aggregate.
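
The lead-time convention above can be written down with the standard library. This is a small sketch, assuming that `valid_time` is simply the initialization date plus the lead time; `LEAD_TIMES`, `valid_time` and the initialization date are illustrative names and values, not part of the challenge scripts.

```python
from datetime import datetime, timedelta

# Lead times as described above: they mark the first day of each biweekly aggregate.
LEAD_TIMES = {
    "week 3-4": timedelta(days=14),
    "week 5-6": timedelta(days=28),
}

def valid_time(forecast_time, lead_time):
    """First valid day of the biweekly window for a given initialization."""
    return forecast_time + lead_time

init = datetime(2020, 1, 2)  # hypothetical initialization date
for name, lead in LEAD_TIMES.items():
    print(name, valid_time(init, lead).date())
```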

In [None]:
# 2 bi-weekly `lead_time`: week 3-4
lead_input = hind_2000_2019.lead_time[0]
lead_output = obs_2000_2019.lead_time[0]
#lead.values
#lead

### create datasets

In [None]:
#mask: same missing values at all forecast_times
mask = xr.where(obs_2000_2019.notnull(),1,np.nan).mean('forecast_time', skipna = False)
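
What the mask does can be illustrated with NumPy alone. The sketch below uses a tiny invented array with dims (forecast_time, grid point); it mirrors the idea of averaging without skipping NaN, so a grid point that is missing at any time ends up masked.

```python
import numpy as np

# Hypothetical observations with dims (forecast_time, grid point).
obs = np.array([
    [1.0, np.nan, 3.0],
    [2.0, np.nan, np.nan],
])

# 1 where a value is present, NaN where it is missing.
present = np.where(np.isnan(obs), np.nan, 1.0)

# np.mean does not skip NaN, so a single missing time makes the grid point NaN.
mask = present.mean(axis=0)
print(mask)
```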

In [None]:
#train
fct_train = hind_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))
verif_train = obs_2000_2019_terciled.sel(forecast_time=slice(time_train_start,time_train_end))[v]

verif_train = verif_train.where(mask[v].notnull())

fct_train_mean = fct_train.mean('forecast_time')
fct_train_std = fct_train.std('forecast_time')

verif_train_mean = verif_train.mean('forecast_time')
verif_train_std = verif_train.std('forecast_time')

In [None]:
#validation
fct_valid = hind_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))#[v]
verif_valid = obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]

verif_valid = verif_valid.where(mask[v].notnull())

In [None]:
fct_train

### Remove annual cycle

In [None]:
def rm_annualcycle(ds, ds_train):
    #remove annual cycle for each location 
    
    ds = add_year_week_coords(ds)
    ds_train = add_year_week_coords(ds_train)
    
    if 'realization' in ds_train.coords:#always use train data to compute the annual cycle
        ens_mean = ds_train.mean('realization')
    else:
        ens_mean = ds_train

    ds_stand = ds - ens_mean.groupby('week').mean(['forecast_time'])

    ds_stand = ds_stand.sel({'week' : ds.coords['week']})
    ds_stand = ds_stand.drop_vars(['week','year'])
    return ds_stand
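
The idea behind `rm_annualcycle` can be shown on a toy series: build a per-week climatology from training samples and subtract the week-matched value from any series. This is a simplified NumPy sketch with invented numbers and 0-based week labels; `remove_annual_cycle` is a hypothetical helper, not the function above.

```python
import numpy as np

# Hypothetical training samples: one value per forecast, labelled by week (0-based).
weeks_train = np.array([0, 1, 0, 1])
values_train = np.array([10.0, 20.0, 12.0, 22.0])

# Climatology: mean of the training values for each week.
n_weeks = 2
climatology = np.array([values_train[weeks_train == w].mean() for w in range(n_weeks)])

def remove_annual_cycle(values, weeks, climatology):
    """Subtract the week-matched training climatology from each sample."""
    return values - climatology[weeks]

anom_train = remove_annual_cycle(values_train, weeks_train, climatology)
# New data are compared against the same training climatology.
anom_new = remove_annual_cycle(np.array([13.0, 19.0]), np.array([0, 1]), climatology)
print(climatology, anom_train, anom_new)
```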

In [None]:
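###remove annual cycle here
# use the training climatology for both sets; compute the validation anomalies first,
# because fct_train is overwritten with its anomalies afterwards
fct_valid = rm_annualcycle(fct_valid, fct_train)
fct_train = rm_annualcycle(fct_train, fct_train)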

In [None]:
#use ensemble mean as input
fct_train = fct_train.mean('realization')
fct_valid = fct_valid.mean('realization')

In [None]:
fct_train.sel(lead_time = lead_input).isel(forecast_time = 0)[v].plot()

In [None]:
verif_train.sel(lead_time = lead_output).isel(forecast_time = 0).plot(col = 'category')

### create basis functions

In [None]:
def get_basis(out_field, r_basis):
    """Return a set of basis functions for the output field, adapted from Scheuerer et al. 2020.

    PARAMETERS:
    out_field : (xarray DataArray) basis functions are created for these lat lon coords
    r_basis : (int) radius of support of the basis functions;
                    the distance between centers of basis functions is half this radius.
                    Should be chosen depending on the field size.
    
    RETURNS:
    basis : (numpy array) normalized basis functions, shape (n_xy, n_basis)
    lats : lats of out_field
    lons : lons of out_field
    n_xy : number of grid points in out_field
    n_basis : number of basis functions
    """
    
    #r_basis = 14 #radius of support of basis functions
    dist_basis = r_basis/2 #distance between centers of basis functions

    lats = out_field.latitude
    lons = out_field.longitude
    
    #number of basis functions
    n_basis = int(np.ceil((lats[0] - lats[-1])/dist_basis + 1)*np.ceil((lons[-1] - lons[0])/dist_basis + 1))

    #grid coords
    lon_np = lons
    lat_np = lats
    
    length_lon = len(lon_np)
    length_lat = len(lat_np)
    
    lon_np = np.outer(lon_np, np.ones(length_lat)).reshape(int(length_lon * length_lat))
    lat_np = np.outer(lat_np, np.ones(length_lon)).reshape(int(length_lon * length_lat))

    #number of grid points
    n_xy = int(length_lon*length_lat)

    #centers of basis functions
    lon_ctr = np.arange(lons[0],lons[-1] + dist_basis,dist_basis)
    length_lon_ctr = len(lon_ctr) #number of center points in lon direction

    lat_ctr = np.arange(lats[0],lats[-1] - dist_basis,- dist_basis)
    length_lat_ctr = len(lat_ctr) #number of center points in lat direction

    lon_ctr = np.outer(lon_ctr, np.ones(length_lat_ctr)).reshape(int(n_basis))
    lat_ctr = np.outer(np.ones(length_lon_ctr), lat_ctr).reshape(int(n_basis))

    #compute distances between fct grid and basis function centers
    dst_lon = np.abs(np.subtract.outer(lon_np,lon_ctr).reshape(len(lons),len(lats),n_basis))#10,14
    dst_lon = np.swapaxes(dst_lon, 0, 1)
    dst_lat = np.abs(np.subtract.outer(lat_np,lat_ctr).reshape(len(lats),len(lons),n_basis))#'14,10'

    dst = np.sqrt(dst_lon**2+dst_lat**2)
    dst = np.swapaxes(dst, 0, 1).reshape(n_xy,n_basis)

    #define basis functions
    basis = np.where(dst>r_basis,0.,(1.-(dst/r_basis)**3)**3)#main step, zero outside, 
    basis = basis/np.sum(basis,axis=1)[:,None]#normalization at each grid point
    nbs = basis.shape[1]
    
    return basis, lats, lons, n_xy, n_basis
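
The kernel used in `get_basis` is easier to inspect in one dimension. The sketch below is a reduced 1-D version under the same formula, `(1 - (dst/r_basis)**3)**3` inside the radius and zero outside, with per-point normalization; `tricube_basis` and the toy grid are illustrative assumptions.

```python
import numpy as np

def tricube_basis(points, centers, r_basis):
    """Compactly supported basis weights, normalized per point (1-D version)."""
    dst = np.abs(np.subtract.outer(points, centers))        # (n_points, n_centers)
    basis = np.where(dst > r_basis, 0.0, (1.0 - (dst / r_basis) ** 3) ** 3)
    return basis / basis.sum(axis=1)[:, None]                # each row sums to 1

r_basis = 14
points = np.arange(0.0, 30.0, 1.5)                           # hypothetical 1-D grid
centers = np.arange(0.0, 30.0 + r_basis / 2, r_basis / 2)    # spacing is r_basis / 2
basis_1d = tricube_basis(points, centers, r_basis)
print(basis_1d.shape)
```

Because centers are spaced at half the radius, every grid point is covered by several overlapping basis functions, which keeps the reconstructed fields smooth.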

#### compute basis

In [None]:
basis, lats, lons, n_xy, n_basis = get_basis(obs_2000_2019_terciled[v], 14) #30
##the smaller you choose the radius of the basis functions, the more memory needs to be allocated

In [None]:
basis.shape
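
The `basis` array has shape `(n_xy, n_basis)`. In the network defined in the `fit` section it is contracted with one coefficient per bin and basis function, which yields one value per grid point and bin. A small NumPy sketch of that contraction, using invented sizes and random numbers:

```python
import numpy as np

n_xy, n_basis, n_bins = 6, 3, 3              # hypothetical small sizes
rng = np.random.default_rng(0)

basis_demo = rng.random((n_xy, n_basis))
basis_demo /= basis_demo.sum(axis=1, keepdims=True)   # rows normalized to sum to 1

coeffs = rng.normal(size=(n_bins, n_basis))  # one coefficient per bin and basis function

# Contract over the basis dimension: (n_xy, n_basis) x (n_bins, n_basis) -> (n_xy, n_bins)
field = np.einsum("kb,cb->kc", basis_demo, coeffs)
print(field.shape)
```

With this construction the network only has to predict `n_bins * n_basis` numbers instead of `n_bins * n_xy`.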

## `fit`

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, Dropout, Reshape, Dot, Add, Activation

In [None]:
#input parameters for CNN

n_bins = 3 #we have to predict probs for 3 bins
dropout_rate = 0.4
hidden_nodes = [10] #they tried different architectures
clim_probs = np.log(1/3)*np.ones((n_xy,n_bins))

In [None]:
# CNN: slightly adapted from Scheuerer et al 2020.

inp_imgs = Input(shape=(121,240,4,)) #fcts
inp_basis = Input(shape=(n_xy,n_basis)) #basis
inp_cl = Input(shape=(n_xy,n_bins,)) #climatology

c = Conv2D(4, (3,3), activation='elu')(inp_imgs)
c = MaxPooling2D((2,2))(c)
c = Conv2D(8, (3,3), activation='elu')(c)
c = MaxPooling2D((2,2))(c)
x = Flatten()(c)
for h in hidden_nodes: 
    x = Dropout(dropout_rate)(x)
    x = Dense(h, activation='elu')(x)
x = Dense(n_bins*n_basis, activation='elu')(x)
x = Reshape((n_bins,n_basis))(x)
z = Dot(axes=2)([inp_basis, x])     # Tensor product with basis functions
z = Add()([z, inp_cl])              # Add (log) probability anomalies to log climatological probabilities 
out = Activation('softmax')(z)
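
The last two layers add the network's log-probability anomalies to the log climatological probabilities and apply a softmax. A minimal sketch in plain Python, independent of Keras, shows the effect; `softmax` and the anomaly values are illustrative assumptions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

n_bins = 3
log_clim = [math.log(1 / n_bins)] * n_bins    # log climatological probabilities

# Zero anomalies reproduce the climatological probabilities.
p_clim = softmax([lc + a for lc, a in zip(log_clim, [0.0, 0.0, 0.0])])

# A positive anomaly in the last bin shifts probability towards it.
p_shifted = softmax([lc + a for lc, a in zip(log_clim, [0.0, 0.0, 1.0])])

print(p_clim, p_shifted)
```

An untrained or heavily regularized network with near-zero anomalies therefore falls back to climatology.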


In [None]:
cnn = Model(inputs=[inp_imgs, inp_basis, inp_cl], outputs=out)

In [None]:
cnn.summary()

In [None]:
cnn.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer=keras.optimizers.Adam(1e-4))#'adam')

In [None]:
import warnings
warnings.simplefilter("ignore")

In [None]:
#from tensorflow import set_random_seed
from tensorflow import random
random.set_seed(1)#this seems not to have any effect so far...

Build a data generator; otherwise converting the data to numpy takes too long. The full dataset is also probably too large for efficient training.

In [None]:

class DataGenerator(keras.utils.Sequence):
    def __init__(self, fct, verif, lead_input, lead_output, basis, clim_probs, batch_size=32, shuffle=True, load=True):
        
        """
        Data generator for S2S forecasts, adapted from the WeatherBench data generator.
        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

        Args:
            fct: forecasts from S2S models: xr.Dataset (converted with .to_array() in __getitem__)
            verif: terciled observations: xr.DataArray with a `category` dimension
            lead_input: lead_time selected from fct
            lead_output: lead_time selected from verif
            basis: basis functions, shape (n_xy, n_basis)
            clim_probs: log climatological probabilities, shape (n_xy, n_bins)
            batch_size: Batch size
            shuffle: bool. If True, data is shuffled.
            load: bool. If True, dataset is loaded into RAM.
            
        Todo:
        - use number in a better way, now uses only ensemble mean forecast
        - dont use .sel(lead_time=lead_time) to train over all lead_time at once
        - be sensitive with forecast_time, pool a few around the weekofyear given
        - use more variables as predictors
        - predict more variables
        """
        
                
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.lead_input = lead_input
        self.lead_output = lead_output
        
        ###remove annual cycle here
        ### add more variables
        
        self.fct_data = fct.transpose('forecast_time', ...).sel(lead_time=lead_input)
        #self.fct_mean = self.fct_data.mean('forecast_time').compute() if mean_fct is None else mean_fct.sel(lead_time = lead)
        #self.fct_std = self.fct_data.std('forecast_time').compute() if std_fct is None else std_fct.sel(lead_time = lead)
        
        self.verif_data = verif.transpose('forecast_time', ...).sel(lead_time=lead_output)
        #self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean_verif is None else mean_verif.sel(lead_time = lead)
        #self.verif_std = self.verif_data.std('forecast_time').compute() if std_verif is None else std_verif.sel(lead_time = lead)

        # Normalize
        #self.fct_data = (self.fct_data - self.fct_mean) / self.fct_std
        #self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std
    
        
        self.n_samples = self.fct_data.forecast_time.size
        #self.forecast_time = self.fct_data.forecast_time
        #self.n_lats = self.fct_data.latitude.size - self.window_size
        #self.n_lons = self.fct_data.longitude.size - self.window_size

        self.on_epoch_end()

        # For some weird reason calling .load() earlier messes up the mean and std computations
        if load:
            # print('Loading data into RAM')
            self.fct_data.load()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    def __getitem__(self, i):
        'Generate one batch of data'
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
        #lats = self.lats [i * self.batch_size:(i + 1) * self.batch_size]
        #lons = self.lons[i * self.batch_size:(i + 1) * self.batch_size]
        ##data comes in a row, not randomly chosen from within train data, if shuffled beforehand,--> stays shuffled because of isel
        # got all nan if nans not masked
        X_x = self.fct_data.isel(forecast_time=idxs).fillna(0.).to_array().transpose('forecast_time', ...,'variable').values#.values
        
        
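        # NOTE: `basis` and `clim_probs` are read from the notebook's global scope here;
        # the constructor arguments of the same name are not stored on the instance.
        X_basis = np.repeat(basis[np.newaxis,:,:],len(idxs),axis=0)
        X_clim =  np.repeat(clim_probs[np.newaxis,:,:],len(idxs),axis=0)#self.batch_size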
        
        X = [X_x, X_basis, X_clim]
        
        #X = self.fct_data.isel(forecast_time=idxs).isel(latitude = slice(lats,lats + self.window_size), 
         #                                               longitude = slice(lons,lons + self.window_size)).fillna(0.).values
        #x_coords = (math.ceil((lats + self.window_size)/2),math.ceil((lons + self.window_size)/2))
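        # stack with longitude as the outer level so that the grid-point order matches `basis`
        # (flattened as lon, lat in get_basis) and the Reshape((len(lons), len(lats), 3)) used for predictions
        y = self.verif_data.stack(Z = ['longitude','latitude']).transpose('forecast_time','Z',...).isel(forecast_time=idxs).fillna(0.).values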
        #y = self.verif_data.isel(forecast_time=idxs).fillna(0.).values
        #y = self.verif_data.isel(forecast_time=idxs).isel(latitude = x_coords[0], 
         #                                                 longitude = x_coords[1]).fillna(0.).values
        
        return X, y # x_coords,

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        #self.lats = np.arange(self.n_lats)
        #self.lons = np.arange(self.n_lons)
        if self.shuffle == True: ###does this make sense here?
            np.random.shuffle(self.idxs)
            #np.random.shuffle(self.lats)
            #np.random.shuffle(self.lons)
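
The batching logic of `__len__`, `__getitem__` and `on_epoch_end` can be checked without any data. The sketch below is a plain-Python stand-in; `batch_indices` is a hypothetical helper and the sample count is invented.

```python
import math
import random

def batch_indices(n_samples, batch_size, shuffle=True, seed=0):
    """Return the sample indices of every batch in one epoch."""
    idxs = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(idxs)
    n_batches = math.ceil(n_samples / batch_size)
    return [idxs[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]

batches = batch_indices(n_samples=70, batch_size=32)
print([len(b) for b in batches])
```

The last batch is usually smaller than `batch_size`, which is why the basis and climatology inputs are repeated `len(idxs)` times rather than `batch_size` times.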

Creating the data generators takes a long time, probably because all data has to be opened and computed after removing the annual cycle with `rm_annualcycle`.

In [None]:
dg_train = DataGenerator(fct_train, verif_train,#.sel(forecast_time = slice('2000','2001'))
                         lead_input=lead_input,lead_output = lead_output,  basis = basis, clim_probs = clim_probs,
                         batch_size=bs, load=True)#,
                         #mean_fct=fct_train_mean, std_fct=fct_train_std, 
                         #mean_verif = verif_train_mean, std_verif = verif_train_std)

In [None]:
dg_train[0][1].shape

In [None]:
dg_train[-1][0][2].shape

In [None]:
dg_valid = DataGenerator(fct_valid, verif_valid,
                         lead_input=lead_input,lead_output = lead_output, basis = basis, clim_probs = clim_probs, batch_size=bs,
                         shuffle=False, load=True) # keep forecast_time order so predictions from dg_valid stay aligned

In [None]:
dg_valid[0][1].shape

In [None]:
dg_valid[-1][0][0].shape

In [None]:
dg_valid[-1][0][1].shape

In [None]:
dg_valid[-1][0][2].shape

In [None]:
cnn.fit(dg_train, 
         epochs=5, shuffle = True,
        validation_data = dg_valid)

In [None]:
#some ideas for improvements:
#circular input
#https://www.tu-chemnitz.de/etit/proaut/publications/schubert19_IV.pdf
#https://www.tu-chemnitz.de/etit/proaut/en/research/ccnn.html
#add indices
#https://stackoverflow.com/questions/47818968/adding-an-additional-value-to-a-convolutional-neural-network-input
#https://datascience.stackexchange.com/questions/68450/how-can-you-include-information-not-present-in-an-image-for-neural-networks
#https://www.nature.com/articles/s41598-019-42294-8


## `predict`

### define create_prediction

In [None]:
def _create_predictions(model, dg, lead):
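    """Create non-iterative predictions
    returns: prediction as xr.DataArray with dims (forecast_time, category, latitude, longitude)
    note: coordinates are taken from the globals fct_valid, verif_train and lead_output; `lead` is currently unused
    """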
    import tensorflow as tf 
    
    preds = model.predict(dg).squeeze()
    
    preds = Reshape((len(lons),len(lats),3))(preds)
    preds = tf.transpose(preds, [0,3,1,2])

    da = xr.DataArray(
                preds,
                dims=['forecast_time', 'category','longitude', 'latitude'],
                coords={'forecast_time': fct_valid.forecast_time, 'category' : verif_train.category, 'latitude': verif_train.latitude,
                        'longitude': verif_train.longitude}
            )
    da = da.transpose('forecast_time','category','latitude',...)
    da = da.assign_coords(lead_time=lead_output)
    return da
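
`_create_predictions` reshapes the flat grid-point dimension into `(longitude, latitude)`. That only recovers the map if the flat dimension was built with longitude as the slow axis. The NumPy sketch below, on an invented 2 x 3 field, shows how a reshape that does not match the flattening order silently scrambles the field.

```python
import numpy as np

n_lat, n_lon = 2, 3
field_2d = np.arange(n_lat * n_lon).reshape(n_lat, n_lon)   # dims (latitude, longitude)

flat_lat_major = field_2d.reshape(-1)      # latitude varies slowest
flat_lon_major = field_2d.T.reshape(-1)    # longitude varies slowest

# Round trips only work when the reshape matches the flattening order.
restored = flat_lon_major.reshape(n_lon, n_lat).T
mismatched = flat_lat_major.reshape(n_lon, n_lat).T

print(np.array_equal(restored, field_2d), np.array_equal(mismatched, field_2d))
```

No error is raised in the mismatched case, so the order of targets, basis functions and predictions is worth checking explicitly.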

In [None]:
def add_valid_time_single(forecast, init_dim='forecast_time'):
    """Creates valid_time(forecast_time, lead_time) for a single lead time and variable
    
    lead_time: pd.Timedelta
    forecast_time: datetime
    """
    times = xr.DataArray(
                forecast[init_dim] + lead_output,
                dims=init_dim,
                coords={init_dim: forecast[init_dim]},
            )
            
    forecast = forecast.assign_coords(valid_time=times)
    return forecast

In [None]:
def single_prediction(cnn, dg, lead):
    """prediction for one var and one lead-time
    
    args:
    time: time slice
    
    """

    preds_test = _create_predictions(cnn, dg, lead)
    
    # add valid_time coord
    ###preds_test = add_valid_time_from_forecast_reference_time_and_lead_time(preds_test)
    # only works for complete output
    preds_test = add_valid_time_single(preds_test)
    #preds_test = preds_test.to_dataset(name=v)
                                    
    return preds_test

#### prediction from CNN

In [None]:
preds = cnn.predict(dg_valid).squeeze()
preds.shape

In [None]:
import tensorflow as tf
preds = Reshape((len(lons),len(lats),3))(preds)
preds = tf.transpose(preds, [0,3,1,2])
preds

In [None]:
preds_single = single_prediction(cnn, 
                                 [fct_valid.sel(lead_time = lead_input).fillna(0.).to_array().transpose('forecast_time', ...,'variable').values,
                                   np.repeat(basis[np.newaxis,:,:],len(fct_valid.forecast_time),axis=0),
                                 np.repeat(clim_probs[np.newaxis,:,:],len(fct_valid.forecast_time),axis=0)],
                                  lead_input) 

In [None]:
preds_masked = preds_single.where(mask[v].sel(lead_time = lead_output).notnull())
preds_masked.isel(forecast_time = 0).plot(col = 'category')#, vmin = 0, vmax = 1)

#### ground truth

In [None]:
verif_valid.sel(lead_time = lead_output).isel(forecast_time = 0).plot(col = 'category')

In [None]:
verif_valid.sel(lead_time = lead_output).isel(forecast_time = 0)

#### tercile probs of the raw ensemble

In [None]:
### load data
#tercile_edges : used in create_predictions --> make_probabilistic
#mask: used in make_probabilistic, but make_probabilistic would also work without mask

#!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc
#tercile_file = f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc'
#tercile_edges = xr.open_dataset(tercile_file)
tercile_edges = load_data(data = 'obs_tercile_edges_2000-2019', aggregation = 'biweekly', path = path_data)
#obs_2000_2019_terciled = load_data(data = 'obs_terciled_2000-2019', aggregation = 'biweekly', path = path_data)

In [None]:
test = hind_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]
test_raw = make_probabilistic(test, tercile_edges)
test_raw.isel(lead_time = 0).isel(forecast_time = 0)['t2m'].where(mask[v].sel(lead_time = lead_output).notnull()).plot(col = 'category')
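
Conceptually, tercile probabilities of a raw ensemble are the fractions of members falling below, between and above the tercile edges. The sketch below illustrates this at one grid point with invented values. It is not the implementation of `make_probabilistic`; in particular, the handling of members exactly on an edge is an assumption.

```python
def tercile_probabilities(members, lower_edge, upper_edge):
    """Fraction of ensemble members below, between, and above the tercile edges."""
    n = len(members)
    below = sum(m < lower_edge for m in members) / n
    above = sum(m >= upper_edge for m in members) / n
    return {"below normal": below, "near normal": 1.0 - below - above, "above normal": above}

# Hypothetical 2 m temperature ensemble (K) and tercile edges at one grid point.
members = [279.5, 280.2, 281.0, 281.4, 282.3, 283.1, 283.8, 284.0]
probs = tercile_probabilities(members, lower_edge=280.5, upper_edge=283.0)
print(probs)
```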

In [None]:
#sanity check: probs of terciles add up to one
#preds_single.isel(forecast_time = 0).sum('category').plot()

## Compute RPSS

In [None]:
#computes RPSS wrt climatology (1/3 for each category). So, negative RPSS values are worse than climatology.

def skill_by_year_single(prediction, terciled_obs):
    """version of skill_by_year adjusted to one var and one lead time and a flexible validation period"""
    fct_p = prediction
    obs_p = terciled_obs


    # climatology
    clim_p = xr.DataArray([1/3, 1/3, 1/3], dims='category', coords={'category':['below normal', 'near normal', 'above normal']}).to_dataset(name='tp')
    clim_p['t2m'] = clim_p['tp']

    clim_p = clim_p[v]

    ## RPSS
    # rps_ML
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # rps_clim
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()

    # rpss
    rpss = 1 - (rps_ML / rps_clim)

    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7

    # penalize
    penalize = obs_p.where(fct_p!=1, other=-10).mean('category')
    rpss = rpss.where(penalize!=0, other=-10)

    # clip
    rpss = rpss.clip(-10, 1)

    # average over all forecasts
    rpss_year = rpss.groupby('forecast_time.year').mean()

    # weighted area mean
    weights = np.cos(np.deg2rad(np.abs(rpss_year.latitude)))
    # spatially weighted score averaged over lead_times and variables to one single value
    scores = rpss_year.sel(latitude=slice(None, -60)).weighted(weights).mean('latitude').mean('longitude')
    #scores = scores.to_array().mean(['lead_time', 'variable'])

    return scores.to_dataframe('RPSS') 
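
For a single forecast, the ranked probability score and the skill score against climatology can be computed by hand. The sketch below assumes the usual definition, the sum of squared differences between cumulative forecast and observed probabilities; `rps` and `rpss` are illustrative helpers, not the `xskillscore` functions.

```python
def rps(obs_probs, fct_probs):
    """Ranked probability score: squared distance between cumulative distributions."""
    score, cum_obs, cum_fct = 0.0, 0.0, 0.0
    for o, f in zip(obs_probs, fct_probs):
        cum_obs += o
        cum_fct += f
        score += (cum_fct - cum_obs) ** 2
    return score

def rpss(obs_probs, fct_probs, ref_probs):
    """Skill relative to a reference forecast: 1 is perfect, 0 matches the reference."""
    return 1.0 - rps(obs_probs, fct_probs) / rps(obs_probs, ref_probs)

obs = [0.0, 0.0, 1.0]                  # observed category: above normal
clim = [1 / 3, 1 / 3, 1 / 3]

skill_good = rpss(obs, [0.2, 0.3, 0.5], clim)   # leans towards the observed category
skill_bad = rpss(obs, [0.5, 0.3, 0.2], clim)    # leans towards the wrong category
print(skill_good, skill_bad)
```

A confident forecast in the wrong category is penalized more strongly than a climatological forecast, which is why overconfident models can end up with negative RPSS.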

In [None]:
#obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v].sel(lead_time = lead)

In [None]:
skill_by_year_single(preds_single, 
                     obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v].sel(lead_time = lead_output))
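
Inside `skill_by_year_single` the scores are averaged with `cos(latitude)` weights, because grid cells on a regular lat-lon grid shrink towards the poles. A minimal plain-Python sketch of such a weighted mean, with invented values (`area_weighted_mean` is a hypothetical helper):

```python
import math

def area_weighted_mean(values, latitudes):
    """Mean of values weighted by the cosine of their latitude (degrees)."""
    weights = [math.cos(math.radians(abs(lat))) for lat in latitudes]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

latitudes = [0.0, 60.0]
values = [1.0, -1.0]
weighted = area_weighted_mean(values, latitudes)
print(weighted)
```

The value at 60 degrees counts only half as much as the value at the equator, so the weighted mean is positive while the unweighted mean would be zero.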

In [None]:
# RPSS on the order of -0.015 (for r_basis = 14)
# with a larger radius for the basis functions, the predictions might come closer to climatology and hence achieve a better RPSS.
# fluctuations in predictions are large.

In [None]:
skill_by_year_single(test_raw[v].sel(latitude = output_lat, longitude = output_lon).sel(lead_time = lead_output),
                     obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v].sel(lead_time = lead_output))

In [None]:
# RPSS on the order of -0.635

#### The RPSS of this CNN approach is higher than for all ANNs. 
CNN fields are less smooth than the fields of ANN terciled. 

# Reproducibility

## memory

In [None]:
# https://phoenixnap.com/kb/linux-commands-check-memory-usage
!free -g

## CPU

In [None]:
!lscpu

## software

In [None]:
!conda list