In [63]:
# -*- encoding: utf-8 -*-
import os

import numpy as np
import pandas as pd
from ACME.ACME import ACME
from sklearn.metrics import f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures, StandardScaler

from prml.linear import VariationalLogisticRegression
from utils import load_lt_data


def create_images_file(_path: str) -> str:
    """Ensure the directory ``_path`` exists and return it unchanged.

    Intermediate directories are created as needed. ``exist_ok=True`` makes the
    call a no-op when the directory already exists. It also avoids the
    check-then-create race when another process creates the directory
    concurrently.

    Args:
        _path: directory to create.

    Returns:
        The same ``_path`` that was passed in, so the call can be used inline.

    Raises:
        FileExistsError: if ``_path`` exists but is not a directory.
    """
    os.makedirs(_path, exist_ok=True)
    return _path


def create_toy_data(is_breast: bool = False,
                    is_heart: bool = False,
                    is_bone: bool = False,
                    is_kaggle_heart: bool = False,
                    _path: str = None,
                    random_state=None):
    """Load one project dataset, split it 70/30 and scale the features.

    Exactly one dataset is used. If several flags are True, the first one in
    the order ``is_breast``, ``is_heart``, ``is_bone``, ``is_kaggle_heart``
    wins. If none is True, the default LT dataset is loaded and scaled with
    ``MinMaxScaler``; every other dataset uses ``StandardScaler``. The scaler
    is fitted on the training split only and then applied to the test split.

    Args:
        is_breast: select ``./breast_data/fix_breast_cancer.xlsx``.
        is_heart: select ``./spectf_data/over_resample.xlsx``.
        is_bone: select ``./bone_marrow_transplant_data/fix_bone_data.xlsx``.
        is_kaggle_heart: select ``./heart_disease_data/over_resample.xlsx``.
        _path: unused; kept only so existing callers passing it keep working.
        random_state: forwarded to ``train_test_split``. The default ``None``
            gives a different split on every call; pass an int for a
            reproducible split.

    Returns:
        ``(image_path, Xtrain, Xtest, Ytrain, Ytest, dataset, feature_names)``.
        ``image_path`` is the (created) image directory for the dataset.
        ``dataset`` stacks the training rows above the test rows, each row
        being ``[label, *scaled_features]``.
    """
    # (flag, image directory, spreadsheet path). The first selected entry wins,
    # which mirrors an if/elif chain.
    candidates = (
        (is_breast, "./images/breast_data", "./breast_data/fix_breast_cancer.xlsx"),
        (is_heart, "./images/spect_data", "./spectf_data/over_resample.xlsx"),
        (is_bone, "./images/bone_marrow_transplant_data",
         "./bone_marrow_transplant_data/fix_bone_data.xlsx"),
        (is_kaggle_heart, "./images/heart_disease_data", "./heart_disease_data/over_resample.xlsx"),
    )

    scaler = StandardScaler()
    for selected, image_dir, data_path in candidates:
        if selected:
            break
    else:
        # No flag set: fall back to the LT dataset, which uses min-max scaling.
        image_dir = "./images/LT"
        data_path = './data/over_resample_all_fields_scaler.xlsx'
        scaler = MinMaxScaler()

    image_path = create_images_file(image_dir)
    # load_lt_data is project-local; the object it returns is used here only
    # through its .data, .target and .feature_names attributes.
    LT = load_lt_data(_all=True, path=data_path)
    feature_names = LT.feature_names

    Xtrain, Xtest, Ytrain, Ytest = train_test_split(
        LT.data, LT.target, test_size=.3, random_state=random_state)

    # Fit on the training split only, so test statistics do not leak into scaling.
    Xtrain = scaler.fit_transform(Xtrain)
    Xtest = scaler.transform(Xtest)

    # Prepend the label as the first column of each row.
    train = np.hstack((Ytrain.reshape(-1, 1), Xtrain))
    test = np.hstack((Ytest.reshape(-1, 1), Xtest))

    dataset = np.vstack((train, test))

    return image_path, Xtrain, Xtest, Ytrain, Ytest, dataset, feature_names


# Redraw random train/test splits of the bone-marrow data until the variational
# logistic regression clears every score threshold. Then write ACME explanation
# plots for that model. The search is bounded, so unreachable thresholds fail
# loudly instead of looping forever.
MAX_ATTEMPTS = 10000

for _attempt in range(MAX_ATTEMPTS):
    image_path, Xtrain, Xtest, Ytrain, Ytest, dataset, feature_names = create_toy_data(is_bone=True)

    vlr = VariationalLogisticRegression(a0=1, b0=2)
    vlr.fit(Xtrain, Ytrain, feature_names)

    y_pred = vlr.predict(Xtest)
    _score = vlr.score(Xtest, Ytest)
    _f1_macro = f1_score(Ytest, y_pred, average='macro')
    _recall_score = recall_score(Ytest, y_pred, average='macro')
    _precision_score = precision_score(Ytest, y_pred, average='macro')

    if _score >= 0.95 and _f1_macro >= 0.93 and _recall_score >= 0.93 and _precision_score >= 0.93:
        print(_score, _f1_macro, _recall_score, _precision_score, "\n")

        # write_image below targets ./image_acme/, so make sure it exists first.
        create_images_file("./image_acme")

        acme_vlr = ACME(vlr, target="survival_status", features=feature_names, task="class")

        # Unpack instead of ``list + feature_names`` so this also works when
        # feature_names is a NumPy array rather than a list.
        dataset = pd.DataFrame(dataset, columns=["survival_status", *feature_names])
        acme_vlr = acme_vlr.explain(dataset, robust=True, label_class=1)
        summary_plot_1 = acme_vlr.summary_plot()
        summary_plot_1.show()
        summary_plot_1.write_image(file='./image_acme/bone_label_1.eps', format='eps')
        acme_vlr = acme_vlr.explain(dataset, robust=True, label_class=0)
        summary_plot_2 = acme_vlr.summary_plot()
        summary_plot_2.show()
        summary_plot_2.write_image(file='./image_acme/bone_label_0.eps', format='eps')
        # The bar plot comes from the label_class=0 explanation computed just above.
        bar_plot = acme_vlr.bar_plot()
        bar_plot.show()
        bar_plot.write_image(file='./image_acme/bone_bar.eps', format='eps')
        break
else:
    raise RuntimeError(
        "no train/test split reached the score thresholds in {} attempts".format(MAX_ATTEMPTS))


all positive 39 [('survival_time', 34, '-2.9097 ± 0.2788'), ('Relapse', 24, '1.1336 ± 0.2643'), ('PLTrecovery', 32, '1.0959 ± 0.3375'), ('Rbodymass', 30, '0.6519 ± 0.4387'), ('Donorage35', 3, '0.5763 ± 0.3776'), ('extcGvHD', 26, '-0.5753 ± 0.2472'), ('ANCrecovery', 31, '0.4918 ± 0.3829'), ('AML', 35, '-0.4342 ± 0.3862'), ('Txpostrelapse', 14, '0.423 ± 0.3028'), ('ABOmatch', 9, '-0.3894 ± 0.2578'), ('CD3dkgx10d8', 29, '-0.3367 ± 0.3193'), ('Riskgroup', 13, '0.3067 ± 0.3005'), ('CD3dCD34', 28, '0.3009 ± 0.2971'), ('lymphoma', 37, '0.2585 ± 0.4522'), ('HLAmismatch', 17, '0.2286 ± 0.406'), ('RecipientCMV', 12, '0.1951 ± 0.4678'), ('DonorABO', 6, '0.1876 ± 0.2598'), ('Alel', 19, '0.1738 ± 0.4578'), ('Recipientage', 21, '0.1712 ± 0.515'), ('CD34kgx10d6', 27, '0.1562 ± 0.327'), ('CMVstatus', 10, '0.1526 ± 0.5024'), ('IIIV', 4, '0.1387 ± 0.2754'), ('aGvHDIIIIV', 25, '-0.1329 ± 0.5213'), ('Stemcellsource', 1, '-0.1288 ± 0.3245'), ('Gendermatch', 5, '-0.1137 ± 0.2698'), ('HLAmatch', 16, '0.1122 

all positive 39 [('survival_time', 34, '-2.9864 ± 0.2771'), ('Relapse', 24, '1.1425 ± 0.2829'), ('PLTrecovery', 32, '0.859 ± 0.3972'), ('Txpostrelapse', 14, '0.7729 ± 0.2909'), ('Rbodymass', 30, '0.5583 ± 0.4379'), ('ANCrecovery', 31, '0.5579 ± 0.4149'), ('RecipientRh', 8, '0.494 ± 0.2571'), ('ABOmatch', 9, '-0.4448 ± 0.2564'), ('RecipientCMV', 12, '0.4104 ± 0.4766'), ('nonmalignant', 38, '-0.408 ± 0.5109'), ('RecipientABO', 7, '-0.4022 ± 0.2642'), ('CD3dkgx10d8', 29, '-0.3613 ± 0.3342'), ('extcGvHD', 26, '-0.3556 ± 0.2549'), ('Alel', 19, '0.3391 ± 0.4397'), ('chronic', 36, '-0.3377 ± 0.3767'), ('Donorage35', 3, '0.2831 ± 0.3803'), ('CMVstatus', 10, '0.2802 ± 0.5228'), ('HLAmatch', 16, '0.2482 ± 0.6607'), ('DonorCMV', 11, '-0.2332 ± 0.3174'), ('AML', 35, '-0.2207 ± 0.3797'), ('Diseasegroup', 15, '0.2139 ± 0.5164'), ('aGvHDIIIIV', 25, '-0.1993 ± 0.4932'), ('lymphoma', 37, '0.1962 ± 0.4295'), ('Riskgroup', 13, '-0.1957 ± 0.3037'), ('Recipientgender', 0, '-0.1848 ± 0.2686'), ('CD34kgx10d6