Skip to content

Commit eb2df2f

Browse files
committed
fixed max_entropy acquisition function
1 parent 85445ae commit eb2df2f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/deep_bayesian_active_learning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def max_entropy(learner, X, n_instances=1, T=100):
6363
learning_phase = True
6464
MC_samples = [MC_output([subset, learning_phase])[0] for _ in range(T)]
6565
MC_samples = np.array(MC_samples) # [#samples x batch size x #classes]
66-
acquisition = - np.mean(np.sum(MC_samples * np.log(MC_samples + 1e-10), axis=-1), axis=0) # [batch size]
66+
expected_p = np.mean(MC_samples, axis=0)
67+
acquisition = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1) # [batch size]
6768
query_idx = (-acquisition).argsort()[:n_instances]
6869
return query_idx, X[query_idx]
6970

0 commit comments

Comments
 (0)