Skip to content

Commit 218de8b

Browse files
committed
examples/keras_integration.py improved
1 parent a008c52 commit 218de8b

File tree

1 file changed

+18
-30
lines changed

1 file changed

+18
-30
lines changed

examples/keras_integration.py

+18-30
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
from modAL.models import ActiveLearner
1313

1414

15+
# build function for the Keras' scikit-learn API
1516
def create_keras_model():
1617
"""
1718
This function compiles and returns a Keras model.
18-
Should be passed for the KerasClassifier in the
19-
Keras scikit-learn API.
20-
:return: Keras model
19+
Should be passed to KerasClassifier in the Keras scikit-learn API.
2120
"""
2221
model = Sequential()
2322
model.add(Dense(512, activation='relu', input_shape=(784, )))
@@ -30,6 +29,9 @@ def create_keras_model():
3029
return model
3130

3231

32+
# create the classifier
33+
classifier = KerasClassifier(create_keras_model)
34+
3335
"""
3436
Data wrangling
3537
1. Reading data from Keras
@@ -38,59 +40,45 @@ def create_keras_model():
3840
"""
3941

4042
# read training data
41-
(x_train, y_train), (x_test, y_test) = mnist.load_data()
42-
x_train = x_train.reshape(60000, 784).astype('float32')/255
43-
x_test = x_test.reshape(10000, 784).astype('float32')/255
43+
(X_train, y_train), (X_test, y_test) = mnist.load_data()
44+
X_train = X_train.reshape(60000, 784).astype('float32') / 255
45+
X_test = X_test.reshape(10000, 784).astype('float32') / 255
4446
y_train = keras.utils.to_categorical(y_train, 10)
4547
y_test = keras.utils.to_categorical(y_test, 10)
4648

47-
# select the first example from each category
48-
initial_idx = list()
49-
for label in range(10):
50-
for elem_idx, elem in enumerate(x_train):
51-
if y_train[elem_idx][label] == 1.0:
52-
initial_idx.append(elem_idx)
53-
break
54-
5549
# assemble initial data
56-
x_initial = x_train[initial_idx]
50+
n_initial = 1000
51+
initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)
52+
X_initial = X_train[initial_idx]
5753
y_initial = y_train[initial_idx]
5854

5955
# generate the pool
6056
# remove the initial data from the training dataset
61-
x_train = np.delete(x_train, initial_idx, axis=0)
62-
y_train = np.delete(y_train, initial_idx, axis=0)
63-
# sample random elements from x_train
64-
pool_size = 10000
65-
pool_idx = np.random.choice(range(len(x_train)), pool_size)
66-
x_pool = x_train[pool_idx]
67-
y_pool = y_train[pool_idx]
57+
X_pool = np.delete(X_train, initial_idx, axis=0)
58+
y_pool = np.delete(y_train, initial_idx, axis=0)
6859

6960
"""
7061
Training the ActiveLearner
7162
"""
7263

73-
# create the classifier
74-
classifier = KerasClassifier(create_keras_model)
75-
7664
# initialize ActiveLearner
7765
learner = ActiveLearner(
7866
predictor=classifier,
79-
X_initial=x_initial, y_initial=y_initial,
67+
X_initial=X_initial, y_initial=y_initial,
8068
verbose=0
8169
)
8270

8371
# the active learning loop
8472
n_queries = 10
8573
for idx in range(n_queries):
86-
query_idx, query_instance = learner.query(x_pool, n_instances=200, verbose=0)
74+
query_idx, query_instance = learner.query(X_pool, n_instances=200, verbose=0)
8775
learner.teach(
88-
X=x_pool[query_idx], y=y_pool[query_idx],
76+
X=X_pool[query_idx], y=y_pool[query_idx],
8977
verbose=0
9078
)
9179
# remove queried instance from pool
92-
x_pool = np.delete(x_pool, query_idx, axis=0)
80+
X_pool = np.delete(X_pool, query_idx, axis=0)
9381
y_pool = np.delete(y_pool, query_idx, axis=0)
9482

9583
# the final accuracy score
96-
print(learner.score(x_test, y_test, verbose=0))
84+
print(learner.score(X_test, y_test, verbose=0))

0 commit comments

Comments
 (0)