# Base plugin

> All plugins should subclass `BasePlugin`.

In [None]:
# | default_exp plugins.core

In [None]:
# | hide


from nbdev.showdoc import *

In [None]:
# | export


from typing import Optional, TYPE_CHECKING
from abc import ABC

from plotly.graph_objs._figure import Figure
from sklearn.base import BaseEstimator

from poniard.utils.utils import get_kwargs, non_default_repr

if TYPE_CHECKING:
    from poniard.estimators.core import PoniardBaseEstimator

In [None]:
# | export


class BasePlugin(ABC):
    """Base plugin class. New plugins should inherit from this class.

    Every hook below is a no-op by default, so subclasses only override the
    hooks they need. The owning Poniard estimator is reachable through
    `self._poniard`, which is `None` until an estimator assigns itself to it.
    """

    def __init__(self):
        # NOTE(review): `back=True` presumably makes `get_kwargs` capture the
        # arguments of the calling (subclass) `__init__`, likely consumed by
        # `non_default_repr` — TODO confirm against poniard.utils.
        self._init_params = get_kwargs(back=True)
        self._poniard: Optional["PoniardBaseEstimator"] = None

    def on_setup_start(self):
        """Called during setup start."""

    def on_setup_data(self):
        """Called after X and y have been set."""

    def on_infer_types(self):
        """Called after type inference."""

    def on_setup_preprocessor(self):
        """Called after preprocessor construction."""

    def on_setup_end(self):
        """Called after setup is complete."""

    def on_fit_start(self):
        """Called during fit start."""

    def on_fit_end(self):
        """Called after fitting is complete."""

    def on_plot(self, figure: Figure, name: str):
        """Called when a plot is created.

        Parameters
        ----------
        figure :
            The plotly figure that was created.
        name :
            Name under which the estimator reports the plot.
        """

    def on_get_estimator(self, estimator: BaseEstimator, name: str):
        """Called when an estimator is selected.

        Parameters
        ----------
        estimator :
            The selected scikit-learn estimator.
        name :
            Name of the selected estimator.
        """

    def on_analyze_estimator(self, estimator: BaseEstimator, name: str):
        """Called when an estimator is analyzed.

        Parameters
        ----------
        estimator :
            The analyzed scikit-learn estimator.
        name :
            Name of the analyzed estimator.
        """

    def on_add_estimators(self):
        """Called after adding an estimator."""

    def on_remove_estimators(self):
        """Called after removing an estimator."""

    def on_add_preprocessing_step(self):
        """Called after adding a preprocessing step."""

    def on_reassign_types(self):
        """Called after reassigning types."""

    def _check_plugin_used(self, plugin_cls_name: str):
        """Check if another plugin is present. If it is, return its instance. Else, return False.

        Must only be called once an estimator has been assigned to
        `self._poniard`, since the lookup reads `self._poniard.plugins`.

        Parameters
        ----------
        plugin_cls_name :
            Class name (a string, not the class itself) of the plugin to find.

        Returns
        -------
        BasePlugin or bool
            The first attached plugin whose class name matches, or ``False``
            if there is none.
        """
        # Single pass over the plugins; the first match wins.
        return next(
            (
                plugin
                for plugin in self._poniard.plugins
                if plugin.__class__.__name__ == plugin_cls_name
            ),
            False,
        )

    def __repr__(self):
        return non_default_repr(self)

`BasePlugin` defines a set of common hooks that allow plugin classes to execute actions during the life cycle of a Poniard estimator. They are similar to the callback hooks found in other libraries such as [Keras](https://keras.io/api/callbacks/), [Transformers](https://huggingface.co/docs/transformers/main_classes/callback) or [fastai](https://docs.fast.ai/callback.core.html).

We named them plugins because callbacks in other libraries are generally not expected to significantly alter what the main code does; instead, they add functionality such as logging or model saving. Poniard plugins have no such restriction.

## Developing plugins

Plugins allow developers to extend Poniard's functionality beyond what the main module offers. Doing so is straightforward: subclass `BasePlugin` and implement the desired hook methods.

Crucially, Poniard estimators inject themselves into all plugins during initialization, meaning that plugin instances have access to the estimator through the `_poniard` attribute.

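The following minimal example builds a plugin that adds a new (useless) constant feature and modifies the preprocessor by removing a trailing `VarianceThreshold` step, which could otherwise discard that constant feature.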

In [None]:
import numpy as np
import pandas as pd
from sklearn.feature_selection import VarianceThreshold

from poniard import PoniardClassifier

In [None]:
class StringFeaturePlugin(BasePlugin):
    """A plugin that adds a feature comprised of a single string.

    The new feature is named after the string and holds that same value in
    every row. A constant column like this would typically be dropped by a
    trailing `VarianceThreshold` step, so the plugin also removes that step
    from the preprocessor.

    Parameters
    ----------
    string :
        The string to add as a feature.
    """

    def __init__(self, string: str):
        super().__init__()
        self.string = string

    def on_setup_data(self):
        """Append a constant column holding `self.string` to the estimator's `X`."""
        data = self._poniard.X
        if hasattr(data, "iloc"):
            # pandas: new column named after the string, broadcast to every row.
            self._poniard.X = data.assign(**{self.string: self.string})
        else:
            # numpy: `np.append(data, scalar, axis=1)` raises a dimension
            # mismatch, so build an explicit column instead. The object dtype
            # keeps existing numeric values rather than casting them to strings.
            column = np.full(len(data), self.string, dtype=object)
            self._poniard.X = np.column_stack((data, column))

    def on_setup_preprocessor(self):
        """Drop a trailing `VarianceThreshold` step and rebuild the pipelines."""
        old_preprocessor = self._poniard.preprocessor
        if isinstance(old_preprocessor[-1], VarianceThreshold):
            self._poniard.preprocessor = old_preprocessor[:-1]
            # The pipelines were built from the old preprocessor; rebuild them.
            self._poniard.pipelines = self._poniard._build_pipelines()


# Seeded generator so the example is reproducible across runs.
rng = np.random.default_rng(0)
features = pd.DataFrame(
    rng.normal(size=(20, 2)), columns=[f"X_{i}" for i in range(2)]
)
target = rng.choice([0, 1], size=20)
pnd = PoniardClassifier(plugins=StringFeaturePlugin("foobar")).setup(features, target)
pnd.preprocessor

Target info
-----------
Type: binary
Shape: (20,)
Unique values: 2

Main metric
-----------
roc_auc

Thresholds
----------
Minimum unique values to consider a feature numeric: 2
Minimum unique values to consider a categorical high cardinality: 20

Inferred feature types
----------------------


Unnamed: 0,numeric,categorical_high,categorical_low,datetime
0,X_0,,foobar,
1,X_1,,,






In [None]:
# | hide
from nbdev import nbdev_export

# Export the `# | export` cells of the project's notebooks to Python modules.
nbdev_export()