In [5]:
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

iris = datasets.load_iris()
X, y = iris.data, iris.target

X_labeled, X_pool, y_labeled, y_pool = train_test_split(X, y, test_size=0.9, random_state=42)

classifier = RandomForestClassifier(n_estimators=100, random_state=42)

classifier.fit(X_labeled, y_labeled)

y_pool_predictions = classifier.predict(X_pool)

initial_accuracy = accuracy_score(y_pool, y_pool_predictions)
print(f"Initial model accuracy: {initial_accuracy:.2f}")

num_queries = 10
for i in range(num_queries):
    uncertainty = np.max(classifier.predict_proba(X_pool), axis=1)
    query_instance_index = np.argmax(uncertainty)

    query_instance = X_pool[query_instance_index].reshape(1, -1)
    queried_label = int(input(f"Query {i + 1}: What is the label for this instance? "))

    X_labeled = np.concatenate((X_labeled, query_instance), axis=0)
    y_labeled = np.concatenate((y_labeled, np.array([queried_label])))

    X_pool = np.delete(X_pool, query_instance_index, axis=0)
    y_pool = np.delete(y_pool, query_instance_index)

    classifier.fit(X_labeled, y_labeled)

    y_pool_predictions = classifier.predict(X_pool)

    current_accuracy = accuracy_score(y_pool, y_pool_predictions)
    print(f"Model accuracy after {i + 1} queries: {current_accuracy:.2f}")


Initial model accuracy: 0.96
Query 1: What is the label for this instance? 0
Model accuracy after 1 queries: 0.93
Query 2: What is the label for this instance? 1
Model accuracy after 2 queries: 0.94
Query 3: What is the label for this instance? 2
Model accuracy after 3 queries: 0.95
Query 4: What is the label for this instance? 3
Model accuracy after 4 queries: 0.92
Query 5: What is the label for this instance? 4
Model accuracy after 5 queries: 0.92
Query 6: What is the label for this instance? 5
Model accuracy after 6 queries: 0.87
Query 7: What is the label for this instance? 6
Model accuracy after 7 queries: 0.84
Query 8: What is the label for this instance? 7
Model accuracy after 8 queries: 0.80
Query 9: What is the label for this instance? 8
Model accuracy after 9 queries: 0.77
Query 10: What is the label for this instance? 9
Model accuracy after 10 queries: 0.74
