# Train ML model to correct predictions of week 3-4 & 5-6

This notebook creates a Machine Learning model (`ML_model`) to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts, and compares the predictions to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).

# Synopsis

## Method: `ML-based mean bias reduction`

- calculate the ML-based bias from the 2000-2019 deterministic ensemble mean forecast
- remove that ML-based bias from the 2020 deterministic ensemble mean forecast

## Data used

type: renku datasets

Training-input for Machine Learning model:
- hindcasts of models:
    - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`

Forecast-input for Machine Learning model:
- real-time 2020 forecasts of models:
    - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`

Compare the Machine Learning model forecast against ground truth:
- `CPC` observations:
    - `hindcast-like-observations_2000-2019_biweekly_deterministic.zarr`
    - `forecast-like-observations_2020_biweekly_deterministic.zarr`

## Resources used
for training; see details in the Reproducibility section

- platform: renku
- memory: 8 GB
- processors: 2 CPU
- storage required: 10 GB

## Safeguards

All points have to be [x] checked. If not, your submission is invalid.

Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.
(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)

### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) 

If the organizers suspect overfitting, your contribution can be disqualified.

  - [x] We did not use 2020 observations in training (explicit overfitting and cheating)
  - [x] We did not repeatedly verify our model on 2020 observations and incrementally improve our RPSS (implicit overfitting)
  - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see section 6.3 `predict`.
  - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
  - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
  - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
  - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).

### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition); if this is not possible: no prize
  - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
  - [x] Code is well documented, readable and reproducible.
  - [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training pipelines that take weeks to train.

# Todos to improve template

This is just a demo.

- [ ] use multiple predictor variables and two predicted variables
- [ ] for both `lead_time`s in one go
- [ ] consider seasonality, for now all `forecast_time` months are mixed
- [ ] make probabilistic predictions with `category` dim; for now it works deterministically

# Description of this notebook

* makes probabilistic predictions for categories
* only for one lead time and temperature variable
* based on an ANN with ensemble spread (std), climatological spread, ensemble mean (with removed annual cycle), lat/lon and week as input; uses softmax to return class probabilities
* always uses train data to remove the annual cycle and to compute other summary statistics
* masking to have the same NaNs over the whole period
* L2 regularization and early stopping
* skill still low (accuracy ~ 0.37)
* categories are unequally distributed in the observations (near-normal is less frequent)
* large differences between obs and fct (each with removed annual cycle)
* improved standardization (outside of pre-processing, always using train data)

# Imports

In [None]:
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Sequential
import tensorflow.keras as keras

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


import xarray as xr
xr.set_options(display_style='text')



from dask.utils import format_bytes
import xskillscore as xs

%matplotlib inline 
#so that figures appear again

#for prediction
from scripts import make_probabilistic
from scripts import add_valid_time_from_forecast_reference_time_and_lead_time
from scripts import skill_by_year
from scripts import add_year_week_coords

In [None]:
cache_path = '../template/data' #if you change this you also have to adjust the git lfs pull paths

# Get training data

Preprocessing of the input data may be done in a separate notebook/script.

## Hindcast

get weekly initialized hindcasts

In [None]:
# preprocessed as renku dataset
!git lfs pull ../template/data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr

In [None]:
hind_2000_2019 = xr.open_zarr(f'{cache_path}/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr', consolidated=True)

## Observations
corresponding to hindcasts

In [None]:
# preprocessed as renku dataset
!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr

In [None]:
obs_2000_2019 = xr.open_zarr(f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr', consolidated=True)#[v]

terciled

In [None]:
!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr

In [None]:
obs_2000_2019_terciled = xr.open_zarr(f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_terciled.zarr', consolidated=True)

### Select region

To make life easier at the beginning, select a region where no periodic padding is needed.

In [None]:
lat = slice(90,0)
lon = slice(0,90) #negative,positive will not work

In [None]:
hind_2000_2019 = hind_2000_2019.sel(longitude = lon, latitude = lat)
obs_2000_2019 = obs_2000_2019.sel(longitude = lon, latitude = lat)
obs_2000_2019_terciled = obs_2000_2019_terciled.sel(longitude = lon, latitude = lat)
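The comment above notes that a slice from negative to positive longitudes does not work. Assuming the data store longitudes on a 0..360 grid (an assumption, not checked here), one common xarray idiom is to remap them to -180..180 and sort, after which a slice across the 0° meridian behaves as expected. A minimal sketch on a toy array:

```python
import numpy as np
import xarray as xr

# toy field on an assumed 0..360 longitude grid
da = xr.DataArray(
    np.arange(4.0),
    dims="longitude",
    coords={"longitude": [0.0, 90.0, 180.0, 270.0]},
)

# remap longitudes from [0, 360) to [-180, 180) and sort so the index is monotonic again
da_180 = da.assign_coords(longitude=((da.longitude + 180) % 360) - 180).sortby("longitude")

# a slice across the 0 degree meridian now works: it keeps longitudes -90 and 0
window = da_180.sel(longitude=slice(-90, 0))
print(window.longitude.values, window.values)
```

This only addresses the region selection; periodic padding for convolutional layers would be a separate step.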

## Train Validation split

In [None]:
# time is the forecast_time
time_train_start,time_train_end='2000','2017' # train
time_valid_start,time_valid_end='2018','2019' # valid

## Weatherbench

based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)

In [None]:
# run once only and don't commit
#!git clone https://github.com/pangeo-data/WeatherBench/

In [None]:
import sys
sys.path.insert(1, 'WeatherBench')
from WeatherBench.src.train_nn import PeriodicConv2D, create_predictions#DataGenerator, 

### define some vars

In [None]:
v='t2m'  # target variable
bs=32    # batch size (not used below; ann.fit uses batch_size=100)

https://s2s-ai-challenge.github.io/

We deal with two fundamentally different variables here: 
- Total precipitation is the precipitation flux `pr` accumulated over `lead_time` until `valid_time` and therefore describes a point observation. 
- 2m temperature is averaged over `lead_time` (`valid_time`) and therefore describes an average observation. 

The submission file data model unifies both approaches and assigns a `lead_time` of 14 days for week 3-4 and 28 days for week 5-6, marking the first day of the biweekly aggregate.
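To make the `lead_time` convention concrete, the sketch below computes the first day of each biweekly aggregate for one example initialization date. Beyond what is stated above, it assumes that each aggregate spans 14 consecutive days starting at `forecast_time + lead_time`; the date is purely illustrative.

```python
import pandas as pd

init = pd.Timestamp("2020-01-02")  # illustrative forecast_time
lead_times = {"week 3-4": pd.Timedelta(days=14), "week 5-6": pd.Timedelta(days=28)}

windows = {}
for name, lead_td in lead_times.items():
    first_day = init + lead_td                    # valid_time: first day of the aggregate
    last_day = first_day + pd.Timedelta(days=13)  # assumed 14-day window, both ends inclusive
    windows[name] = (first_day, last_day)
    print(name, first_day.date(), "to", last_day.date())
```

Under these assumptions the two windows are back to back: week 3-4 covers 2020-01-16 to 2020-01-29 and week 5-6 covers 2020-01-30 to 2020-02-12.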

In [None]:
# there are 2 biweekly lead_times; select the first one: week 3-4
lead = hind_2000_2019.lead_time[0]

### create datasets

In [None]:
#mask: same missing values at all forecast_times
mask = xr.where(obs_2000_2019.notnull(),1,np.nan).mean('forecast_time', skipna = False)
mask

#What if hind contains nan?--> precip
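The mask keeps a grid point only if it has observations at every `forecast_time`: `xr.where` turns valid values into 1 and missing ones into NaN, and `mean(..., skipna=False)` propagates a single NaN through the whole time mean. A toy illustration (array sizes and values are made up):

```python
import numpy as np
import xarray as xr

# toy observations: 3 forecast_times on a 2x2 grid, one point missing at a single time
obs_toy = xr.DataArray(
    np.ones((3, 2, 2)),
    dims=("forecast_time", "latitude", "longitude"),
    coords={"latitude": [10.0, 0.0], "longitude": [0.0, 10.0]},
)
obs_toy[1, 0, 1] = np.nan

# 1 where valid, NaN where missing; skipna=False lets any NaN spoil the time mean
mask_toy = xr.where(obs_toy.notnull(), 1, np.nan).mean("forecast_time", skipna=False)

# applying the mask removes that grid point at all forecast_times (3 NaNs in total)
masked = obs_toy.where(mask_toy.notnull())
print(int(masked.isnull().sum()))
```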

In [None]:
#train:
#keeps all realizations; ensemble mean and spread are computed later in ann_preprocess
fct_train = hind_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v]
verif_train = obs_2000_2019_terciled.sel(forecast_time=slice(time_train_start,time_train_end))[v]

fct_train = fct_train.where(mask[v].notnull())
verif_train = verif_train.where(mask[v].notnull())

In [None]:
#validation
fct_valid = hind_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]
verif_valid = obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]

fct_valid = fct_valid.where(mask[v].notnull())
verif_valid = verif_valid.where(mask[v].notnull())

### Remove bias from fct
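Not used: removing the weekly mean bias from the forecasts did not improve the skill. Note that `fct_valid_bias` below is estimated from the validation observations; a leakage-free variant would reuse `fct_train_bias` for the validation period.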

obs_train = obs_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v]
obs_valid = obs_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]

#from mean_bias_reduction
from scripts import add_year_week_coords
fct_train_bias = add_year_week_coords(fct_train.mean('realization') - obs_train).groupby('week').mean().compute()
fct_valid_bias = add_year_week_coords(fct_valid.mean('realization') - obs_valid).groupby('week').mean().compute()

fct_train = add_year_week_coords(fct_train) - fct_train_bias.sel(week=fct_train.week)
fct_valid = add_year_week_coords(fct_valid) - fct_valid_bias.sel(week=fct_valid.week)

### Preprocessing

In [None]:
def rm_annualcycle(ds, ds_train):
    #remove the annual cycle for each location;
    #the weekly climatology is always computed from the train data
    
    ds = add_year_week_coords(ds)
    ds_train = add_year_week_coords(ds_train)
    
    #climatology of the ensemble mean (if ds_train is an ensemble)
    if 'realization' in ds_train.dims:
        ens_mean = ds_train.mean('realization')
    else:
        ens_mean = ds_train

    ds_stand = ds - ens_mean.groupby('week').mean(['forecast_time'])

    #select the matching week for each forecast_time
    ds_stand = ds_stand.sel({'week' : ds.coords['week']})
    ds_stand = ds_stand.drop(['week','year'])
    return ds_stand
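`rm_annualcycle` subtracts a weekly climatology computed from the train data only. The same idea can be written with xarray's groupby arithmetic, which subtracts from every `forecast_time` the climatology of its own `week`. This is a sketch on toy data (hypothetical names `train_toy`, `valid_toy`), not a drop-in replacement that has been validated against the function above:

```python
import numpy as np
import xarray as xr

# toy train period: two years with two weekly initializations each
train_toy = xr.DataArray(
    [1.0, 5.0, 3.0, 7.0],
    dims="forecast_time",
    coords={"forecast_time": np.arange(4), "week": ("forecast_time", [1, 2, 1, 2])},
)
# toy validation period: one more year
valid_toy = xr.DataArray(
    [4.0, 4.0],
    dims="forecast_time",
    coords={"forecast_time": [4, 5], "week": ("forecast_time", [1, 2])},
)

# weekly climatology from the train data only: week 1 -> 2.0, week 2 -> 6.0
clim_toy = train_toy.groupby("week").mean("forecast_time")

# groupby arithmetic subtracts the matching week's climatology from each forecast_time
anom_toy = (valid_toy.groupby("week") - clim_toy).drop_vars("week", errors="ignore")
print(anom_toy.values)  # anomalies 4 - 2 and 4 - 6
```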

In [None]:
def ann_preprocess(ds, ds_train, v,lead):

    ds = ds.sel(lead_time = lead)
    ds_train = ds_train.sel(lead_time = lead)

    #remove annual cycle for each location 
    ds = rm_annualcycle(ds, ds_train)

    #compute ensemble mean
    ens_mean = ds.mean('realization')

    #compute ensemble spread
    ens_spread = ds.std('realization')

    #climatological ensemble spread per week from the train forecasts (open question: from obs or from ens?)
    ds_train_weekly = add_year_week_coords(ds_train)

    clim_spread = ds_train_weekly.groupby('week').std(['forecast_time','realization'])
    clim_spread = clim_spread.sel({'week' : add_year_week_coords(ds).coords['week']})
    clim_spread = clim_spread.drop(['week','year'])
    
    #time feature: cosine of the week number (1..53)
    week = add_year_week_coords(ds)
    week_ = np.cos(2*np.pi/53*(week.week +53/2))
    week_ = week_.drop(['week','year'])
    week_ = week_.expand_dims({'longitude': clim_spread.longitude, 'latitude': clim_spread.latitude})
    week_ = week_.transpose('forecast_time', 'latitude', 'longitude')
    
    #combine data arrays
    ens_mean = ens_mean.to_dataset(name = 'mean_{}'.format(v))
    spread = ens_spread.to_dataset(name = 'spread_{}'.format(v))
    clim_spread = clim_spread.to_dataset(name = 'clim_spread_{}'.format(v))
    week_ = week_.to_dataset(name = 'week')
    combined = xr.combine_by_coords([ens_mean, spread, clim_spread, week_])

    df = combined.to_dataframe()
    df = df.drop(['lead_time','valid_time'], axis =1).reset_index()
    
    df = df.dropna(axis = 0)
    
    #to get input shape back later
    df_ref = df
    
    df = df.drop(['forecast_time'], axis = 1)  # latitude and longitude are kept as input features
    
    return df, df_ref
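`ann_preprocess` encodes the week as a single cosine, `cos(2π/53 (week + 53/2))`. A single cosine is symmetric, so week `w` and week `53 - w` (e.g. early spring and late autumn) get the same feature value. A common alternative is to add the matching sine so that every week maps to a unique point on the circle. The sketch below (hypothetical helpers `week_cos` and `week_sin_cos`) only illustrates the encoding; using it would add one input feature, so the ANN's `input_shape` would change accordingly.

```python
import numpy as np

N_WEEKS = 53  # number of weekly forecast_time slots per year, as in the notebook

def week_cos(week):
    # single-cosine encoding as used in ann_preprocess
    return np.cos(2 * np.pi / N_WEEKS * (np.asarray(week) + N_WEEKS / 2))

def week_sin_cos(week):
    # two-feature encoding: (sin, cos) is unique for each week on the circle
    angle = 2 * np.pi * np.asarray(week) / N_WEEKS
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

w = np.array([10, 43])   # 43 = 53 - 10
print(week_cos(w))       # identical values for both weeks
print(week_sin_cos(w))   # the sine component has opposite signs
```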

In [None]:
def ann_preprocess_label(ds,v,lead):
    df = ds.sel(lead_time = lead).to_dataframe()
    df = df.drop(['lead_time','valid_time'], axis =1).reset_index()
    df = df.pivot(index = ['forecast_time', 'latitude','longitude'], columns = 'category', values = v).reset_index()
    df.rename_axis(None, inplace = True, axis = 1)
    df = df.dropna(axis = 0)
    
    
    df=df.drop(['forecast_time','latitude','longitude'], axis = 1)
    df=df[['below normal', 'near normal','above normal']]
    return df

In [None]:
#define dataframes

df_verif_train = ann_preprocess_label(verif_train, v, lead)
df_fct_train, df_fct_train_ref = ann_preprocess(fct_train, fct_train, v, lead)
df_verif_valid = ann_preprocess_label(verif_valid, v, lead)
df_fct_valid, df_fct_valid_ref = ann_preprocess(fct_valid, fct_train, v, lead)
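Features and labels are produced by two separate reshapes (`to_dataframe` plus `dropna` in `ann_preprocess`, `pivot` plus `dropna` in `ann_preprocess_label`) and are then paired by row position in `ann.fit`. One way to make the pairing explicit is to keep the key columns on both sides and join on them. A minimal pandas sketch with made-up toy rows (all names here are hypothetical):

```python
import pandas as pd

keys = ["forecast_time", "latitude", "longitude"]

# toy feature rows, i.e. before the key columns are dropped
features = pd.DataFrame({
    "forecast_time": pd.to_datetime(["2018-01-02", "2018-01-02"]),
    "latitude": [10.0, 20.0],
    "longitude": [5.0, 5.0],
    "mean_t2m": [0.4, -1.2],
})

# toy long-format labels, one row per category (in a different row order on purpose)
labels_long = pd.DataFrame({
    "forecast_time": pd.to_datetime(["2018-01-02"] * 6),
    "latitude": [20.0] * 3 + [10.0] * 3,
    "longitude": [5.0] * 6,
    "category": ["below normal", "near normal", "above normal"] * 2,
    "t2m": [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
})

# one column per category, keys kept as ordinary columns
labels = labels_long.pivot_table(index=keys, columns="category", values="t2m").reset_index()
labels.columns.name = None

# join on the keys instead of relying on row order; validate guards against duplicates
merged_toy = features.merge(labels, on=keys, how="inner", validate="one_to_one")
X_toy = merged_toy[["mean_t2m"]]
y_toy = merged_toy[["below normal", "near normal", "above normal"]]
print(y_toy)
```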

In [None]:
#standardize input

mean_fct_train = df_fct_train.mean(axis = 0)
std_fct_train = df_fct_train.std(axis = 0)

#validation set using train mean and std
df_fct_valid   = (df_fct_valid - mean_fct_train)/std_fct_train

df_fct_train   = (df_fct_train - mean_fct_train)/std_fct_train

In [None]:
df_verif_train

In [None]:
df_fct_train

In [None]:
df_fct_train_ref

### ANN

In [None]:
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.regularizers import l2

# 6 features: latitude, longitude, ensemble mean, ensemble spread, climatological spread, week
n_features = df_fct_train.shape[1]

ann = keras.models.Sequential([
    Dense(10, input_shape=(n_features,), kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01)),
    Activation('elu'),
    #optional: a second hidden layer Dense(10, kernel_regularizer=l2(0.001), bias_regularizer=l2(0.001)) + Activation('elu'), and Dropout(0.4)
    Dense(3),
    #idea: add climatological probabilities to the logits here, x = x + log(1/3)
    Activation('softmax')  # probabilities for the 3 tercile categories
])
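The commented idea `x = x + log(1/3)` in the model above would not change the predictions: softmax is invariant to adding the same constant to every logit. To nudge predictions towards climatology, the offset has to differ between categories, e.g. `log(p_clim)` with category-dependent climatological probabilities. A small numpy check (the logits and `p_clim` are hypothetical):

```python
import numpy as np

def softmax(z):
    # numerically stable softmax over the last axis
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([0.5, -0.2, 0.1])  # arbitrary logits for (below, near, above normal)

# the same constant log(1/3) added to every logit leaves the probabilities unchanged
same = softmax(logits + np.log(1 / 3))

# a category-dependent offset log(p_clim) shifts the prediction towards p_clim;
# hypothetical p_clim with a slightly under-represented near-normal category
p_clim = np.array([0.35, 0.30, 0.35])
shifted = softmax(logits + np.log(p_clim))
print(softmax(logits), same, shifted)
```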

In [None]:
ann.summary()

In [None]:
ann.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')  # alternatives: keras.optimizers.Adam(0.05); keras.optimizers.Adam(1e-4) with 'mse' loss

In [None]:
import warnings
warnings.simplefilter("ignore")

In [None]:
#early stopping
import tensorflow as tf
early_stopping = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)

In [None]:
ann.fit(df_fct_train, df_verif_train, batch_size = 100, epochs=5, validation_data=(df_fct_valid, df_verif_valid),
        callbacks=[early_stopping])

### Predict

In [None]:
###predict plus first postprocessing

predicted_bins = pd.DataFrame(ann.predict(df_fct_valid), columns = df_verif_train.columns)

In [None]:
def postprocess_output(output, df_ref, ds_input, v, lead):
    #work on a copy so that the caller's DataFrame is not modified
    output = output.copy()

    #add the key columns of the (NaN-free) input rows
    output['latitude'] = df_ref.latitude.values
    output['longitude'] = df_ref.longitude.values
    output['forecast_time'] = df_ref.forecast_time.values
    
    #merge category columns into one
    output = output.melt(id_vars = ['forecast_time','latitude','longitude'], var_name = 'category', 
                                       value_name = v)
    
    #create MultiIndex
    output = output.pivot_table(values = v, index = ['latitude','longitude','forecast_time','category'])
    
    #convert to dataset
    xr_output = xr.Dataset.from_dataframe(output)
    
    #retain the complete coords (grid points dropped as NaN come back as NaN)
    temp = ds_input.sel(lead_time = lead).drop(['valid_time','lead_time'])
    temp = temp.to_dataset(name = 'zeros')
    merged = xr.merge([xr_output, temp])
    merged = merged.drop('zeros')

    return merged 

In [None]:
xr_predicted_bins = postprocess_output(predicted_bins, df_fct_valid_ref, fct_valid, v=v, lead=lead)

In [None]:
#change order of categories from alphabetical to below, near, above normal
xr_predicted_bins = xr_predicted_bins.reindex(category=['below normal', 'near normal', 'above normal'])

### plot predictions

**The prediction does not depend on the forecast_time!**

If lat/lon are used as features, the predicted fields become even smoother and the RPSS is marginally lower.

Using week as a feature seems to slightly improve the forecast (in terms of accuracy).


In [None]:
xr_predicted_bins['t2m'].isel(forecast_time = 0).plot(col = 'category')
#xr_predicted_bins['t2m'].isel(forecast_time = slice(0,50,10)).plot(col = 'category', row = 'forecast_time')#.isel(forecast_time = 50)

In [None]:
#ground truth
verif_valid.isel(forecast_time = 0, lead_time = 0).to_dataset()['t2m'].plot(col = 'category')

## Climatological category probabilities for the observations in the train set (i.e. train labels)

In [None]:
annual_cycle = add_year_week_coords(verif_train)
annual_cycle = annual_cycle.groupby('week').mean(['forecast_time'])
#annual_cycle.sel({'week' : verif_train.coords['week']})
annual_cycle.sel(week = 1, lead_time = lead).plot(col = 'category')

For week 1, the category 'near normal' is less frequent than the other two categories.

This also holds for most locations in all other weeks (see next figure).

In [None]:
#for obs, only 18 train years (2000-2017), so weekly frequencies are multiples of 1/18: 5/18 ~ 0.28, 6/18 ~ 0.33, 7/18 ~ 0.39
annual_cycle = annual_cycle.sel(lead_time = lead).stack( z = ('latitude','longitude')).reset_index("z")
annual_cycle.plot.line(hue = 'z', col = 'category', add_legend = False)

In [None]:
annual_cycle.sum('z').plot.line(hue = 'category')

Summed over all grid points, the **near-normal category is clearly underrepresented**. There is no indication of seasonal differences in the climatological category probabilities. The underrepresentation is most likely the result of using the ensemble forecasts to compute the category edges, since the ensemble forecasts are most likely underdispersive. Therefore, it actually "makes sense" that the ANN predicts lower probabilities for the near-normal category.
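The argument above can be checked with a toy simulation: if tercile edges are taken from an ensemble that is narrower than the observed distribution, fewer observations fall between the edges. The sketch assumes Gaussian observations with standard deviation 1 and an ensemble with standard deviation 0.7; both numbers are illustrative, not estimated from the S2S data.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# illustrative setup: under-dispersive ensemble vs. wider observations
ens = rng.normal(0.0, 0.7, size=n)
obs_sim = rng.normal(0.0, 1.0, size=n)

# tercile edges computed from the ensemble
lower, upper = np.quantile(ens, [1 / 3, 2 / 3])

# observed frequency of each category (below, near, above normal)
freq = np.array([
    np.mean(obs_sim < lower),
    np.mean((obs_sim >= lower) & (obs_sim <= upper)),
    np.mean(obs_sim > upper),
])
print(freq.round(3))
```

With these assumptions the theoretical near-normal frequency is 2Φ(0.7 · 0.431) - 1 ≈ 0.24, and each outer category gets about 0.38.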

## Case study for 2018-01-02 to compare the anomaly fields (removed annual cycle) of the forecasts to the observations

potential skill

In [None]:
ds = fct_valid
ds_train = fct_train
ds = ds.sel(lead_time = lead)
ds_train = ds_train.sel(lead_time = lead)

#remove annual cycle for each location 
ds = rm_annualcycle(ds, ds_train)

#### ensemble

In [None]:
ds.isel(forecast_time = 0).plot(col = 'realization', col_wrap = 4)

In [None]:
ds.isel(forecast_time = 0).mean('realization').plot()

#### obs

In [None]:
obs_valid = obs_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]

obs_train = obs_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v]
obs_valid_ = rm_annualcycle(obs_valid, obs_train).sel(lead_time = lead)
obs_valid_.isel(forecast_time = 0).plot()

#### conclusion
The forecast fields and the observations deviate quite a bit, which limits the accuracy of the ANN probability predictions.

In [None]:
#remove annual cycle from obs using train forecasts.
#not the best thing to do..., forecasts are biased!
annual_cycle = add_year_week_coords(fct_train)
annual_cycle = annual_cycle.groupby('week').mean(['forecast_time','realization'])

(obs_valid.isel(forecast_time = 0).sel(lead_time = lead) - annual_cycle.sel(week = 1, lead_time = lead)).plot()

In [None]:
#verif_valid.sel(lead_time = lead).isel(forecast_time = 0).plot(col = 'category')

## Compute RPSS

In [None]:
#computes the RPSS wrt climatology (1/3 for each category), so a negative RPSS is worse than climatology

def skill_by_year_single(prediction, terciled_obs):
    """version of skill_by_year adjusted to one variable and one lead time, with a flexible validation period"""
    fct_p = prediction
    obs_p = terciled_obs

    # climatology: probability 1/3 for each category
    clim_p = xr.DataArray([1/3, 1/3, 1/3], dims='category',
                          coords={'category': ['below normal', 'near normal', 'above normal']})

    ## RPSS
    # rps_ML
    rps_ML = xs.rps(obs_p, fct_p, category_edges=None, dim=[], input_distributions='p').compute()
    # rps_clim
    rps_clim = xs.rps(obs_p, clim_p, category_edges=None, dim=[], input_distributions='p').compute()

    # rpss
    rpss = 1 - (rps_ML / rps_clim)

    # https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge-template/-/issues/7

    # penalize
    penalize = obs_p.where(fct_p!=1, other=-10).mean('category')
    rpss = rpss.where(penalize!=0, other=-10)

    # clip
    rpss = rpss.clip(-10, 1)

    # average over all forecasts
    rpss_year = rpss.groupby('forecast_time.year').mean()

    # weighted area mean
    weights = np.cos(np.deg2rad(np.abs(rpss_year.latitude)))
    # spatially (cos-latitude) weighted mean over the grid, excluding latitudes south of 60S
    scores = rpss_year.sel(latitude=slice(None, -60)).weighted(weights).mean('latitude').mean('longitude')

    return scores.to_dataframe('RPSS') 
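For a single forecast, the RPS is the sum of squared differences between forecast and observed cumulative category probabilities, and the RPSS compares it with the climatological (1/3, 1/3, 1/3) forecast: `RPSS = 1 - RPS_ML / RPS_clim`. The plain-numpy sketch below works through one hypothetical case; it illustrates the definition and is not `xs.rps` itself (any constant normalization factor would cancel in the ratio).

```python
import numpy as np

def rps_np(p_fcst, p_obs):
    # ranked probability score: squared differences of the cumulative probabilities
    return np.sum((np.cumsum(p_fcst) - np.cumsum(p_obs)) ** 2)

obs_onehot = np.array([0.0, 0.0, 1.0])  # observed category: above normal
fcst = np.array([0.2, 0.3, 0.5])        # hypothetical ANN probabilities
clim = np.array([1 / 3, 1 / 3, 1 / 3])  # climatological reference

rps_fcst = rps_np(fcst, obs_onehot)     # 0.2^2 + 0.5^2 + 0 = 0.29
rps_clim = rps_np(clim, obs_onehot)     # (1/3)^2 + (2/3)^2 + 0 = 5/9
rpss = 1 - rps_fcst / rps_clim          # 1 - 0.29 / (5/9) = 0.478
rpss_clipped = np.clip(rpss, -10, 1)    # same clipping range as in skill_by_year_single
print(rps_fcst, rps_clim, rpss_clipped)
```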

In [None]:
prediction = xr_predicted_bins.reindex(latitude=xr_predicted_bins.latitude[::-1])
#prediction.t2m

In [None]:
skill_by_year_single(prediction.t2m, 
                      verif_valid.sel(lead_time = lead))

**The RPSS of this approach is clearly higher than for the ensemble post-processing approach. Thus, the ANN predicting tercile probabilities outperforms the ensemble post-processing approach.**
However, this is probably because this approach predicts the smoothest fields and relaxes the most towards climatology. 

# Reproducibility

## memory

In [None]:
# https://phoenixnap.com/kb/linux-commands-check-memory-usage
!free -g

## CPU

In [None]:
!lscpu

## software

In [None]:
!conda list