from modAL.models import ActiveLearner
from modAL.utils.data import modALinput
-from modAL.utils.selection import multi_argmax
+from modAL.utils.selection import multi_argmax, shuffled_argmax
from typing import Tuple, Optional
from itertools import combinations


def _SVM_loss(multiclass_classifier: ActiveLearner,
-              X: modALinput,
-              most_certain_classes: Optional[int] = None) -> np.ndarray:
+              X: modALinput, most_certain_classes: Optional[int] = None) -> np.ndarray:
    """
    Utility function for max_loss and mean_max_loss strategies.
@@ -43,8 +42,8 @@ def _SVM_loss(multiclass_classifier: ActiveLearner,
    return cls_loss


-def SVM_binary_minimum(classifier: ActiveLearner,
-                       X_pool: modALinput) -> Tuple[np.ndarray, modALinput]:
+def SVM_binary_minimum(classifier: ActiveLearner, X_pool: modALinput,
+                       random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    SVM binary minimum multilabel active learning strategy. For details see the paper
    Klaus Brinker, On Active Learning in Multi-label Classification
@@ -53,23 +52,30 @@ def SVM_binary_minimum(classifier: ActiveLearner,
    Args:
        classifier: The multilabel classifier for which the labels are to be queried. Must be an SVM model
            such as the ones from sklearn.svm.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    decision_function = np.array([svm.decision_function(X_pool)
                                  for svm in classifier.estimator.estimators_]).T

    min_abs_dist = np.min(np.abs(decision_function), axis=1)
-    query_idx = np.argmin(min_abs_dist)
+
+    if not random_tie_break:
+        query_idx = np.argmin(min_abs_dist)
+    else:
+        query_idx = shuffled_argmax(min_abs_dist)
+
    return query_idx, X_pool[query_idx]


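For orientation, here is a minimal usage sketch of SVM_binary_minimum with the new random_tie_break flag. It is not part of the commit: the toy data is generated on the spot, the LinearSVC-based estimator is just one choice that exposes decision_function, and the ActiveLearner keyword names assume a recent modAL.

```python
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from modAL.models import ActiveLearner
from modAL.multilabel import SVM_binary_minimum

# Toy multilabel data: y_initial is an (n_samples, n_labels) binary indicator matrix.
rng = np.random.RandomState(0)
X_initial, X_pool = rng.rand(40, 6), rng.rand(200, 6)
y_initial = rng.randint(0, 2, size=(40, 3))

learner = ActiveLearner(
    estimator=OneVsRestClassifier(LinearSVC()),  # one SVM with decision_function per label
    query_strategy=SVM_binary_minimum,
    X_training=X_initial, y_training=y_initial
)

# Keyword arguments of query() are forwarded to the strategy, so random tie
# breaking can be switched on per call.
query_idx, query_instance = learner.query(X_pool, random_tie_break=True)
```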
-def max_loss(classifier: OneVsRestClassifier,
-             X_pool: modALinput,
-             n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
+             n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:

    """
    Max Loss query strategy for SVM multilabel classification.
@@ -82,24 +88,30 @@ def max_loss(classifier: OneVsRestClassifier,
        classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
            such as the ones from sklearn.svm. Although the function will execute for other models as well,
            the mathematical calculations in Li et al. work only for SVM-s.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'

    most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
    loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)

-    query_idx = multi_argmax(loss, n_instances)
+    if not random_tie_break:
+        query_idx = multi_argmax(loss, n_instances)
+    else:
+        query_idx = shuffled_argmax(loss, n_instances)
+
    return query_idx, X_pool[query_idx]


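As a quick aside (not part of the commit), the effect of the new flag is easiest to see directly on the two selectors from modAL.utils.selection; the utility vector below is made up:

```python
import numpy as np
from modAL.utils.selection import multi_argmax, shuffled_argmax

# Made-up utility scores with a three-way tie for the maximum.
utility = np.array([0.2, 0.9, 0.9, 0.9, 0.1])

# multi_argmax is deterministic, so repeated calls return the same tied index.
print(multi_argmax(utility, n_instances=1))

# shuffled_argmax shuffles the scores before selecting, so any of the tied
# indices 1, 2 or 3 may come back, while the selected utility stays maximal.
print(shuffled_argmax(utility, n_instances=1))
```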
-def mean_max_loss(classifier: OneVsRestClassifier,
-                  X_pool: modALinput,
-                  n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def mean_max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
+                  n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    Mean Max Loss query strategy for SVM multilabel classification.
@@ -111,22 +123,28 @@ def mean_max_loss(classifier: OneVsRestClassifier,
        classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
            such as the ones from sklearn.svm. Although the function will execute for other models as well,
            the mathematical calculations in Li et al. work only for SVM-s.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
    loss = _SVM_loss(classifier, X_pool)

-    query_idx = multi_argmax(loss, n_instances)
+    if not random_tie_break:
+        query_idx = multi_argmax(loss, n_instances)
+    else:
+        query_idx = shuffled_argmax(loss, n_instances)
+
    return query_idx, X_pool[query_idx]


-def min_confidence(classifier: OneVsRestClassifier,
-                   X_pool: modALinput,
-                   n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def min_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
+                   n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    MinConfidence query strategy for multilabel classification.
@@ -136,22 +154,28 @@ def min_confidence(classifier: OneVsRestClassifier,

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    classwise_confidence = classifier.predict_proba(X_pool)
    classwise_min = np.min(classwise_confidence, axis=1)
-    query_idx = multi_argmax((-1)*classwise_min, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(-classwise_min, n_instances)
+    else:
+        query_idx = shuffled_argmax(-classwise_min, n_instances)

    return query_idx, X_pool[query_idx]


-def avg_confidence(classifier: OneVsRestClassifier,
-                   X_pool: modALinput,
-                   n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def avg_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
+                   n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    AvgConfidence query strategy for multilabel classification.
@@ -161,22 +185,28 @@ def avg_confidence(classifier: OneVsRestClassifier,

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    classwise_confidence = classifier.predict_proba(X_pool)
    classwise_mean = np.mean(classwise_confidence, axis=1)
-    query_idx = multi_argmax(classwise_mean, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_mean, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_mean, n_instances)

    return query_idx, X_pool[query_idx]


-def max_score(classifier: OneVsRestClassifier,
-              X_pool: modALinput,
-              n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def max_score(classifier: OneVsRestClassifier, X_pool: modALinput,
+              n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    MaxScore query strategy for multilabel classification.
@@ -186,24 +216,30 @@ def max_score(classifier: OneVsRestClassifier,

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    classwise_confidence = classifier.predict_proba(X_pool)
    classwise_predictions = classifier.predict(X_pool)
    classwise_scores = classwise_confidence * (classwise_predictions - 1/2)
    classwise_max = np.max(classwise_scores, axis=1)
-    query_idx = multi_argmax(classwise_max, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_max, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_max, n_instances)

    return query_idx, X_pool[query_idx]


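A short aside on the ranking used by max_score and avg_score (not part of the commit): multiplying the per-label confidence by (prediction - 1/2) keeps +confidence/2 for labels predicted positive and flips it to -confidence/2 for labels predicted negative. A tiny sketch with made-up numbers:

```python
import numpy as np

# Hypothetical per-label probabilities and 0/1 predictions for two pool samples.
classwise_confidence = np.array([[0.9, 0.2],
                                 [0.6, 0.5]])
classwise_predictions = np.array([[1, 0],
                                  [1, 1]])

classwise_scores = classwise_confidence * (classwise_predictions - 1/2)
print(classwise_scores)
# [[ 0.45 -0.1 ]
#  [ 0.3   0.25]]

print(np.max(classwise_scores, axis=1))   # ranked by max_score: [0.45 0.3 ]
print(np.mean(classwise_scores, axis=1))  # ranked by avg_score: [0.175 0.275]
```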
-def avg_score(classifier: OneVsRestClassifier,
-              X_pool: modALinput,
-              n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def avg_score(classifier: OneVsRestClassifier, X_pool: modALinput,
+              n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
    """
    AvgScore query strategy for multilabel classification.
@@ -213,16 +249,23 @@ def avg_score(classifier: OneVsRestClassifier,

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.

    Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
    """

    classwise_confidence = classifier.predict_proba(X_pool)
    classwise_predictions = classifier.predict(X_pool)
    classwise_scores = classwise_confidence * (classwise_predictions - 1/2)
    classwise_mean = np.mean(classwise_scores, axis=1)
-    query_idx = multi_argmax(classwise_mean, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_mean, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_mean, n_instances)

    return query_idx, X_pool[query_idx]
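Finally, a minimal end-to-end sketch (not part of the commit) of how the pool-based strategies above can be plugged into an ActiveLearner. The toy data is generated on the spot, SVC needs probability=True because these strategies call predict_proba, and the ActiveLearner keyword names assume a recent modAL:

```python
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from modAL.models import ActiveLearner
from modAL.multilabel import avg_score

# Toy multilabel data: binary indicator matrix with 3 labels.
rng = np.random.RandomState(42)
X_train, X_pool = rng.rand(50, 5), rng.rand(300, 5)
y_train = rng.randint(0, 2, size=(50, 3))

learner = ActiveLearner(
    estimator=OneVsRestClassifier(SVC(probability=True)),
    query_strategy=avg_score,
    X_training=X_train, y_training=y_train
)

# Query a batch of 3; random_tie_break=True routes the selection through
# shuffled_argmax instead of multi_argmax.
query_idx, query_instances = learner.query(X_pool, n_instances=3, random_tie_break=True)
print(query_idx)
```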