# Classification Plugins


**Adjutorium** provides a set of default prediction plugins and can be extended with any number of additional plugins.

In this tutorial, we focus on __classification__ tasks.

### Plugins 101

Every **Adjutorium plugin** must implement the **`Plugin`** interface provided by `adjutorium/plugins/core/base_plugin.py`.

Each **Adjutorium prediction plugin** must implement the **`PredictionPlugin`** interface provided by `adjutorium/plugins/prediction/base.py`.

__Warning__: If a plugin does not override all of the abstract methods, the library will not load it.
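One plausible mechanism behind this warning is Python's `abc` module: a subclass that leaves any `@abstractmethod` unimplemented cannot be instantiated and raises `TypeError`. The minimal sketch below assumes a loader that simply skips plugins failing to instantiate. `BasePlugin`, `CompletePlugin`, `IncompletePlugin`, and `try_load` are hypothetical names for illustration, not part of Adjutorium, and the real loader may work differently.

```python
from abc import ABC, abstractmethod
from typing import Optional


class BasePlugin(ABC):
    """Hypothetical stand-in for an abstract plugin interface."""

    @staticmethod
    @abstractmethod
    def name() -> str: ...

    @staticmethod
    @abstractmethod
    def subtype() -> str: ...


class CompletePlugin(BasePlugin):
    # Overrides every abstract method, so it can be instantiated.
    @staticmethod
    def name() -> str:
        return "complete"

    @staticmethod
    def subtype() -> str:
        return "classification"


class IncompletePlugin(BasePlugin):
    # subtype() is not overridden, so this class is still abstract.
    @staticmethod
    def name() -> str:
        return "incomplete"


def try_load(cls: type) -> Optional[BasePlugin]:
    """Hypothetical loader: instantiate a plugin, or return None if it is abstract."""
    try:
        return cls()
    except TypeError:
        return None


loaded = {cls.__name__: try_load(cls) for cls in (CompletePlugin, IncompletePlugin)}
print({name: obj is not None for name, obj in loaded.items()})
# {'CompletePlugin': True, 'IncompletePlugin': False}
```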




__API__: Every prediction plugin must implement the following methods:
- `name()` - a static method that returns the name of the plugin, e.g., `neural_nets` or `perceptron`.

- `subtype()` - a static method that returns the plugin's subtype, e.g., `"classification"` or `"survival_analysis"`. The optimization process uses it to filter plugins.

- `hyperparameter_space()` - a static method that returns the hyperparameters that can be tuned during optimization, as a list of objects derived from `skopt.space.Dimension`.
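Because `name()` and `subtype()` are static, a registry can query them without instantiating (or fitting) a plugin. The sketch below shows subtype-based filtering under that assumption. The plugin classes and `filter_by_subtype` are hypothetical, and `hyperparameter_space()` is left out so the example does not depend on scikit-optimize.

```python
from typing import Dict, List


class LogisticRegressionPlugin:
    """Hypothetical plugin exposing only the static metadata methods."""

    @staticmethod
    def name() -> str:
        return "logistic_regression"

    @staticmethod
    def subtype() -> str:
        return "classification"


class CoxPHPlugin:
    """Hypothetical survival-analysis plugin."""

    @staticmethod
    def name() -> str:
        return "cox_ph"

    @staticmethod
    def subtype() -> str:
        return "survival_analysis"


def filter_by_subtype(registry: Dict[str, type], subtype: str) -> List[str]:
    """Return the sorted names of registered plugin classes with the given subtype."""
    # subtype() is called on the class itself; no instance is created.
    return sorted(name for name, cls in registry.items() if cls.subtype() == subtype)


registry = {cls.name(): cls for cls in (LogisticRegressionPlugin, CoxPHPlugin)}
print(filter_by_subtype(registry, "classification"))  # ['logistic_regression']
```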

### Setup

```python
import sys
import time
import warnings

import numpy as np
import tabulate
from IPython.display import HTML, display
from sklearn import metrics
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Silence library warnings unless the user passed explicit -W options.
if not sys.warnoptions:
    warnings.simplefilter("ignore")
```

### Loading the Classification plugins

Make sure Adjutorium is installed in your workspace, for example by running `pip install .` from the root of the repository.

```python
from adjutorium.plugins.prediction.classifiers import Classifiers

# Load every available classification plugin.
classifiers = Classifiers()
```

### List the existing plugins

```python
classifiers.list()
```


## Benchmarks

We test the prediction plugins using the [Wisconsin Breast Cancer dataset](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)).

### Loading the data

```python
# Load the Wisconsin Breast Cancer data and hold out 20% of the samples for testing.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
```
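This split is random and unstratified, so class proportions in the test set, and therefore the reported scores, can change from run to run. If reproducible benchmarks are wanted, one option is scikit-learn's `stratify` and `random_state` arguments, sketched below; the seed `42` is an arbitrary assumption.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)

# stratify=y keeps the benign/malignant ratio the same in both splits;
# random_state fixes the shuffle so reruns produce the same split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print(X_train.shape, X_test.shape)  # (455, 30) (114, 30)
print(round(y.mean(), 3), round(y_train.mean(), 3), round(y_test.mean(), 3))
```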

### Duration benchmarks

__About__: This step measures how long `fit_predict` takes for each plugin on the full dataset. Times are reported in milliseconds.
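`time.time()` follows the system wall clock, which can be adjusted while a benchmark runs, whereas `time.perf_counter()` is monotonic and has higher resolution. The sketch below times fit + predict with `perf_counter()` and keeps the best of several runs to damp noise. It uses scikit-learn's `DecisionTreeClassifier` as a stand-in for an Adjutorium plugin, and `time_fit_predict_ms` is a hypothetical helper.

```python
import time

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier


def time_fit_predict_ms(model, X, y, repeats: int = 3) -> float:
    """Return the best-of-`repeats` duration (ms) of model.fit followed by model.predict."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        model.fit(X, y)
        model.predict(X)
        timings.append((time.perf_counter() - start) * 1000)
    # The minimum is the run least disturbed by other processes.
    return round(min(timings), 4)


X, y = load_breast_cancer(return_X_y=True)
elapsed = time_fit_predict_ms(DecisionTreeClassifier(random_state=0), X, y)
print(elapsed)
```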

```python
duration = []

plugins = classifiers.list()

for plugin in tqdm(plugins):
    ctx = classifiers.get(plugin)

    # perf_counter() is a monotonic, high-resolution clock suited to benchmarking.
    start = time.perf_counter()
    ctx.fit_predict(X, y)
    elapsed_ms = (time.perf_counter() - start) * 1000

    duration.append([plugin, round(elapsed_ms, 4)])
```

### Duration (ms) results

```python
display(
    HTML(tabulate.tabulate(duration, headers=["Plugin", "Duration (ms)"], tablefmt="html"))
)
```

### Prediction performance

__Steps__
- We train each prediction plugin on the training split.
- We report accuracy, AUROC, and AUPRC on the test set.
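ROC and precision-recall curves are built by sweeping a decision threshold over a continuous score. If hard 0/1 labels are passed instead, the ROC "curve" has a single interior point and its area reduces to balanced accuracy, (TPR + TNR) / 2. The sketch below demonstrates this with a scikit-learn logistic regression as a stand-in for a plugin; the model choice and the seed are assumptions.

```python
from sklearn import metrics
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Scaling helps the solver converge on these unnormalized features.
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)

y_pred = model.predict(X_test)               # hard 0/1 labels
y_score = model.predict_proba(X_test)[:, 1]  # probability of the positive class

# AUROC from hard labels: only one threshold lies between the endpoints.
fpr, tpr, _ = metrics.roc_curve(y_test, y_pred)
auroc_labels = metrics.auc(fpr, tpr)

# AUROC from continuous scores: every distinct score is a threshold.
fpr, tpr, _ = metrics.roc_curve(y_test, y_score)
auroc_scores = metrics.auc(fpr, tpr)

# With hard labels the area equals balanced accuracy, (TPR + TNR) / 2.
balanced_acc = metrics.balanced_accuracy_score(y_test, y_pred)

# AUPRC via the trapezoidal rule over the precision-recall curve, from scores.
precision, recall, _ = metrics.precision_recall_curve(y_test, y_score)
auprc_scores = metrics.auc(recall, precision)

print(round(auroc_labels, 4), round(balanced_acc, 4), round(auroc_scores, 4), round(auprc_scores, 4))
```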

```python
def get_metrics(plugin, X_train, y_train, X_test, y_test):
    plugin.fit(X_train, y_train)

    # Hard class labels, used for accuracy.
    y_pred = np.asarray(plugin.predict(X_test)).ravel()
    # Positive-class probabilities, used for the threshold-based curves
    # (assumes the plugin exposes predict_proba, as scikit-learn classifiers do).
    y_score = np.asarray(plugin.predict_proba(X_test))[:, 1]

    score = metrics.accuracy_score(y_test, y_pred)

    fpr, tpr, _ = metrics.roc_curve(y_test, y_score)
    auroc = metrics.auc(fpr, tpr)

    precision, recall, _ = metrics.precision_recall_curve(y_test, y_score)
    auprc = metrics.auc(recall, precision)

    return round(score, 4), round(auroc, 4), round(auprc, 4)


metrics_headers = ["Plugin", "Accuracy", "AUROC", "AUPRC"]
test_score = []

for plugin in plugins:
    plugin_ctx = classifiers.get(plugin)

    score, auroc, auprc = get_metrics(plugin_ctx, X_train, y_train, X_test, y_test)

    test_score.append([plugin, score, auroc, auprc])
```

```python
display(
    HTML(tabulate.tabulate(test_score, headers=metrics_headers, tablefmt="html"))
)
```

# Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star Adjutorium on GitHub

The easiest way to help our community is by starring the repository! This helps raise awareness of the tools we're building.

- [Star Adjutorium](https://github.com/vanderschaarlab/adjutorium-framework)
