In [2]:
import os
import pickle
import pandas as pd
import numpy as np
import altair as alt
alt.data_transformers.enable("vegafusion")

DataTransformerRegistry.enable('vegafusion')

In [3]:
def error_df(file):
    """Load a pickled results object and return its overall error metrics.

    Parameters
    ----------
    file : str or path-like
        Path to a pickle whose top-level object supports
        ``['error_metrics_overall']`` lookup.

    Returns
    -------
    object
        Whatever is stored under ``'error_metrics_overall'`` (``all_errors``
        wraps it with ``pd.DataFrame(...).T``).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code -- only open trusted result files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle)['error_metrics_overall']


def all_errors(file_list):
    """Stack the overall error metrics of several result pickles into one table.

    Each file contributes ``pd.DataFrame(error_df(path)).T`` tagged with:

    - ``Model``: 'LSTM' if the file *name* contains 'lstm', else 'GRU';
    - ``Features``: 'Satellite' if the file name contains 'no_site',
      else 'Site + Satellite'.

    Parameters
    ----------
    file_list : iterable of str
        Paths to pickled results (see ``error_df``).

    Returns
    -------
    pandas.DataFrame
        The per-file rows in input order, with a fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``file_list`` is empty.
    """
    frames = []
    for path in file_list:
        # Classify from the file name only, so directory names that happen to
        # contain 'lstm' or 'no_site' cannot mislabel a run.
        name = os.path.basename(path)
        data = pd.DataFrame(error_df(path)).T
        data['Model'] = 'LSTM' if 'lstm' in name else 'GRU'
        data['Features'] = 'Satellite' if 'no_site' in name else 'Site + Satellite'
        frames.append(data)
    if not frames:
        raise ValueError('file_list must contain at least one results file')
    # Concatenate once instead of re-copying the accumulated frame each iteration.
    return pd.concat(frames).reset_index(drop=True)

In [4]:
# Overall error metrics for the four runs stored under ../results/80/.
all_errors([f'../results/80/{run}.pkl'
            for run in ('lstm_site', 'lstm_no_site', 'gru_site', 'gru_no_site')])

Unnamed: 0,F1 Score,F2 Score,Precision,Recall,Accuracy,% Low Survival Rate,% High Survival Rate,RMSE,Model,Features
0,0.371,0.371,0.37,0.371,0.622,29.9,70.1,17.206508,LSTM,Site + Satellite
1,0.391,0.392,0.39,0.392,0.635,29.9,70.1,17.086888,LSTM,Satellite
2,0.441,0.51,0.359,0.57,0.567,29.9,70.1,19.38535,GRU,Site + Satellite
3,0.426,0.51,0.334,0.588,0.525,29.9,70.1,19.055124,GRU,Satellite


In [6]:
def conf_matrix_df(file, model, features):
    """Load a confusion matrix from a result pickle as a long-format frame.

    The stored ``'conf_matrix_overall'`` frame is read with its index labels as
    the true class and its column labels as the predicted class, then melted to
    one row per matrix cell.

    Parameters
    ----------
    file : str or path-like
        Path to a pickle containing ``['conf_matrix_overall']``.
        ``pickle.load`` can execute arbitrary code -- only open trusted files.
    model : str
        Value written to the ``Model`` column.
    features : str
        Value written to the ``Features`` column.

    Returns
    -------
    pandas.DataFrame
        Columns ``True``, ``Predicted``, ``Count``, ``Features``, ``Model``.
    """
    with open(file, 'rb') as f:
        result = pickle.load(f)
    conf_matrix = (
        result['conf_matrix_overall']
        # Force the index name so reset_index always yields a 'True' column,
        # even when the stored matrix has a named index (the old hard-coded
        # 'index' id column raised KeyError in that case).
        .rename_axis('True')
        .reset_index()
        .melt(id_vars='True', var_name='Predicted', value_name='Count')
    )
    conf_matrix['Features'] = features
    conf_matrix['Model'] = model
    return conf_matrix


def plot_conf_matrix(threshold, file_list):
    """Facet grid of confusion-matrix heatmaps, one panel per run.

    Parameters
    ----------
    threshold : int or float
        Threshold the runs belong to; shown as the column-facet title
        (``f'{threshold}%'``).
    file_list : iterable of str
        Paths to result pickles (see ``conf_matrix_df``). Each run is labelled
        from its file *name*: ``Model`` is 'LSTM' if it contains 'lstm', else
        'GRU'; ``Features`` is 'Satellite' if it contains 'no_site', else
        'Site + Satellite'.

    Returns
    -------
    Faceted Altair chart
        Rows are models and columns are feature sets. Each panel is a heatmap
        of ``Count`` with the count printed in every cell.

    Raises
    ------
    ValueError
        If ``file_list`` is empty.
    """
    frames = []
    for path in file_list:
        # Label from the file name only so directory names containing
        # 'lstm' or 'no_site' cannot mislabel a run.
        name = os.path.basename(path)
        model = 'LSTM' if 'lstm' in name else 'GRU'
        features = 'Satellite' if 'no_site' in name else 'Site + Satellite'
        frames.append(conf_matrix_df(path, model, features))
    if not frames:
        raise ValueError('file_list must contain at least one results file')
    # Concatenate once instead of re-copying the accumulated frame each iteration.
    conf_df = pd.concat(frames)

    # Explicit :N types keep the heatmap layer consistent with the text layer
    # even if class labels are numeric (type inference would pick :Q).
    conf_plot = alt.Chart(conf_df).mark_rect(opacity=0.8).encode(
        y=alt.Y('True:N', title=''),
        x=alt.X('Predicted:N', title='', axis=alt.Axis(labelAngle=0)),
        color=alt.Color('Count:Q', scale=alt.Scale(
            scheme='greens'), legend=None)
    ).properties(
        width=200,
        height=200
    )

    # Overlay the raw count on each heatmap cell.
    text = alt.Chart(conf_df).mark_text(baseline='middle', fontSize=12).encode(
        x='Predicted:N',
        y='True:N',
        text='Count:Q'
    ).properties(
        width=200,
        height=200
    )
    conf_mat_plot = (conf_plot + text).facet(
        row=alt.Row('Model:N', title='', header=alt.Header(
            labelAngle=0, labelPadding=7, labelOrient='right')),
        column=alt.Column('Features:O', sort=[
            'Satellite', 'Site + Satellite'], title=f'{threshold}%')
    ).configure_view(
        strokeWidth=0  # Removes the border around each facet
    ).configure_header(
        titleFontSize=20,  # Size of the facet title
        labelFontSize=14
    ).configure_axis(
        labelFontSize=12
    )

    return conf_mat_plot

In [7]:
# Confusion matrices for the four runs stored under ../results/80/.
plot_conf_matrix(80, [f'../results/80/{run}.pkl'
                      for run in ('lstm_site', 'lstm_no_site',
                                  'gru_site', 'gru_no_site')])

In [8]:
def rmse_df(file, model, features):
    """Load the per-age error table from a result pickle and tag it with run labels.

    Parameters
    ----------
    file : str or path-like
        Path to a pickle containing ``['error_metrics_age']``.
        ``pickle.load`` can execute arbitrary code -- only open trusted files.
    model : str
        Value written to a new ``Model`` column.
    features : str
        Value written to a new ``Features`` column.

    Returns
    -------
    The stored ``'error_metrics_age'`` object with ``Features`` and then
    ``Model`` columns added in place (``plot_rmse`` expects it to carry
    ``Age`` and ``RMSE`` columns).
    """
    with open(file, 'rb') as handle:
        by_age = pickle.load(handle)['error_metrics_age']
    by_age['Features'] = features
    by_age['Model'] = model
    return by_age


def plot_rmse(threshold, file_list):
    """Line chart of RMSE against Age, coloured by model and faceted by feature set.

    Parameters
    ----------
    threshold : int or float
        Threshold the runs belong to. Previously accepted but ignored; it is
        now shown in the facet title so the figure records which runs it shows.
    file_list : iterable of str
        Paths to result pickles (see ``rmse_df``). Each run is labelled from its
        file *name*: ``Model`` is 'LSTM' if it contains 'lstm', else 'GRU';
        ``Features`` is 'Satellite' if it contains 'no_site', else
        'Site + Satellite'.

    Returns
    -------
    Faceted Altair chart
        One column per feature set.

    Raises
    ------
    ValueError
        If ``file_list`` is empty.
    """
    frames = []
    for path in file_list:
        # Label from the file name only so directory names containing
        # 'lstm' or 'no_site' cannot mislabel a run.
        name = os.path.basename(path)
        model = 'LSTM' if 'lstm' in name else 'GRU'
        features = 'Satellite' if 'no_site' in name else 'Site + Satellite'
        frames.append(rmse_df(path, model, features))
    if not frames:
        raise ValueError('file_list must contain at least one results file')
    # Concatenate once instead of re-copying the accumulated frame each iteration.
    rmse = pd.concat(frames)

    rmse_plot = alt.Chart(rmse).mark_line().encode(
        x=alt.X('Age'),
        y=alt.Y('RMSE'),
        color=alt.Color('Model:N')
    ).facet(
        column=alt.Column('Features:O', sort=[
            'Satellite', 'Site + Satellite'], title=f'RMSE Plots ({threshold}%)')
    ).configure_view(
        strokeWidth=0  # Removes the border around each facet
    ).configure_header(
        titleFontSize=20,  # Size of the facet title
    )

    return rmse_plot

In [9]:
# RMSE-by-age curves for the four runs stored under ../results/80/.
plot_rmse(80, [f'../results/80/{run}.pkl'
               for run in ('lstm_site', 'lstm_no_site',
                           'gru_site', 'gru_no_site')])

In [10]:
def residual_df(file, model, features):
    """Per-ID mean predictions from a result pickle, with run labels and residuals.

    ``result['pred_df']`` is grouped by ``ID`` and its numeric columns are
    averaged. ``Features``, ``Model`` and ``Residual`` are then appended, where
    ``Residual = |mean raw_y_true - mean raw_y_pred|``. Because of the absolute
    value, the residual is never negative.

    Parameters
    ----------
    file : str or path-like
        Path to a pickle containing ``['pred_df']`` with ``ID``, ``raw_y_true``
        and ``raw_y_pred`` columns. ``pickle.load`` can execute arbitrary code
        -- only open trusted files.
    model : str
        Value written to the ``Model`` column.
    features : str
        Value written to the ``Features`` column.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``ID``.
    """
    with open(file, 'rb') as handle:
        payload = pickle.load(handle)
    per_id = payload['pred_df'].groupby(by='ID').mean(numeric_only=True)
    per_id['Features'] = features
    per_id['Model'] = model
    per_id['Residual'] = (per_id['raw_y_true'] - per_id['raw_y_pred']).abs()
    return per_id


def plot_residual(threshold, file_list):
    """Scatter of per-ID residual against true value, faceted by model and feature set.

    Parameters
    ----------
    threshold : int or float
        Threshold the runs belong to. Previously accepted but ignored; it is
        now shown in the facet title so the figure records which runs it shows.
    file_list : iterable of str
        Paths to result pickles (see ``residual_df``). Each run is labelled from
        its file *name*: ``Model`` is 'LSTM' if it contains 'lstm', else 'GRU';
        ``Features`` is 'Satellite' if it contains 'no_site', else
        'Site + Satellite'.

    Returns
    -------
    Faceted Altair chart
        Rows are models and columns are feature sets.

    Raises
    ------
    ValueError
        If ``file_list`` is empty.
    """
    frames = []
    for path in file_list:
        # Label from the file name only so directory names containing
        # 'lstm' or 'no_site' cannot mislabel a run.
        name = os.path.basename(path)
        model = 'LSTM' if 'lstm' in name else 'GRU'
        features = 'Satellite' if 'no_site' in name else 'Site + Satellite'
        frames.append(residual_df(path, model, features))
    if not frames:
        raise ValueError('file_list must contain at least one results file')
    # Concatenate once instead of re-copying the accumulated frame each iteration.
    resid_df = pd.concat(frames)

    resid_plot = alt.Chart(resid_df).mark_circle(opacity=0.5, size=40).encode(
        x=alt.X('raw_y_true:Q', title='True value'),
        y=alt.Y('Residual:Q', title='Residual'),
        color=alt.Color('Model:N', legend=None)
    ).properties(
        width=400,
        height=300
    ).facet(
        row=alt.Row('Model:N', title='', header=alt.Header(
            labelAngle=0, labelPadding=7, labelOrient='right', labelFontSize=15)),
        column=alt.Column('Features:O', sort=[
            'Satellite', 'Site + Satellite'],
            title=f'Residual Plots ({threshold}%)',
            header=alt.Header(labelFontSize=15))
    ).configure_view(
        strokeWidth=0  # Removes the border around each facet
    ).configure_header(
        titleFontSize=20,  # Size of the facet title
    ).configure_axis(titleFontSize=15)

    return resid_plot

In [11]:
# Residual scatter plots for the four runs stored under ../results/80/.
plot_residual(80, [f'../results/80/{run}.pkl'
                   for run in ('lstm_site', 'lstm_no_site',
                               'gru_site', 'gru_no_site')])