-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
79 lines (54 loc) · 2.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from glob import glob
import numpy as np
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from keras.datasets import cifar100, cifar10
def normalize_minus1_1(data):
    """Linearly map values from the range [0, 255] onto [-1, 1].

    Args:
        data: A number or NumPy array. 0 maps to -1 and 255 maps to 1; values
            outside [0, 255] are not clipped.

    Returns:
        ``2 * (data / 255.) - 1``, computed element-wise for array input.
    """
    # Dividing by a float literal forces floating-point division, so integer
    # input is not truncated.
    return 2 * (data / 255.) - 1
def get_channels_axis():
    """Return the channels-axis index for the active Keras image data format.

    Returns:
        1 when ``keras.backend.image_data_format()`` reports
        ``'channels_first'``, and 3 when it reports ``'channels_last'``.

    Raises:
        AssertionError: If the backend reports any other format.
    """
    # keras is imported inside the function, so the import cost is only paid
    # when this helper is actually called.
    import keras

    idf = keras.backend.image_data_format()
    if idf == 'channels_first':
        return 1
    assert idf == 'channels_last', 'unexpected image data format: %r' % (idf,)
    # NOTE(review): 3 is the last axis only for 4-D tensors, presumably
    # (batch, height, width, channels) -- confirm against callers.
    return 3
def load_cifar10():
    """Load the CIFAR-10 dataset through ``keras.datasets.cifar10``.

    Returns:
        ``(X_train, y_train), (X_test, y_test)`` as unpacked from
        ``cifar10.load_data()``.
    """
    # Unpacking verifies the loader returned two (data, labels) pairs before
    # they are handed back to the caller.
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    return (X_train, y_train), (X_test, y_test)
def load_cifar100(label_mode='coarse'):
    """Load the CIFAR-100 dataset through ``keras.datasets.cifar100``.

    Args:
        label_mode: Forwarded unchanged to ``cifar100.load_data``; defaults to
            ``'coarse'``.

    Returns:
        ``(X_train, y_train), (X_test, y_test)`` as unpacked from
        ``cifar100.load_data(label_mode=label_mode)``.
    """
    # Unpacking verifies the loader returned two (data, labels) pairs before
    # they are handed back to the caller.
    (X_train, y_train), (X_test, y_test) = cifar100.load_data(label_mode=label_mode)
    return (X_train, y_train), (X_test, y_test)
def save_roc_pr_curve_data(scores, labels, file_path):
    """Compute ROC and precision-recall curve data and save it with ``np.savez_compressed``.

    Samples whose label equals 1 form the positive ("normal") class; every
    other label is treated as the negative ("anomaly") class.

    Args:
        scores: Array-like of per-sample scores; flattened before use.
        labels: Array-like of per-sample labels with as many elements as
            ``scores``; flattened before use.
        file_path: Destination passed to ``np.savez_compressed``.

    Saved arrays:
        ``preds``, ``truth``,
        ``fpr``, ``tpr``, ``roc_thresholds``, ``roc_auc``,
        ``precision_norm``, ``recall_norm``, ``pr_thresholds_norm``, ``pr_auc_norm``,
        ``precision_anom``, ``recall_anom``, ``pr_thresholds_anom``, ``pr_auc_anom``.
    """
    # np.asarray lets plain Python sequences be passed as well as ndarrays;
    # for ndarray input this is the same as calling .flatten() directly.
    scores = np.asarray(scores).flatten()
    labels = np.asarray(labels).flatten()

    scores_pos = scores[labels == 1]
    scores_neg = scores[labels != 1]

    # Negatives first, then positives; truth and preds share this ordering.
    truth = np.concatenate((np.zeros_like(scores_neg), np.ones_like(scores_pos)))
    preds = np.concatenate((scores_neg, scores_pos))

    fpr, tpr, roc_thresholds = roc_curve(truth, preds)
    roc_auc = auc(fpr, tpr)

    # pr curve where "normal" is the positive class
    precision_norm, recall_norm, pr_thresholds_norm = precision_recall_curve(truth, preds)
    pr_auc_norm = auc(recall_norm, precision_norm)

    # pr curve where "anomaly" is the positive class: the scores are negated
    # so that a lower original score counts as stronger evidence for label 0.
    precision_anom, recall_anom, pr_thresholds_anom = precision_recall_curve(
        truth, -preds, pos_label=0)
    pr_auc_anom = auc(recall_anom, precision_anom)

    np.savez_compressed(file_path,
                        preds=preds, truth=truth,
                        fpr=fpr, tpr=tpr, roc_thresholds=roc_thresholds, roc_auc=roc_auc,
                        precision_norm=precision_norm, recall_norm=recall_norm,
                        pr_thresholds_norm=pr_thresholds_norm, pr_auc_norm=pr_auc_norm,
                        precision_anom=precision_anom, recall_anom=recall_anom,
                        pr_thresholds_anom=pr_thresholds_anom, pr_auc_anom=pr_auc_anom)
# Class-name tables keyed by dataset name. Defined once at import time so the
# lookup below does not rebuild the whole mapping on every call.
_CLASS_NAMES_BY_DATASET = {
    'cifar10': ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
                'truck'),
    'cifar100': ('aquatic mammals', 'fish', 'flowers', 'food containers', 'fruit and vegetables',
                 'household electrical devices', 'household furniture', 'insects',
                 'large carnivores', 'large man-made outdoor things',
                 'large natural outdoor scenes', 'large omnivores and herbivores',
                 'medium-sized mammals', 'non-insect invertebrates', 'people', 'reptiles',
                 'small mammals', 'trees', 'vehicles 1', 'vehicles 2'),
    '20news': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14',
               '15', '16', '17', '18', '19'),
    'reuters': ('0', '1', '2', '3', '4'),
    'caltech': ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'),
}


def get_class_name_from_index(index, dataset_name):
    """Return the display name of class ``index`` in dataset ``dataset_name``.

    Args:
        index: Integer position in the dataset's class-name tuple. Negative
            values follow normal tuple indexing.
        dataset_name: One of ``'cifar10'``, ``'cifar100'``, ``'20news'``,
            ``'reuters'`` or ``'caltech'``.

    Returns:
        The class name as a string. For ``'20news'``, ``'reuters'`` and
        ``'caltech'`` the name is the index written as a decimal string.

    Raises:
        KeyError: If ``dataset_name`` is not one of the known datasets.
        IndexError: If ``index`` is out of range for that dataset.
    """
    return _CLASS_NAMES_BY_DATASET[dataset_name][index]