# Scikit-Active-Machine-Learning: Overview of Pool-based Active Learning

In [None]:
import sys
sys.path.append('../..')

import numpy as np
import pylab as plt
import matplotlib.animation as animation
import ipywidgets as widgets
import skactiveml.pool as skacmlp
import warnings
warnings.filterwarnings('ignore')

from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import HTML
from skactiveml.utils import is_labeled, is_unlabeled, MISSING_LABEL, call_func

# classifiers
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from skactiveml.classifier import PWC, SklearnClassifier, CMM

# data sets
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

## Data Sets

In [None]:
def create_2d_data_set(seed=42, n_samples=200):
    """Create a standardized two-dimensional binary toy data set.

    ``make_blobs`` draws ``n_samples`` points around 12 Gaussian centres.
    The blob index is then folded into a binary label with ``% 2``, so each
    of the two classes consists of several separate clusters.

    Parameters
    ----------
    seed : int, default=42
        Random state forwarded to ``make_blobs``.
    n_samples : int, default=200
        Total number of generated samples.

    Returns
    -------
    X : numpy.ndarray of shape (n_samples, 2)
        Features, standardized column-wise by ``StandardScaler``.
    y : numpy.ndarray of shape (n_samples,)
        Binary class labels in {0, 1}.
    """
    X, y = make_blobs(n_samples=n_samples, n_features=2, centers=12,
                      cluster_std=1, random_state=seed)
    # Merge the 12 blobs into two classes: even blob ids -> 0, odd ids -> 1.
    y = y % 2
    X = StandardScaler().fit_transform(X)
    return X, y

## Classifiers

In [None]:
# Candidate classifiers, keyed by the name offered in the dropdown. Every model
# is its own instance, restricted to the binary label set {0, 1}, and treats
# MISSING_LABEL as "unlabeled". skactiveml's native models are built directly;
# scikit-learn estimators are wrapped in SklearnClassifier.
clf_dict = {
    'PWC': PWC(classes=[0, 1], missing_label=MISSING_LABEL),
    'CMM': CMM(classes=[0, 1], missing_label=MISSING_LABEL),
}
clf_dict.update({
    name: SklearnClassifier(estimator, classes=[0, 1], missing_label=MISSING_LABEL)
    for name, estimator in (
        ('GaussianNaiveBayes', GaussianNB()),
        ('DecisionTree', DecisionTreeClassifier()),
        ('KNN', KNeighborsClassifier()),
        ('SVC', SVC(probability=True)),
    )
})

# Performance estimators, stored as classes (not instances); the plotting
# code instantiates the selected one.
perf_est_dict = {'PWC': PWC}

## Query Strategies

In [None]:
# Map every public name exported by skactiveml.pool (its ``__all__``,
# presumably the query strategy classes) to the object itself, so the
# dropdowns can offer all of them by name.
query_strategies = {name: getattr(skacmlp, name) for name in skacmlp.__all__}
print(query_strategies.keys())

## Active Learning (AL) Cycle

In [None]:
def get_labels_with_selector(X, y, y_oracle, clf, selector, budget=30):
    """Simulate a pool-based active learning cycle and record the label states.

    Each iteration does three things:
    1. Fit ``clf`` on the current labels.
    2. Let ``selector`` pick among the still-unlabeled samples.
    3. Reveal the true label of the picked sample from ``y_oracle``.

    Parameters
    ----------
    X : numpy.ndarray
        Feature matrix of the whole pool; rows are samples.
    y : numpy.ndarray of shape (n_samples,)
        Initial label vector; unlabeled entries are ``MISSING_LABEL``.
        It is copied, so the caller's array is left untouched.
    y_oracle : numpy.ndarray of shape (n_samples,)
        True labels, revealed when a sample is queried.
    clf : classifier
        Model fitted on ``(X, y)`` before every query.
    selector : query strategy
        Object whose ``query`` method returns positions within ``X_cand``.
    budget : int, default=30
        Maximum number of queries.

    Returns
    -------
    numpy.ndarray
        One row per performed query, each a snapshot of the label vector
        after that query. There are ``budget`` rows unless the pool runs
        out of unlabeled samples first.
    """
    # Work on a copy so the caller's label vector is not silently modified.
    y = np.array(y, copy=True)
    y_list = []

    for _ in range(budget):
        unlabeled = np.where(is_unlabeled(y))[0]
        if len(unlabeled) == 0:
            # Every sample is labeled; there is nothing left to query.
            break
        clf.fit(X, y)
        # call_func is used (presumably to drop keyword arguments that a
        # strategy's ``query`` does not accept) so one call fits all strategies.
        unlabeled_id = call_func(selector.query, X_cand=X[unlabeled], X=X, y=y, X_eval=X)
        # The selector returns positions within the candidate subset; map
        # them back to indices of the full pool.
        sample_id = unlabeled[unlabeled_id]
        y[sample_id] = y_oracle[sample_id]
        y_list.append(y.copy())

    return np.array(y_list)

## Plotting Functions

In [None]:
def _contour_artists(contour_set):
    """Return the drawable artists of a contour set as a list.

    From Matplotlib 3.8 on, ``ContourSet`` is itself an ``Artist``, and its
    ``collections`` attribute is deprecated (removed in 3.10). Older versions
    only expose the per-level ``collections``.
    """
    from matplotlib.artist import Artist

    if isinstance(contour_set, Artist):
        return [contour_set]
    return list(contour_set.collections)


def plot_scores_2d(X, y, y_oracle, X_1_mesh, X_2_mesh, X_mesh, clf, selector, ax):
    """Draw one animation frame: utilities, decision boundary and samples.

    Parameters
    ----------
    X : numpy.ndarray
        Pool features; columns 0 and 1 are plotted.
    y : numpy.ndarray
        Current label vector (``MISSING_LABEL`` for unlabeled samples).
    y_oracle : numpy.ndarray
        True labels, used to colour the still-unlabeled samples.
    X_1_mesh, X_2_mesh : numpy.ndarray
        Coordinate grids (as from ``np.meshgrid``) for the contour plots.
    X_mesh : numpy.ndarray
        The grid points flattened to rows; presumably shape (n_grid_points, 2).
    clf : classifier
        Fitted on ``(X, y)`` here; its class-0 posterior is contoured.
    selector : query strategy
        Its utilities on ``X_mesh`` are drawn as the green background.
    ax : matplotlib.axes.Axes
        Axes to draw into.

    Returns
    -------
    list
        Every artist created for this frame, so that ``ArtistAnimation``
        shows and hides the whole frame as a unit.
    """
    # Fit the classifier and evaluate P(class 0) on the plotting grid.
    clf.fit(X, y)
    posteriors = clf.predict_proba(X_mesh)[:, 0].reshape(X_1_mesh.shape)

    # Compute the selector's utility (gain) for every grid point.
    _, scores = call_func(selector.query, X_cand=X_mesh, X=X, y=y, X_eval=X,
                          return_utilities=True)
    scores = scores.reshape(X_1_mesh.shape)

    # Get indices for plotting.
    labeled_indices = np.where(is_labeled(y))[0]
    unlabeled_indices = np.where(is_unlabeled(y))[0]

    artists = []

    # Grey disc with a dark rim behind every labeled sample.
    artists.append(ax.scatter(X[labeled_indices, 0], X[labeled_indices, 1],
                              c=[[.2, .2, .2]], s=90, marker='o', zorder=3.8))
    artists.append(ax.scatter(X[labeled_indices, 0], X[labeled_indices, 1],
                              c=[[.8, .8, .8]], s=60, marker='o', zorder=4))
    for cl, marker in zip([0, 1], ['D', 's']):
        cl_labeled_idx = labeled_indices[y[labeled_indices] == cl]
        cl_unlabeled_idx = unlabeled_indices[y_oracle[unlabeled_indices] == cl]
        # Labeled samples, coloured by their acquired label.
        artists.append(ax.scatter(X[cl_labeled_idx, 0], X[cl_labeled_idx, 1],
                                  c=np.ones(len(cl_labeled_idx)) * cl, marker=marker,
                                  vmin=-0.2, vmax=1.2, cmap='coolwarm', s=20, zorder=5))
        # Unlabeled samples, coloured by their (hidden) true label ...
        artists.append(ax.scatter(X[cl_unlabeled_idx, 0], X[cl_unlabeled_idx, 1],
                                  c=np.ones(len(cl_unlabeled_idx)) * cl, marker=marker,
                                  vmin=-0.2, vmax=1.2, cmap='coolwarm', s=20, zorder=3))
        # ... on top of a slightly larger black marker that acts as an outline.
        artists.append(ax.scatter(X[cl_unlabeled_idx, 0], X[cl_unlabeled_idx, 1],
                                  c='k', marker=marker, s=30, zorder=2.8))

    # Utility heat map.
    artists += _contour_artists(
        ax.contourf(X_1_mesh, X_2_mesh, scores, cmap='Greens', alpha=.75))
    # Decision boundary: P(class 0) = 0.5.
    artists += _contour_artists(
        ax.contour(X_1_mesh, X_2_mesh, posteriors, [.5], colors='k',
                   linewidths=[2], zorder=1))
    # Dashed 0.25 / 0.75 posterior iso-lines.
    artists += _contour_artists(
        ax.contour(X_1_mesh, X_2_mesh, posteriors, [.25, .75], cmap='coolwarm_r',
                   linewidths=[2, 2], zorder=1, linestyles='--', alpha=.9,
                   vmin=.2, vmax=.8))
    return artists

def plot_gain_data_set_2d(al, clf, usefulness, perf_est, budget=50, fps=2, seed=43, n_samples=250):
    """Animate an active learning run on the 2-D toy data set.

    Parameters
    ----------
    al : str
        Key into ``query_strategies``: the strategy that selects the samples
        to label.
    clf : str
        Key into ``clf_dict``: the classifier trained on the labeled samples.
    usefulness : str
        Key into ``query_strategies``: the strategy whose utilities are drawn
        as the green background of every frame.
    perf_est : str
        Key into ``perf_est_dict``: the performance estimator passed (via
        ``call_func``) to the strategies.
    budget : int, default=50
        Maximum number of AL iterations; one frame is drawn per iteration.
    fps : float, default=2
        Frames per second of the resulting video.
    seed : int, default=43
        Seed for the data set and the query strategies.
    n_samples : int, default=250
        Currently unused; the data set size is determined by
        ``create_2d_data_set``. Kept for interface compatibility.

    Returns
    -------
    IPython.display.HTML
        The animation rendered as an HTML5 video.
    """
    # Create data set. create_2d_data_set already standardizes X, so it is
    # not scaled a second time here.
    X, y_oracle = create_2d_data_set(seed=seed)
    y = np.full(y_oracle.shape, MISSING_LABEL)  # start with no labels at all
    classes = np.unique(y_oracle)

    # 21 x 21 grid spanning the data, used for the contour plots.
    x_1_min, x_1_max = X[:, 0].min(), X[:, 0].max()
    x_2_min, x_2_max = X[:, 1].min(), X[:, 1].max()
    X_1_mesh, X_2_mesh = np.meshgrid(np.linspace(x_1_min, x_1_max, 21),
                                     np.linspace(x_2_min, x_2_max, 21))
    X_mesh = np.column_stack([X_1_mesh.ravel(), X_2_mesh.ravel()])

    # Resolve the dropdown names into objects.
    clf_name, perf_est_name = clf, perf_est
    clf = clf_dict[clf_name]
    perf_est = call_func(perf_est_dict[perf_est_name], n_classes=len(classes))

    # Strategy used only to visualize utilities.
    ut_qs = call_func(query_strategies[usefulness], clf=clf, perf_est=perf_est,
                      model=clf, classes=classes, random_state=seed)

    # Strategy that actually drives the AL cycle.
    al_qs = call_func(query_strategies[al], clf=clf, perf_est=perf_est,
                      model=clf, classes=classes, random_state=seed)
    Y = get_labels_with_selector(X, y, y_oracle, clf, al_qs, budget=budget)

    # Set up figure for plotting.
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111)
    ax.set_title(f"AL={al}, Classifier={clf_name}, \n Usefulness={usefulness}, "
                 f"Performance Estimator={perf_est_name}")
    ax.set_xlim(x_1_min, x_1_max)
    ax.set_ylim(x_2_min, x_2_max)
    ax.set_xlabel(r'$x_1$')
    ax.set_ylabel(r'$x_2$')

    # One list of artists per recorded label state.
    frames = [plot_scores_2d(X, y_frame, y_oracle, X_1_mesh, X_2_mesh, X_mesh,
                             clf, ut_qs, ax)
              for y_frame in Y]
    ani = animation.ArtistAnimation(fig, frames, blit=False,
                                    interval=1000 / fps, repeat_delay=1)
    # Close exactly this figure (presumably so the notebook does not also
    # render it as a static image); only the video is returned.
    plt.close(fig)
    return HTML(ani.to_html5_video())

## Animation of the AL Cycle

In [None]:
# Numeric controls: label budget and animation speed.
budget_slider = widgets.IntSlider(value=15, min=1, max=200, step=1)
fps_slider = widgets.FloatSlider(value=1, min=0.1, max=15, step=0.1)
# Choice controls: classifier, AL strategy, usefulness strategy, performance
# estimator (created in this order, one dropdown per option set).
clf_slider, al_slider, usefulness_slider, perf_slider = (
    widgets.Dropdown(options=choices)
    for choices in (clf_dict.keys(), query_strategies.keys(),
                    query_strategies.keys(), perf_est_dict.keys())
)
interact(
    plot_gain_data_set_2d,
    al=al_slider,
    clf=clf_slider,
    usefulness=usefulness_slider,
    perf_est=perf_slider,
    budget=budget_slider,
    fps=fps_slider,
)