-
Notifications
You must be signed in to change notification settings - Fork 325
/
Copy pathmultilabel_svm.py
34 lines (29 loc) · 1.24 KB
/
multilabel_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import matplotlib.pyplot as plt
import numpy as np
from modAL.models import ActiveLearner
from modAL.multilabel import *
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
# Synthetic two-label dataset: every sample is a 2-D standard-normal point, and
# its label vector flags, coordinate by coordinate, whether that coordinate is
# strictly positive (column 0 <- x1 > 0, column 1 <- x2 > 0).
n_samples = 500
X = np.random.normal(size=(n_samples, 2))
y = (X > 0).astype(int)

# A small random labelled seed set for the learner; everything else becomes
# the pool that later queries are drawn from.
n_initial = 10
initial_idx = np.random.choice(n_samples, size=n_initial, replace=False)
X_initial = X[initial_idx]
y_initial = y[initial_idx]

# Keep every row not picked for the seed set, in original order.
in_pool = np.ones(n_samples, dtype=bool)
in_pool[initial_idx] = False
X_pool = X[in_pool]
y_pool = y[in_pool]
# Plot every sample as a small black dot, then outline the samples whose label
# column 0 is set with blue rings and those whose label column 1 is set with
# larger red rings (a sample can carry both).
#
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so 'seaborn-white' alone raises on current
# releases; use whichever spelling this installation provides.
_plot_style = ('seaborn-v0_8-white'
               if 'seaborn-v0_8-white' in plt.style.available
               else 'seaborn-white')
with plt.style.context(_plot_style):
    plt.figure(figsize=(10, 10))
    plt.scatter(X[:, 0], X[:, 1], c='k', s=20)
    plt.scatter(X[y[:, 0] == 1, 0], X[y[:, 0] == 1, 1],
                facecolors='none', edgecolors='b', s=50, linewidths=2, label='class 1')
    plt.scatter(X[y[:, 1] == 1, 0], X[y[:, 1] == 1, 1],
                facecolors='none', edgecolors='r', s=100, linewidths=2, label='class 2')
    plt.legend()
    plt.show()
# One round of pool-based active learning. The estimator wraps an SVC in a
# one-vs-rest classifier so each label column is fitted separately, and
# avg_score (from modAL.multilabel) is the query strategy. probability=True is
# presumably required because that strategy scores with predicted
# probabilities (NOTE(review): confirm against modAL docs).
learner = ActiveLearner(
    estimator=OneVsRestClassifier(SVC(probability=True, gamma='auto')),
    query_strategy=avg_score,
    X_training=X_initial, y_training=y_initial,
)

# Ask which pool instance(s) the learner wants labelled, then teach it the
# labels, which are simulated here by reading them from y_pool.
query_idx, query_inst = learner.query(X_pool)
learner.teach(X_pool[query_idx], y_pool[query_idx])

# The queried samples are now labelled training data: remove them from the
# unlabelled pool so any further query cannot select them again.
X_pool = np.delete(X_pool, query_idx, axis=0)
y_pool = np.delete(y_pool, query_idx, axis=0)