In [2]:
from pyparsing import col
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import statistics
import pickle

from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.metrics import precision_score, recall_score, roc_auc_score
from xgboost import XGBClassifier
from sklearn.neural_network import MLPClassifier

from util import data, metrics, adult_data, credit_data, crime_data, plots
from models.DECAF import DECAF

In [3]:
dfr, Xy, min_max_scaler = crime_data.load()
dm = data.DataModule(Xy)
X = Xy[:, :-1].astype(np.float32)
y = np.round(Xy[:, -1]).astype(np.uint32)

dag_seed = [[84, 67], [41, 61], [0, 22], [43, 9], [54, 57], [22, 66], [3, 33], [86, 87], [62, 82], [86, 73], [74, 40], [91, 0], [77, 70], [91, 74], [45, 8], [67, 70], [62, 92], [8, 11], [42, 11], [50, 23], [19, 34], [80, 84], [85, 48], [29, 8], [51, 34], [72, 97], [56, 69], [11, 9], [74, 89], [85, 10], [29, 72], [24, 27], [92, 59], [94, 76], [96, 95], [92, 58], [68, 15], [4, 46], [41, 76], [65, 70], [36, 53], [73, 37], [13, 78], [36, 8], [91, 63], [94, 99], [44, 77], [5, 23], [84, 65], [21, 82], [66, 88], [32, 47], [80, 82], [94, 20], [85, 83], [63, 40], [17, 38], [43, 8], [18, 78], [4, 87], [39, 10], [20, 18], [51, 5], [0, 45], [23, 62], [83, 36], [21, 19], [65, 3], [14, 17], [83, 98], [21, 10], [63, 69], [44, 45], [11, 92], [40, 9], [49, 34], [85, 82], [35, 73], [16, 92], [15, 73], [82, 98], [43, 66], [13, 64], [13, 100], [74, 88], [40, 8], [45, 74], [95, 28], [43, 63], [83, 29], [13, 35], [65, 44], [38, 23], [33, 70], [29, 56], [49, 44], [49, 17], [45, 47], [8, 69], [18, 54], [29, 42], [49, 48], [3, 34], [16, 19], [69, 9], [75, 35], [13, 97], [36, 42], [64, 68], [72, 52], [67, 44], [43, 42], [95, 94], [84, 11], [92, 60], [65, 74], [13, 89], [65, 17], [67, 74], [19, 10], [61, 7], [0, 67], [81, 92], [58, 93], [78, 31], [78, 37], [4, 75], [99, 37], [95, 75], [14, 74], [8, 38], [85, 49], [92, 32], [86, 4], [21, 8], [65, 10], [83, 89], [22, 82], [78, 73], [18, 11], [94, 89], [41, 69], [94, 97], [80, 17], [67, 31], [85, 14], [13, 95], [64, 63], [87, 64], [14, 63], [67, 66], [34, 82], [32, 57], [20, 73], [44, 17], [50, 52], [23, 82], [46, 69], [39, 5], [80, 66], [80, 19], [29, 18], [30, 32], [65, 63], [17, 82], [0, 80], [75, 29], [28, 78], [5, 58], [0, 78], [85, 84], [86, 96], [80, 18], [72, 100], [34, 62], [14, 9], [96, 73], [16, 67], [14, 8], [62, 31], [40, 3], [39, 62], [29, 34], [45, 23], [74, 77], [5, 92], [75, 45], [77, 93], [34, 33], [19, 8], [22, 44], [44, 23], [4, 2], [84, 43], [12, 14], [56, 20], [60, 98], [41, 94], [42, 40], [49, 56], [67, 21], [95, 35], [60, 31], [43, 74], [45, 63], [18, 33], [21, 63], [70, 88], [35, 37], [49, 67], [17, 33], [33, 62], [44, 39], [45, 94], [66, 90], [77, 38], [34, 47], [43, 19], [12, 97], [6, 7], [85, 65], [29, 10], [20, 30], [50, 63], [23, 81], [29, 44], [16, 36], [73, 76], [17, 10], [94, 88], [13, 76], [8, 89], [68, 30], [64, 9], [82, 6], [42, 101], [65, 75], [64, 8], [38, 70], [56, 54], [85, 45], [56, 55], [84, 83], [49, 23], [68, 69], [83, 45], [75, 76], [34, 11], [21, 42], [10, 34], [16, 77], [50, 42], [54, 53], [2, 28], [14, 57], [66, 60], [75, 49], [67, 93], [51, 10], [18, 60], [40, 30], [80, 49], [76, 79], [12, 100], [14, 70], [51, 89], [35, 99], [50, 97], [58, 57], [96, 94], [13, 41], [49, 11], [20, 3], [37, 32], [65, 67], [21, 29], [60, 52], [89, 7], [34, 31], [17, 93], [33, 11], [29, 46], [80, 67], [20, 15], [20, 35], [94, 68], [19, 63], [40, 10], [100, 101], [75, 41], [56, 27], [95, 20], [74, 98], [86, 22], [36, 59], [21, 56], [46, 40], [66, 51], [2, 97], [36, 77], [21, 40], [86, 0], [10, 69], [80, 21], [29, 28], [63, 61], [8, 15], [70, 71], [22, 17], [2, 6], [91, 28], [43, 33], [22, 27], [73, 72], [65, 94], [34, 93], [15, 88], [21, 36], [67, 94], [4, 70], [77, 33], [28, 40], [87, 89], [17, 31], [63, 17], [64, 20], [21, 74], [65, 46], [36, 46], [86, 64], [70, 69], [67, 29], [22, 21], [4, 82], [8, 70], [22, 15], [0, 21], [19, 11], [64, 42], [40, 39], [96, 11], [87, 93], [90, 99], [85, 66], [24, 82], [61, 78], [23, 25], [68, 89], [61, 6], [67, 18], [83, 74], [29, 94], [51, 31], [80, 41], [89, 61], [59, 32], [70, 31], [39, 32], [83, 9], [21, 32], [60, 58], [77, 98], [60, 59], [10, 88], [43, 46], [67, 7], [44, 42], [94, 18], [62, 30], [4, 44], [69, 48], [4, 0], [51, 42], [46, 77], [51, 94], [86, 89], [2, 13], [29, 30], [83, 66], [2, 94], [39, 9], [49, 46], [89, 76], [95, 32], [73, 52], [69, 59], [22, 94], [66, 14], [86, 24], [43, 11], [14, 37], [43, 23], [46, 9], [86, 94], [59, 58], [13, 3], [23, 48], [51, 100], [67, 6], [0, 94], [66, 20], [65, 66], [20, 31], [77, 31], [22, 67], [50, 100], [21, 78], [94, 6], [95, 43], [95, 45], [10, 71], [64, 34], [59, 73], [28, 52], [21, 88], [60, 57], [14, 93], [40, 15], [21, 51], [36, 31], [36, 73], [87, 35], [94, 56], [85, 70], [15, 18], [2, 42], [95, 40], [31, 26], [29, 15], [22, 80], [19, 33], [21, 14], [38, 33], [38, 37], [38, 93], [86, 80], [17, 24], [66, 68], [67, 61], [16, 46], [83, 92], [64, 18], [19, 93], [4, 19], [41, 34], [10, 9], [4, 16], [88, 98], [87, 46], [96, 93], [87, 95], [61, 30], [4, 50], [16, 95], [19, 46], [34, 15], [28, 50], [21, 34], [96, 21], [56, 53], [75, 67], [65, 21], [10, 8], [95, 19], [36, 37], [4, 84], [19, 31], [80, 61], [89, 62], [17, 89], [12, 75], [77, 90], [2, 12], [36, 39], [77, 23], [13, 67], [11, 70], [77, 60], [70, 73], [86, 67], [14, 46], [87, 66], [54, 26], [75, 94], [0, 89], [4, 40], [13, 12], [64, 44], [75, 42], [87, 33], [49, 21], [36, 35], [74, 99], [22, 42], [4, 68], [64, 62], [31, 32], [4, 45], [8, 35], [70, 92], [89, 18], [0, 64], [96, 84], [96, 70], [86, 85], [80, 64], [49, 15], [4, 49], [36, 20], [65, 20], [87, 34], [89, 88], [29, 88], [18, 31], [88, 97], [14, 24], [77, 76], [44, 68], [62, 79], [77, 35], [44, 94], [53, 57], [11, 69], [61, 58], [35, 52], [96, 83], [43, 64], [77, 99], [66, 18], [95, 67], [19, 73], [67, 24], [19, 37], [95, 21], [29, 78], [96, 92], [14, 15], [19, 61], [66, 74], [69, 60], [21, 15], [22, 40], [68, 3], [74, 63], [41, 17], [0, 66], [89, 69], [89, 6], [14, 78], [63, 56], [64, 11], [17, 73], [13, 8], [80, 78], [28, 17], [86, 21], [4, 51], [85, 87], [85, 37], [46, 44], [19, 15], [95, 36], [62, 61], [95, 98], [46, 51], [74, 93], [92, 7], [64, 14], [28, 51], [14, 99], [95, 65], [16, 35], [64, 49], [19, 41], [68, 9], [4, 77], [36, 61], [69, 62], [46, 70], [90, 52], [94, 9], [95, 77], [83, 94], [29, 40], [66, 42], [22, 64], [64, 67], [43, 20], [22, 32], [33, 61], [4, 62], [29, 20], [22, 60], [58, 35], [11, 97], [4, 29], [0, 42], [63, 31], [19, 35], [12, 18], [64, 21], [4, 43], [15, 11], [29, 69], [94, 82], [42, 18], [94, 3], [10, 15], [22, 6], [66, 94], [17, 30], [4, 5], [94, 31], [41, 39], [2, 50], [80, 81], [51, 99], [80, 89], [41, 59], [4, 23], [46, 47], [36, 40], [87, 14], [80, 40], [8, 9], [89, 79], [16, 30], [98, 99], [22, 62], [31, 30], [98, 30], [86, 65], [22, 65], [33, 6], [68, 74], [62, 6], [36, 32], [86, 84], [42, 41], [87, 38], [74, 70], [78, 98], [20, 39], [82, 73], [20, 10], [50, 90], [22, 26], [22, 78], [75, 44], [22, 14], [77, 88], [86, 14], [40, 77], [33, 73], [91, 90], [80, 15], [85, 46], [69, 55], [11, 76], [28, 62], [60, 6], [44, 18], [80, 45], [36, 45], [4, 83], [28, 72], [68, 63], [13, 44], [67, 46], [66, 44], [19, 79], [95, 29], [67, 45], [65, 69], [17, 78], [84, 8], [74, 8], [21, 94], [67, 41], [41, 74], [69, 76], [80, 88], [64, 89], [49, 30], [45, 77], [75, 68], [13, 66], [22, 49], [68, 56], [92, 93], [68, 59], [39, 78], [66, 23], [40, 58], [84, 42], [3, 15], [67, 15], [70, 93], [4, 67], [36, 41], [38, 30], [13, 98], [16, 94], [69, 88], [87, 63], [22, 20], [85, 96], [41, 3], [95, 11], [98, 73], [17, 35], [51, 6], [42, 15], [62, 7], [67, 40], [5, 6], [87, 36], [85, 88], [69, 101], [46, 39], [96, 67], [67, 68], [87, 9], [85, 40], [44, 47], [65, 61], [55, 58], [54, 58], [55, 59], [85, 68], [78, 92], [19, 40], [65, 14], [80, 68], [0, 34], [75, 46], [14, 36], [80, 14], [65, 68], [63, 10], [83, 64], [78, 79], [95, 42], [51, 18], [83, 78], [19, 56], [54, 61], [83, 42], [98, 97], [75, 28], [80, 96], [61, 88], [36, 62], [4, 90], [69, 61], [67, 36], [84, 82], [2, 91], [62, 60], [96, 68], [17, 9], [64, 15], [67, 14], [95, 99], [75, 39], [56, 10], [80, 59], [36, 38], [61, 31], [75, 38], [14, 68], [45, 51], [34, 88], [86, 74], [20, 70], [13, 19], [22, 37], [11, 62], [13, 50], [64, 70], [13, 37], [3, 8], [75, 73], [91, 50], [9, 82], [51, 40], [37, 31], [29, 19], [66, 61], [8, 34], [8, 82], [0, 19], [29, 38], [14, 29], [96, 38], [77, 62], [83, 76], [29, 50], [43, 68], [44, 101], [10, 54], [10, 55], [56, 33], [75, 52], [62, 58], [29, 77], [2, 90], [83, 3], [4, 95], [34, 9], [69, 7], [14, 32], [83, 44], [65, 43], [34, 18], [49, 77], [22, 23], [66, 3], [43, 34], [87, 65], [95, 66], [43, 67], [22, 38], [74, 76], [19, 6], [6, 58], [41, 89], [2, 72], [61, 35], [86, 25], [22, 87], [13, 16], [93, 35], [19, 32], [56, 52], [49, 18], [87, 75], [69, 5], [65, 29], [89, 73], [49, 29], [89, 15], [20, 17], [87, 88], [94, 77], [91, 52], [38, 32], [67, 37], [22, 10], [29, 45], [52, 7], [0, 96], [74, 62], [95, 89], [87, 68], [68, 7], [87, 78], [56, 92], [83, 75], [65, 34], [19, 78], [55, 54], [68, 5], [48, 61], [14, 19], [0, 68], [19, 20], [73, 30], [44, 32], [36, 17], [0, 70], [42, 39], [14, 34], [21, 46], [56, 60], [82, 81], [24, 23], [94, 10], [61, 60], [44, 40], [2, 52], [65, 19], [5, 9], [29, 32], [64, 10], [61, 92], [18, 77], [68, 34], [96, 32], [38, 26], [16, 44], [20, 34], [45, 17], [16, 49], [87, 49], [17, 48], [75, 77], [75, 43], [80, 73], [29, 36], [86, 43], [36, 18], [75, 92], [93, 31], [80, 43], [16, 20], [96, 19], [43, 17], [98, 93], [51, 68], [45, 101], [41, 93], [51, 8], [68, 38], [96, 51], [96, 34], [16, 65], [45, 18], [51, 88], [19, 44], [43, 29], [20, 32], [49, 68], [0, 43], [66, 15], [89, 99], [89, 33], [29, 74], [22, 43], [80, 83], [3, 72], [51, 50], [17, 32], [39, 30], [66, 41], [80, 85], [45, 34], [6, 57], [56, 23], [18, 32], [16, 29], [56, 59], [15, 31], [67, 69], [44, 8], [49, 20], [13, 77], [18, 38], [4, 27], [10, 3], [95, 41], [85, 67], [48, 6], [0, 87], [41, 31], [49, 19], [92, 30], [13, 70], [94, 39], [2, 43], [22, 69], [87, 29], [83, 13], [75, 69], [3, 11], [68, 20], [30, 72], [20, 40], [68, 76], [4,101]]
bias_dict_FTU = {101: [4]}
bias_dict_CF = {101: [100, 4, 45, 69]}
bias_dict_DP = {101: [42, 100, 69, 44, 45, 4]}

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


In [4]:
original = {
    "precision": [],
    "recall": [],
    "auroc": [],
    "FTU": [],
    "DP": []
}

DECAF_ND = {
    "precision": [],
    "recall": [],
    "auroc": [],
    "FTU": [],
    "DP": []
}
DECAF_FTU = {
    "precision": [],
    "recall": [],
    "auroc": [],
    "FTU": [],
    "DP": []
}
DECAF_CF = {
    "precision": [],
    "recall": [],
    "auroc": [],
    "FTU": [],
    "DP": []
}
DECAF_DP = {
    "precision": [],
    "recall": [],
    "auroc": [],
    "FTU": [],
    "DP": []
}

def calculate_scores(Xy_synth, fairness_type):
    ## Takes Xy_synthetic data (including label) of the adult data set and fairness type
    ## Calculates the precision, recall, auroc, FTU and DP metrics
    X_synth = Xy_synth[:, :-1]
    y_synth = np.round(Xy_synth[:, -1]).astype(int)  
    
    synth_clf = MLPClassifier(max_iter=2000).fit(X_synth, y_synth)

    y_pred_synth = synth_clf.predict(X)
    y_pred_synth_proba = synth_clf.predict_proba(X)
    
    if fairness_type == "original":
        original["precision"].append(precision_score(y, y_pred_synth))
        original["recall"].append(precision_score(y, y_pred_synth))
        original["auroc"].append(roc_auc_score(y, y_pred_synth_proba[:, 1]))
        original["FTU"].append(metrics.ftu(synth_clf, X_synth, 4))
        original["DP"].append(metrics.dp(synth_clf, X_synth, 4))
    elif fairness_type == "ND":
        DECAF_ND["precision"].append(precision_score(y, y_pred_synth))
        DECAF_ND["recall"].append(recall_score(y, y_pred_synth))
        DECAF_ND["auroc"].append(roc_auc_score(y, y_pred_synth_proba[:, 1]))
        DECAF_ND["FTU"].append(metrics.ftu(synth_clf, X_synth, 4))
        DECAF_ND["DP"].append(metrics.dp(synth_clf, X_synth, 4))
    elif fairness_type == "FTU":
        DECAF_FTU["precision"].append(precision_score(y, y_pred_synth))
        DECAF_FTU["recall"].append(recall_score(y, y_pred_synth))
        DECAF_FTU["auroc"].append(roc_auc_score(y, y_pred_synth_proba[:, 1]))
        DECAF_FTU["FTU"].append(metrics.ftu(synth_clf, X_synth, 4))
        DECAF_FTU["DP"].append(metrics.dp(synth_clf, X_synth, 4))
    elif fairness_type == "CF":
        DECAF_CF["precision"].append(precision_score(y, y_pred_synth))
        DECAF_CF["recall"].append(recall_score(y, y_pred_synth))
        DECAF_CF["auroc"].append(roc_auc_score(y, y_pred_synth_proba[:, 1]))
        DECAF_CF["FTU"].append(metrics.ftu(synth_clf, X_synth, 4))
        DECAF_CF["DP"].append(metrics.dp(synth_clf, X_synth, 4))    
    elif fairness_type == "DP":
        DECAF_DP["precision"].append(precision_score(y, y_pred_synth))
        print("precision done")
        DECAF_DP["recall"].append(recall_score(y, y_pred_synth))
        print("recall done")
        DECAF_DP["auroc"].append(roc_auc_score(y, y_pred_synth_proba[:, 1]))
        print("auroc done")
        DECAF_DP["FTU"].append(metrics.ftu(synth_clf, X_synth, 4))
        print("ftu done")
        DECAF_DP["DP"].append(metrics.dp(synth_clf, X_synth, 4))
        print("dp done")
    else:
        print("Warning: fairness_type not recognized")
    return

In [5]:
try1 = range(10) # 50 epochs

for i in try1:
    print(i)
    model = DECAF(input_dim=dm.dims[0])    
    trained_model = "logs/DECAF_crime/version_{}/checkpoints/epoch=49-step=1449.ckpt".format(i)
    model = model.load_from_checkpoint(trained_model)
    
    calculate_scores(Xy, fairness_type="original")
    
    synthetic_data = model.gen_synthetic(dm.dataset.x, gen_order=model.get_gen_order(), biased_edges={}).detach().numpy()
    calculate_scores(synthetic_data, fairness_type="ND")
    
    synthetic_data_FTU = model.gen_synthetic(dm.dataset.x, gen_order=model.get_gen_order(), biased_edges=bias_dict_FTU).detach().numpy()
    calculate_scores(synthetic_data_FTU, fairness_type="FTU")
    
    synthetic_data_CF = model.gen_synthetic(dm.dataset.x, gen_order=model.get_gen_order(), biased_edges=bias_dict_CF).detach().numpy()
    calculate_scores(synthetic_data_CF, fairness_type="CF")
    
    synthetic_data_DP = model.gen_synthetic(dm.dataset.x, gen_order=model.get_gen_order(), biased_edges=bias_dict_DP).detach().numpy()    
    calculate_scores(synthetic_data_DP, fairness_type="DP")


0
Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


precision done
recall done
auroc done
ftu done
dp done
1


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
2


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
3


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
4


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


precision done
recall done
auroc done
ftu done
dp done
5


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
6


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
7


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done
8


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


precision done
recall done
auroc done
ftu done
dp done
9


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
precision done
recall done
auroc done
ftu done
dp done


In [6]:
def calc_metrics(data):
    # Calculate mean and stdev of the data, returns a string
#     mean = statistics.mean(data)
#     print(mean)
#     stdev = statistics.stdev(data, mean)
#     print(stdev)
    data = np.array(data)
    mean = np.nanmean(data) # DP sometimes gives nan
    stdev = np.nanstd(data) # DP sometimes gives nan
    return "{:.3f}±{:.3f}".format(mean, stdev)

d = [
     ["original", calc_metrics(original["precision"]), calc_metrics(original["recall"]), calc_metrics(original["auroc"]), calc_metrics(original["FTU"]), calc_metrics(original["DP"])],
     ["DECAF-ND", calc_metrics(DECAF_ND["precision"]), calc_metrics(DECAF_ND["recall"]), calc_metrics(DECAF_ND["auroc"]), calc_metrics(DECAF_ND["FTU"]), calc_metrics(DECAF_ND["DP"])],
     ["DECAF-FTU", calc_metrics(DECAF_FTU["precision"]), calc_metrics(DECAF_FTU["recall"]), calc_metrics(DECAF_FTU["auroc"]), calc_metrics(DECAF_FTU["FTU"]), calc_metrics(DECAF_FTU["DP"])],
     ["DECAF-CF", calc_metrics(DECAF_CF["precision"]), calc_metrics(DECAF_CF["recall"]), calc_metrics(DECAF_CF["auroc"]), calc_metrics(DECAF_CF["FTU"]), calc_metrics(DECAF_CF["DP"])],
     ["DECAF-DP", calc_metrics(DECAF_DP["precision"]), calc_metrics(DECAF_DP["recall"]), calc_metrics(DECAF_DP["auroc"]), calc_metrics(DECAF_DP["FTU"]), calc_metrics(DECAF_DP["DP"])]
]

df = pd.DataFrame(d, columns = ["Model", "Precision", "Recall", "AUROC", "FTU", "DP"])
df

Unnamed: 0,Model,Precision,Recall,AUROC,FTU,DP
0,original,0.936±0.035,0.936±0.035,0.975±0.016,0.023±0.016,0.510±0.014
1,DECAF-ND,0.671±0.103,0.689±0.133,0.734±0.074,0.466±0.153,0.477±0.134
2,DECAF-FTU,0.587±0.062,0.697±0.143,0.661±0.086,0.132±0.132,0.122±0.067
3,DECAF-CF,0.528±0.060,0.610±0.134,0.549±0.093,0.088±0.052,0.045±0.036
4,DECAF-DP,0.499±0.037,0.631±0.157,0.518±0.068,0.045±0.044,0.016±0.010
