In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import logging
import numpy as np
from sklearn import metrics
from sklearn.metrics import roc_auc_score, confusion_matrix, precision_recall_curve, auc, ndcg_score
from tdescore.classifier.train import train_classifier
from tdescore.classifier.features import relevant_columns, column_descriptions
from tdescore.classifier.collate import get_classified_sources, convert_to_train_dataset
import matplotlib.patheffects as path_effects
import pandas as pd
import shap

In [None]:
# Set the "tdescore" logger to INFO level.
logging.getLogger("tdescore").setLevel(logging.INFO)

In [None]:
# Number of iterations handed to train_classifier; flatten() further down
# reads one `probs_{i}` column per iteration.
n_iter = 10

# Returns a summary table (one row per estimator count) together with a
# mapping from estimator count to classifier.
all_all_res, clfs = train_classifier(columns=relevant_columns, n_iter=n_iter)

# Estimator counts in ascending order.
n_estimator_set = sorted(clfs)

In [None]:
# One panel per summary metric, plotted against the number of estimators.
c_metrics = [x for x in all_all_res.columns if x not in ["n_estimator", "all_res"]]

# Keep the 2x3 layout for up to six metrics, and add rows beyond that so
# plt.subplot never receives an out-of-range panel index.
n_cols = 3
n_rows = max(2, -(-len(c_metrics) // n_cols))

plt.figure()
plt.suptitle("tdescore")
for i, metric_name in enumerate(c_metrics):
    plt.subplot(n_rows, n_cols, i + 1)
    # Line and markers both take x from the same table rows as y, so the
    # markers always sit on the line regardless of the row ordering.
    plt.plot(all_all_res["n_estimator"], all_all_res[metric_name] * 100., color=f"C{i}")
    plt.scatter(all_all_res["n_estimator"], all_all_res[metric_name] * 100., color=f"C{i}")
    plt.ylabel(f'{metric_name} [%]')

plt.subplots_adjust(wspace=1.0)
plt.show()

In [None]:
# Metric used to pick the best-performing number of estimators.
metric = "precision_recall_area"

# idxmax() returns an index *label*, so the row is looked up with .loc.
# Positional .iloc only agrees with it for a default 0..n-1 index.
best_index = all_all_res[metric].idxmax()
best_row = all_all_res.loc[best_index]

best_estimator = best_row["n_estimator"]

print(f"Best value is {best_estimator}")

# Classifier and per-source results belonging to the selected row.
clf = clfs[best_estimator]
all_res = best_row["all_res"]

In [None]:
# Display the results selected above for the best estimator count.
all_res

In [None]:
def flatten(results=None, n_iterations=None):
    """
    Concatenate the true classes and predicted probabilities of all iterations.

    For each iteration ``i`` the ``probs_{i}`` column is appended to one flat
    list, and the ``class`` column is repeated alongside it so that both
    lists stay aligned element by element.

    :param results: table with a ``class`` column and one ``probs_{i}`` column
        per iteration; defaults to the notebook-level ``all_res``
    :param n_iterations: number of ``probs_{i}`` columns to read; defaults to
        the notebook-level ``n_iter``
    :return: tuple ``(true_class, all_probs)`` of equally long lists
    """
    if results is None:
        results = all_res
    if n_iterations is None:
        n_iterations = n_iter

    # The class column is the same for every iteration, so convert it once.
    labels = results["class"].tolist()

    true_class = []
    all_probs = []
    for i in range(n_iterations):
        true_class.extend(labels)
        all_probs.extend(results[f"probs_{i}"].tolist())
    return true_class, all_probs

In [None]:
# Flat lists of true classes and predicted probabilities over all iterations.
tclass, aprobs = flatten()

# Golden-ratio figure size, reused by the threshold plot cell further down.
fscale = 4.
figsize = (fscale * 1.618, fscale)

# ROC curve, zoomed in on the low false-positive region.
# All savefig calls below write into a "figures/" directory relative to the
# working directory.
fpr, tpr, thresholds = metrics.roc_curve(tclass, aprobs)
plt.figure(figsize=figsize)
plt.plot(fpr, tpr, label=f"AUC={roc_auc_score(tclass, aprobs):.3f}")
plt.plot([0.0, 1.0], [0.0, 1.0], linestyle=":", label="random AUC=0.500")
plt.xlabel("false positive")
plt.ylabel("true positive")
plt.legend()
plt.xlim(0.0, 0.15)
plt.ylim(0.0, 1.0)
plt.savefig("figures/roc.pdf", bbox_inches='tight')
plt.show()

# Detection error trade-off: false negatives against false positives.
fpr, fnr, thresholds = metrics.det_curve(tclass, aprobs)
plt.figure(figsize=figsize)
plt.plot(fpr, fnr)
plt.xlabel("false positive")
plt.ylabel("false negative")
plt.xlim(0.0, 0.1)
plt.ylim(0.0, 1.0)
plt.savefig("figures/fp_fn.pdf", bbox_inches='tight')
plt.show()

# Precision-recall curve: `x` holds precision and `y` holds recall. Both
# names (and `thresholds`) are read again by the threshold plot cell below.
x, y, thresholds = metrics.precision_recall_curve(tclass, aprobs)
plt.figure(figsize=figsize)
plt.plot(x, y)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.xlim(0.0, 1.0)
plt.ylim(0.0, 1.0)
# Saved under its own name: the threshold plot cell further down writes
# "figures/precision_recall.pdf" and would otherwise overwrite this figure.
plt.savefig("figures/precision_recall_curve.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Precision and recall for every candidate score threshold. Both arrays have
# one more element than `thresholds`.
pr, recall, thresholds = metrics.precision_recall_curve(tclass, aprobs)

index = np.arange(thresholds.size)

# Loose cut: last threshold index whose *following* recall entry is >= 95%.
# NOTE(review): this mask is built from recall[1:], while the two precision
# masks below use pr[:-1] and the printed values use the index directly.
# Confirm the one-element offset is intended.
loose_index = index[recall[1:] >= 0.95].max()
threshold_loose = thresholds[loose_index]
print(
    f"Loose threshold {threshold_loose:.2f}, "
    f"Precision={100.*pr[loose_index]:.1f}%, "
    f"Recall={100.*recall[loose_index]:.1f}%"
)

# Strict cut: first threshold index reaching at least 95% precision.
strict_index = index[pr[:-1] >= 0.95].min()
threshold_strict = thresholds[strict_index]
print(
    f"Strict Threshold {threshold_strict:.2f}, "
    f"Precision={100.*pr[strict_index]:.1f}%, "
    f"Recall={100.*recall[strict_index]:.1f}%"
)

# Balanced cut: first threshold index reaching at least 80% precision.
balanced_index = index[pr[:-1] >= 0.8].min()
threshold_balanced = thresholds[balanced_index]
print(
    f"Balanced Threshold {threshold_balanced:.2f}, "
    f"Precision={100.*pr[balanced_index]:.1f}%, "
    f"Recall={100.*recall[balanced_index]:.1f}%"
)

In [None]:
# Precision and recall as a function of the score threshold, with the three
# selected cuts marked as dotted vertical lines.
fig, ax = plt.subplots(figsize=figsize)
ax.plot(thresholds[:-1], x[1:-1], label="Precision")
ax.plot(thresholds[:-1], y[1:-1], label="Recall")
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.0)
ax.set_xlabel(r"$\it{tdescore}$ threshold")
ax.legend()
for cut in (threshold_loose, threshold_balanced, threshold_strict):
    ax.axvline(cut, linestyle=":", color="k")
fig.savefig("figures/precision_recall.pdf", bbox_inches='tight')
plt.show()

In [None]:
def plot_matrix(cut, label):
    """
    Plot and save two normalised confusion matrices for one score cut.

    Sources with a score strictly greater than ``cut`` count as predicted
    TDEs. The notebook-level ``tclass`` and ``aprobs`` lists provide the true
    classes and scores. Two figures are produced: one normalised per
    predicted class (columns sum to 1) and one per true class (rows sum to 1).
    Each cell shows the percentage with the raw count underneath. The figures
    are saved as ``figures/matrix_{label}_{k}.pdf`` and left open, so the
    caller can display them with ``plt.show()``.

    :param cut: score threshold applied to ``aprobs``
    :param label: name of the cut, used in the titles and file names
    :return: None
    """
    print(f"Cut of {cut:.3f}")

    base_cm = confusion_matrix(tclass, np.array(aprobs) > cut)

    # (title prefix, normalised matrix) in the order used for the file names.
    normalisations = [
        ("Prediction-Normalised", base_cm / np.sum(base_cm, axis=0)),
        ("Truth-Normalised", (base_cm.T / np.sum(base_cm, axis=1)).T),
    ]

    for k, (norm_name, cm) in enumerate(normalisations):
        # plt.subplots already creates the figure; no separate plt.figure()
        # call, which would leave an empty figure behind on every iteration.
        fig, ax = plt.subplots(figsize=(5, 4))
        ax.imshow(cm, cmap=plt.cm.Blues)
        ax.set_title(f'{norm_name} Confusion Matrix ({label})')
        ax.set_xticks(np.arange(len(cm)))
        ax.set_yticks(np.arange(len(cm)))
        ax.set_xticklabels(['Non-TDE', "TDE"])
        ax.set_yticklabels(['Non-TDE', "TDE"])
        ax.set_xlabel('Predicted label')
        ax.set_ylabel('True label')
        for i in range(len(cm)):
            for j in range(len(cm)):
                # White text with a black outline stays readable on both
                # light and dark cells of the colour map.
                ax.text(
                    j, i, f"{100.*cm[i, j]:.1f}%\n\n({base_cm[i, j]})",
                    ha='center', va='center', color='white', fontsize=15,
                    path_effects=[
                        path_effects.Stroke(linewidth=2, foreground='black'),
                        path_effects.Normal(),
                    ],
                )
        path = f"figures/matrix_{label}_{k}.pdf"
        fig.savefig(path)

In [None]:
# Draw the confusion matrices for each of the three cuts in turn.
cut_labels = ["Inclusive", "Balanced", "Clean"]
cut_values = [threshold_loose, threshold_balanced, threshold_strict]
for cut_label, cut in zip(cut_labels, cut_values):
    plot_matrix(cut, label=cut_label)
    plt.show()

In [None]:
# One row per feature; columns 0 / 1 / 2 hold the name, the description and
# the classifier's importance, sorted from most to least important.
features = (
    pd.DataFrame([relevant_columns, column_descriptions, list(clf.feature_importances_)])
    .T
    .sort_values(by=2, ascending=False)
)
features

In [None]:
# Number of input features (the caption below counts the table rows instead).
n_feature = len(relevant_columns)

# Print the feature importance table as LaTeX source.
text_str = r"""\begin{table*}[]
\centering
    \begin{tabular}{c|c|c}
    \textbf{Feature} &\textbf{Description}& \textbf{Importance (\%)}\\
    \hline
"""
print(text_str)
for _, row in features.iterrows():
    # Underscores must be escaped for LaTeX. A raw string is used because
    # '\_' is not a valid Python escape sequence.
    name = row[0].replace('_', r'\_')
    print(f"\t{name} & {row[1]} & {100.*row[2]:.1f} \\\\")
print(r"\end{tabular}")
print(r"\caption{Relative importance of all " + str(len(features)) + r" features in \tdes, calculated by \xgboost \citep{xgboost} using the standard averaging of importance across all decision trees in the final model \citep[see e.g][]{ml_textbook}.}")
print(r"""\label{tab:importance}
\end{table*}""")

In [None]:
# Fetch the classified sources and build the dataset passed to the SHAP
# explainers below.
classified_sources = get_classified_sources()
data_to_use = convert_to_train_dataset(classified_sources)

# NOTE(review): `explainer` and `shap_values` from this first pass are both
# rebound a few lines below; only `expected_value` survives. The shap_values
# call is left in place because expected_value is read right after it and may
# depend on it having run. Confirm before removing.
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(data_to_use)
expected_value = explainer.expected_value

# Second pass with named features. This `shap_values` object is the one
# indexed by explain().
explainer = shap.Explainer(clf, data_to_use, feature_names=relevant_columns)
shap_values = explainer(data_to_use)

def explain(name, classification=None):
    """
    Draw a SHAP waterfall plot for a single source.

    The source is located by its position in the ``ztf_name`` column of the
    notebook-level ``classified_sources``, and that position is used to index
    the notebook-level ``shap_values``.

    :param name: source name to look up in ``classified_sources["ztf_name"]``
    :param classification: optional classification, appended to the title in
        brackets; when omitted, the title is just the source name
    :return: the figure created for the plot
    :raises ValueError: if ``name`` is not present in the ``ztf_name`` column
    """
    fig = plt.figure()
    index = classified_sources["ztf_name"].tolist().index(name)

    # show=False keeps the plot open so that a title can still be added.
    shap.plots.waterfall(
        shap_values[index],
        max_display=5, show=False
    )

    # Always define the title: previously it was only set when a
    # classification was given, so the default call raised UnboundLocalError.
    if classification is None:
        title = name
    else:
        title = f"{name} ({classification})"

    plt.title(title)
    plt.show()
    return fig

In [None]:
# Explain one example source and keep the resulting figure as a PDF.
example_source = "ZTF19aapreis"
fig = explain(example_source, "Tidal Disruption Event")
fig.savefig(f"figures/{example_source}.pdf", bbox_inches='tight')