# Train ML model to correct predictions of week 3-4 & 5-6

This notebook creates a Machine Learning model `ML_model` that predicts weeks 3-4 & 5-6 from `S2S` weeks 3-4 & 5-6 forecasts and compares the result with `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).

# Synopsis

## Method: `ML-based mean bias reduction`

- calculate the ML-based bias from the 2000-2019 deterministic ensemble mean forecast
- remove that ML-based bias from the 2020 deterministic ensemble mean forecast

## Data used

type: renku datasets

Training-input for Machine Learning model:
- hindcasts of models:
    - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`

Forecast-input for Machine Learning model:
- real-time 2020 forecasts of models:
    - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`

Compare Machine Learning model forecast against ground truth:
- `CPC` observations:
    - `hindcast-like-observations_2000-2019_biweekly_deterministic.zarr`
    - `forecast-like-observations_2020_biweekly_deterministic.zarr`

## Resources used
for training; details are in the Reproducibility section

- platform: renku
- memory: 8 GB
- processors: 2 CPU
- storage required: 10 GB

## Safeguards

All points have to be [x] checked. If not, your submission is invalid.

Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.
(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)

### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) 

If the organizers suspect overfitting, your contribution can be disqualified.

  - [x] We did not use 2020 observations in training (explicit overfitting and cheating)
  - [x] We did not repeatedly verify our model on 2020 observations and incrementally improve our RPSS (implicit overfitting)
  - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.
  - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
  - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
  - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
  - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).
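
The RPSS mentioned in the checklist compares the ranked probability score (RPS) of a forecast with that of a reference forecast. The following is a minimal sketch, assuming the standard definition of the RPS as the sum of squared differences between cumulative forecast and observed probabilities, with a climatological reference of 1/3 per tercile. The function name `rps` and the toy numbers are illustrative; the challenge's own scoring script may aggregate over space and time differently.

```python
import numpy as np

def rps(prob_forecast, obs_onehot):
    """Ranked probability score over the last (category) axis; lower is better."""
    cum_fct = np.cumsum(prob_forecast, axis=-1)
    cum_obs = np.cumsum(obs_onehot, axis=-1)
    return np.sum((cum_fct - cum_obs) ** 2, axis=-1)

# toy example: the observed category is "above normal"
obs = np.array([0.0, 0.0, 1.0])
clim = np.array([1 / 3, 1 / 3, 1 / 3])   # climatological reference forecast
fct = np.array([0.2, 0.3, 0.5])          # hypothetical model forecast

rps_clim = rps(clim, obs)
rps_fct = rps(fct, obs)
rpss = 1 - rps_fct / rps_clim            # > 0 means better than climatology
```

An RPSS of 0 corresponds to a forecast that is no better than climatology, and 1 to a perfect forecast.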

### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition); if this is not possible, no prize can be awarded.
  - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
  - [x] Code is well documented, readable and reproducible.
  - [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training pipelines that take weeks to train.

# Todos to improve template

This is just a demo.

- [ ] use multiple predictor variables and two predicted variables
- [ ] for both `lead_time`s in one go
- [ ] consider seasonality, for now all `forecast_time` months are mixed
- [ ] make probabilistic predictions with `category` dim, for now works deterministic

# Description of this notebook

* makes probabilistic predictions for categories
* only for one lead time and temperature variable
* based on ANN with ensemble spread (max - min) and ensemble mean as input (with removed annual cycle), uses softmax to return class probabilities
* low skill (accuracy ~ 0.33)

# Imports

In [None]:
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Sequential
import tensorflow.keras as keras

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


import xarray as xr
xr.set_options(display_style='text')



from dask.utils import format_bytes
import xskillscore as xs

%matplotlib inline 
#so that figures appear again

#for prediction
from scripts import make_probabilistic
from scripts import add_valid_time_from_forecast_reference_time_and_lead_time
from scripts import skill_by_year

In [None]:
cache_path = '../template/data' #if you change this you also have to adjust the git lfs pull paths

# Get training data

preprocessing of input data may be done in separate notebook/script

## Hindcast

get weekly initialized hindcasts

In [None]:
# preprocessed as renku dataset
!git lfs pull ../template/data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr

In [None]:
hind_2000_2019 = xr.open_zarr(f'{cache_path}/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr', consolidated=True)

## Observations
corresponding to hindcasts

In [None]:
# preprocessed as renku dataset
!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr

In [None]:
obs_2000_2019 = xr.open_zarr(f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr', consolidated=True)#[v]

terciled

In [None]:
!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr

In [None]:
obs_2000_2019_terciled = xr.open_zarr(f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_terciled.zarr', consolidated=True)

### Select region

To keep things simple at the beginning, we select a regional subset so that no periodic padding is needed.

In [None]:
hind_2000_2019 = hind_2000_2019.sel(longitude = slice(0,30), latitude = slice(70,40))
obs_2000_2019 = obs_2000_2019.sel(longitude = slice(0,30), latitude = slice(70,40))
obs_2000_2019_terciled = obs_2000_2019_terciled.sel(longitude = slice(0,30), latitude = slice(70,40))

In [None]:
#hind_2000_2019
#obs_2000_2019_terciled

## Train Validation split

In [None]:
# time is the forecast_time
time_train_start,time_train_end='2000','2017' # train
time_valid_start,time_valid_end='2018','2019' # valid
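
The year strings above are used as label-based slice bounds, which include both end years. A minimal sketch with pandas, assuming a hypothetical weekly `forecast_time` axis (the real axis comes from the hindcast dataset):

```python
import pandas as pd

time_train_start, time_train_end = "2000", "2017"
time_valid_start, time_valid_end = "2018", "2019"

# hypothetical weekly forecast_time axis covering the hindcast period
forecast_time = pd.date_range("2000-01-06", "2019-12-31", freq="7D")
s = pd.Series(range(len(forecast_time)), index=forecast_time)

# partial-string slicing keeps every date within the given years
train = s.loc[time_train_start:time_train_end]
valid = s.loc[time_valid_start:time_valid_end]
```

Because the two slices do not overlap, no validation date can leak into training.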

## Weatherbench

based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)

In [None]:
# run once only and don't commit
#!git clone https://github.com/pangeo-data/WeatherBench/

In [None]:
import sys
sys.path.insert(1, 'WeatherBench')
from WeatherBench.src.train_nn import PeriodicConv2D, create_predictions#DataGenerator, 

### define some vars

In [None]:
v='t2m'
bs=32

https://s2s-ai-challenge.github.io/

We deal with two fundamentally different variables here: 
- Total precipitation is precipitation flux pr accumulated over lead_time until valid_time and therefore describes a point observation. 
- 2m temperature is averaged over lead_time(valid_time) and therefore describes an average observation. 

The submission file data model unifies both approaches: it assigns a `lead_time` of 14 days to week 3-4 and of 28 days to week 5-6, each marking the first day of the biweekly aggregate.
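
As an illustration of this data model, the first day of each biweekly aggregate can be derived by adding the lead time to the initialization date. This is a sketch using only the standard library; the initialization date is an arbitrary example.

```python
from datetime import date, timedelta

# lead_time in days for the two biweekly aggregates, as stated above
lead_days = {"week 3-4": 14, "week 5-6": 28}

forecast_time = date(2020, 1, 2)   # example initialization date (a Thursday)

# valid_time = forecast_time + lead_time marks the first day of each aggregate
valid_time = {name: forecast_time + timedelta(days=days)
              for name, days in lead_days.items()}
```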

In [None]:
# two bi-weekly `lead_time`s are available; select the first one (week 3-4)
lead = hind_2000_2019.lead_time[0]

### create datasets

In [None]:
#train:
#uses only ensemble mean so far
fct_train = hind_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v]
verif_train = obs_2000_2019_terciled.sel(forecast_time=slice(time_train_start,time_train_end))[v]

#make sure that all nans from obs are also nan in hind, if hind contains nans, this does not work!
fct_train = fct_train.where(verif_train.mean('category', skipna = False).notnull())

In [None]:
#validation
fct_valid = hind_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]
verif_valid = obs_2000_2019_terciled.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]
#make sure that all nans from obs are also nan in hind, if hind contains nans, this does not work!
fct_valid = fct_valid.where(verif_valid.mean('category', skipna = False).notnull())
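
The masking step relies on the fact that a mean computed without skipping NaN is itself NaN wherever any category is missing. A minimal NumPy sketch of the same idea on toy arrays (the shapes and values are assumptions for illustration):

```python
import numpy as np

# toy arrays: verif has shape (category, point), fct has shape (point,)
verif = np.array([[0.0, np.nan, 1.0],
                  [1.0, np.nan, 0.0],
                  [0.0, np.nan, 0.0]])
fct = np.array([280.0, 285.0, 290.0])

# the plain mean propagates NaN, so it flags points without observations
valid_mask = ~np.isnan(verif.mean(axis=0))

# keep forecasts only where observations exist
fct_masked = np.where(valid_mask, fct, np.nan)
```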

### Remove bias from fct
did not improve the skill

In [None]:
obs_train = obs_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v]
obs_valid = obs_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v]

In [None]:
#from mean_bias_reduction
from scripts import add_year_week_coords
fct_train_bias = add_year_week_coords(fct_train.mean('realization') - obs_train).groupby('week').mean().compute()
fct_valid_bias = add_year_week_coords(fct_valid.mean('realization') - obs_valid).groupby('week').mean().compute()

In [None]:
fct_train = add_year_week_coords(fct_train) - fct_train_bias.sel(week=fct_train.week)
fct_valid = add_year_week_coords(fct_valid) - fct_valid_bias.sel(week=fct_valid.week)
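
The weekly bias removal above can be illustrated on a small table. This is a sketch with pandas under simplified assumptions (one grid point, toy values, hypothetical column names `fct` and `obs`), not the `add_year_week_coords` workflow itself:

```python
import pandas as pd

df = pd.DataFrame({
    "week": [1, 1, 2, 2],
    "fct": [2.0, 4.0, 10.0, 12.0],   # ensemble-mean forecast (toy values)
    "obs": [1.0, 3.0, 12.0, 14.0],   # matching observations (toy values)
})

# mean forecast-minus-observation error for each week of the year
bias = (df["fct"] - df["obs"]).groupby(df["week"]).mean()

# subtract the bias that belongs to each row's week
df["fct_corrected"] = df["fct"] - df["week"].map(bias)
```

For data that must stay unseen, the bias estimated on the training period would be the one to subtract, since observations of the target period are not available at forecast time.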

### More preprocessing

In [None]:
from scripts import add_year_week_coords
def ann_preprocess(ds, v,lead):
    ds = ds.sel(lead_time = lead)
    
    #remove annual cycle for each location 
    ds = add_year_week_coords(ds)
    ens_mean = ds.mean('realization')
    ens_mean = ens_mean - ens_mean.groupby('week').mean(['forecast_time'])

    #compute ensemble spread, remove local seasonal cycle
    spread = ds.max('realization')-ds.min('realization')
    spread = spread/spread.groupby('week').mean(['forecast_time'])
    
    #combine data arrays
    spread = spread.to_dataset(name = 'spread_{}'.format(v))
    ens_mean = ens_mean.to_dataset(name = 'mean_{}'.format(v))
    combined = xr.combine_by_coords([ens_mean, spread])
    combined= combined.sel({'week' : ds.coords['week']})
    
    df = combined.to_dataframe()
    df = df.drop(['lead_time','valid_time','week','year'], axis =1).reset_index()
    
    df = df.dropna(axis = 0)
    
    #to get input shape back later
    df_ref = df
    
    df = df.drop(['forecast_time','latitude','longitude'], axis = 1)
    df = (df - df.mean(axis = 0))/df.std(axis = 0)#standardize everything
    
    return df, df_ref
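
The two input features built by `ann_preprocess` can be sketched with NumPy for a single grid point and a single week of the year. The toy ensemble and the variable names are assumptions for illustration; `ddof=1` mirrors the default of `DataFrame.std`.

```python
import numpy as np

# toy ensemble with shape (forecast_time, realization), same week of the year
ens = np.array([[1.0, 3.0, 5.0],
                [2.0, 4.0, 9.0]])

# feature 1: ensemble mean with the weekly climatological mean removed
ens_mean = ens.mean(axis=1)
anomaly = ens_mean - ens_mean.mean()

# feature 2: ensemble spread (max - min) relative to the weekly mean spread
spread = ens.max(axis=1) - ens.min(axis=1)
rel_spread = spread / spread.mean()

# standardize each feature column to zero mean and unit sample variance
features = np.column_stack([anomaly, rel_spread])
standardized = (features - features.mean(axis=0)) / features.std(axis=0, ddof=1)
```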

In [None]:
def ann_preprocess_label(ds,v,lead):
    df = ds.sel(lead_time = lead).to_dataframe()
    df = df.drop(['lead_time','valid_time'], axis =1).reset_index()
    df = df.pivot(index = ['forecast_time', 'latitude','longitude'], columns = 'category', values = v).reset_index()
    df.rename_axis(None, inplace = True, axis = 1)
    df = df.dropna(axis = 0)
    
    
    df=df.drop(['forecast_time','latitude','longitude'], axis = 1)
    df=df[['below normal', 'near normal','above normal']]
    return df
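
The label preprocessing turns a long table with one row per category into a wide one-hot table with one column per category. A minimal sketch with pandas on toy data (a single grid point, so only `forecast_time` is used as index here):

```python
import pandas as pd

long_df = pd.DataFrame({
    "forecast_time": ["2000-01-06"] * 3 + ["2000-01-13"] * 3,
    "category": ["below normal", "near normal", "above normal"] * 2,
    "t2m": [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
})

# one row per forecast_time, one column per category
wide = long_df.pivot(index="forecast_time", columns="category", values="t2m")

# pivot sorts the columns alphabetically, so restore below/near/above order
wide = wide[["below normal", "near normal", "above normal"]]
```

Fixing the column order explicitly matters because the network's three outputs are matched to the label columns by position.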

In [None]:
#define dataframes

df_verif_train = ann_preprocess_label(verif_train, v, lead)
df_fct_train, df_fct_train_ref = ann_preprocess(fct_train, v, lead)
df_verif_valid = ann_preprocess_label(verif_valid, v, lead)
df_fct_valid, df_fct_valid_ref = ann_preprocess(fct_valid, v, lead)

In [None]:
df_verif_train

In [None]:
df_fct_train

### ANN

In [None]:
# import only the layers that are used instead of a wildcard import
from tensorflow.keras.layers import Activation, Dense, Dropout

ann = keras.models.Sequential([
    Dense(10, input_shape=(2,), activation='relu'),
    #Dropout(0.2),
    Dense(3),
    Activation('softmax')
])
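
To make explicit what this small network computes, here is a NumPy sketch of its forward pass with randomly initialized weights of the same shapes. The weights are placeholders, not trained values, and `forward` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

# random weights with the same shapes as the Keras model above
W1, b1 = rng.normal(size=(2, 10)), np.zeros(10)
W2, b2 = rng.normal(size=(10, 3)), np.zeros(3)

def forward(x):
    hidden = np.maximum(x @ W1 + b1, 0.0)                # Dense(10, relu)
    logits = hidden @ W2 + b2                            # Dense(3)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)          # softmax

x = rng.normal(size=(4, 2))   # 4 samples with 2 standardized features each
probs = forward(x)
n_params = W1.size + b1.size + W2.size + b2.size
```

Each row of `probs` sums to one, so it can be read as the probabilities of the three tercile categories.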

In [None]:
ann.summary()

In [None]:
ann.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')#keras.optimizers.Adam(1e-4), 'mse')
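
A useful reference point for the training log is the loss of a forecast that always predicts 1/3 per category. The sketch below computes categorical cross-entropy with NumPy; the helper is a simplified stand-in written for illustration, not the Keras implementation.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# one sample per observed category
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)

# climatological forecast: equal probability for every tercile
uniform = np.full((3, 3), 1 / 3)

baseline_loss = categorical_crossentropy(y_true, uniform)   # equals ln(3)
```

A validation loss close to ln(3) together with an accuracy close to 0.33 indicates that the model has learned little beyond climatology.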

In [None]:
import warnings
warnings.simplefilter("ignore")

In [None]:
ann.fit(df_fct_train, df_verif_train, batch_size = 100, epochs=2, validation_data=(df_fct_valid, df_verif_valid))

### Predict

In [None]:
###predict plus first postprocessing

predicted_bins = pd.DataFrame(ann.predict(df_fct_valid), columns = df_verif_train.columns)

In [None]:
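def postprocess_output(output, df_ref,v):
    #add columns; use the raw values because df_ref keeps the gapped index left by dropna,
    #while output has a fresh RangeIndex, so index-aligned assignment would mismatch rows
    output = output.assign(latitude = df_ref.latitude.to_numpy(), longitude = df_ref.longitude.to_numpy(),
                           forecast_time = df_ref.forecast_time.to_numpy())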
    #merge category columns into one
    output = output.melt(id_vars = ['forecast_time','latitude','longitude'], var_name = 'category', 
                                       value_name = v)#'t2m'
    #create MultiIndex
    output = output.pivot_table(values = v, index = ['latitude','longitude','forecast_time','category'])
    
    #convert to dataset
    xr_output = xr.Dataset.from_dataframe(output)
    
    return xr_output
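
The reshaping done in `postprocess_output` can be followed on a toy table. This sketch uses pandas only and stops before the conversion to an xarray dataset; `set_index` is used here in place of `pivot_table` for brevity, and all values are made up.

```python
import pandas as pd

wide = pd.DataFrame({
    "forecast_time": ["2018-01-04", "2018-01-04"],
    "latitude": [40.0, 41.5],
    "longitude": [0.0, 0.0],
    "below normal": [0.2, 0.3],
    "near normal": [0.3, 0.3],
    "above normal": [0.5, 0.4],
})

# merge the three category columns into a single value column
long_df = wide.melt(id_vars=["forecast_time", "latitude", "longitude"],
                    var_name="category", value_name="t2m")

# one index level per future dataset dimension
indexed = long_df.set_index(
    ["latitude", "longitude", "forecast_time", "category"]).sort_index()
```

Sorting the index orders the categories alphabetically, which is why the category order is changed again after the conversion.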

In [None]:
xr_predicted_bins = postprocess_output(predicted_bins, df_fct_valid_ref, v=v)

In [None]:
#change order of categories
xr_predicted_bins = xr_predicted_bins.reindex(category=[xr_predicted_bins.category[1].values, 
                                                                xr_predicted_bins.category[2].values, 
                                                                xr_predicted_bins.category[0].values ])

### plot predictions

In [None]:
xr_predicted_bins.isel(forecast_time = 0)['t2m'].plot(col = 'category')

In [None]:
verif_valid.isel(forecast_time = 0, lead_time = 0).to_dataset()['t2m'].plot(col = 'category')

Using obs as input without the spread feature improved the model performance (accuracy of 0.4).
This is still really low.

### Ground truth

In [None]:
!git lfs pull ../template/data/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc
tercile_edges = xr.open_dataset(f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc')

In [None]:
tercile_edges = tercile_edges.sel(longitude = slice(0,30), latitude = slice(70,40))
tercile_edges

In [None]:
from scripts import make_probabilistic

obs_preds = make_probabilistic(obs_valid, tercile_edges)
obs_preds
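
Conceptually, turning deterministic values into tercile categories means comparing each value with the lower and upper tercile edge. The sketch below shows this with NumPy for scalar edges; it is an assumption about the idea behind `make_probabilistic`, not its actual implementation, and the treatment of values that fall exactly on an edge may differ.

```python
import numpy as np

def to_tercile_onehot(values, lower_edge, upper_edge):
    """One-hot encode values as below / near / above normal."""
    category = np.digitize(values, [lower_edge, upper_edge])   # 0, 1 or 2
    return np.eye(3)[category]

obs = np.array([271.0, 275.0, 280.0])   # toy 2m temperatures in K
onehot = to_tercile_onehot(obs, lower_edge=273.0, upper_edge=277.0)
```

In the notebook the edges vary by grid point and week of the year, so the comparison is carried out per location and week rather than with two scalars.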

In [None]:
obs_preds.isel(forecast_time = 0, lead_time = 0)['t2m'].plot(col = 'category')

### to create tercile_edges
add week for groupby, see https://renkulab.io/gitlab/aaron.spring/s2s-ai-challenge/-/issues/29

obs_2000_2019 = add_year_week_coords(obs_2000_2019)
obs_2000_2019.chunk({'forecast_time':-1,'longitude':'auto'}).groupby('week').quantile(q=[1./3.,2./3.], 
               dim='forecast_time').rename({'quantile':'category_edge'}).astype('float32').to_netcdf(tercile_file)
               
The tercile edges and weeks refer to the forecast date, i.e. the terciles for week 1 with `lead_time` = weeks 3-4 refer to the climatology for a `valid_time` at the end of January.
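
For a single grid point and a single week of the year, the tercile edges are the 1/3 and 2/3 quantiles of the observations across years. A minimal NumPy sketch with synthetic data (the 20 random values stand in for 20 years of observations):

```python
import numpy as np

rng = np.random.default_rng(42)

# synthetic climatology: 20 years of observations for one week at one grid point
obs_one_week = rng.normal(loc=275.0, scale=3.0, size=20)

# the two category edges separate below, near and above normal
lower_edge, upper_edge = np.quantile(obs_one_week, [1 / 3, 2 / 3])

below = int(np.sum(obs_one_week < lower_edge))
above = int(np.sum(obs_one_week > upper_edge))
near = len(obs_one_week) - below - above
```

By construction, roughly one third of the years falls into each category, which is what makes 1/3 per category the climatological reference forecast.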

# Reproducibility

## memory

In [None]:
# https://phoenixnap.com/kb/linux-commands-check-memory-usage
!free -g

## CPU

In [None]:
!lscpu

## software

In [None]:
!conda list