# Hindcast Prediction Demo

---

**Authors:** Riley Brady and Aaron Spring

---

This demo shows the capabilities of the prediction module for a decadal prediction ensemble that is initialized from a reconstruction. Note that we use the word "reconstruction," but this could be replaced with "hindcast" or "assimilation," among other terms. This setup differs from the "perfect-model" approach of, e.g., the MPI decadal prediction system. For a perfect-model approach, see `perfect-model_predictability.ipynb`.


**CESM Decadal Prediction Large Ensemble**

The prediction ensemble used here is the Community Earth System Model Decadal Prediction Large Ensemble (CESM-DPLE). It is initialized annually from 1954-2017 on November 1st from a Forced Ocean-Sea Ice (FOSI) reconstruction simulation.


**References**:

1. **Official manuscript of the CESM-DPLE release:** Yeager, S. G., et al. "Predicting near-term changes in the Earth System: A large ensemble of initialized decadal prediction simulations using the Community Earth System Model." Bulletin of the American Meteorological Society 2018 (2018). 


2. **Applied case of the DPLE on air-sea CO$_{2}$ fluxes:** Lovenduski, N. S., Yeager, S. G., Lindsay, K., and Long, M. C.: Predicting near-term changes in ocean carbon uptake, Earth Syst. Dynam. Discuss., https://doi.org/10.5194/esd-2018-73, in review, 2018. 


3. **Broad overview of decadal climate prediction and terminology:** Meehl, Gerald A., et al. "Decadal climate prediction: an update from the trenches." Bulletin of the American Meteorological Society 95.2 (2014): 243-267.

In [None]:
import numpy as np
import xarray as xr
from climpred.stats import rm_trend, corr
from climpred.prediction import (compute_hindcast, compute_persistence,
                                 compute_uninitialized)
import matplotlib.pyplot as plt 
%matplotlib inline
from climpred.tutorial import load_dataset

In [None]:
# Called without arguments, this lists the available sample datasets.
load_dataset()

For this demo, I'm using `proplot` from https://github.com/lukelbd/proplot. If you're compiling this notebook on your own, you need to run `plot.install_fonts()` and then restart your notebook.

## Load and process data

`climpred` contains a folder with post-processed sample data for MPI and CESM ensembles. To avoid supplying massive files, these are all computed as global (area-weighted) averages of SST at annual resolution.

### Reconstruction

The CESM-DPLE (Community Earth System Model-Decadal Prediction Large Ensemble) is initialized from a forced ocean sea-ice (FOSI) reconstruction. This reconstruction mainly uses CORE forcing (i.e., a data atmosphere), with active (or modeled) ocean and sea ice components. It has been shown to reasonably reproduce historical ocean conditions, including El Niño events.

The reconstruction output is provided at monthly resolution, but for the purpose of this demo, we will just look at annual means.

In [None]:
def _load_reconstruction():
    recon = load_dataset('FOSI-SST')
    recon = recon.sel(time=slice(1955, 2015))     # Same timeframe as DPLE
    recon = recon['SST']
    recon.name = 'reconstruction'
    return recon

In [None]:
recon = _load_reconstruction()

The reconstruction comes out as raw output, but we want to compare it directly to the anomalies provided by the DPLE. For annual averages, we just subtract the mean of the simulation. For monthly output, we have to remove monthly climatologies.

In [None]:
# The climatology for the DPLE was computed over 1964-2014 so we should
# generate our anomalies with the same window.
recon = recon - recon.sel(time=slice(1964, 2014)).mean('time')
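For the monthly case mentioned above, the following is a minimal sketch of removing a monthly climatology. It uses plain NumPy on a synthetic series rather than the FOSI output; the series, the baseline window, and all variable names are assumptions made for illustration.

```python
import numpy as np

# Synthetic monthly series: 10 years of a seasonal cycle plus a slow drift.
# (All values are made up for illustration.)
n_years = 10
months = np.arange(n_years * 12)
seasonal = 2.0 * np.sin(2 * np.pi * months / 12)
drift = 0.01 * months
monthly = 15.0 + seasonal + drift

# Reshape to (year, month) so each column holds one calendar month.
by_month = monthly.reshape(n_years, 12)

# Monthly climatology: mean of each calendar month over the baseline years.
baseline = slice(2, 8)  # hypothetical baseline window (years 2 through 7)
climatology = by_month[baseline].mean(axis=0)

# Anomalies: subtract the matching calendar-month mean from every value.
monthly_anom = (by_month - climatology).reshape(-1)
```

The seasonal cycle cancels because each value is compared only with the mean of its own calendar month. With an xarray object that carries a datetime `time` axis, the same operation is usually written with `groupby('time.month')`.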

### CESM Decadal Prediction Large Ensemble

Here, we load the initialized CESM-DPLE, which we use to compute prediction metrics (see the introduction of this notebook).

In [None]:
def _load_dple():
    dple = load_dataset('CESM-DP-SST')
    dple = dple.sel(init=slice(1955, 2015))
    return dple

In [None]:
dple = _load_dple()

In [None]:
dple.info()
dple = dple['SST'] # easier to work with DataArray since we don't have other variables
dple.name = 'initialized'

### CESM Large Ensemble

This is the uninitialized companion ensemble to the Decadal Prediction Ensemble. Here we have an ensemble (40 members) of simulations that are generated from minor perturbations to the initial conditions of one member in 1920.

This differs from the CESM-DPLE, as it is only initialized in 1920 and never again. It differs from FOSI, as it is a freely running ESM, *i.e.,* it is not being nudged/assimilated/forced by any data products.

**Reference:**

Kay, J. E., et al. "The Community Earth System Model (CESM) large ensemble project: A community resource for studying climate change in the presence of internal climate variability." Bulletin of the American Meteorological Society 96.8 (2015): 1333-1349.

In [None]:
cesmLE = load_dataset('CESM-LE')['SST']
# remove the 1964-2014 mean to generate anomalies, as for the other products
cesmLE = cesmLE - cesmLE.sel(time=slice(1964, 2014)).mean('time')
# take ensemble mean for this demo
cesmLE = cesmLE.mean('member')
cesmLE.name = 'uninitialized'
cesmLE.to_dataset().info()

### ERSSTv4 Observations

It is useful to compare the DPLE to observations to get a sense of actual skill. When we correlate the DPLE with the FOSI, we get a sense of its potential to predict the Earth system ("potential predictability"), i.e., the skill we could achieve if the initial conditions and model equations were perfect and the resolution sufficiently fine. If we correlate it with ERSST, we get actual prediction skill.

**Reference**:

https://iridl.ldeo.columbia.edu/SOURCES/.NOAA/.NCDC/.ERSST/.version4/

In [None]:
data = load_dataset('ERSST')['SST']
data = data - data.sel(time=slice(1964, 2014)).mean('time')
data.name = 'data'

## High-level view of CESM-DPLE for Global SSTs

To get the user comfortable with what a DPLE looks like, we provide some simple plots to look at the structure of global SSTs in CESM-DPLE. 

In [None]:
dple_mean = dple.mean('member')

### Ensemble Mean View

By taking the ensemble mean across all 40 members, we get a sense of what the best prediction is from the ensemble. This is like the thick black line on a spaghetti plot of individual hurricane forecasts. Note that we haven't detrended yet, so you will see the SST warming trend from 1955-2015.

In [None]:
def set_mpl_aeshetics(ax, cb=None):
    """Sets fontsizes for matplotlib plots"""
    ax.tick_params(labelsize=14)
    if cb is not None:
        cb.ax.tick_params(labelsize=12)

In [None]:
f, ax = plt.subplots()
m = ax.pcolormesh(dple_mean.lead, dple_mean.init, dple_mean,
                  cmap='RdBu_r', vmin=-0.6, vmax=0.6)
cbar = plt.colorbar(m, boundaries=np.arange(-0.6, 0.61, 0.1))
ax.set_title('DPLE Ensemble Mean SST', fontsize=18)
ax.set_xlabel('Lead Year', fontsize=14)
ax.set_ylabel('Initialization Year', fontsize=14)
set_mpl_aeshetics(ax, cbar)
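To illustrate why the ensemble mean is treated as the best single prediction, here is a minimal sketch with a synthetic ensemble. The array layout `(init, lead, member)`, the noise levels, and the variable names are assumptions for illustration and are not taken from the CESM-DPLE data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical ensemble: a shared "predictable" signal plus independent
# noise for each member. Dimensions are (init, lead, member).
n_init, n_lead, n_member = 61, 10, 40
signal = rng.normal(scale=0.2, size=(n_init, n_lead, 1))
noise = rng.normal(scale=0.3, size=(n_init, n_lead, n_member))
ensemble = signal + noise

# Ensemble mean: average over the member axis only.
ens_mean = ensemble.mean(axis=-1)

# Compare the error of one member with the error of the ensemble mean.
member_rmse = np.sqrt(np.mean((ensemble[..., 0] - signal[..., 0]) ** 2))
mean_rmse = np.sqrt(np.mean((ens_mean - signal[..., 0]) ** 2))
print(f"single member RMSE: {member_rmse:.3f}")
print(f"ensemble mean RMSE: {mean_rmse:.3f}")
```

Averaging over members suppresses the independent noise while keeping the shared signal, which is what `dple.mean('member')` does for the real ensemble.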

### Initialized Predictions

The whole point of CESM-DPLE is that the ensemble is initialized with "data" from the FOSI every single year and then run forward in a coupled ESM. This raises the question: how well can a freely coupled ESM reproduce the reconstruction?

Here is a look at the spread of ensemble members at a few initialization points. The thin purple lines are the 40 individual members, the thick indigo line is the ensemble mean, and the black line is the FOSI reconstruction that CESM-DPLE is trying to replicate.

**NOTE**: You might notice that the CESM-DPLE starts at a point slightly different than the FOSI. This is because, per protocol, the CESM-DPLE was initialized in November of the preceding year, so it is slightly different than the annual means from the FOSI.

In [None]:
f, ax = plt.subplots(figsize=(8,2))
r = ax.plot(recon.time, recon, linewidth=1.5, color='k',
            label='reconstruction')
init_years = [1960, 1980, 1995]
for iy in init_years:
    case = dple.sel(init=iy)
    case['lead'] = np.arange(iy, iy+10)
    # Use a distinct name so the figure handle `f` is not overwritten.
    members = ax.plot(case.lead, case, color='#9932CC', linewidth=0.5,
                      alpha=0.75, label='individual forecasts')
    fm = ax.plot(case.lead, case.mean('member'), linewidth=2, color='#4B0082',
                 label='forecast mean', zorder=4)
    ax.plot(iy, case.isel(lead=0).mean('member'), 'o', markersize=6,
            color='#4B0082')

ax.set_xlabel('Year', fontsize=12)
ax.set_ylabel('SST Anomaly [K]', fontsize=12)
ax.set_title('Initialized Forecasts of Global SST', fontsize=16)
set_mpl_aeshetics(ax)

## Basic Prediction Metrics

Now, let's get into some predictability metrics (see intro for definitions). We aren't doing anything too advanced here; we are just correlating anomalies (trended and detrended) with the FOSI reference to get a sense of the skill of our predictions.


**NOTE**: For each of the metrics, we will be plotting a version that retains the long-term warming trend and a version that is detrended. The latter is much more common and important in the prediction community: can we predict anomalies relative to the secular trend that are mostly produced by random fluctuations in the climate system?

In [None]:
recon_dt = rm_trend(recon)
recon_dt.name = 'reconstruction' # naming for easier plotting later
dple_dt = rm_trend(dple_mean, dim='init')
dple_dt.name = 'initialized'
cesmLE_dt = rm_trend(cesmLE)
cesmLE_dt.name = 'uninitialized'
data_dt = rm_trend(data)
data_dt.name = 'data'
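As a rough illustration of what removing a linear trend involves, here is a minimal NumPy sketch on synthetic data. It is not `climpred`'s `rm_trend` implementation; the helper `remove_linear_trend` and the synthetic series are hypothetical.

```python
import numpy as np


def remove_linear_trend(y, x=None):
    """Subtract the least-squares linear fit of y on x (illustrative helper)."""
    y = np.asarray(y, dtype=float)
    if x is None:
        x = np.arange(y.size)
    # Center x so the fit is well conditioned for calendar-year values.
    xc = np.asarray(x, dtype=float) - np.mean(x)
    slope, intercept = np.polyfit(xc, y, deg=1)
    return y - (slope * xc + intercept)


# Synthetic annual series: a warming trend plus random fluctuations.
years = np.arange(1955, 2016)
rng = np.random.default_rng(1)
sst = 0.01 * (years - years[0]) + rng.normal(scale=0.05, size=years.size)

sst_dt = remove_linear_trend(sst, years)
residual_slope = np.polyfit(years - years.mean(), sst_dt, deg=1)[0]

# For an (init, lead) array, detrend along the init axis for each lead,
# analogous in spirit to passing dim='init' above.
forecast = sst[:, None] + rng.normal(scale=0.02, size=(years.size, 10))
forecast_dt = np.apply_along_axis(remove_linear_trend, 0, forecast)
```

After detrending, what remains are the fluctuations about the secular trend, which is the part of the signal the detrended metrics below evaluate.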

### Ensemble Mean Plots

First, let's look at the ensemble mean (lead year 1 for CESM-DPLE) of each of our products: the initialized prediction ensemble, the uninitialized ensemble, the reconstruction, and our actual SST data.

Here you see the power of the initialized ensemble. In the bottom panel, compare the dark red line (initialized) to the blue line (uninitialized).

In [None]:
COLORS = {'recon': '#000000', 'hind': '#790604',
          'uninit': '#006994', 'data': '#a9957b'}

f, ax = plt.subplots(nrows=2, sharex=True, sharey=True)

"""
TOP PANEL:
Ensemble means for all four products without removing the trend.
"""
r = ax[0].plot(recon.time, recon, color=COLORS['recon'], linewidth=2, 
               label='reconstruction')
i = ax[0].plot(dple.init, dple_mean.isel(lead=0), linewidth=2, 
               color=COLORS['hind'], label='initialized forecast')
u = ax[0].plot(cesmLE.time, cesmLE, color=COLORS['uninit'], linewidth=2,
               label='uninitialized forecast')
d = ax[0].plot(data.time, data, color=COLORS['data'], linewidth=2,
               label='ERSSTv4')
ax[0].set_title('Raw SST Anomalies')
ax[0].set_ylabel('SST Anomalies [K]')


"""
BOTTOM PANEL:
Ensemble means for all four products after removing a linear trend.
"""
ax[1].plot(recon_dt.time, recon_dt, color=COLORS['recon'], linewidth=2)
ax[1].plot(dple_dt.init, dple_dt.isel(lead=0), color=COLORS['hind'], linewidth=2)
ax[1].plot(cesmLE_dt.time, cesmLE_dt, color=COLORS['uninit'], linewidth=2)
ax[1].plot(data_dt.time, data_dt, color=COLORS['data'], linewidth=2)
ax[1].set_title('Detrended SST Anomalies')
ax[1].set_ylabel('SST Anomalies [K]')
ax[1].set_xlabel('Year')

# The labeled lines live on the top panel, so build the legend from ax[0];
# plt.legend() alone would target the unlabeled bottom panel.
ax[0].legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize=12)
plt.show()

### Potential Predictability and Skill

Now we can leverage the simulations and data we have to compute potential predictability (in reference to the FOSI) and skill (in reference to ERSST).

In [None]:
def _compute_skills(recon, dple, cesmLE, data, predType='potential'):
    """
    Quick function to compute the predictability/skill given all
    four datasets.
    
    predType should either be 'potential' for potential predictability
    or 'skill' for true prediction skill.
    """
    if 'member' in dple.dims:
        dple = dple.mean('member')
    if predType == 'potential':
        # Initialized ensemble predictability
        ip = compute_hindcast(dple, recon)
        # Uninitialized ensemble predictability
        up = compute_uninitialized(cesmLE, recon)
        up = xr.concat([up]*10, 'lead')
        # Persistence forecast
        persist = compute_persistence(dple, recon)
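    elif predType == 'skill':
        ip = compute_hindcast(dple, data)
        up = corr(cesmLE, data)
        # Repeat the single uninitialized correlation for each of the 10 leads
        # (DP protocol).
        up = xr.DataArray([np.asarray(up)]*10, dims='lead')
        persist = compute_persistence(dple, data)
    else:
        raise ValueError("predType must be 'potential' or 'skill'")
    return ip, up, persist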


def _plot_skill(ax, result, color='k', linestyle='-', marker='o', 
                markersize=6, linewidth=2, **kwargs):
    """
    Quick function to plot results of predictability analysis.
    """
    if 'lead' not in result.coords:
        N = len(result)
        result['lead'] = np.arange(1, N+1)
    p = ax.plot(result['lead'], result, color=color, linestyle=linestyle, 
                marker=marker, markersize=markersize, linewidth=linewidth, 
                **kwargs)
    return p
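To make the metric itself concrete, the following is a minimal NumPy sketch of an anomaly correlation coefficient (ACC) as a function of lead, together with a persistence baseline. It is not `climpred`'s implementation of `compute_hindcast` or `compute_persistence`. The helper names, the synthetic data, and the alignment convention (a forecast at lead `l` verifies at year `init + l`) are all assumptions made for illustration.

```python
import numpy as np


def pearson_r(a, b):
    """Pearson correlation of two equal-length 1-D arrays."""
    a = np.asarray(a, dtype=float) - np.mean(a)
    b = np.asarray(b, dtype=float) - np.mean(b)
    return float(np.sum(a * b) / np.sqrt(np.sum(a * a) * np.sum(b * b)))


def acc_by_lead(forecast, inits, ref, ref_years):
    """ACC per lead for forecast[init, lead-1], verified at year init + lead.

    The `init + lead` alignment is an assumption of this sketch.
    """
    year_to_idx = {int(y): k for k, y in enumerate(ref_years)}
    n_lead = forecast.shape[1]
    acc = np.full(n_lead, np.nan)
    for lead in range(1, n_lead + 1):
        # Keep only the inits whose verification year exists in the reference.
        f_idx = [i for i, y in enumerate(inits) if int(y) + lead in year_to_idx]
        r_idx = [year_to_idx[int(inits[i]) + lead] for i in f_idx]
        if len(f_idx) > 2:
            acc[lead - 1] = pearson_r(forecast[f_idx, lead - 1], ref[r_idx])
    return acc


def persistence_by_lead(ref, n_lead):
    """Skill of forecasting ref[t + lead] with the value ref[t]."""
    return np.array([pearson_r(ref[:-lead], ref[lead:])
                     for lead in range(1, n_lead + 1)])


# Synthetic demo (all numbers are made up).
rng = np.random.default_rng(2)
ref_years = np.arange(1955, 2026)
inits = np.arange(1955, 2016)
n_lead = 10

# Reference: AR(1) red noise standing in for detrended SST anomalies.
ref = np.zeros(ref_years.size)
for t in range(1, ref.size):
    ref[t] = 0.8 * ref[t - 1] + rng.normal(scale=0.1)

# Forecast: the verifying reference value plus error that grows with lead.
forecast = np.empty((inits.size, n_lead))
for i, y in enumerate(inits):
    for lead in range(1, n_lead + 1):
        truth = ref[y - ref_years[0] + lead]
        forecast[i, lead - 1] = truth + rng.normal(scale=0.02 * lead)

skill = acc_by_lead(forecast, inits, ref, ref_years)
persist = persistence_by_lead(ref, n_lead)
```

An initialized forecast is usually considered useful at a given lead when its ACC exceeds both the persistence baseline and the uninitialized ensemble, which is the comparison the plots below are meant to show.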

In [None]:
COLORS = {'recon': '#000000', 'hind': '#790604',
          'uninit': '#006994', 'data': '#a9957b',
          'persist': '#D3D3D3'}
f, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,4),
                      sharey=True)

trended_set = xr.merge([recon, dple, cesmLE, data])
detrended_set = xr.merge([recon_dt, dple_dt, cesmLE_dt, data_dt])
titles = ['trended predictability', 'trended skill',
          'detrended predictability', 'detrended skill']

for ax, predType, ds, title in zip(axs.ravel(), 
                            ['potential', 'skill', 'potential', 'skill'],
                            [trended_set, trended_set, detrended_set, 
                            detrended_set], titles):
    ip, up, persist = _compute_skills(ds.reconstruction, ds.initialized,
                                      ds.uninitialized, ds.data, 
                                      predType=predType)
    i = _plot_skill(ax, ip, color='r', label='initialized forecast')
    u = _plot_skill(ax, up, color=COLORS['uninit'], linewidth=1.5,
                    label='uninitialized forecast')
    p = _plot_skill(ax, persist, color=COLORS['persist'], linestyle='--',
                   label='persistence forecast')
    ax.set_ylim([-1, 1.1])
    
f.suptitle('Global SST Predictability', fontsize=16)
axs[0,0].set_title('potential predictability', fontsize=14)
axs[0,1].set_title('skill', fontsize=14)
axs[0,0].set_ylabel('trended ACC', fontsize=13)
axs[1,0].set_ylabel('detrended ACC', fontsize=13)