# Eruption forecasting with Bayesian Networks for Whakaari

In [28]:
import warnings
warnings.simplefilter(action='ignore')

from collections.abc import Sequence, Callable
import math
import os
from pathlib import Path
from datetime import datetime
from functools import partial
import shutil

import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
pio.templates.default = "plotly_white"

import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import (LongitudeFormatter,
                                   LatitudeFormatter)

from ipywidgets import widgets

from whakaaribn import (load_all_whakaari_data,
                        load_whakaari_catalogue,
                        Discretizer,
                        get_group_labels,
                        pre_eruption_window,
                        WhakaariModel,
                        get_color,
                        BayesNet,
                        get_data)

from sklearn import set_config
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import (roc_auc_score,
                             log_loss,
                             average_precision_score,
                             auc)
import tqdm as tqdm
import xarray as xr
set_config(transform_output="pandas")

## Whakaari image and network structure

In [None]:
# Figures and cached results go into a 'data' folder below the current
# working directory; create it on first use and report what happened.
data_dir = os.path.join(Path().resolve(), 'data')
if os.path.exists(data_dir):
    print(data_dir, "already exists")
else:
    print("creating ", data_dir)
    os.makedirs(data_dir)

In [None]:
# Four-panel overview figure: (A) location map, (B) photograph of the island,
# (C) causal network structure, (D) fully-connected network structure.
# get_data presumably resolves the path of a file shipped with whakaaribn
# -- TODO confirm.
fn = get_data('data/271912_Brad Scott_GNS Science.jpg')
arr_img = plt.imread(fn, format='jpg')
# NOTE(review): y, x, z and factor are not used anywhere else in this cell.
y, x, z = arr_img.shape
factor=300
fig = plt.figure(figsize=(8, 6))
# Axes are placed by hand as [left, bottom, width, height] in figure fractions.
ax1 = fig.add_axes([0.02,0.525,0.5,0.45], projection=ccrs.Mercator())
ax2 = fig.add_axes([0.48, 0.5, 0.51, 0.5])
ax3 = fig.add_axes([0.1,0.05,0.4,0.4])
ax4 = fig.add_axes([0.55,0.05,0.4,0.4])


# Add a Map
ax1.set_extent([160, 180, -49, -32], crs=ccrs.PlateCarree())
ax1.coastlines(resolution='50m')
ax1.add_feature(cfeature.LAND, facecolor='lightgray')
ax1.add_feature(cfeature.OCEAN, facecolor='lightblue')
# NOTE(review): the basemap API key is hard-coded in the URL and fetching the
# tiles needs network access; consider moving the key into configuration.
url = "https://basemaps.linz.govt.nz/v1/tiles/aerial/WebMercatorQuad/WMTSCapabilities.xml?api=c01jj05fc72acjxevhtrem76m80"
layer = "aerial"
ax1.add_wmts(url, layer)

# Labelled graticule: labels only on the left and bottom edges, at fixed
# longitudes/latitudes, drawn slightly inside the map frame (negative padding).
label_style = {'color': 'black', 'weight': 'bold', 'size': 8}
gl = ax1.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                     linewidth=1, color='gray', alpha=0.5,
                                     linestyle='--', xlabel_style=label_style,
                                     ylabel_style=label_style)
gl.top_labels = False
gl.right_labels = False
gl.xlines = True
gl.ylines = True
gl.xlocator = mticker.FixedLocator([166, 172])
gl.ylocator = mticker.FixedLocator([-45, -40, -35])
gl.xformatter = LongitudeFormatter(direction_label=True)
gl.yformatter = LatitudeFormatter(direction_label=True)
gl.xpadding = -1
gl.ypadding = -5
# Red dot at (lon, lat); presumably the position of Whakaari -- verify.
ax1.plot(177.18270881065772, -37.51992819737241, marker='o', color='red',
                  markersize=3, transform=ccrs.PlateCarree())
ax1.text(
    .1,
    .9,
    "(A)",
    horizontalalignment="center",
    transform=ax1.transAxes,
    fontdict={"color": "k", "fontsize": 14}
)

# Add an image of Whakaari
ax2.imshow(arr_img)
ax2.set_xticks([])
ax2.set_yticks([])
ax2.text(
    .1,
    .9,
    "(B)",
    horizontalalignment="center",
    transform=ax2.transAxes,
    fontdict={"color": "k", "fontsize": 14}
)

# Add Network structure for causal model
# Both cause nodes point at 'Eruption' and at each of the five observables.
magma_edges = [('Magmatic Intrusion', 'Eruption'),
               ('Magmatic Intrusion', 'Eqr'),
               ('Magmatic Intrusion', 'RSAM'),
               ('Magmatic Intrusion', 'SO2'),
               ('Magmatic Intrusion', 'CO2'),
               ('Magmatic Intrusion', 'H2S')]
seal_edges = [('Hydrothermal Seal', 'Eruption'),
              ('Hydrothermal Seal', 'Eqr'),
              ('Hydrothermal Seal', 'RSAM'),
              ('Hydrothermal Seal', 'SO2'),
              ('Hydrothermal Seal', 'CO2'),
              ('Hydrothermal Seal', 'H2S')]

G1 = nx.DiGraph(magma_edges + seal_edges)

# Fully-connected structure: one directed edge for every pair of the six nodes.
G2 = nx.DiGraph([('Eruption', 'Eqr'),
                 ('Eruption', 'RSAM'),
                 ('Eruption', 'SO2'),
                 ('Eruption', 'CO2'),
                 ('Eruption', 'H2S'),
                 ('Eqr', 'RSAM'),
                 ('Eqr', 'SO2'),
                 ('Eqr', 'CO2'),
                 ('Eqr', 'H2S'),
                 ('RSAM', 'SO2'),
                 ('RSAM', 'CO2'),
                 ('RSAM', 'H2S'),
                 ('CO2', 'SO2'),
                 ('CO2', 'H2S'),
                 ('SO2', 'H2S')])

# group nodes by rows
bottom_nodes = ['Magmatic Intrusion', 'Hydrothermal Seal']
top_nodes = ['Eqr', 'RSAM', 'SO2', 'CO2', 'H2S']

# set the position according to column (x-coord)
x_pos_bottom = [0., 1.1]
x_pos_top = [0., 0.35, 0.7, 1, 1.3]
pos = {n: (_x, 0) for n, _x in zip(bottom_nodes, x_pos_bottom)}
pos.update({n: (_x, .5) for n, _x in zip(top_nodes, x_pos_top)})
pos.update({'Eruption': (-.3, 0.25)})
# The hand-made positions are only the starting point of the spring layout
# (no nodes are pinned via `fixed`); afterwards the two cause nodes are put
# back at explicit coordinates.
pos = nx.spring_layout(G1, pos=pos)
pos.update({'Magmatic Intrusion': (-.3, 0.25)})
pos.update({'Hydrothermal Seal': (.3, -0.25)})

# Edges are drawn per cause so they can be coloured differently
# (red = magmatic intrusion, blue = hydrothermal seal).
arrowsize = 14
nx.draw_networkx_edges(G1, pos, edgelist=magma_edges,
                       edge_color='r', ax= ax3,
                       arrowsize=arrowsize)
nx.draw_networkx_edges(G1, pos, edgelist=seal_edges,
                       edge_color='b', ax= ax3,
                       arrowsize=arrowsize)

label_options = {"fc": "white", "alpha": 0.5}
nx.draw_networkx_labels(G1, pos, font_size=12, ax=ax3, bbox=label_options, verticalalignment='center')
nx.draw_networkx_nodes(G1, pos, ax= ax3)

ax3.axis("off")
ax3.text(
    0.1,
    0.9,
    "(C)",
    horizontalalignment="center",
    transform=ax3.transAxes,
    fontdict={"color": "k", "fontsize": 14}
)

# Add network structure for fully-connected model
# `pos` is rebound here; the causal-model layout above is no longer available.
pos = nx.kamada_kawai_layout(G2)
nx.draw_networkx_edges(G2, pos, ax= ax4, arrowsize=arrowsize)
nx.draw_networkx_nodes(G2, pos, ax= ax4)
label_options = {"fc": "white", "alpha": 0.5}
nx.draw_networkx_labels(G2, pos, font_size=12, ax=ax4, bbox=label_options, verticalalignment='center')
ax4.set_xlim(-1.2, 1.35)
ax4.axis("off")
ax4.text(
    0.05,
    0.9,
    "(D)",
    horizontalalignment="center",
    transform=ax4.transAxes,
    fontdict={"color": "k", "fontsize": 14}
)
# Relative path: the file lands in ./data below the current working directory.
fig.savefig('data/whakaari_island_and_graphs.png', dpi=300, bbox_inches='tight')


## Data preparation

In [3]:
def group_train_test_split(data, groups):
    """
    Split data into train, test and remainder parts based on group labels.

    Rows labelled 'd' form the test set, rows labelled 'e' form the remainder
    and all other rows form the training set. The split is done with
    row-wise boolean masks, so it also works when index labels repeat.

    Arguments:
    ----------
        data: pandas.DataFrame or pandas.Series
            The data to split. It is copied, the input is not modified.
        groups: sequence
            One group label per row of `data`.
    Returns:
    --------
        train, test, remainder: pandas.DataFrame
            The three parts, without the helper 'group' column.
    Raises:
    -------
        ValueError: if `data` and `groups` differ in length.
    """
    if len(data) != len(groups):
        raise ValueError(
            f"data has {len(data)} rows but {len(groups)} group labels were given")
    frame = pd.DataFrame(data.copy())
    frame['group'] = groups
    # Build the masks on the frame they are applied to, so no index
    # re-alignment is needed.
    is_test = frame['group'] == 'd'
    is_remainder = frame['group'] == 'e'
    values = frame.drop(columns=['group'])

    train = values[~(is_test | is_remainder)]
    test = values[is_test]
    remainder = values[is_remainder]
    return train, test, remainder
    
# Load the observations for 2009-01-01 to 2024-09-10 with fill_method=None;
# 'LP' and 'VLP' are passed as ignore_data and both cache flags are set.
# The exact meaning of these options is defined by load_all_whakaari_data.
data = load_all_whakaari_data(fill_method=None, startdate=datetime(2009, 1, 1),
                              enddate=datetime(2024, 9, 10),
                              ignore_data=('LP', 'VLP'), ignore_cache=True,
                              ignore_all_caches=True, fuse_so2=False)


def get_train_test_data(data, ndays=30, min_interval=360, min_size=2):
    """
    Split observations and eruption labels into train/test/remainder sets.

    Group labels for the time span of `data` are requested from
    get_group_labels and used to split both the observations and a daily
    0/1 eruption indicator (the sign of the catalogue's 'Activity_Scale').

    Arguments:
    ----------
        data: pandas.DataFrame
            The observations; the first and last index entries define the
            time span.
        ndays, min_interval, min_size:
            Passed through to get_group_labels; `min_size` is also the first
            argument of load_whakaari_catalogue.
    Returns:
    --------
        X_train, X_test, X_remainder, y_train, y_test, y_remainder, groups
    """
    first_day, last_day = data.index[0], data.index[-1]
    groups = get_group_labels(first_day, last_day, ndays=ndays,
                              min_interval=min_interval, min_size=min_size)
    feature_parts = group_train_test_split(data, groups)

    # Put the catalogue on a gap-free daily index; days without an entry get 0.
    catalogue = load_whakaari_catalogue(min_size, '0D')
    daily_index = pd.date_range(first_day, last_day, freq='1D')
    daily_catalogue = (catalogue.loc[first_day:]
                       .reindex(daily_index, fill_value=0)
                       .drop(columns=['delta', 'tvalue']))
    eruption_flag = np.sign(daily_catalogue['Activity_Scale'])
    label_parts = group_train_test_split(eruption_flag, groups)
    return (*feature_parts, *label_parts, groups)


# Split the full data set once at module level; the label parts and `groups`
# are reused by the plotting cell further down.
X_train, X_test, X_remainder, y_train, y_test, y_remainder, groups = get_train_test_data(data)

# Sanity check: observations and labels agree on the number of train + test rows.
assert y_train.shape[0] + y_test.shape[0] == X_train.shape[0] + X_test.shape[0]

## Data plot

In [None]:
# Overview plot of the input data. With start_cell="bottom-left" row 1 is the
# bottom panel (pre-eruption window per data group); rows 2, 4, 6, 8 and 10
# each span two grid rows, carry a secondary y-axis and show one observable.
fig = make_subplots(rows=11, cols=1, shared_xaxes=True, vertical_spacing=0.02,
                    specs=[[{}], 
                           [{"rowspan": 2, "secondary_y": True}],
                           [{}], 
                           [{"rowspan": 2, "secondary_y": True}],
                           [{}],
                           [{"rowspan": 2, "secondary_y": True}], 
                           [{}], 
                           [{"rowspan": 2, "secondary_y": True}],
                           [{}],
                           [{"rowspan": 2, "secondary_y": True}],
                           [{}]],
                    start_cell="bottom-left")
# Axis title and trace mode per observable; the key order fixes the panel
# order from the bottom upwards.
plot_dict = {'RSAM': dict(name='RSAM [nm/s]', mode='lines'),
             'Eqr': dict(name='Eq. rate [1/day]', mode='lines'),
             'CO2': dict(name=u'CO\u2082 [t/day]', mode='markers'),
             'SO2': dict(name=u'SO\u2082 [t/day]', mode='markers'),
             'H2S': dict(name=u'H\u2082S [t/day]', mode='markers')}
plot_data = data.assign(group=groups)
# The labels are concatenated in train/test/remainder order before the 30-day
# pre-eruption window is applied; assign() then aligns the result on the index.
plot_data = plot_data.assign(eruptions=pre_eruption_window(pd.concat([y_train, y_test, y_remainder]), 30))
# Bottom panel: one line per data group, each in its own colour.
for i, g in enumerate(plot_data.groupby('group')):
    fig.add_trace(go.Scatter(x=g[1].index, y=g[1]['eruptions'], mode='lines', line_color=get_color(i+2),
                             showlegend=False, legendgroup='group1', legendgrouptitle_text='Data splits',
                             name='Group %s' % g[0]), row=1, col=1)
# RSAM and Eqr are only drawn up to this date.
seismic_end_date = pd.Timestamp('2022-08-04')
for i, col in enumerate(plot_dict.keys()):
    # Only the first observable contributes legend entries.
    showlegend = True
    if i > 0:
        showlegend = False
    _x = plot_data.index
    _y = plot_data[col].ffill()
    _y_raw = plot_data[col]
    if col in ['RSAM', 'Eqr']:
        # Positional truncation: assumes the index is sorted in time so that
        # the first len(_x) rows are exactly those up to seismic_end_date.
        _x = _x[_x <= seismic_end_date]
        _y = _y.iloc[:len(_x)]
        _y_raw = plot_data[col].iloc[:len(_x)]
    # Forward-filled series first, raw observations on top of it.
    fig.add_trace(go.Scatter(x=_x, y=_y, mode='lines', line_color=get_color(1),
                             showlegend=showlegend, name='Interpolated data', legendgroup='group',
                             legendgrouptitle_text="Input data"), row=i*2+2, col=1, secondary_y=False)
    color = get_color(0)
    if plot_dict[col]['mode'] == 'markers':
        color = get_color(0, alpha=0.7)
    fig.add_trace(go.Scatter(x=_x, y=_y_raw, mode=plot_dict[col]['mode'],
                             showlegend=showlegend, name='Raw data', legendgroup='group',
                             legendgrouptitle_text="Input data", line_color=color), row=i*2+2, col=1,
                    secondary_y=False)
    
    fig.update_yaxes(title=plot_dict[col]['name'], row=i*2+2, col=1, secondary_y=False)
# Always-on switch for the eruption markers.
if True:
    eruptions = load_whakaari_catalogue(2, '0D')
    # NOTE(review): the catalogue is cut at 2004 although the data are loaded
    # from 2009 onwards -- confirm that earlier eruptions should be drawn.
    dfe_ = eruptions.loc["2004-01-01":]
    showlegend = True 
    # One vertical line per eruption on the secondary axis of every data panel;
    # only the very first line gets a legend entry.
    for row in [2, 4, 6, 8, 10]:
        for i in range(len(dfe_.index)):
            fig.add_trace(go.Scatter(x=[dfe_.index[i], dfe_.index[i]], y=[0., 1.], mode='lines',
                                     line_width=.8, line_color='black', name='Observed eruption', 
                                     showlegend=showlegend), secondary_y=True, row=row, col=1)
            showlegend = False

fig.update_layout(height=1000, width=1200)
fig.update_layout(legend=dict(y=1.15, orientation='h'))
# Primary axes are logarithmic everywhere, then row 1 is switched back to linear.
fig.update_yaxes(type='log', nticks=3, secondary_y=False)
fig.update_yaxes(type='linear', row=1, col=1)
fig.update_yaxes(dtick=1, row=1, col=1)
# Panel letters, anchored in the domain of the given y-axis; the axis ids
# presumably follow plotly's numbering of this subplot grid -- verify if the
# layout changes.
x_annot= "2009-01-01"
y_annot = 0.9
fig.add_annotation(text="<b>(F)</b>", xref="x", yref="y domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(E)</b>", xref="x", yref="y2 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(D)</b>", xref="x", yref="y5 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(C)</b>", xref="x", yref="y8 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(B)</b>", xref="x", yref="y11 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(A)</b>", xref="x", yref="y14 domain", x=x_annot, y=y_annot, showarrow=False)
# Text labels inside the bottom panel; the x positions are hand-picked dates.
fig.add_annotation(text="<b>Pre-eruption window</b>", xref="x", yref="y domain", x="2010-09-01", y=y_annot, showarrow=False)
fig.add_annotation(text="<b>Data groups:</b>", xref="x", yref="y domain", x="2010-01-01", y=.2, showarrow=False)
fig.add_annotation(text="<b>Group A</b>", xref="x", yref="y domain", x="2011-10-01", y=.2, showarrow=False)
fig.add_annotation(text="<b>Group B</b>", xref="x", yref="y domain", x="2013-03-01", y=.2, showarrow=False)
fig.add_annotation(text="<b>Group C</b>", xref="x", yref="y domain", x="2015-06-01", y=.2, showarrow=False)
fig.add_annotation(text="<b>Group D</b>", xref="x", yref="y domain", x="2018-03-01", y=.2, showarrow=False)
fig.add_annotation(text="<b>Group E</b>", xref="x", yref="y domain", x="2022-02-01", y=.2, showarrow=False)
# The secondary axes only carry the eruption lines, so hide their ticks and grid.
fig.update_yaxes(showticklabels=False, showgrid=False, secondary_y=True)
fig.write_image('./data/dataset_plot_whakaari.png', height=1000, width=1200, scale=3)
fig

## Forecasts 

In [5]:
class SequentialGroupSplit:
    """
    Expanding-window cross-validation splitter over sorted group labels.

    The unique group labels are sorted; split ``i`` trains on the first
    ``i + 1`` groups and tests on the group that follows. Train and test
    indices are the contiguous positional ranges from the first to the last
    row that carries one of the respective labels.

    Arguments:
    ----------
        groups: array-like
            One group label per row of the data passed to `split`.
    """
    def __init__(self, groups):
        self.groups = groups

    def split(self, X, y=None, groups=None):
        """
        Yield (train, test) arrays of positional row indices.

        Arguments:
        ----------
            X: sized container
                The data to split; only its length is used.
            y, groups:
                Ignored. They are accepted for compatibility with the
                scikit-learn splitter interface; the labels given to the
                constructor are used instead.
        Raises:
        -------
            ValueError: if `X` and the stored labels differ in length.
        """
        labels = np.asarray(self.groups)
        if len(X) != labels.shape[0]:
            raise ValueError(
                f"X has {len(X)} rows but {labels.shape[0]} group labels were given")
        group_ids = np.unique(labels)
        for i in range(group_ids.size - 1):
            # Membership tests on the label array itself; no code is built
            # from the labels, so any hashable, sortable label works.
            train_rows = np.flatnonzero(np.isin(labels, group_ids[:i + 1]))
            test_rows = np.flatnonzero(labels == group_ids[i + 1])
            train = np.arange(train_rows[0], train_rows[-1] + 1)
            test = np.arange(test_rows[0], test_rows[-1] + 1)
            yield train, test

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splits: one fewer than the number of distinct groups."""
        return np.unique(self.groups).size - 1


def test_sequential_group_split():
    """Check SequentialGroupSplit on four groups of three consecutive days."""
    groups = np.array(['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd'])
    n_rows = groups.size
    data = pd.DataFrame({'x': np.arange(n_rows)},
                        index=pd.date_range('2000-01-01', periods=n_rows))
    splitter = SequentialGroupSplit(groups)
    # The training set grows by one group per fold, the test set stays one group.
    for fold, (train, test) in enumerate(splitter.split(data), start=1):
        assert train.size == 3 * fold
        assert test.size == 3
    # The last fold trains on everything except the final group.
    assert train.size == n_rows - 3
    assert splitter.get_n_splits() == 3

# Run the self-check right away so a broken splitter is noticed before it is used.
test_sequential_group_split()

In [6]:
def forecasts(data: pd.DataFrame, exclude_from_test: Sequence=(), pew: int=30, expert_only: bool=False, eq_sample_size: int=1,
              modelfile: str='data/Whakaari_4s_initial1.xdsl', bins: tuple=(0, 5, 95, 100),
              hidden_nodes: bool=True, uniformize: bool=False, randomize: bool=False, zarr_store='data/whakaari_forecasts.zarr',
              recompute=False, save_trained_model=False, smoothing=None, ex_nodes=[]):
    """
    Compute BN forecasts with an expanding-window cross-validation.

    The data are forward-filled, discretised and fed to a WhakaariModel. For
    every fold of SequentialGroupSplit the pipeline is re-fitted on the
    training rows and evaluated on the test rows; results are collected in
    one xarray.Dataset and, if `zarr_store` is given, cached on disk.

    Arguments:
    ----------
        data: pandas.DataFrame
            The observations, one row per time step.
        exclude_from_test: Sequence, optional
            Columns that are set to NaN before predicting on the test rows.
        pew: int, optional
            Window passed to pre_eruption_window for the train and test labels.
        expert_only, eq_sample_size, hidden_nodes, uniformize, randomize,
        modelfile, smoothing:
            Passed through to WhakaariModel.
        bins: tuple, optional
            Passed to Discretizer (quantile strategy).
        zarr_store: str or None, optional
            Directory of the zarr cache; None disables caching.
        recompute: bool, optional
            If True, ignore a cached result.
        save_trained_model: bool, optional
            If True, write the fitted network and the discretised data next
            to `modelfile`.
        ex_nodes: list, optional
            Not used in the body.
            NOTE(review): mutable default argument; harmless while unused.
    Returns:
    --------
        xarray.Dataset with the forecasts, hidden-node probabilities,
        sensitivities, original and discretised data and the labels.
    """
    if zarr_store is not None:
        # NOTE(review): the cache key covers modelfile, expert_only, bins,
        # eq_sample_size, pew and exclude_from_test only. hidden_nodes,
        # uniformize, randomize and smoothing are not part of it, so runs that
        # differ only in those options share a cache entry -- confirm intent.
        group = "/model={}/expert_only={}/bins={}/eq_sample_size={}/pew={}/exclude_from_test={}".format(modelfile.replace('/', '_'), expert_only, str(bins),
                                                        eq_sample_size, pew, exclude_from_test)
        path = os.path.join(zarr_store, group[1:])
        if os.path.isdir(path) and not recompute:
            xds = xr.open_zarr(path, consolidated=False)
            print(f"Loading forecasts from {path}")
            return xds
    # Forward-fill gaps, then blank the two seismic columns from 2022-07-01 on
    # so that they are not carried forward past that date.
    data_fill = data.ffill(axis=0)
    data_fill.loc['2022-07-01':, 'RSAM'] = np.nan
    data_fill.loc['2022-07-01':, 'Eqr'] = np.nan
    pipe = Pipeline([('discretize', Discretizer(bins=bins, strategy='quantile', names=None)),
                     ('clf', WhakaariModel(expert_only=expert_only, uniformize=uniformize, eq_sample_size=eq_sample_size,
                                          randomize=randomize, hidden_nodes=hidden_nodes,
                                          modelfile=modelfile, smoothing=smoothing))])

    x_train, x_test, x_remainder, y_train, y_test, y_remainder, groups = get_train_test_data(data_fill, ndays=30)
    # Only the train and test labels are widened to the pre-eruption window;
    # the remainder labels are used as returned.
    y_train = pre_eruption_window(y_train, pew)
    y_test = pre_eruption_window(y_test, pew)
    # y_all is indexed by position below, so this concatenation order is
    # assumed to reproduce the row order of data_fill -- TODO confirm.
    y_all = pd.concat([y_train, y_test, y_remainder])
    cv = SequentialGroupSplit(groups) 
    # Result buffers, one entry per row of the input data.
    probs = np.zeros(data_fill.shape[0])
    magma = np.zeros(data_fill.shape[0])
    seal = np.zeros(data_fill.shape[0])
    sens = np.zeros(data_fill.shape)
    disc_data = np.full(data_fill.shape, '*', dtype='<U7')
    init = True
    for train, test in cv.split(data_fill):
        pipe.fit(data_fill.iloc[train], y_all.iloc[train])
        if init:
            # The first training block has no out-of-sample forecast, so its
            # in-sample predictions are stored instead.
            probs[train] = pipe.predict_proba(data_fill.iloc[train])[:, 1]
            # Columns 1 and 3 of hidden_proba_ are presumably the 'yes'
            # probabilities of the two hidden nodes -- verify in WhakaariModel.
            magma[train] = pipe['clf'].hidden_proba_[:, 1]
            seal[train] = pipe['clf'].hidden_proba_[:, 3]
            init = False
        _data_test = data_fill.copy()
        for col in exclude_from_test:
            _data_test[col] = np.nan
        probs[test] = pipe.predict_proba(_data_test.iloc[test])[:, 1]
        magma[test] = pipe['clf'].hidden_proba_[:, 1]
        seal[test] = pipe['clf'].hidden_proba_[:, 3]
        _sens = pd.DataFrame(pipe.score(_data_test.iloc[test]))
        sens[test, :] = _sens.values[:]
        # The discretised data are taken from data_fill, i.e. including the
        # columns that were blanked for the prediction.
        disc_data[test, :] = pipe[:-1].transform(data_fill.iloc[test])
    if save_trained_model:
        # The pipeline still holds the fit of the last fold at this point.
        trained_modelfile = modelfile.replace('.xdsl', '_trained.xdsl')
        pipe['clf'].net_.write(trained_modelfile)
        datafile = modelfile.replace('.xdsl', '_data.csv')
        pd.DataFrame(data=disc_data, index=data_fill.index, columns=data_fill.columns).to_csv(datafile)

    # The *_min / *_max variables are filled with the same arrays as the
    # central values; no spread is computed here.
    xds = xr.Dataset(
        {
            "probs": (["time"], probs),
            "probs_min": (["time"], probs),
            "probs_max": (["time"], probs),
            "magma": (["time"], magma),
            "magma_min": (["time"], magma),
            "magma_max": (["time"], magma),
            "seal": (["time"], seal),
            "seal_min": (["time"], seal),
            "seal_max": (["time"], seal),
            "sens": (["time", "type"], sens),
            "original_data": (["time", "type"], data.values),
            "discrete_data": (["time", "type"], disc_data),
            "y_all": (["time"], y_all.values.squeeze())
        },
        coords={"time": data.index.tz_localize(None),
                "type": data.columns.astype(str)},
    )
    if zarr_store is not None:
        xds.to_zarr(zarr_store, group=group, mode='a')
    return xds


In [7]:
def forecast_plot(frcst, frcst_min=None, frcst_max=None, fig=None, ploterr=True, log=False, eruptions=True,
                  label='Forecast', showlegend=True, color_id=0, alpha=1., window=30,
                  row=1, col=1):
    """
    Plot timeseries of forecast probabilities.
    Arguments:
    ----------
        frcst: xarray.DataArray
            The forecast probabilities; must have a `time` coordinate.
        frcst_min: xarray.DataArray, optional
            Lower bound of the forecast probabilities.
        frcst_max: xarray.DataArray, optional
            Upper bound of the forecast probabilities.
        fig: plotly Figure, optional
            The figure to add the plot to. If None, a new figure with a
            secondary y-axis is created.
        ploterr: bool, optional
            Whether to plot error bounds of the forecast probabilities.
            The band is only drawn if both `frcst_min` and `frcst_max`
            are given.
        log: bool, optional
            Whether to plot the y-axis on a log scale.
        eruptions: bool, optional
            Whether to plot the eruption times.
        label: str, optional
            The label of the forecast.
        showlegend: bool, optional
            Whether the forecast and the first eruption line get a legend entry.
        color_id: int, optional
            The color of the forecast.
        alpha: float, optional
            The opacity of the forecast line.
        window: int, optional
            The window size for the rolling mean. None disables smoothing.
        row: int, optional
            The row to add the plot to.
        col: int, optional
            The column to add the plot to.
    Returns:
    --------
        fig: plotly Figure
            The figure with the plot.

    """
    if fig is None:
        fig = make_subplots(specs=[[{"secondary_y": True}]])

    forecast = frcst
    if window is not None:
        forecast = frcst.rolling(dict(time=window)).mean()
    fig.add_trace(go.Scatter(x=pd.to_datetime(forecast.time), y=forecast, mode='lines',
                             line_color=get_color(color_id, alpha=alpha), name=label, showlegend=showlegend,
                             legendgroup='group1'), row=row, col=col)
    # ploterr defaults to True while the bounds default to None, so the band
    # is skipped unless both bounds were actually supplied.
    if ploterr and frcst_min is not None and frcst_max is not None:
        fig.add_trace(go.Scatter(x=pd.to_datetime(frcst_min.time), y=frcst_min, mode='lines', marker=dict(color="#444"),
                                 line=dict(width=0), showlegend=False), row=row, col=col)
        # fill='tonexty' shades the area between this trace and the lower bound.
        fig.add_trace(go.Scatter(x=pd.to_datetime(frcst_max.time), y=frcst_max, mode='lines', marker=dict(color="#444"),
                                 line=dict(width=0), showlegend=False, fillcolor=get_color(color_id, alpha=0.3),
                                 fill='tonexty'), row=row, col=col)
    if eruptions:
        catalogue = load_whakaari_catalogue(1, '0D')
        dfe = catalogue.loc[frcst.time[0].values:frcst.time[-1].values]
        # One vertical line per eruption on the secondary axis; only the
        # first one may appear in the legend.
        eruption_legend = showlegend
        for eruption_time in dfe.index:
            fig.add_trace(go.Scatter(x=[eruption_time, eruption_time], y=[0., 1.], mode='lines',
                                     line_width=.8, line_color='black', name='Observed eruption',
                                     showlegend=eruption_legend), secondary_y=True, row=row, col=col)
            eruption_legend = False

    if log:
        fig.update_yaxes(type='log', secondary_y=False)
    return fig

### Grid search

In [8]:
def circular_node_positions(num_nodes, radius=200, offset=(200, 200)):
    """
    Create integer node positions evenly spaced on a circle.

    Arguments:
    ----------
        num_nodes: int
            The number of positions to create.
        radius: number, optional
            The radius of the circle.
        offset: tuple, optional
            The (x, y) centre of the circle.
    Returns:
    --------
        list: (x, y) tuples of ints, starting at angle 0 and going
        counter-clockwise in mathematical orientation.
    """
    positions = []
    for i in range(num_nodes):
        theta = (2 * math.pi * i) / num_nodes
        # Round to the nearest integer; truncating would turn values such as
        # 199.99999999999997 into 199 and make the layout asymmetric.
        x = int(round(radius * math.cos(theta) + offset[0]))
        y = int(round(radius * math.sin(theta) + offset[1]))
        positions.append((x, y))
    return positions

def stacked_node_positions(num_causal_nodes, num_child_nodes, x_child=150,
                           y_causal=200, y_child=100, node_distance=100):
    """
    Create node positions for a two-row layout: the causal nodes share the
    y-coordinate `y_causal`, the child nodes share `y_child`. Within a row
    neighbouring nodes are `node_distance` apart.

    Returns a list of (x, y) tuples, causal nodes first, then child nodes.
    """
    # NOTE(review): the reference point is (x_child + child row width) / 2,
    # which differs from the geometric centre of the child row unless
    # x_child is 0 -- confirm whether that is intended.
    reference_x = (x_child + (num_child_nodes - 1) * node_distance) / 2.
    causal_half_width = (num_causal_nodes - 1) * node_distance / 2.
    first_causal_x = int(reference_x - causal_half_width)

    causal_row = [(first_causal_x + k * node_distance, y_causal)
                  for k in range(num_causal_nodes)]
    child_row = [(x_child + k * node_distance, y_child)
                 for k in range(num_child_nodes)]
    return causal_row + child_row

def fully_connected(nodes):
    """
    Return one edge (nodes[i], nodes[j]) for every pair with i < j, i.e. each
    node points at all nodes that come after it in `nodes`.
    """
    return [(parent, child)
            for position, parent in enumerate(nodes)
            for child in nodes[position + 1:]]

def causal(causal_nodes, child_nodes):
    """
    Return an edge from every causal node to every child node. The edges are
    grouped by child node, with the causal nodes in their given order.
    """
    return [(cause, child) for child in child_nodes for cause in causal_nodes]

def create_network(network_file, node_names, positions, edges):
    """
    Build a BayesNet and write it to `network_file`.

    Arguments:
    ----------
        network_file: str
            The file the network is written to.
        node_names: list
            (name, states) tuples, one per node.
        positions: list
            One position per node, paired with `node_names` in order.
        edges: list
            (parent, child) tuples of node names.
    """
    network = BayesNet()
    for (name, states), position in zip(node_names, positions):
        # Every node starts with equal weight on each of its states.
        n_states = len(states)
        network.add_node(name, states, np.ones(n_states)/n_states,
                         description=name, position=position)
    for parent, child in edges:
        network.add_arc(parent, child)
    network.write(network_file)

# Write one fully-connected and one causal network file for each number of
# states (2 to 5) of the observable nodes.
for nstates in [2, 3, 4, 5]:
    obs_states = ["state_{:d}".format(i) for i in range(nstates)]
    binary_states = ["no", "yes"]
    # The eruption node and the two causal nodes are always binary.
    node_names = [('eruptions', binary_states), ('Eqr', obs_states),
                  ('CO2', obs_states), ('RSAM', obs_states),
                  ('SO2', obs_states), ('H2S', obs_states)]
    causal_node_names = [('Magmatic_Intrusion', binary_states), ('Hydrothermal_Seal', binary_states)]
    positions = circular_node_positions(len(node_names)) 
    # Causal positions come first in the returned list, matching the order
    # causal_node_names + node_names used below.
    causal_positions = stacked_node_positions(len(causal_node_names), len(node_names)) 
    edges = fully_connected([node for node, _ in node_names])
    causal_edges = causal([node for node, _ in causal_node_names], [node for node, _ in node_names])
    create_network(f"data/fully_connected_model_{nstates}_states.xdsl", node_names, positions, edges) 
    create_network(f"data/causal_model_{nstates}_states.xdsl", causal_node_names + node_names, causal_positions, causal_edges) 

#### Validation functions

In [9]:
def get_evaluation_windows(starttime, endtime, pew):
    """
    Get the evaluation windows for the forecasted rates.

    Arguments:
    ----------
        starttime: pd.Timestamp
            The start time of the forecast.
        endtime: pd.Timestamp
            The end time of the forecast.
        pew: int
            The pre-eruption window size in days.
    Returns:
    --------
        positive_windows: list
            (start, end) tuples covering the `pew` days up to and including
            each eruption.
        negative_windows: list
            (start, end) tuples for the time outside the pre-eruption windows.
    """
    catalogue = load_whakaari_catalogue(2, '0D')
    eruption_times = catalogue.loc[starttime:endtime].index

    positive_windows = []
    negative_windows = []
    quiet_start = starttime
    for eruption_time in eruption_times:
        window_start = eruption_time - pd.Timedelta(days=pew-1)
        positive_windows.append((window_start, eruption_time))
        # The quiet period ends the day before the pre-eruption window opens ...
        negative_windows.append((quiet_start, eruption_time - pd.Timedelta(days=pew)))
        # ... and the next one starts a fixed 30 days after the eruption.
        quiet_start = eruption_time + pd.Timedelta(days=30)
    negative_windows.append((quiet_start, endtime))
    return positive_windows, negative_windows
 
def compute_rates(model: pd.DataFrame, pew: int, debug: bool=False):
    """
    Compute the forecasted rates for positive and negative windows.

    Probabilities p are converted to rates with -log(1 - p). Means and
    standard errors are scaled from one window of `pew` days to one year.

    Arguments:
    ----------
        model: pandas.DataFrame
            The model probabilities.
        pew: int
            The pre-eruption window size.
        debug: bool, optional
            Whether to print debug information.
    Returns:
    --------
        dict: The forecasted rates for positive and negative windows as well
        as the overall rate, each as a (mean, standard error) tuple.
    """
    def _rates_per_window(windows):
        # One entry of rates per (start, end) window, in the given order.
        collected = []
        for win_start, win_end in windows:
            rates = -np.log(1-model.loc[win_start:win_end])
            if debug:
                print('--->', win_start, win_end, 'Forecasted rate:', rates)
            collected.append(rates)
        return collected

    positive_windows, negative_windows = get_evaluation_windows(model.index[0], model.index[-1], pew)
    positive_parts = _rates_per_window(positive_windows)
    negative_parts = _rates_per_window(negative_windows)
    positive_rates = np.concatenate(positive_parts)
    negative_rates = np.concatenate(negative_parts)

    yearly_scale = 365.25/pew
    pos_stats = (np.mean(positive_rates) * yearly_scale,
                 np.std(positive_rates) / np.sqrt(len(positive_rates)) * yearly_scale)
    neg_stats = (np.mean(negative_rates) * yearly_scale,
                 np.std(negative_rates) / np.sqrt(len(negative_rates)) * yearly_scale)
    all_rates = -np.log(1-model)
    overall_stats = (np.mean(all_rates) * yearly_scale,
                     np.std(all_rates) / np.sqrt(len(model)) * yearly_scale)

    return dict(positive_rates=pos_stats, negative_rates=neg_stats,
                overall_rate=overall_stats)


def evaluate_threshold(thresh: float, model: pd.DataFrame,
                       debug: bool=False, pew: int=90,
                       return_windows: bool=False):
    """
    Evaluate the number of true positives, false positives, true negatives and false negatives for a given threshold.

    The forecast is restricted to the time from 2013-01-01 (or its own start,
    if later), min-max normalised, turned into a binary alert with `thresh`
    and cut into runs of constant alert state. Each run is then scored in
    days against the eruption catalogue.

    Arguments:
    ----------
        thresh: float
            The threshold value, applied to the normalised forecast.
        model: pandas.DataFrame
            The forecasted probabilities.
        debug: bool, optional
            Whether to print debug information.
        pew: int, optional
            The pre-eruption window size in days. If None, whole runs are
            counted as one category.
        return_windows: bool, optional
            Whether to return the positive and negative windows.

    Returns:
    --------
        dict: The evaluation results with the keys tp, fp, tn, fn.
        list: A list of dictionaries for the tp, fp, tn, fn windows
            (only if `return_windows` is True).
    """
    starttime = max(pd.Timestamp(model.index[0]), pd.Timestamp('2013-01-01'))
    endtime = pd.Timestamp(model.index[-1])
    explosive_eruptions = load_whakaari_catalogue(2, '0D')
    explosive_eruptions = explosive_eruptions.loc[starttime:endtime]
    if debug:
        print(explosive_eruptions)
    model = model.loc[starttime:endtime]
    # normalise to 0-1
    model = (model - model.min()) / (model.max() - model.min())
    dt = pd.to_datetime(model.index)
    bin_model = np.where(model >= thresh, 1, 0)
    # NOTE(review): for 2-D input this keeps the first row; for a
    # single-column frame of shape (n, 1) that is a length-1 array, so
    # `[:, 0]` may have been intended -- confirm.
    if len(bin_model.shape) > 1:
        bin_model = bin_model[0]
    # assign data on eruption days to the value on the day before 
    # NOTE(review): an eruption on the first day gives index -1 here, which
    # reads the last element.
    eidx = np.where(np.isin(model.index, explosive_eruptions.index))[0]
    bin_model[eidx] = bin_model[eidx - 1] 
    # Cut the binary series into runs of constant alert state, stored as
    # (first position, last position). The final sample always closes the run
    # that is open when it is reached.
    negative_windows = []
    positive_windows = []
    first_day = 0
    alert = bin_model[0]
    for i, date in enumerate(dt):
        if bin_model[i] != alert or i == len(dt) - 1:
            last_day = i - 1
            if i == len(dt) - 1:
                last_day = i
            if alert > 0:
                positive_windows.append((first_day, last_day))
            else:
                negative_windows.append((first_day, last_day))
            first_day = i
            alert = bin_model[i]

    true_positives = 0
    false_positives = 0
    true_negatives = 0
    false_negatives = 0
    # NOTE(review): `windows` is not extended for runs that contain an
    # eruption while pew is not None.
    windows = []
    if debug:
        print('Positive windows:')
    # Alert runs: scored against the first eruption that falls inside the run.
    for win_start, win_end in positive_windows:
        date_start = dt[win_start]
        date_end = dt[win_end]
        if debug:
            print('--->', date_start, date_end)
        explosion_in_window = False
        for ee in explosive_eruptions.iterrows():
            if date_start < ee[0] <= date_end:
                if pew is not None:
                    fp_ = tp_ = tn_ = fn_ = 0
                    # Local name; inside this function it shadows the
                    # pre_eruption_window helper imported from whakaaribn.
                    pre_eruption_window = ee[0] - pd.Timedelta(days=pew)
                    # Alert days after the eruption and alert days inside the
                    # pre-eruption window both count as true positives.
                    tp_ += (date_end - ee[0]).days
                    tp_ += (ee[0] - max(pre_eruption_window, date_start)).days
                    # Negative: the alert started before the pre-eruption
                    # window (false positives). Otherwise: days of the window
                    # that passed before the alert started (false negatives).
                    pre_win = (date_start - pre_eruption_window).days
                    if pre_win < 0:
                        fp_ += abs(pre_win)
                    else:
                        fn_ += pre_win
                    false_positives += fp_
                    true_positives += tp_
                    false_negatives += fn_
                    # subtract from previous windows true_negatives
                    true_negatives -= fn_
                else:
                    true_positives += (win_end - win_start + 1)
                    windows.append(dict(start=date_start, end=date_end, type='true_positive'))
                if debug:
                    print('Eruption in positive window: ', date_start, date_end)
                # stop if there is at least one eruption in the window
                explosion_in_window = True
                break
        if not explosion_in_window:
            false_positives += (win_end - win_start + 1)
            windows.append(dict(start=date_start, end=date_end, type='false_positive'))

    if debug:
        print('Negative windows:')
    # No-alert runs: an eruption inside such a run is (partly) missed.
    for win_start, win_end in negative_windows:
        date_start = dt[win_start]
        date_end = dt[win_end]
        if debug:
            print('--->', date_start, date_end)
        explosion_in_window = False
        for ee in explosive_eruptions.iterrows():
            if date_start < ee[0] <= date_end:
                if pew is not None:
                    fp_ = tp_ = tn_ = fn_ = 0
                    # Days after the eruption are true negatives, days inside
                    # the pre-eruption window are false negatives and days
                    # before the window are true negatives again.
                    tn_ = (date_end - ee[0]).days
                    pre_eruption_window = ee[0] - pd.Timedelta(days=pew)
                    fn_ = (ee[0] - max(pre_eruption_window, date_start)).days
                    pre_win = (date_start - pre_eruption_window).days
                    if pre_win < 0:
                        tn_ += abs(pre_win)
                    true_negatives += tn_
                    false_negatives += fn_
                else:
                    false_negatives += (win_end - win_start + 1)
                    # if window ends later than 09/01/2020 count the 
                    # days between that date and the window end date as true negatives
                    # as the last explosive eruption was on 09/12/2019
                    if date_end > pd.Timestamp('2020-01-09'):
                        diff = (date_end - pd.Timestamp('2020-01-09')).days
                        true_negatives += diff
                        false_negatives -= diff
                        windows.append(dict(start=date_start, end=date_end - pd.Timedelta(days=diff), type='false_negative'))
                        windows.append(dict(start=date_end - pd.Timedelta(days=diff-1), end=date_end, type='true_negative'))
                    else:
                        windows.append(dict(start=date_start, end=date_end, type='false_negative'))
                if debug:
                    print('Eruption in negative window: ', date_start, date_end)
                # stop if there is at least one eruption in the window
                explosion_in_window = True 
                break
        if not explosion_in_window:
            true_negatives += (win_end - win_start + 1)
            windows.append(dict(start=date_start, end=date_end, type='true_negative'))
    if return_windows:
        return dict(tp=true_positives, fp=false_positives, tn=true_negatives, fn=false_negatives), windows
    return dict(tp=true_positives, fp=false_positives, tn=true_negatives, fn=false_negatives)

def get_roc_curve(model: pd.DataFrame, thresholds: list, func: Callable,
                  debug: bool=False):
    """
    Compute the ROC curve for a given model and thresholds.

    Arguments:
    ----------
        model: pandas.DataFrame
            The model probabilities. Passed unchanged as the second
            argument to `func`.
        thresholds: list
            The thresholds to evaluate.
        func: function
            The evaluation function. It is called as `func(threshold, model)`
            and has to return a mapping with the keys 'tp', 'fp', 'tn'
            and 'fn'.
        debug: bool, optional
            Whether to print debug information.

    Returns:
    --------
        tpr: numpy.ndarray
            True positive rate per threshold, followed by a final 1.
        fpr: numpy.ndarray
            False positive rate per threshold, followed by a final 1.
        precision: numpy.ndarray
            Precision per threshold (no trailing element is appended).

        Thresholds that are skipped (a zero denominator or no true
        positives) keep the value 0 in all three arrays.
    """
    n_thresholds = len(thresholds)
    # np.zeros rather than np.empty(...)*0.: uninitialised memory can hold
    # inf/nan, which remain nan after a multiplication with zero.
    tpr = np.zeros(n_thresholds)
    fpr = np.zeros(n_thresholds)
    precision = np.zeros(n_thresholds)

    for i, thresh in enumerate(thresholds):
        result = func(thresh, model)
        if debug:
            print(thresh, result)
        tp, fp = result['tp'], result['fp']
        tn, fn = result['tn'], result['fn']
        # Check the denominators explicitly: NumPy integer counts produce
        # nan/inf on division by zero instead of raising ZeroDivisionError.
        if (tp + fn) == 0 or (fp + tn) == 0 or (tp + fp) == 0:
            print("Threshold: ", thresh, "True positives: ", tp, "False negatives: ", fn)
            print("Threshold: ", thresh, "True negatives: ", tn, "False positives: ", fp)
            continue
        # start the evaluation from the first correct alerts
        if tp == 0:
            continue
        tpr[i] = tp/(tp + fn)
        fpr[i] = fp/(fp + tn)
        precision[i] = tp/(tp + fp)
    # make sure that tpr and fpr end in 1
    # so that the AUC value is comparable
    tpr = np.r_[tpr, 1.]
    fpr = np.r_[fpr, 1.]
    return tpr, fpr, precision


def make_strictly_increasing(sequence: Sequence) -> Sequence:
    """
    Make a sequence monotonically non-decreasing. Some of the ROC curves
    computed with our own metric are not monotonic due to the way tps, fps,
    tns and fns are defined. This function removes the dips so that we can
    calculate the AUC value.

    Despite the function name the result is non-decreasing rather than
    strictly increasing: an element that is smaller than its predecessor is
    replaced by the predecessor's value, so repeated values can occur.

    Arguments:
    ----------
        sequence: Sequence
            The sequence to make non-decreasing. It is not modified.

    Returns:
    --------
        Sequence: A copy of the input (same container type) in which every
        element is greater than or equal to the one before it.
    """
    import copy

    # Work on a shallow copy so that the caller's list/array stays untouched
    result = copy.copy(sequence)

    # Iterate through the sequence starting from the second element
    for i in range(1, len(result)):
        # If the current element is not greater than the previous one
        if result[i] <= result[i - 1]:
            # Clamp the current element to the value of the previous one
            result[i] = result[i - 1]

    return result

def test_make_strictly_increasing():
    """Check that dips in a sequence are raised to the preceding value."""
    given = [.1, .2, .3, .4, .2, .6, .7, .5, 1]
    expected = [.1, .2, .3, .4, .4, .6, .7, .7, 1]
    assert make_strictly_increasing(given) == expected

def test_evaluate_threshold():
    """
    Check the day counts returned by `evaluate_threshold` for two synthetic
    100-day probability series and a threshold of 0.5.

    Both series cover the 100 days that end on 2013-10-20, i.e. the day
    before 2013-10-21 (presumably an eruption date known to
    `evaluate_threshold` -- verify against the catalogue).
    """
    first_day = pd.Timestamp('2013-10-21') - pd.Timedelta(days=100)
    index = pd.date_range(first_day, periods=100)

    # Case 1: 30 days above, 30 days below and the final 40 days above
    # the threshold
    values = np.r_[np.ones(30)*.6, np.ones(30)*.2, np.ones(40)*.7]
    series = pd.Series(data=values, index=index)
    counts, wins = evaluate_threshold(0.5, series, debug=False,
                                      pew=None, return_windows=True)
    assert counts['tp'] == 40
    assert counts['fp'] == 30
    assert counts['tn'] == 30
    assert counts['fn'] == 0
    assert len(wins) == 3
    for w in wins:
        if w['type'] != 'true_positive':
            continue
        # the true-positive window spans the last 40 days of the series
        assert w['end'] == pd.Timestamp('2013-10-20')
        assert (w['end'] - w['start']) == pd.Timedelta(days=39)

    # Case 2: only the first 30 days are above the threshold
    values = np.r_[np.ones(30)*.7, np.ones(30)*.2, np.ones(40)*.4]
    series = pd.Series(data=values, index=index)
    counts, wins = evaluate_threshold(0.5, series, debug=False,
                                      pew=None, return_windows=True)
    assert counts['tp'] == 0
    assert counts['fp'] == 30
    assert counts['tn'] == 0
    assert counts['fn'] == 70


# Run the inline sanity checks right away so that a regression in the helper
# functions stops the notebook before the grid search starts.
test_make_strictly_increasing()
test_evaluate_threshold()

In [None]:
def ap(estimator, X, y, w=1):
    """
    Scorer: average precision of the estimator's predictions on (X, y).

    Arguments:
    ----------
        estimator: object
            Fitted estimator providing `predict_proba`.
        X: array-like
            The input data passed to `predict_proba`.
        y: array-like
            The true labels.
        w: float, optional
            Sample weight given to samples with `y == 1`; all other
            samples get weight 1.

    Returns:
    --------
        float: The value returned by `average_precision_score`.
    """
    # second column of predict_proba is used as the score of the positive class
    positive_scores = estimator.predict_proba(X)[:, 1]
    sample_weight = np.where(y == 1, w, 1)
    return average_precision_score(y, positive_scores,
                                   sample_weight=sample_weight,
                                   average='weighted')

def aic(estimator, X, y, dof=1):
    """
    Akaike Information Criterion

    Returns `-2*log_loss(y, p) + 2*dof`, where `p` is the second column of
    the estimator's `predict_proba` output.

    NOTE(review): the textbook AIC is `2*k - 2*ln(L)` with the total
    log-likelihood `ln(L)`. `log_loss` returns the mean negative
    log-likelihood by default, so the sign and scaling of the first term
    differ from that definition -- confirm this is the intended score.

    Arguments:
    ----------
        estimator: object
            Fitted estimator providing `predict_proba`.
        X: array-like
            The input data passed to `predict_proba`.
        y: array-like
            The true labels.
        dof: int, optional
            Number of degrees of freedom entering the penalty term.

    Returns:
    --------
        float: The score described above.
    """
    proba = estimator.predict_proba(X)
    return -2*log_loss(y, proba[:, 1]) + 2*dof

def my_log_loss(estimator, X, y):
    """
    Scorer: negated log loss, so that larger values mean a better fit.

    The loss is computed from the second column of the estimator's
    `predict_proba` output and the true labels `y`.
    """
    proba = estimator.predict_proba(X)
    return -log_loss(y, proba[:, 1])

def my_auc(estimator, X, y, w=1):
    """
    Scorer: area under the ROC curve as computed by `roc_auc_score`.

    Arguments:
    ----------
        estimator: object
            Fitted estimator providing `predict_proba`.
        X: array-like
            The input data passed to `predict_proba`.
        y: array-like
            The true labels.
        w: float, optional
            Sample weight given to samples with `y == 1`; all other
            samples get weight 1.

    Returns:
    --------
        float: The value returned by `roc_auc_score`.
    """
    # second column of predict_proba is used as the score of the positive class
    positive_scores = estimator.predict_proba(X)[:, 1]
    sample_weight = np.where(y == 1, w, 1)
    return roc_auc_score(y, positive_scores,
                         sample_weight=sample_weight,
                         average='weighted')

def my_roc_auc(estimator, X, y, pew=90):
    """
    Scorer: area under the ROC curve built with `evaluate_threshold`.

    The second column of the estimator's `predict_proba` output is turned
    into a series indexed like `X` and evaluated with `get_roc_curve` at 100
    thresholds running from 0.99 down to 0.01. `y` is not used.

    Arguments:
    ----------
        estimator: object
            Fitted estimator providing `predict_proba`.
        X: pandas.DataFrame
            The input data; its index becomes the index of the forecast.
        y: array-like
            Unused; present to match the scorer call signature.
        pew: int or None, optional
            Forwarded to `evaluate_threshold` as `pew`.

    Returns:
    --------
        float: The area computed by `auc` from the false positive rates
        (passed through `make_strictly_increasing`) and true positive rates.
    """
    thresholds = np.linspace(0.01, 0.99, 100)[::-1]
    mdl = pd.Series(estimator.predict_proba(X)[:, 1], index=X.index)
    assert mdl.shape[0] > 0
    evaluate = partial(evaluate_threshold, pew=pew)
    tpr_bn, fpr_bn, _ = get_roc_curve(mdl, thresholds, evaluate)
    return auc(make_strictly_increasing(fpr_bn), tpr_bn)


# One parameter grid per number of discrete states (2 to 5). Each grid pairs
# a set of candidate bin edges for the discretizer with the model file that
# has the matching number of states.
params_gcv = [
    {   "discretize__bins": [(0, 5, 100), (0, 50, 100), (0, 95, 100)],
        "clf__hidden_nodes": [False],
        "clf__uniformize": [True],
        "clf__modelfile": ['data/fully_connected_model_2_states.xdsl']
    },
     {   "discretize__bins": [(0, 5, 95, 100), (0, 33, 66, 100), (0, 25, 75, 100)],
        "clf__hidden_nodes": [False],
        "clf__uniformize": [True],
        "clf__modelfile": ['data/fully_connected_model_3_states.xdsl']
    },
    {  "discretize__bins": [(0, 25, 50, 75, 100), (0, 5, 50, 95, 100), (0, 10, 50, 90, 100), (0, 20, 50, 80, 100)],
        "clf__hidden_nodes": [False],
        "clf__uniformize": [True],
        "clf__modelfile": ['data/fully_connected_model_4_states.xdsl']
    },
    {  "discretize__bins": [(0, 20, 40, 60, 80, 100), (0, 5, 20, 80, 95, 100), (0, 5, 25, 75, 95, 100)],
        "clf__hidden_nodes": [False],
        "clf__uniformize": [True],
        "clf__modelfile": ['data/fully_connected_model_5_states.xdsl']
    }
]

# Pipeline: discretize the observations, then fit the Bayesian network
whmdl = WhakaariModel(modelfile='data/Whakaari_bn_start.xdsl', uniformize=False, randomize=False,
                      hidden_nodes=True, smoothing=30)
pipe = Pipeline([('discretize', Discretizer(strategy='quantile', names=None)),
                 ('clf', whmdl)])

pipe.set_output(transform="pandas")
# Cross-validation splits are built from all groups except group 'e'
cv = SequentialGroupSplit(groups[groups != 'e'])

# search_results[pew][nstates] holds the cv_results_ table of one grid search
search_results = {}
for pew in np.arange(10, 110, 10):
    search_results[pew] = {}
    print("Pre-eruption window = ", pew)
    # Build the labels for this pre-eruption window from the ORIGINAL labels.
    # Loop-local names are used for both parts so that y_train and y_test are
    # never overwritten; otherwise each iteration would apply the window to
    # labels that were already transformed by the previous iteration.
    _y_train = pre_eruption_window(y_train, pew)
    _y_test = pre_eruption_window(y_test, pew)
    for nstates, params in zip([2, 3, 4, 5], params_gcv):
        print("Number of states = ", nstates)
        # degrees of freedom used by the 'aic' score:
        # 2*(n + n**2 + n**3 + n**4 + n**5) with n = nstates
        dof = np.sum(2*nstates**np.arange(1,6))
        search = GridSearchCV(estimator=pipe, param_grid=[params],
                              scoring={'average_precision': partial(ap, w=1),
                                    'aic': partial(aic, dof=dof),
                                    'log_loss': my_log_loss,
                                    'roc_auc': partial(my_auc, w=1),
                                    'mod_roc_auc': partial(my_roc_auc, pew=pew),
                                    'mod_roc_auc_no_pew': partial(my_roc_auc, pew=None)},
                              cv=cv, n_jobs=10, verbose=1, refit=False)
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)
            search.fit(pd.concat((X_train, X_test)).ffill(), pd.concat((_y_train, _y_test)))
        # keep the results sorted by the metric used for model selection
        sdf = pd.DataFrame(search.cv_results_)
        sdf = sdf.sort_values(by=['rank_test_mod_roc_auc_no_pew'])
        search_results[pew][nstates] = sdf

### Evaluation

In [None]:
# Interactive inspection of the grid-search results: the forecast of the
# best parameter set for a given pre-eruption window / number of states is
# shown together with one observable and its discretized version.
pew = np.arange(10, 110, 10)
nstates = np.arange(2, 6)
metric = 'mod_roc_auc_no_pew'
data_types = ['RSAM', 'Eqr', 'CO2', 'SO2', 'H2S']
smooth_win = 30
fig = make_subplots(rows=2, cols=1, specs=[[{"secondary_y": True}],
                                           [{"secondary_y": True}]], shared_xaxes=True)

# Initial view: first pre-eruption window and first number of states
sdf_tmp = search_results[pew[0]][nstates[0]]
sdf_tmp = sdf_tmp.sort_values(by=[f'rank_test_{metric}'])
estimator = sdf_tmp.iloc[0].params
xds = forecasts(pd.concat((X_train, X_test, X_remainder)), pew=pew[0], expert_only=False,
                modelfile=estimator['clf__modelfile'],
                bins=estimator['discretize__bins'],
                hidden_nodes=estimator['clf__hidden_nodes'],
                uniformize=estimator['clf__uniformize'],
                randomize=False, zarr_store=None,
                recompute=True, save_trained_model=False,
                smoothing=30)

forecast_plot(xds['probs'], fig=fig, log=False, ploterr=False,
            eruptions=True, label='Forecast', showlegend=True, color_id=0, window=smooth_win, row=1, col=1)
datatype = 'RSAM'
# .ffill() replaces fillna(method='ffill'), whose `method` argument is
# deprecated in pandas; the rest of the notebook already uses .ffill()
df_data = xds['original_data'].loc[dict(type=datatype)].to_pandas().ffill()
df_data_disc = xds['discrete_data'].loc[dict(type=datatype)].to_pandas()
fig.add_trace(go.Scatter(name=datatype, x=df_data.index, y=df_data, mode='lines', line_color=get_color(7)), row=2, col=1)
# The discrete states are plotted as integer codes (index into the sorted
# unique values) on the secondary y-axis
fig.add_trace(go.Scatter(name="%s_disc" % datatype, x=df_data_disc.index,
                        y=np.unique(df_data_disc, return_inverse=True)[1],
                        mode='lines', line_color=get_color(7, alpha=.5)), row=2, col=1, secondary_y=True)
fig.update_layout(height=600)
fig.update_xaxes(showticklabels=True, row=2, col=1)

# Widget version of the figure plus one dropdown per selectable quantity
figw = go.FigureWidget(fig)
dataset = widgets.Dropdown(
    options=list(data_types),
    value=data_types[0],
    description='Dataset:',
)

states = widgets.Dropdown(
    options=list(nstates),
    value=nstates[0],
    description='States:',
)

prewin = widgets.Dropdown(
    options=list(pew),
    value=pew[0],
    description='Pre-eruption window:',
)


def response(change):
    """
    Dropdown callback: recompute the forecast for the selected pre-eruption
    window and number of states and refresh the traces of `figw`.

    Arguments:
    ----------
        change: dict
            The change notification passed by the widget; it is not used,
            the current values are read from the dropdowns directly.
    """
    with figw.batch_update():
        # Best parameter set for the current selection
        sdf_tmp = search_results[prewin.value][states.value]
        sdf_tmp = sdf_tmp.sort_values(by=[f'rank_test_{metric}'])
        estimator = sdf_tmp.iloc[0].params
        xds = forecasts(pd.concat((X_train, X_test, X_remainder)), pew=prewin.value, expert_only=False,
                        modelfile=estimator['clf__modelfile'],
                        bins=estimator['discretize__bins'],
                        hidden_nodes=estimator['clf__hidden_nodes'],
                        uniformize=estimator['clf__uniformize'],
                        randomize=False, zarr_store=None,
                        recompute=True, save_trained_model=False,
                        smoothing=30)

        # .ffill() replaces fillna(method='ffill'), whose `method` argument
        # is deprecated in pandas
        _df_data = xds['original_data'].to_pandas().ffill()
        _df_disc_data = xds['discrete_data'].to_pandas()
        for trace in figw.data:
            if trace.name == 'Forecast':
                trace.y = xds['probs']
                trace.x = pd.to_datetime(xds['probs']['time'])
            # The data traces are identified by the name of the dataset they
            # currently show and are renamed to the newly selected dataset
            for node_name in data_types:
                if trace.name == "%s_disc" % node_name:
                    trace.y = np.unique(_df_disc_data[dataset.value], return_inverse=True)[1]
                    trace.x = _df_disc_data.index
                    trace.name = "%s_disc" % dataset.value
                if trace.name == node_name:
                    trace.y = _df_data[dataset.value]
                    trace.x = _df_data.index
                    trace.name = dataset.value


# Call `response` whenever the value of one of the dropdowns changes
dataset.observe(response, names='value')
states.observe(response, names='value')
prewin.observe(response, names='value')

# Cell output: the three dropdowns in one row above the figure widget
widgets.VBox([widgets.HBox([dataset, states, prewin]), figw])

In [12]:
def get_best_model(search_results, pews, nstates, exclude_from_test=()):
    """
    Compute the forecast of the best model found by the grid search.

    For every combination of pre-eruption window and number of states the
    top-ranked parameter set (by 'rank_test_mod_roc_auc_no_pew') is looked
    up; the combination with the highest 'mean_test_mod_roc_auc_no_pew'
    is then passed to `forecasts`.

    Arguments:
    ----------
        search_results: dict
            Nested mapping `search_results[pew][nstates]` holding the
            grid-search result tables (pandas.DataFrame).
        pews: sequence
            The pre-eruption windows to consider.
        nstates: sequence
            The numbers of states to consider.
        exclude_from_test: sequence, optional
            Forwarded to `forecasts` as `exclude_from_test`.

    Returns:
    --------
        The object returned by `forecasts` for the best model.

    Raises:
    -------
        ValueError
            If no combination has a score greater than -inf (for example
            because all scores are NaN or `pews`/`nstates` is empty).
    """
    metric = 'mod_roc_auc_no_pew'
    best_objective = -np.inf
    best_pew = None
    best_estimator = None
    for _ns in nstates:
        for _pew in pews:
            sdf_tmp = search_results[_pew][_ns]
            top = sdf_tmp.sort_values(by=[f'rank_test_{metric}']).iloc[0]
            score = top[f'mean_test_{metric}']
            # NaN scores compare False here and are therefore ignored
            if score > best_objective:
                best_objective = score
                best_estimator = top['params']
                best_pew = _pew
    if best_estimator is None:
        # Without this guard the call below would fail with an opaque
        # UnboundLocalError
        raise ValueError(
            f"No grid-search result with a valid 'mean_test_{metric}' score found")
    xds = forecasts(pd.concat((X_train, X_test, X_remainder)), pew=best_pew,
                    expert_only=False, exclude_from_test=exclude_from_test,
                    modelfile=best_estimator['clf__modelfile'],
                    bins=best_estimator['discretize__bins'],
                    hidden_nodes=best_estimator['clf__hidden_nodes'],
                    uniformize=best_estimator['clf__uniformize'],
                    randomize=False, zarr_store='data/best_model.zarr',
                    recompute=True, save_trained_model=False,
                    smoothing=30)
    return xds

### Validation

In [13]:
def validation_plot(fcst: pd.Series, threshold: float, showlegend: bool=False,
                    debug: bool=False, fig=None, row: int=1, col: int=1):
    """
    Plot the forecast probabilities and the evaluation windows.

    The forecast is min-max normalised, evaluated against `threshold` with
    `evaluate_threshold` (pew=None), and every returned window is drawn as a
    filled area between the normalised forecast and the threshold, coloured
    by window type. Eruptions between 2013 and 2020 are drawn as vertical
    lines.

    Arguments:
    ----------
        fcst: pandas.Series
            The forecast probabilities.
        threshold: float
            The threshold value.
        showlegend: bool, optional
            Whether to add legend entries for the observed eruptions and
            for the four window types.
        debug: bool, optional
            Whether to print debug information.
        fig: plotly Figure, optional
            The figure to add the plot to. If None, a new figure is created.
        row: int, optional
            The row to add the plot to.
        col: int, optional
            The column to add the plot to.

    Returns:
    --------
        plotly Figure: The figure the traces were added to.
    """
    trace = (fcst - fcst.min())/(fcst.max() - fcst.min())
    time = trace.index
    stats_, time_windows = evaluate_threshold(threshold, trace, pew=None, return_windows=True)
    if debug:
        print(stats_)
    if fig is None:
        fig = make_subplots(rows=1, cols=1, specs=[[{"secondary_y": True}]])

    eruptions = load_whakaari_catalogue(2, '0D')
    eruptions = eruptions.loc['2013':'2020']
    # Only the first eruption line gets a legend entry. A separate flag is
    # used so that `showlegend` keeps its value for the window-type legend
    # entries added at the end.
    show_eruption_legend = showlegend
    for eruption_time in eruptions.index:
        fig.add_trace(go.Scatter(x=[eruption_time, eruption_time], y=[0., .7], mode='lines',
                                    line_width=.8, line_color='black', name='Observed Eruption',
                                    showlegend=show_eruption_legend), secondary_y=False, row=row, col=col)
        show_eruption_legend = False
    cl_ = dict(true_positive=get_color(0), true_negative=get_color(2), false_positive=get_color(1), false_negative=get_color(3))
    for window in time_windows:
        # Mask the time series within the current window
        mask = (time >= window["start"]) & (time <= window["end"])
        x_window = time[mask]
        y_window = trace[mask]

        # Add a trace for the shaded area; the colour depends on the type
        fig.add_trace(go.Scatter(
            x=list(x_window) + list(x_window[::-1]),  # x-coordinates for fill
            y=list(y_window) + [threshold] * len(y_window),  # y-coordinates for fill
            fill='toself',
            fillcolor=cl_[window["type"]],
            line=dict(color='rgba(255,255,255,0)'),  # No line around fill
            hoverinfo='skip',
            showlegend=False
        ), row=row, col=col)
    # Dummy traces that only serve as legend entries for the window types
    for type_, color in cl_.items():
        fig.add_trace(go.Scatter(
            x=[None],  # Dummy x value
            y=[None],  # Dummy y value
            mode='markers',
            marker=dict(size=10, color=color),
            name=f"{type_.capitalize()}",
            showlegend=showlegend
        ), row=row, col=col)

    fig.update_yaxes(showticklabels=False, showgrid=False, secondary_y=True, row=row, col=col)
    return fig

In [14]:
# Forecast of the best model over all pre-eruption windows (10..100 days)
# and numbers of states (2..5)
xds_best = get_best_model(search_results, np.arange(10, 110, 10), np.arange(2, 6))
# Evaluation period used for the validation plots
tstart = "2012-09-03"
tend = "2022-09-10"
fcst = xds_best['probs'].to_pandas()
fcst = fcst.loc[tstart:tend]
# 100 thresholds running from 0.99 down to 0.01
thresholds = np.linspace(0.01, 0.99, 100)[::-1]
# NOTE(review): the ROC curve is computed from the full forecast series and
# not from the tstart..tend slice stored in `fcst` -- confirm this is intended.
tpr_bn, fpr_bn, precision_bn = get_roc_curve(xds_best['probs'].to_pandas(), thresholds, partial(evaluate_threshold, pew=None))  

In [15]:
# ROC curves obtained by applying the thresholds to each observable on its
# own, for comparison with the Bayesian network forecast
rocs = {}
datatypes = ['RSAM', 'CO2', 'SO2', 'H2S']
for _ds in datatypes:
    _data = xds_best['original_data'].loc[dict(type=_ds)].to_pandas().loc[tstart:tend]
    # .ffill() replaces fillna(method='ffill'), whose `method` argument is
    # deprecated in pandas
    _data = _data.ffill()
    tpr_, fpr_, prec_ = get_roc_curve(_data, thresholds, partial(evaluate_threshold, pew=None))
    rocs[_ds] = dict(tpr=tpr_, fpr=fpr_, precision=prec_)

In [None]:
# Validation figure with four panels:
#   (A), (B) forecast and evaluation windows for two thresholds (rows 1, 2),
#   (C) ROC curves (rows 3-4, left), (D) warning times (rows 3-4, right).
fig = make_subplots(rows=4, cols=2, specs=[[{"colspan": 2, "secondary_y": True}, None], 
                                           [{"colspan": 2, "secondary_y": True}, None],
                                           [{"rowspan": 2}, {"rowspan": 2}], [{}, {}]],
                                           horizontal_spacing=0.1)

# Panels (A) and (B): evaluation windows for the thresholds 0.1 and 0.3
validation_plot(xds_best['probs'].to_pandas(), 0.1, debug=False, fig=fig, row=1, col=1)
validation_plot(xds_best['probs'].to_pandas(), 0.3, debug=False, fig=fig, row=2, col=1)
fig.update_yaxes(tickvals=[.1, .3, .5], row=1, col=1)
fig.update_yaxes(tickvals=[.1, .3, .5], row=2, col=1)
fig.update_xaxes(range=[tstart, tend], row=1, col=1)
fig.update_xaxes(range=[tstart, tend], row=2, col=1)
fig.add_annotation(text="<b>(A)</b>", xref="x domain", yref="y domain", x=.05, y=.9, showarrow=False, row=1, col=1)
fig.add_annotation(text="<b>(B)</b>", xref="x domain", yref="y domain", x=.05, y=.9, showarrow=False, row=2, col=1)
# add legend manually
# NOTE(review): these traces use mode='markers' but set `line_color`, and x
# has one element while y has two -- confirm the markers get the intended
# colours.
fig.add_trace(go.Scatter(x=['2013-12-01'], y=[1, 1], mode='markers',
                        showlegend=False, line_color=get_color(0)), row=1, col=1)
fig.add_annotation(text="True Positives", x='2014-01-01', y=1, showarrow=False, xanchor='left', row=1, col=1)
fig.add_trace(go.Scatter(x=['2015-04-01'], y=[1, 1], mode='markers',
                        showlegend=False, line_color=get_color(1)), row=1, col=1)
fig.add_annotation(text="False Positives", x='2015-05-01', y=1, showarrow=False, xanchor='left', row=1, col=1)
fig.add_trace(go.Scatter(x=['2016-09-01'], y=[1, 1], mode='markers',
                        showlegend=False, line_color=get_color(2)), row=1, col=1)
fig.add_annotation(text="True Negatives", x='2016-10-01', y=1, showarrow=False, xanchor='left', row=1, col=1)
fig.add_trace(go.Scatter(x=['2018-03-01'], y=[1, 1], mode='markers',
                        showlegend=False, line_color=get_color(3)), row=1, col=1)
fig.add_annotation(text="False Negatives", x='2018-04-01', y=1, showarrow=False, xanchor='left', row=1, col=1)

# Panel (C): ROC curves of the single observables (colours 1..4), of the
# Bayesian network (colour 0) and the diagonal of a random classifier
for i, _ds in enumerate(datatypes):
    fpr_, tpr_, prec_ = rocs[_ds]['fpr'], rocs[_ds]['tpr'], rocs[_ds]['precision']
    fig.add_trace(go.Scatter(x=make_strictly_increasing(fpr_), y=tpr_, mode='lines', showlegend=False, line_color=get_color(i + 1)),
                  row=3, col=1)
fig.add_trace(go.Scatter(x=fpr_bn, y=tpr_bn, mode='lines', showlegend=False, line_color=get_color(0)), row=3, col=1)
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', showlegend=False, line=dict(color='navy', dash='dash')), row=3, col=1)

# add legend manually
# (short line segments plus text annotations placed in data coordinates)
legend_x = 0.95
legend_len = 0.1
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.9, 0.9], mode='lines', showlegend=False, line_color=get_color(1)), row=3, col=1)
fig.add_annotation(text="RSAM", x=1.2, y=0.9, showarrow=False, row=3, col=1)
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.8, 0.8], mode='lines', showlegend=False, line_color=get_color(2)), row=3, col=1)
fig.add_annotation(text=u'CO\u2082', x=1.2, y=0.8, showarrow=False, row=3, col=1)
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.7, 0.7], mode='lines', showlegend=False, line_color=get_color(3)), row=3, col=1)
fig.add_annotation(text=u'SO\u2082', x=1.2, y=0.7, showarrow=False, row=3, col=1)
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.6, 0.6], mode='lines', showlegend=False, line_color=get_color(4)), row=3, col=1)
fig.add_annotation(text=u'H\u2082S', x=1.2, y=0.6, showarrow=False, row=3, col=1)
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.5, 0.5], mode='lines', showlegend=False, line_color=get_color(0)), row=3, col=1)
fig.add_annotation(text="Bayesian Network", x=1.3, y=0.5, showarrow=False, row=3, col=1)
fig.add_trace(go.Scatter(x=[legend_x, legend_x + legend_len], y=[0.4, 0.4], mode='lines', showlegend=False, line=dict(color='navy', dash='dash')), row=3, col=1)
fig.add_annotation(text="Random Classifier", x=1.3, y=0.4, showarrow=False, row=3, col=1)
fig.update_xaxes(tickvals=[0., 0.5, 1.], ticktext=["0", "0.5", "1"], title='False positive rate', row=3, col=1)
fig.update_yaxes(title='True positive rate', row=3, col=1)
fig.add_annotation(text="<b>(C)</b>", xref="x domain", yref="y domain", x=.95, y=.95, showarrow=False, row=3, col=1)


# Warning times
# Panel (D): for every threshold, the number of days between the start of a
# true-positive window and each eruption that falls inside that window.
plot_tresh = []
warning_times = []
explosive_eruptions = load_whakaari_catalogue(2, '0D').loc[tstart:tend]
for thresh in thresholds:
    # One slot per eruption, NaN until a true-positive window covers it.
    # NOTE(review): the size 3 is hard-coded; more than three eruptions in
    # the period would raise an IndexError below.
    warning_time = np.empty(3)*np.nan
    _ , time_windows = evaluate_threshold(thresh, fcst, pew=None, return_windows=True)
    for w in time_windows:
        if w['type'] == 'true_positive':
            for i, e in enumerate(explosive_eruptions.iterrows()):
                # e[0] is the index value (eruption time) of the row
                if w['start'] < e[0] <= w['end']:
                    warning_time[i] = (e[0] - w['start']).days
    warning_times.append(warning_time)
    plot_tresh.append(thresh)
warning_times = np.array(warning_times)
fig.add_trace(go.Scatter(x=plot_tresh, y=warning_times[:,0], mode='lines', showlegend=False, line_color=get_color(1)), row=3, col=2)
fig.add_trace(go.Scatter(x=plot_tresh, y=warning_times[:,1], mode='lines', showlegend=False, line_color=get_color(2)), row=3, col=2)
fig.add_trace(go.Scatter(x=plot_tresh, y=warning_times[:,2], mode='lines', showlegend=False, line_color=get_color(3)), row=3, col=2)

# Add legend manually
# (one line segment and date label per eruption, in data coordinates)
fig.add_annotation(text='Eruptions',
                   x=.19, y=70, showarrow=False, row=3, col=2)
fig.add_trace(go.Scatter(x=[.15, .17], y=[60, 60], mode='lines', showlegend=False, line_color=get_color(1)), row=3, col=2)
fig.add_annotation(text=explosive_eruptions.iloc[0].tvalue.strftime('%Y-%m-%d'),
                   x=.2, y=60, showarrow=False, row=3, col=2)
fig.add_trace(go.Scatter(x=[.15, .17], y=[50, 50], mode='lines', showlegend=False, line_color=get_color(2)), row=3, col=2)
fig.add_annotation(text=explosive_eruptions.iloc[1].tvalue.strftime('%Y-%m-%d'),
                   x=.2, y=50, showarrow=False, row=3, col=2)
fig.add_trace(go.Scatter(x=[.15, .17], y=[40, 40], mode='lines', showlegend=False, line_color=get_color(3)), row=3, col=2)
fig.add_annotation(text=explosive_eruptions.iloc[2].tvalue.strftime('%Y-%m-%d'),
                   x=.2, y=40, showarrow=False, row=3, col=2)
fig.add_annotation(text="<b>(D)</b>", xref="x domain", yref="y domain", x=.95, y=.95, showarrow=False, row=3, col=2)

fig.update_xaxes(range=[0.05, 0.25], title='Threshold', row=3, col=2)
fig.update_yaxes(range=[0., 80], title='Warning time [days]', row=3, col=2)

# Save a high-resolution copy and show the figure as the cell output
fig.update_layout(width=1000, height=600)
fig.write_image('data/validation_plots.png', width=1000, height=600, scale=5)
fig
 

In [None]:
# Rates computed from the forecast of the best model with pew=40
res = compute_rates(fcst, pew=40)
print(res)
# Ratio of the first positive rate to the first negative rate
print(res['positive_rates'][0] / res['negative_rates'][0])

### Forecast Trellis Plot

In [None]:
# Trellis plot: one row per data group ('b', 'c', 'd', 'e'), each with a
# secondary y-axis that carries the eruption markers.
fig = make_subplots(rows=4, cols=1, specs=[[{"secondary_y": True}], 
                                           [{"secondary_y": True}],
                                           [{"secondary_y": True}],
                                           [{"secondary_y": True}]])

# Forecasts of the best model using all observables, and with the gas
# (SO2, H2S, CO2) or the seismic (RSAM, Eqr) observables passed as
# `exclude_from_test`
xds_best = get_best_model(search_results, np.arange(10, 110, 10), np.arange(2, 6))
xds_best_seismic = get_best_model(search_results, np.arange(10, 110, 10), np.arange(2, 6),
                                exclude_from_test=['SO2', 'H2S', 'CO2'])
xds_best_gas = get_best_model(search_results, np.arange(10, 110, 10), np.arange(2, 6),
                            exclude_from_test=['RSAM', 'Eqr'])

# `showlegend` is switched off after the first eruption line has been added,
# which also disables the legend entries of all traces added afterwards
showlegend = True
eruptions = load_whakaari_catalogue(2, '0D')
# NOTE(review): q_min and q_max are not used within this cell
q_min = 0.15
q_max = 0.85
for irow, group_name in enumerate(['b', 'c', 'd', 'e']):
    # Select the samples of the current group
    time = pd.to_datetime(xds_best['time'])[ groups == group_name]
    probs_best = xds_best['probs'].values[groups == group_name]
    fig.add_trace(go.Scatter(x=time, y=probs_best,
                            mode='lines', name="Eruption Probability (best model)",
                            line_color=get_color(0), showlegend=showlegend),
                            row=irow+1, col=1)
    fig.add_trace(go.Scatter(x=pd.to_datetime(xds_best['time'])[groups == group_name],
                            y=xds_best_gas['probs'].values[groups == group_name],
                            mode='lines', name="Gas Eruption Probability",
                            line_color=get_color(3, alpha=0.5), showlegend=showlegend),
                            row=irow+1, col=1)
    fig.add_trace(go.Scatter(x=pd.to_datetime(xds_best['time'])[groups == group_name],
                            y=xds_best_seismic['probs'].values[groups == group_name],
                            mode='lines', name="Seismic Eruption Probability",
                            line_color=get_color(4, alpha=0.5), showlegend=showlegend),
                            row=irow+1, col=1)
    # Eruptions between the first and the last time stamp of the group are
    # drawn as vertical lines on the secondary y-axis
    t1 = xds_best['time'][groups == group_name][0].values
    t2 = xds_best['time'][groups == group_name][-1].values
    dfe = eruptions.loc[t1:t2]
    for i in range(len(dfe.index)):
        fig.add_trace(go.Scatter(x=[dfe.index[i], dfe.index[i]], y=[0., 1.], mode='lines',
                                line_width=.8, line_color='black', name='Explosive Eruption', 
                                showlegend=showlegend), secondary_y=True, row=irow + 1, col=1)
        showlegend = False

    # NOTE(review): `indices` is not used within this cell
    indices = np.where(groups == group_name)[0]

# Panel labels, placed in axis-domain coordinates.
# NOTE(review): every row of the figure has a secondary y-axis, so the
# primary y-axes of rows 1-4 are presumably y, y3, y5 and y7 and the x-axes
# x, x2, x3 and x4. The labels of rows 2-4 all use "x3 domain" as xref --
# confirm that this is intended.
x_annot = 0.08
y_annot = 0.9
fig.add_annotation(text="<b>(A) Group B</b>", xref="x domain", yref="y domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(B) Group C</b>", xref="x3 domain", yref="y3 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(C) Group D</b>", xref="x3 domain", yref="y5 domain", x=x_annot, y=y_annot, showarrow=False)
fig.add_annotation(text="<b>(D) Group E</b>", xref="x3 domain", yref="y7 domain", x=x_annot, y=y_annot, showarrow=False)

# Each episode of unrest below is marked by a shaded date range and a text
# label placed relative to the y-domain of the row it belongs to.

# Row 1
fig.add_annotation(text="Dome extrusion", x="2012-11-24", y=1.2, showarrow=False, yref='y domain')
fig.add_vrect(
    x0="2012-11-22", x1="2012-12-10",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=1, col=1)

fig.add_annotation(text="Geysering", x="2013-02-15", y=1.2, showarrow=False, yref='y domain')
# Add shape regions
fig.add_vrect(
    x0="2013-01-15", x1="2013-04-10",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=1, col=1)

# One label with an arrow, plus a second arrow without text, for the two
# shaded ranges that follow
fig.add_annotation(text="Minor steam and mud eruptions", x="2013-10-04", y=1., showarrow=True, yref='y domain',
                   xanchor='right')
fig.add_annotation(x="2013-08-17", y=1., showarrow=True, yref='y domain')
fig.add_vrect(
    x0="2013-08-15", x1="2013-08-18",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=1, col=1)
fig.add_vrect(
    x0="2013-10-01", x1="2013-10-08",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=1, col=1)

# Row 2
fig.add_annotation(text="Banded tremor", x="2015-10-13", y=1.2, showarrow=False, yref='y3 domain', xref='x2')
fig.add_vrect(
    x0="2015-10-13", x1="2015-10-20",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=2, col=1)

# Row 3
fig.add_annotation(text="Non-explosive ash venting", x="2016-09-13", y=1.2, showarrow=False, yref='y5 domain', xref='x3')
fig.add_vrect(
    x0="2016-09-13", x1="2016-09-18",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=3, col=1)


fig.add_annotation(text="Earthquake swarm", x="2019-06-15", y=1.2, showarrow=False, yref='y5 domain', xref='x3')
fig.add_vrect(
    x0="2019-04-23", x1="2019-07-01",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=3, col=1)

fig.add_annotation(text="Minor ash emissions", x="2019-12-24", y=1.2, showarrow=False,
                   arrowcolor=get_color(7), yref='y5 domain', xref='x3')
fig.add_vrect(
    x0="2019-12-23", x1="2019-12-29",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=3, col=1)


# Row 4
fig.add_annotation(text="Lava extrusion", x="2020-01-15", y=1.2, showarrow=False,
                   arrowcolor=get_color(7), yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2020-01-10", x1="2020-01-20",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)


fig.add_annotation(text="Minor ash emissions", x="2020-11-13", y=1.2, showarrow=False,
                   arrowcolor=get_color(7), yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2020-11-13", x1="2020-12-01",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)


fig.add_annotation(text="Small steam explosions", x="2020-12-29", y=1.1, showarrow=False,
                   xanchor='left', arrowcolor=get_color(7), yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2020-12-29", x1="2021-01-02",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)

fig.add_annotation(text="Minor ash emissions", x="2022-09-18", y=1.2, showarrow=False, yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2022-09-18", x1="2022-09-24",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)

fig.add_annotation(text="Small steam explosion", x="2024-05-24", y=1.2, showarrow=False,
                   yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2024-05-24", x1="2024-05-31",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)

fig.add_annotation(text="Minor ash emissions", x="2024-07-24", y=1.1, showarrow=False,
                   xanchor='left', yref='y7 domain', xref='x4')
fig.add_vrect(
    x0="2024-07-24", x1="2024-09-10",
    fillcolor=get_color(7), opacity=0.2,
    layer="below", line_width=0,
    row=4, col=1)


# Final layout: horizontal legend above the plot, hidden secondary y-axes
fig.update_layout(width=1200, height=1000, legend=dict(y=1.1, orientation='h', font=dict(size=15)),)
# NOTE(review): the figure is created with 4 rows, so the two calls that use
# row=5 presumably match no axis and have no effect -- confirm the intended
# row.
fig.update_yaxes(range=[0,1.2], row=5, col=1)
fig.update_yaxes(showticklabels=False, showgrid=False, secondary_y=True)
fig.update_xaxes(range=["2012-09-01", "2024-06-19"], row=5, col=1)
# Larger tick labels on all four rows
for _r in range(1, 5):
    fig.update_xaxes(tickfont=dict(size=15), row=_r, col=1)
    fig.update_yaxes(tickfont=dict(size=15), row=_r, col=1)
# Save a high-resolution copy and show the figure as the cell output
fig.write_image('data/eruption_forecasts_whakaari.png', width=1200, height=1000, scale=5)
fig


## Supplementary material

In [None]:
# Best mean test score of the selection metric as a function of the
# pre-eruption window, one line per number of states, with a band of
# +/- one standard deviation of the test score.
fig = go.Figure()
metric = 'mod_roc_auc_no_pew'
for nstates in [2, 3, 4, 5]:
    vals = []
    error = []
    # NOTE(review): `showlegend` is assigned but not used within this cell
    showlegend=True
    for pew in search_results.keys():
        # top-ranked parameter set for this window and number of states
        sdf_tmp = search_results[pew][nstates] 
        sdf_tmp = sdf_tmp.sort_values(by=[f'rank_test_{metric}'])
        vals.append(sdf_tmp.iloc[0][f'mean_test_{metric}'])
        error.append(sdf_tmp.iloc[0][f'std_test_{metric}'])
    vals = np.array(vals)
    error = np.array(error)
    fig.add_trace(go.Scatter(x=list(search_results.keys()), y=vals, mode='lines',
                             name=f'{nstates} states', line_color=get_color(nstates-3)))
    # Invisible upper and lower bounds; the lower one is filled up to the
    # upper one ('tonexty') to draw the error band
    fig.add_trace(go.Scatter(x=list(search_results.keys()), y=vals+error, mode='lines', marker=dict(color="#444"),
                                 line=dict(width=0), showlegend=False))
    fig.add_trace(go.Scatter(x=list(search_results.keys()), y=vals-error, mode='lines', marker=dict(color="#444"),
                                 line=dict(width=0), showlegend=False, fillcolor=get_color(nstates-3, alpha=0.1),
                                 fill='tonexty'))
 
fig.update_layout(title="Model objective function vs pre-eruption window", xaxis_title="Pre-eruption window")

In [None]:
### Uncertainty estimates
# Build an ensemble of forecasts from every grid-search candidate whose mean
# test score reaches the threshold below. Each candidate is run three times:
# with the default `exclude_from_test`, with ('RSAM', 'Eqr') excluded and with
# ('CO2', 'SO2', 'H2S') excluded. The three ensembles are then written to
# separate groups of the zarr store.
zarr_store = 'data/whakaari_forecasts.zarr'
# Change to True to delete an existing store before recomputing.
if False:
    if os.path.isdir(zarr_store):
        shutil.rmtree(zarr_store)

# Candidates scoring below this mean test score are left out of the ensemble.
_min_score = 0.85

# Every forecast uses the same input table, so it is concatenated only once.
# NOTE(review): assumes `forecasts` does not modify its input in place
# (`hindcasts` in this notebook works on a copy) — confirm.
_X_all = pd.concat((X_train, X_test, X_remainder))

fts = []
fts_no_seismic = []
fts_no_gas = []
for nstates in np.arange(2, 6):
    for pew in np.arange(10, 110, 10):
        for _, _row in search_results[pew][nstates].iterrows():
            if _row.mean_test_mod_roc_auc_no_pew < _min_score:
                continue
            _e = _row.params
            # Settings shared by the three runs of this candidate; only
            # `exclude_from_test` differs between them.
            _forecast = partial(forecasts, _X_all,
                                pew=pew, expert_only=False,
                                modelfile=_e['clf__modelfile'],
                                bins=_e['discretize__bins'],
                                hidden_nodes=_e['clf__hidden_nodes'],
                                uniformize=_e['clf__uniformize'],
                                randomize=False, zarr_store=zarr_store,
                                recompute=True, save_trained_model=False,
                                smoothing=30)
            xds = _forecast()
            xds_no_seismic = _forecast(exclude_from_test=('RSAM', 'Eqr'))
            xds_no_gas = _forecast(exclude_from_test=('CO2', 'SO2', 'H2S'))

            fts.append(xds['probs'].values)
            fts_no_seismic.append(xds_no_seismic['probs'].values)
            fts_no_gas.append(xds_no_gas['probs'].values)

# Without this check an empty ensemble fails further down with an unrelated
# error (undefined `xds` or mismatching array dimensions).
if not fts:
    raise RuntimeError(
        f"No grid-search candidate reached a mean test score of {_min_score}; "
        "the forecast ensemble is empty.")

# Stack the members along a new 'model' dimension; the time coordinate is
# taken from the last forecast that was computed.
xds_all = xr.DataArray(np.array(fts), dims=['model', 'time'],
                       coords={'model': np.arange(len(fts)), 'time': xds['probs'].time})
xds_all_no_seismic = xr.DataArray(np.array(fts_no_seismic), dims=['model', 'time'],
                                  coords={'model': np.arange(len(fts_no_seismic)),
                                          'time': xds_no_seismic['probs'].time})
xds_all_no_gas = xr.DataArray(np.array(fts_no_gas), dims=['model', 'time'],
                              coords={'model': np.arange(len(fts_no_gas)),
                                      'time': xds_no_gas['probs'].time})

group = "/ensemble_model/exclude_from_test=()"
xds_all.to_zarr(zarr_store, group=group, mode='a')
group = "/ensemble_model/exclude_from_test=('RSAM', 'Eqr')"
xds_all_no_seismic.to_zarr(zarr_store, group=group, mode='a')
group = "/ensemble_model/exclude_from_test=('CO2', 'SO2', 'H2S')"
xds_all_no_gas.to_zarr(zarr_store, group=group, mode='a')

In [None]:
# Forecast uncertainty figure: one subplot per group ('b' to 'e') showing the
# ensemble median, the best single model, the q_min-q_max quantile band of the
# ensemble and the catalogued eruptions (on the secondary y-axis).
model1 = "/ensemble_model/exclude_from_test=()"
zarr_store = 'data/whakaari_forecasts.zarr'
xds1 = xr.open_zarr(zarr_store, group=model1, consolidated=False).to_array()
xds_best = get_best_model(search_results, np.arange(10, 110, 10), np.arange(2, 6))
fig = make_subplots(rows=4, cols=1,
                    specs=[[{"secondary_y": True}] for _ in range(4)])
eruptions = load_whakaari_catalogue(1, '0D')
q_min = 0.15
q_max = 0.85

# The statistics across the 'model' dimension do not depend on the group, so
# they are evaluated once here instead of once per subplot. Index 0 selects
# the first (presumably only) variable stacked by `to_array()`.
_single_chunk = xds1.chunk(dict(model=-1))
_median_all = xds1.median('model').values[0]
_lower_all = _single_chunk.quantile(q_min, 'model').values[0]
_upper_all = _single_chunk.quantile(q_max, 'model').values[0]
_best_all = xds_best['probs'].values
_time_all = pd.to_datetime(xds1['time'])

# Legend flag of the eruption markers; switched off after the first marker.
showlegend = True
for irow, group_name in enumerate(['b', 'c', 'd', 'e']):
    subplot_row = irow + 1
    in_group = np.asarray(groups == group_name)
    time = _time_all[in_group]
    probs_median = _median_all[in_group]
    probs_best = _best_all[in_group]
    probs_min = _lower_all[in_group]
    probs_max = _upper_all[in_group]

    # The forecast traces get a legend entry in the first subplot only. This
    # no longer relies on the eruption flag, which stays True for as long as
    # no eruption has been drawn.
    forecast_legend = irow == 0
    fig.add_trace(go.Scatter(x=time, y=probs_median,
                             mode='lines', name="Eruption Probability (median model)",
                             line_color=get_color(0), showlegend=forecast_legend),
                  row=subplot_row, col=1)
    fig.add_trace(go.Scatter(x=time, y=probs_best,
                             mode='lines', name="Eruption Probability (best model)",
                             line_dash='dash',
                             line_color=get_color(0), showlegend=forecast_legend),
                  row=subplot_row, col=1)
    # Quantile band: invisible lower edge, then the upper edge filled down to
    # it ('tonexty' fills towards the trace added just before).
    fig.add_trace(go.Scatter(x=time, y=probs_min,
                             mode='lines', marker=dict(color="#444"),
                             line=dict(width=0), showlegend=False),
                  row=subplot_row, col=1)
    fig.add_trace(go.Scatter(x=time, y=probs_max,
                             mode='lines', marker=dict(color="#444"),
                             line=dict(width=0), showlegend=False,
                             fillcolor=get_color(0, alpha=0.3),
                             fill='tonexty'),
                  row=subplot_row, col=1)

    # Eruptions between the first and the last time stamp of this group.
    _group_time = xds1['time'][in_group]
    t1 = _group_time[0].values
    t2 = _group_time[-1].values
    dfe = eruptions.loc[t1:t2]
    for _eruption_time in dfe.index:
        fig.add_trace(go.Scatter(x=[_eruption_time, _eruption_time], y=[0., 1.], mode='lines',
                                 line_width=.8, line_color='black', name='Observed Eruption',
                                 showlegend=showlegend),
                      secondary_y=True, row=subplot_row, col=1)
        showlegend = False

# Panel titles, positioned relative to the axis domains of the subplots.
x_annot = 0.09
y_annot = 0.9
for _title, _xref, _yref in [("<b>(A) Group B</b>", "x domain", "y domain"),
                             ("<b>(B) Group C</b>", "x3 domain", "y3 domain"),
                             ("<b>(C) Group D</b>", "x3 domain", "y5 domain"),
                             ("<b>(D) Group E</b>", "x3 domain", "y7 domain")]:
    fig.add_annotation(text=_title, xref=_xref, yref=_yref,
                       x=x_annot, y=y_annot, showarrow=False)

# Shared styling: every shaded time span looks the same, and most labels are
# plain text without an arrow.
_span_color = get_color(7)
_shade = partial(fig.add_vrect, fillcolor=_span_color, opacity=0.2,
                 layer="below", line_width=0, col=1)
_label = partial(fig.add_annotation, showarrow=False)

# --- Row 1 ---
_label(text="Dome extrusion", x="2012-11-24", y=1.2, yref='y domain')
_shade(x0="2012-11-22", x1="2012-12-10", row=1)

_label(text="Geysering", x="2013-02-15", y=1.2, yref='y domain')
_shade(x0="2013-01-15", x1="2013-04-10", row=1)

# One labelled arrow plus a second arrow without text, followed by two spans.
fig.add_annotation(text="Minor steam and mud eruptions", x="2013-10-04", y=1.,
                   showarrow=True, yref='y domain', xanchor='right')
fig.add_annotation(x="2013-08-17", y=1., showarrow=True, yref='y domain')
_shade(x0="2013-08-15", x1="2013-08-18", row=1)
_shade(x0="2013-10-01", x1="2013-10-08", row=1)

# --- Row 2 ---
_label(text="Banded tremor", x="2015-10-13", y=1.2, yref='y3 domain', xref='x2')
_shade(x0="2015-10-13", x1="2015-10-20", row=2)

# --- Row 3 ---
_label(text="Non-explosive ash venting", x="2016-09-13", y=1.2,
       yref='y5 domain', xref='x3')
_shade(x0="2016-09-13", x1="2016-09-18", row=3)

_label(text="Earthquake swarm", x="2019-06-15", y=1.2, yref='y5 domain', xref='x3')
_shade(x0="2019-04-23", x1="2019-07-01", row=3)

_label(text="Minor ash emissions", x="2019-12-24", y=1.2,
       arrowcolor=_span_color, yref='y5 domain', xref='x3')
_shade(x0="2019-12-23", x1="2019-12-29", row=3)

# --- Row 4 ---
_label(text="Lava extrusion", x="2020-01-15", y=1.2,
       arrowcolor=_span_color, yref='y7 domain', xref='x4')
_shade(x0="2020-01-10", x1="2020-01-20", row=4)

_label(text="Minor ash emissions", x="2020-11-13", y=1.2,
       arrowcolor=_span_color, yref='y7 domain', xref='x4')
_shade(x0="2020-11-13", x1="2020-12-01", row=4)

_label(text="Small steam explosions", x="2020-12-29", y=1.1, xanchor='left',
       arrowcolor=_span_color, yref='y7 domain', xref='x4')
_shade(x0="2020-12-29", x1="2021-01-02", row=4)

_label(text="Minor ash emissions", x="2022-09-18", y=1.2, yref='y7 domain', xref='x4')
_shade(x0="2022-09-18", x1="2022-09-24", row=4)

_label(text="Small steam explosion", x="2024-05-24", y=1.2, yref='y7 domain', xref='x4')
_shade(x0="2024-05-24", x1="2024-05-31", row=4)

_label(text="Minor ash emissions", x="2024-07-24", y=1.1, xanchor='left',
       yref='y7 domain', xref='x4')
_shade(x0="2024-07-24", x1="2024-09-10", row=4)

# --- Layout and export ---
fig.update_layout(width=1200, height=1000,
                  legend=dict(y=1.1, orientation='h', font=dict(size=15)))
# NOTE(review): this figure is created with four rows, so the two `row=5`
# updates below appear to select no axis — confirm whether they are leftovers.
fig.update_yaxes(range=[0, 1.2], row=5, col=1)
fig.update_yaxes(showticklabels=False, showgrid=False, secondary_y=True)
fig.update_xaxes(range=["2012-09-01", "2024-06-19"], row=5, col=1)
_tick_font = dict(size=15)
for _r in range(1, 5):
    fig.update_xaxes(tickfont=_tick_font, row=_r, col=1)
    fig.update_yaxes(tickfont=_tick_font, row=_r, col=1)
fig.write_image('data/eruption_forecast_uncertainties_whakaari.png', width=1200, height=1000, scale=5)
fig


In [22]:
def hindcasts(data: pd.DataFrame, exclude_from_test: Sequence=(), pew: int=30, expert_only: bool=False, eq_sample_size: int=1,
              modelfile: str='data/Whakaari_4s_initial1.xdsl', bins: tuple=(0, 5, 95, 100),
              hidden_nodes: bool=True, uniformize: bool=False, randomize: bool=False, zarr_store='data/whakaari_forecasts.zarr',
              recompute=False, save_trained_model=False, smoothing=None, ex_nodes=()):
    """
    Compute BN hindcasts.

    A pipeline of a quantile `Discretizer` followed by a `WhakaariModel` is
    fitted on the complete forward-filled record and then asked for the
    probabilities of that same record.

    Parameters
    ----------
    data : pd.DataFrame
        Time-indexed observations, one column per observable. The frame itself
        is not modified; all changes are made to a forward-filled copy.
    pew : int
        Passed to `pre_eruption_window` for the train and test labels. The
        remainder labels are used as returned by `get_train_test_data`.
    bins : tuple
        Bin edges handed to the `Discretizer`.
    expert_only, eq_sample_size, modelfile, hidden_nodes, uniformize,
    randomize, smoothing
        Forwarded unchanged to `WhakaariModel`.
    exclude_from_test, zarr_store, recompute, save_trained_model, ex_nodes
        Accepted but not used by this function.
        NOTE(review): presumably kept so that the call signature matches
        `forecasts` — confirm.

    Returns
    -------
    xr.Dataset
        Indexed by `time` (timezone information removed) and `type` (column
        names as strings) with the variables

        * ``probs`` -- second column of ``predict_proba``,
        * ``probs_min`` / ``probs_max`` -- the same values as ``probs``,
        * ``original_data`` -- the values of `data` before forward filling,
        * ``discrete_data`` -- an array filled with the placeholder ``'*'``,
        * ``y_all`` -- the labels the pipeline was fitted on.
    """
    data_fill = data.ffill(axis=0)
    # Blank the seismic columns from July 2022 onwards, after forward filling.
    # NOTE(review): the reason for this cut-off date is not visible here.
    data_fill.loc['2022-07-01':, 'RSAM'] = np.nan
    data_fill.loc['2022-07-01':, 'Eqr'] = np.nan
    pipe = Pipeline([('discretize', Discretizer(bins=bins, strategy='quantile', names=None)),
                     ('clf', WhakaariModel(expert_only=expert_only, uniformize=uniformize,
                                           eq_sample_size=eq_sample_size,
                                           randomize=randomize, hidden_nodes=hidden_nodes,
                                           modelfile=modelfile, smoothing=smoothing))])

    # Only the label parts of the split are needed in this function.
    _, _, _, y_train, y_test, y_remainder, _ = get_train_test_data(data_fill, ndays=30)
    y_train = pre_eruption_window(y_train, pew)
    y_test = pre_eruption_window(y_test, pew)
    y_all = pd.concat([y_train, y_test, y_remainder])

    # Placeholder only: no discretised values are written into this array.
    disc_data = np.full(data_fill.shape, '*', dtype='<U7')
    pipe.fit(data_fill, y_all)
    probs = pipe.predict_proba(data_fill)[:, 1]

    xds = xr.Dataset(
        {
            "probs": (["time"], probs),
            "probs_min": (["time"], probs),
            "probs_max": (["time"], probs),
            "original_data": (["time", "type"], data.values),
            "discrete_data": (["time", "type"], disc_data),
            "y_all": (["time"], y_all.values.squeeze())
        },
        coords={"time": data.index.tz_localize(None),
                "type": data.columns.astype(str)},
    )
    return xds


In [None]:
# Search all pre-eruption windows and numbers of states for the grid-search
# candidate with the highest mean test score, then hindcast with its settings.
pews = np.arange(10, 110, 10)
nstates = np.arange(2, 6)
best_objective = -np.inf
for _ns in nstates:
    for _pew in pews:
        sdf_tmp = search_results[_pew][_ns]
        sdf_tmp = sdf_tmp.sort_values(by=['rank_test_mod_roc_auc_no_pew'])
        _top = sdf_tmp.iloc[0]
        # Written as `not (a > b)` so that a NaN score is skipped as well.
        if not (_top.mean_test_mod_roc_auc_no_pew > best_objective):
            continue
        best_objective = _top.mean_test_mod_roc_auc_no_pew
        best_estimator = _top.params
        best_pew = _pew

_hindcast_settings = dict(
    pew=best_pew,
    expert_only=False,
    exclude_from_test=(),
    modelfile=best_estimator['clf__modelfile'],
    bins=best_estimator['discretize__bins'],
    hidden_nodes=best_estimator['clf__hidden_nodes'],
    uniformize=best_estimator['clf__uniformize'],
    randomize=False,
    zarr_store=None,
    recompute=True,
    save_trained_model=False,
    smoothing=30,
)
xds = hindcasts(pd.concat((X_train, X_test, X_remainder)), **_hindcast_settings)

fig = forecast_plot(xds['probs'], log=False, ploterr=False, eruptions=True,
                    label='Forecast', showlegend=True, color_id=0, window=1)
# Hide tick labels and grid lines of the secondary y-axis.
fig.update_yaxes(showticklabels=False, showgrid=False, secondary_y=True)
fig.write_image('data/hindcast_whakaari.png', width=1200, height=400, scale=5)
fig

In [None]:
# Pass the hindcast probabilities, converted to pandas, to `compute_rates`
# (defined elsewhere in this notebook).
# NOTE(review): `pew` is fixed at 40 here, whereas the hindcast above was
# computed with `best_pew` — confirm that the two are meant to differ.
compute_rates(xds['probs'].to_pandas(), pew=40)