
Commit b7815d8

add: random tie break for multilabel classification
1 parent bb3b579 commit b7815d8

File tree

6 files changed (+100, -49 lines)

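The change is the same across every query strategy in modAL/multilabel.py: each function gains an opt-in random_tie_break flag (default False, so existing behaviour is unchanged), and the example tests pin the random forest size. A minimal usage sketch for the new flag follows; the dataset, the OneVsRestClassifier(SVC(...)) estimator and the choice of avg_confidence are illustrative assumptions, not part of this commit:

# Illustrative sketch only; the estimator and data below are assumed, not taken from this commit.
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

from modAL.models import ActiveLearner
from modAL.multilabel import avg_confidence

# toy multilabel setup: 20 seed samples, 3 labels, both classes present per label
X_initial = np.random.rand(20, 5)
y_initial = np.vstack([np.zeros((10, 3), dtype=int), np.ones((10, 3), dtype=int)])
X_pool = np.random.rand(100, 5)

learner = ActiveLearner(
    estimator=OneVsRestClassifier(SVC(probability=True)),
    query_strategy=avg_confidence,
    X_training=X_initial, y_training=y_initial
)

# With random_tie_break=True the utility scores are shuffled before the argmax,
# so ties between equally scored samples are broken at random.
query_idx, query_instance = learner.query(X_pool, n_instances=3, random_tie_break=True)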

modAL/multilabel.py

Lines changed: 87 additions & 44 deletions
@@ -5,14 +5,13 @@
 
 from modAL.models import ActiveLearner
 from modAL.utils.data import modALinput
-from modAL.utils.selection import multi_argmax
+from modAL.utils.selection import multi_argmax, shuffled_argmax
 from typing import Tuple, Optional
 from itertools import combinations
 
 
 def _SVM_loss(multiclass_classifier: ActiveLearner,
-              X: modALinput,
-              most_certain_classes: Optional[int] = None) -> np.ndarray:
+              X: modALinput, most_certain_classes: Optional[int] = None) -> np.ndarray:
     """
     Utility function for max_loss and mean_max_loss strategies.
 
@@ -43,8 +42,8 @@ def _SVM_loss(multiclass_classifier: ActiveLearner,
     return cls_loss
 
 
-def SVM_binary_minimum(classifier: ActiveLearner,
-                       X_pool: modALinput) -> Tuple[np.ndarray, modALinput]:
+def SVM_binary_minimum(classifier: ActiveLearner, X_pool: modALinput,
+                       random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
     """
     SVM binary minimum multilabel active learning strategy. For details see the paper
         Klaus Brinker, On Active Learning in Multi-label Classification
@@ -53,23 +52,30 @@ def SVM_binary_minimum(classifier: ActiveLearner,
     Args:
         classifier: The multilabel classifier for which the labels are to be queried. Must be an SVM model
             such as the ones from sklearn.svm.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     decision_function = np.array([svm.decision_function(X_pool)
                                   for svm in classifier.estimator.estimators_]).T
 
     min_abs_dist = np.min(np.abs(decision_function), axis=1)
-    query_idx = np.argmin(min_abs_dist)
+
+    if not random_tie_break:
+        query_idx = np.argmin(min_abs_dist)
+    else:
+        query_idx = shuffled_argmax(min_abs_dist)
+
     return query_idx, X_pool[query_idx]
 
 
-def max_loss(classifier: OneVsRestClassifier,
-             X_pool: modALinput,
-             n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
+             n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
 
     """
     Max Loss query strategy for SVM multilabel classification.
@@ -82,24 +88,30 @@ def max_loss(classifier: OneVsRestClassifier,
         classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
             such as the ones from sklearn.svm. Although the function will execute for other models as well,
             the mathematical calculations in Li et al. work only for SVM-s.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
 
     most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
     loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)
 
-    query_idx = multi_argmax(loss, n_instances)
+    if not random_tie_break:
+        query_idx = multi_argmax(loss, n_instances)
+    else:
+        query_idx = shuffled_argmax(loss, n_instances)
+
     return query_idx, X_pool[query_idx]
 
 
-def mean_max_loss(classifier: OneVsRestClassifier,
-                  X_pool: modALinput,
-                  n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def mean_max_loss(classifier: OneVsRestClassifier, X_pool: modALinput,
+                  n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
     """
     Mean Max Loss query strategy for SVM multilabel classification.
 
@@ -111,22 +123,28 @@ def mean_max_loss(classifier: OneVsRestClassifier,
         classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
             such as the ones from sklearn.svm. Although the function will execute for other models as well,
             the mathematical calculations in Li et al. work only for SVM-s.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
     loss = _SVM_loss(classifier, X_pool)
 
-    query_idx = multi_argmax(loss, n_instances)
+    if not random_tie_break:
+        query_idx = multi_argmax(loss, n_instances)
+    else:
+        query_idx = shuffled_argmax(loss, n_instances)
+
     return query_idx, X_pool[query_idx]
 
 
-def min_confidence(classifier: OneVsRestClassifier,
-                   X_pool: modALinput,
-                   n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def min_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
+                   n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
     """
     MinConfidence query strategy for multilabel classification.
 
@@ -136,22 +154,28 @@ def min_confidence(classifier: OneVsRestClassifier,
 
     Args:
         classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     classwise_confidence = classifier.predict_proba(X_pool)
     classwise_min = np.min(classwise_confidence, axis=1)
-    query_idx = multi_argmax((-1)*classwise_min, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(-classwise_min, n_instances)
+    else:
+        query_idx = shuffled_argmax(-classwise_min, n_instances)
 
     return query_idx, X_pool[query_idx]
 
 
-def avg_confidence(classifier: OneVsRestClassifier,
-                   X_pool: modALinput,
-                   n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def avg_confidence(classifier: OneVsRestClassifier, X_pool: modALinput,
+                   n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
     """
     AvgConfidence query strategy for multilabel classification.
 
@@ -161,22 +185,28 @@ def avg_confidence(classifier: OneVsRestClassifier,
 
     Args:
         classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     classwise_confidence = classifier.predict_proba(X_pool)
     classwise_mean = np.mean(classwise_confidence, axis=1)
-    query_idx = multi_argmax(classwise_mean, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_mean, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_mean, n_instances)
 
     return query_idx, X_pool[query_idx]
 
 
-def max_score(classifier: OneVsRestClassifier,
-              X_pool: modALinput,
-              n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def max_score(classifier: OneVsRestClassifier, X_pool: modALinput,
+              n_instances: int = 1, random_tie_break: bool = 1) -> Tuple[np.ndarray, modALinput]:
     """
     MaxScore query strategy for multilabel classification.
 
@@ -186,24 +216,30 @@ def max_score(classifier: OneVsRestClassifier,
 
     Args:
         classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     classwise_confidence = classifier.predict_proba(X_pool)
     classwise_predictions = classifier.predict(X_pool)
     classwise_scores = classwise_confidence*(classwise_predictions - 1/2)
     classwise_max = np.max(classwise_scores, axis=1)
-    query_idx = multi_argmax(classwise_max, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_max, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_max, n_instances)
 
     return query_idx, X_pool[query_idx]
 
 
-def avg_score(classifier: OneVsRestClassifier,
-              X_pool: modALinput,
-              n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
+def avg_score(classifier: OneVsRestClassifier, X_pool: modALinput,
+              n_instances: int = 1, random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
     """
     AvgScore query strategy for multilabel classification.
 
@@ -213,16 +249,23 @@ def avg_score(classifier: OneVsRestClassifier,
 
     Args:
         classifier: The multilabel classifier for which the labels are to be queried.
-        X: The pool of samples to query from.
+        X_pool: The pool of samples to query from.
+        random_tie_break: If True, shuffles utility scores to randomize the order. This
+            can be used to break the tie when the highest utility score is not unique.
 
     Returns:
-        The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
+        The index of the instance from X_pool chosen to be labelled;
+        the instance from X_pool chosen to be labelled.
     """
 
     classwise_confidence = classifier.predict_proba(X_pool)
     classwise_predictions = classifier.predict(X_pool)
     classwise_scores = classwise_confidence*(classwise_predictions-1/2)
     classwise_mean = np.mean(classwise_scores, axis=1)
-    query_idx = multi_argmax(classwise_mean, n_instances)
+
+    if not random_tie_break:
+        query_idx = multi_argmax(classwise_mean, n_instances)
+    else:
+        query_idx = shuffled_argmax(classwise_mean, n_instances)
 
     return query_idx, X_pool[query_idx]
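The new branch in every strategy defers to shuffled_argmax from modAL.utils.selection, which this commit only imports. As a rough mental model (a sketch of the idea, not modAL's actual implementation), it amounts to permuting the candidates before taking the top n, so that when several samples share the maximal utility one of them wins at random:

# Sketch of the shuffled-argmax idea; modAL's own shuffled_argmax may differ in details.
import numpy as np

def shuffled_argmax_sketch(values: np.ndarray, n_instances: int = 1) -> np.ndarray:
    # randomly permute the candidate order ...
    shuffle = np.random.permutation(len(values))
    shuffled_values = values[shuffle]
    # ... take the n largest of the permuted scores ...
    top_in_shuffled = np.argpartition(-shuffled_values, n_instances - 1)[:n_instances]
    # ... and map the winners back to their original indices
    return shuffle[top_in_shuffled]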

tests/core_tests.py

Lines changed: 9 additions & 1 deletion
@@ -1049,14 +1049,22 @@ def test_strategies(self):
         classifier.fit(X_training, y_training)
 
         active_learner = modAL.models.ActiveLearner(classifier)
+        # no random tie break
         modAL.multilabel.SVM_binary_minimum(active_learner, X_pool)
-
         modAL.multilabel.mean_max_loss(classifier, X_pool, n_query_instances)
         modAL.multilabel.max_loss(classifier, X_pool, n_query_instances)
         modAL.multilabel.min_confidence(classifier, X_pool, n_query_instances)
         modAL.multilabel.avg_confidence(classifier, X_pool, n_query_instances)
         modAL.multilabel.max_score(classifier, X_pool, n_query_instances)
         modAL.multilabel.avg_score(classifier, X_pool, n_query_instances)
+        # random tie break
+        modAL.multilabel.SVM_binary_minimum(active_learner, X_pool, random_tie_break=True)
+        modAL.multilabel.mean_max_loss(classifier, X_pool, n_query_instances, random_tie_break=True)
+        modAL.multilabel.max_loss(classifier, X_pool, n_query_instances, random_tie_break=True)
+        modAL.multilabel.min_confidence(classifier, X_pool, n_query_instances, random_tie_break=True)
+        modAL.multilabel.avg_confidence(classifier, X_pool, n_query_instances, random_tie_break=True)
+        modAL.multilabel.max_score(classifier, X_pool, n_query_instances, random_tie_break=True)
+        modAL.multilabel.avg_score(classifier, X_pool, n_query_instances, random_tie_break=True)
 
 
 class TestExamples(unittest.TestCase):
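These calls are smoke tests: each strategy is exercised once with and once without the flag, so the suite only verifies that both code paths run. A slightly stronger check (hypothetical, not part of this commit) could seed NumPy's global RNG, which shuffled_argmax presumably draws from, and assert that the tie-broken index still points into the pool:

        # Hypothetical extra assertion; names reuse the test's own fixtures.
        np.random.seed(0)
        query_idx, query_inst = modAL.multilabel.min_confidence(
            classifier, X_pool, n_query_instances, random_tie_break=True
        )
        assert np.all(np.asarray(query_idx) < len(X_pool))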

tests/example_tests/ensemble.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 learner_list = []
 for _ in range(n_learners):
     learner = ActiveLearner(
-        estimator=RandomForestClassifier(),
+        estimator=RandomForestClassifier(n_estimators=10),
         X_training=X_pool[initial_idx], y_training=y_pool[initial_idx],
         bootstrap_init=True
     )
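Here and in the three example scripts below, the forest size is pinned rather than left at the library default, presumably because scikit-learn changed RandomForestClassifier's default n_estimators from 10 to 100 and warned about the upcoming change in the interim releases; passing n_estimators=10 explicitly keeps the example tests quiet and fast without altering their behaviour.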

tests/example_tests/query_by_committee.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 
     # initializing learner
     learner = ActiveLearner(
-        estimator=RandomForestClassifier(),
+        estimator=RandomForestClassifier(n_estimators=10),
        X_training=X_train, y_training=y_train
     )
     learner_list.append(learner)

tests/example_tests/shape_learning.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
 
 # create an ActiveLearner instance
 learner = ActiveLearner(
-    estimator=RandomForestClassifier(),
+    estimator=RandomForestClassifier(n_estimators=10),
     X_training=X_train, y_training=y_train
 )
 initial_prediction = learner.predict_proba(X_full)[:, 1].reshape(im_height, im_width)

tests/example_tests/stream_based_sampling.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 
 # initialize the learner
 learner = ActiveLearner(
-    estimator=RandomForestClassifier(),
+    estimator=RandomForestClassifier(n_estimators=10),
     X_training=X_train, y_training=y_train
 )
 