In [50]:
import pandas as pd
from ACME.ACME import ACME
from sklearn.metrics import f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split

from prml.linear import VariationalLogisticRegression

# Load the pre-processed dataset; every column except the target
# ('is_readmission') is used as a model feature.
dataset = pd.read_excel("./data/over_resample_all_fields_scaler.xlsx")
features = dataset.drop(columns='is_readmission').columns.to_list()

# Minimum value that accuracy, macro-F1, macro-recall and macro-precision must
# all reach on the hold-out split before the model is explained.
MIN_METRIC = 0.85
# Upper bound on split/fit attempts so the search cannot spin forever when the
# threshold is unreachable for this data.
MAX_ATTEMPTS = 1000


def _show_and_save(figure, path):
    """Display *figure*, then export it to *path* in EPS format."""
    figure.show()
    figure.write_image(file=path, format='eps')


for _attempt in range(MAX_ATTEMPTS):
    # No random_state is passed, so every attempt draws a fresh 70/30 split.
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(
        dataset[features].values,
        dataset['is_readmission'].values,
        test_size=.3,
    )

    vlr = VariationalLogisticRegression()
    vlr.fit(Xtrain, Ytrain, feature_names=features)

    y_pred = vlr.predict(Xtest)
    _score = vlr.score(Xtest, Ytest)
    _f1_macro = f1_score(Ytest, y_pred, average='macro')
    _recall_score = recall_score(Ytest, y_pred, average='macro')
    _precision_score = precision_score(Ytest, y_pred, average='macro')

    _metrics = (_score, _f1_macro, _recall_score, _precision_score)
    if not all(metric >= MIN_METRIC for metric in _metrics):
        # This split/fit missed the bar on at least one metric; try again.
        continue

    # NOTE(review): the model is accepted based on its score on the very split
    # it is reported on, so the printed metrics are optimistic -- confirm this
    # selection procedure is intended.
    print(_score, _f1_macro, _recall_score, _precision_score, "\n")

    acme_vlr = ACME(vlr, 'is_readmission', features=features,
                    cat_features=['sex', 'HepatitisC', 'is_alcohol', 'Diabetes', 'HBP', 'Hepatitis'], K=50,
                    task='class')

    # Explanation and summary plot for label_class=1.
    acme_vlr = acme_vlr.explain(dataset, robust=True, label_class=1)
    summary_plot_1 = acme_vlr.summary_plot()
    _show_and_save(summary_plot_1, './image_acme/lt_label_1.eps')

    # Explanation and summary plot for label_class=0.
    acme_vlr = acme_vlr.explain(dataset, robust=True, label_class=0)
    summary_plot_2 = acme_vlr.summary_plot()
    _show_and_save(summary_plot_2, './image_acme/lt_label_0.eps')

    # The bar plot is produced after the label_class=0 explanation, as in the
    # original cell.
    bar_plot = acme_vlr.bar_plot()
    _show_and_save(bar_plot, './image_acme/lt_bar.eps')
    break
else:
    # The loop ran out of attempts without ever hitting `break`.
    raise RuntimeError(
        f"No train/test split reached {MIN_METRIC} on all metrics "
        f"within {MAX_ATTEMPTS} attempts"
    )


all positive 33 [('ICUTime', 10, '20.7559 ± 1.011'), ('OperationTime_h', 23, '9.0431 ± 1.3118'), ('T', 7, '-4.9306 ± 0.8979'), ('NTproBNP', 12, '4.8463 ± 0.8477'), ('RBCT', 22, '4.3964 ± 0.9491'), ('PT_sec_minus', 29, '4.2413 ± 0.6062'), ('PLT_minus', 27, '-4.0855 ± 0.7872'), ('CREA', 21, '-3.8997 ± 0.8535'), ('SBP', 9, '2.8596 ± 0.6258'), ('Fib_gL', 17, '-2.8112 ± 2.1695'), ('UA_minus', 25, '-2.7894 ± 0.8111'), ('HR', 8, '-2.509 ± 0.557'), ('ALB_minus', 31, '-2.3848 ± 0.7306'), ('HOD', 24, '2.0221 ± 0.6481'), ('TBIL', 20, '1.9331 ± 0.5925'), ('PCT', 11, '-1.8392 ± 0.6485'), ('ALB', 18, '-1.8167 ± 0.6924'), ('WBC', 15, '-1.4359 ± 0.8417'), ('TBA_minus', 32, '-1.3915 ± 0.6429'), ('NEUT_minus', 28, '1.0597 ± 0.427'), ('age', 1, '-0.7989 ± 0.5343'), ('sex', 0, '-0.6724 ± 0.263'), ('Diabetes', 4, '0.6592 ± 0.2214'), ('HGB_minus', 26, '0.6463 ± 0.4792'), ('NEUT', 16, '-0.6438 ± 0.6187'), ('HBP', 5, '0.6376 ± 0.3082'), ('HepatitisC', 2, '0.5627 ± 0.4584'), ('PLT', 13, '0.4398 ± 0.9604'), ('A