/
core.py
92 lines (69 loc) · 2.55 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/plugins.core.ipynb.

# %% auto 0
# Names exported by a star-import of this module; only the plugin base class is public.
__all__ = [
    "BasePlugin",
]
# %% ../../nbs/plugins.core.ipynb 3
from typing import Optional, TYPE_CHECKING
from abc import ABC
from plotly.graph_objs._figure import Figure
from sklearn.base import BaseEstimator
from ..utils.utils import get_kwargs, non_default_repr
if TYPE_CHECKING:
from poniard.estimators.core import PoniardBaseEstimator
# %% ../../nbs/plugins.core.ipynb 4
class BasePlugin(ABC):
    """Base plugin class. New plugins should inherit from this class.

    Every ``on_*`` hook defined here is a no-op, so a subclass only needs to
    override the events it cares about.

    Attributes
    ----------
    _init_params :
        Result of ``get_kwargs(back=True)`` captured at construction time;
        presumably the constructor arguments of the concrete subclass, used by
        ``non_default_repr`` in ``__repr__`` (TODO confirm).
    _poniard :
        The estimator this plugin is attached to. ``None`` until assigned by
        code outside this class.
    """

    def __init__(self) -> None:
        # NOTE(review): ``back=True`` suggests get_kwargs inspects the calling
        # frame, so keep this call directly inside ``__init__`` — TODO confirm.
        self._init_params = get_kwargs(back=True)
        self._poniard: Optional["PoniardBaseEstimator"] = None

    def on_setup_start(self) -> None:
        """Called during setup start."""

    def on_setup_data(self) -> None:
        """Called after X and y have been set."""

    def on_infer_types(self) -> None:
        """Called after type inference."""

    def on_setup_preprocessor(self) -> None:
        """Called after preprocessor construction."""

    def on_setup_end(self) -> None:
        """Called after setup is complete."""

    def on_fit_start(self) -> None:
        """Called during fit start."""

    def on_fit_end(self) -> None:
        """Called after fitting is complete."""

    def on_plot(self, figure: Figure, name: str) -> None:
        """Called when a plot is created.

        Parameters
        ----------
        figure :
            The plotly figure that was created.
        name :
            Name supplied by the caller alongside ``figure``.
        """

    def on_get_estimator(self, estimator: BaseEstimator, name: str) -> None:
        """Called when an estimator is selected.

        Parameters
        ----------
        estimator :
            The selected scikit-learn estimator.
        name :
            Name supplied by the caller alongside ``estimator``.
        """

    def on_analyze_estimator(self, estimator: BaseEstimator, name: str) -> None:
        """Called when an estimator is analyzed.

        Parameters
        ----------
        estimator :
            The scikit-learn estimator being analyzed.
        name :
            Name supplied by the caller alongside ``estimator``.
        """

    def on_add_estimators(self) -> None:
        """Called after adding estimators."""

    def on_remove_estimators(self) -> None:
        """Called after removing estimators."""

    def on_add_preprocessing_step(self) -> None:
        """Called after adding a preprocessing step."""

    def on_reassign_types(self) -> None:
        """Called after reassigning types."""

    def _check_plugin_used(self, plugin_cls_name: str):
        """Check if another plugin is present. If it is, return its instance. Else, return False.

        Parameters
        ----------
        plugin_cls_name :
            Class name (``plugin.__class__.__name__``) of the plugin to look for.

        Returns
        -------
        BasePlugin or bool
            The first plugin in ``self._poniard.plugins`` whose class name
            matches ``plugin_cls_name``, or ``False`` if none matches. Also
            ``False`` when this plugin is not attached to an estimator yet.
        """
        # ``_poniard`` starts out as None (see ``__init__``); an unattached plugin
        # cannot see any other plugin, so report "not found" instead of raising
        # AttributeError on ``None.plugins``.
        if self._poniard is None:
            return False
        # Defensive: treat a missing/None plugins collection as empty.
        plugins = self._poniard.plugins or ()
        # Single pass; yields the first match, same as ``list.index`` on the names.
        return next(
            (
                plugin
                for plugin in plugins
                if plugin.__class__.__name__ == plugin_cls_name
            ),
            False,
        )

    def __repr__(self) -> str:
        return non_default_repr(self)