In [None]:
from sklearn import metrics
import seaborn as sns

# Evaluation helper -- the model passed in must already be fitted.
# NOTE: relies on `np` (numpy) and `plt` (matplotlib.pyplot) being imported
# elsewhere in the notebook; they are not imported in this cell.
def evaluate_class_mdl(fitted_model, X_train, X_test, y_train, y_test, plot=True, pct=True, thresh=.5):
    """Evaluate an already-fitted binary classifier on train and test data.

    The kind of model is detected from what ``fitted_model.predict(X_train)``
    returns:

    * more than two distinct values -> ``predict`` is treated as returning
      positive-class scores (e.g. a network with a sigmoid output), and
      ``predict(X_test)`` supplies the test scores;
    * otherwise -> ``predict`` is treated as returning hard labels, and the
      model must expose ``predict_proba``; its column 1 is used as the
      positive-class score for both train and test data.

    In both cases the same decision threshold ``thresh`` is applied to the
    train and test scores. A score strictly greater than ``thresh`` becomes 1,
    otherwise 0, so ``y_train``/``y_test`` are expected to hold 0/1 labels.

    Args:
        fitted_model: Trained model exposing ``predict`` (and ``predict_proba``
            when ``predict`` returns hard labels).
        X_train, X_test: Features passed to the model's prediction methods.
        y_train, y_test: True 0/1 labels for the train and test sets.
        plot: If True, print the metrics, draw the test confusion matrix as a
            seaborn heatmap and return the predictions. If False, print
            nothing and return a dict of metrics.
        pct: When plotting, show the confusion matrix as fractions of the
            total (True) or as raw counts (False).
        thresh: Decision threshold applied to the positive-class scores.

    Returns:
        If ``plot`` is True: ``(y_train_pred, y_test_prob, y_test_pred)``.
        Otherwise, a dict with keys 'accuracy_train', 'accuracy_test',
        'precision', 'recall', 'roc_auc', 'f1', 'mcc' (test-set metrics except
        'accuracy_train') and 'tn%', 'fp%', 'fn%', 'tp%'. The last four are
        the confusion-matrix cells divided by the test-set size, i.e.
        fractions in [0, 1] despite the '%' in the key names.
    """
    y_train_raw = fitted_model.predict(X_train).squeeze()
    if len(np.unique(y_train_raw)) > 2:
        # predict() already yields continuous positive-class scores.
        y_train_prob = y_train_raw
        y_test_prob = fitted_model.predict(X_test).squeeze()
    else:
        # predict() yields hard labels; use the positive-class probability so
        # that `thresh` governs the train and test predictions alike.
        y_train_prob = fitted_model.predict_proba(X_train)[:, 1]
        y_test_prob = fitted_model.predict_proba(X_test)[:, 1]
    y_train_pred = np.where(y_train_prob > thresh, 1, 0)
    y_test_pred = np.where(y_test_prob > thresh, 1, 0)
    roc_auc_te = metrics.roc_auc_score(y_test, y_test_prob)

    # labels=[0, 1] keeps the matrix 2x2 even when a class is missing from
    # both y_test and y_test_pred; otherwise the 4-way unpacking would fail.
    cf_matrix = metrics.confusion_matrix(y_test, y_test_pred, labels=[0, 1])
    tn, fp, fn, tp = cf_matrix.ravel()
    acc_tr = metrics.accuracy_score(y_train, y_train_pred)
    acc_te = metrics.accuracy_score(y_test, y_test_pred)
    pre_te = metrics.precision_score(y_test, y_test_pred)
    rec_te = metrics.recall_score(y_test, y_test_pred)
    f1_te = metrics.f1_score(y_test, y_test_pred)
    mcc_te = metrics.matthews_corrcoef(y_test, y_test_pred)

    if plot:
        print(f'Accuracy_train : {acc_tr: .4f}\t\tAccuracy_test : {acc_te: .4f}')
        print(f'Precision_test : {pre_te: .4f}\t\tRecall_test : {rec_te: .4f}')
        print(f'ROC-AUC_test : {roc_auc_te: .4f}\t\tF1_test : {f1_te: .4f}\t\tMCC_test : {mcc_te: .4f}')
        if pct:  # normalize the cells to fractions of the total
            ax = sns.heatmap(cf_matrix/np.sum(cf_matrix), annot=True,
                             fmt='.2%', cmap='Blues', annot_kws={'size':16})
        else:    # show raw counts
            ax = sns.heatmap(cf_matrix, annot=True,
                             fmt='d', cmap='Blues', annot_kws={'size':16})
        ax.set_xlabel('Predicted', fontsize=12)
        ax.set_ylabel('Observed', fontsize=12)
        plt.show()

        return y_train_pred, y_test_prob, y_test_pred

    t = cf_matrix.sum()
    metrics_dict = {
        'accuracy_train' : acc_tr,
        'accuracy_test' : acc_te,
        'precision' : pre_te,
        'recall' : rec_te,
        'roc_auc' : roc_auc_te,
        'f1' : f1_te,
        'mcc' : mcc_te,
        'tn%' : tn/t,
        'fp%' : fp/t,
        'fn%' : fn/t,
        'tp%' : tp/t
    }
    return metrics_dict