-
Notifications
You must be signed in to change notification settings - Fork 325
/
Copy pathcustom_query_strategies.py
111 lines (92 loc) · 3.81 KB
/
custom_query_strategies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
=============================
Template for query strategies
=============================
The first two arguments of a query strategy function are always the estimator and the pool
of instances to be queried from. Additional arguments are accepted as keyword arguments.
A valid query strategy function always returns indices of the queried
instances.
def custom_query_strategy(classifier, X, a_keyword_argument=42):
    # measure the utility of each instance in the pool
    utility = utility_measure(classifier, X)
    # select and return the indices of the instances to be queried
    return select_instances(utility)
This function can be used in the active learning workflow.
learner = ActiveLearner(classifier, query_strategy=custom_query_strategy)
query_idx, query_instance = learner.query(X_pool, a_keyword_argument=42*42)
In this example, we will construct a custom query function by combining classifier uncertainty
and classifier margin.
"""
import matplotlib.pyplot as plt
import numpy as np
from modAL.models import ActiveLearner
from modAL.uncertainty import classifier_margin, classifier_uncertainty
from modAL.utils.combination import make_linear_combination, make_product
from modAL.utils.selection import multi_argmax
from sklearn.datasets import make_blobs
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
# Synthetic dataset: 1000 two-feature points drawn around three fixed
# cluster centers, with a fixed random_state for make_blobs.
centers = np.array([[-2, 3], [0.5, 5], [1, 1.5]])
X, y = make_blobs(
    n_samples=1000,
    n_features=2,
    centers=centers,
    cluster_std=0.7,
    random_state=0,
)
# visualizing the dataset
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases removed the old names, so use whichever one is available.
dataset_style = (
    'seaborn-v0_8-white'
    if 'seaborn-v0_8-white' in plt.style.available
    else 'seaborn-white'
)
with plt.style.context(dataset_style):
    plt.figure(figsize=(7, 7))
    # Points are colored by their label in y.
    plt.scatter(x=X[:, 0], y=X[:, 1], c=y, cmap='viridis', s=50)
    plt.title('The dataset')
    plt.show()
# initial training data
# Draw 20 distinct indices: np.random.choice samples with replacement by
# default, which could put duplicate rows into the initial training set.
initial_idx = np.random.choice(len(X), size=20, replace=False)
X_training, y_training = X[initial_idx], y[initial_idx]
# Baseline learner: a Gaussian process classifier with a scaled RBF kernel,
# initialized with the initial training data. No query_strategy is passed.
learner = ActiveLearner(
    estimator=GaussianProcessClassifier(kernel=1.0 * RBF(length_scale=1.0)),
    X_training=X_training,
    y_training=y_training,
)
# Two derived utility measures built from classifier uncertainty and
# classifier margin.

# Weighted sum: 1.0*classifier_uncertainty + 1.0*classifier_margin
linear_combination = make_linear_combination(
    classifier_uncertainty,
    classifier_margin,
    weights=[1.0, 1.0],
)

# Weighted product: (classifier_uncertainty**0.5)*(classifier_margin**0.1)
product = make_product(
    classifier_uncertainty,
    classifier_margin,
    exponents=[0.5, 0.1],
)
# visualizing the different utility metrics
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases removed the old names, so use whichever one is available.
utility_style = (
    'seaborn-v0_8-white'
    if 'seaborn-v0_8-white' in plt.style.available
    else 'seaborn-white'
)
with plt.style.context(utility_style):
    # (subplot position, utility value per instance in X, subplot title)
    utilities = [
        (1, classifier_uncertainty(learner, X), 'Classifier uncertainty'),
        (2, classifier_margin(learner, X), 'Classifier margin'),
        (3, linear_combination(learner, X), '1.0*uncertainty + 1.0*margin'),
        # The title matches the exponents=[0.5, 0.1] given to make_product.
        (4, product(learner, X), '(uncertainty**0.5)*(margin**0.1)'),
    ]

    plt.figure(figsize=(18, 14))
    for idx, utility, title in utilities:
        # One panel per utility measure on a 2x2 grid.
        plt.subplot(2, 2, idx)
        plt.scatter(x=X[:, 0], y=X[:, 1], c=utility, cmap='viridis', s=50)
        plt.title(title)
        plt.colorbar()

    plt.show()
# Custom query strategy: rank pool instances by the linear combination of
# classifier uncertainty and classifier margin.
def custom_query_strategy(classifier, X, n_instances=1):
    """Select the pool instances with the highest combined utility.

    The utility of the instances in ``X`` is computed with the module-level
    ``linear_combination`` measure; the selection itself is delegated to
    ``multi_argmax``.

    Args:
        classifier: The estimator/learner handed to ``linear_combination``.
        X: The pool of instances to score.
        n_instances: Number of instances to select.

    Returns:
        The result of ``multi_argmax`` for the ``n_instances`` largest
        utility values.
    """
    return multi_argmax(
        linear_combination(classifier, X),
        n_instances=n_instances,
    )
# Learner that uses custom_query_strategy to pick instances; it gets the
# same estimator configuration and initial training data as the baseline.
custom_query_learner = ActiveLearner(
    estimator=GaussianProcessClassifier(kernel=1.0 * RBF(length_scale=1.0)),
    query_strategy=custom_query_strategy,
    X_training=X_training,
    y_training=y_training,
)
# pool-based sampling
# Work on a separate pool so that X and y stay intact; np.delete returns new
# arrays, so rebinding X_pool/y_pool never modifies the original data.
X_pool, y_pool = X, y
n_queries = 20
for _ in range(n_queries):
    # Ask the learner for the two instances its query strategy selects.
    query_idx, query_instance = custom_query_learner.query(
        X_pool, n_instances=2
    )
    # Teach the learner the selected instances together with their labels.
    # The reshape keeps the inputs 2-D / 1-D regardless of the index shape.
    custom_query_learner.teach(
        X=X_pool[query_idx].reshape(-1, X_pool.shape[1]),
        y=y_pool[query_idx].reshape(-1),
    )
    # Remove the queried instances from the pool so they cannot be selected
    # (and taught) again in a later iteration.
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)