[![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanderschaarlab/temporai/blob/main/tutorials/extending/tutorial01_custom_method.ipynb)

# Extending TemporAI Tutorial 01: Writing a Custom Method Plugin

This tutorial shows how to extend TemporAI by writing a custom *method* (as in algorithm, model) plugin.

## Writing a Custom `Plugin` 101

In order to write a custom plugin for TemporAI, you need to do the following:
1. Inherit from the appropriate **base class** for the category of plugin you are writing.
2. Implement the **methods** (as in, functions of the class) that the plugin needs.
3. **Register** the plugin with TemporAI.

We will go through an example in this tutorial.

### 1. Inherit from the appropriate **base class** for the category of the method plugin you are writing.

First, determine which category of method plugin you are writing.

A summary of different plugin categories is available in the 
[README](https://github.com/vanderschaarlab/temporai/blob/main/README.md#-methods).

You can also list all the plugin categories as follows:

In [None]:
from tempor import plugin_loader

plugin_categories = plugin_loader.list_categories(plugin_type="method")

list(plugin_categories.keys())

['prediction.one_off.classification',
 'prediction.one_off.regression',
 'prediction.temporal.classification',
 'prediction.temporal.regression',
 'preprocessing.encoding.static',
 'preprocessing.encoding.temporal',
 'preprocessing.imputation.static',
 'preprocessing.imputation.temporal',
 'preprocessing.nop',
 'preprocessing.scaling.static',
 'preprocessing.scaling.temporal',
 'time_to_event',
 'treatments.one_off.regression',
 'treatments.temporal.classification',
 'treatments.temporal.regression']

Remember you can also see the existing method plugins and how they correspond to different categories, as follows:

In [None]:
all_plugins = plugin_loader.list(plugin_type="method")

from rich.pretty import pprint  # For prettifying the print output only.

pprint(all_plugins, indent_guides=True)

Let's say you would like to write a plugin of category `"prediction.one_off.classification"`.

You can find which base class you need to inherit from as follows.

In [None]:
plugin_categories = plugin_loader.list_categories(plugin_type="method")

print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)

print("Base class you need:")
print(plugin_categories["prediction.one_off.classification"])

Base classes for all categories:


Base class you need:
<class 'tempor.methods.prediction.one_off.classification.BaseOneOffClassifier'>


You can then find the class in the TemporAI source code, to see its method signatures etc.
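As an alternative to reading the source, the required methods can also be inspected programmatically. The sketch below uses only the standard library and a hypothetical stand-in class (`StandInBase`) so that it runs without TemporAI; the helper `describe_abstract_methods` is likewise an illustrative name, not part of TemporAI. It assumes the base class is built on Python's `abc` module, in which case `__abstractmethods__` lists what a subclass must implement.

```python
import abc
import inspect
from typing import Any, Dict


# Hypothetical stand-in for a plugin base class (not a TemporAI class).
class StandInBase(abc.ABC):
    @abc.abstractmethod
    def _fit(self, data: Any, *args: Any, **kwargs: Any) -> "StandInBase":
        ...

    @abc.abstractmethod
    def _predict(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        ...


def describe_abstract_methods(base_class: type) -> Dict[str, str]:
    """Map each abstract method name of `base_class` to its signature string."""
    names = sorted(getattr(base_class, "__abstractmethods__", frozenset()))
    return {name: str(inspect.signature(getattr(base_class, name))) for name in names}


signatures = describe_abstract_methods(StandInBase)
for name, signature in signatures.items():
    print(f"{name}{signature}")
```

The same helper could presumably be applied to the class returned by `plugin_categories["prediction.one_off.classification"]`.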

### 2. Implement the **methods** the plugin needs.

Plugins of different categories require different methods (functions) to be implemented, but the key ones are:
* `_fit()` where you provide your implementation of the fitting (training).
* `_predict()` where you provide your implementation of the prediction (inference).
* `_transform()` where you provide your implementation of data transformation (for preprocessing plugins).

Classification-related plugins also have `_predict_proba()` and treatment effects plugins have `_predict_counterfactuals()`.

Note that these methods have a preceding underscore `_`, and are different from the corresponding "public" methods
without the underscore (e.g. `fit()`). When extending, you implement the `_<...>` method,
while the corresponding "public" method in TemporAI is what the user of your plugin will call.
The "public" methods also perform various necessary validation and other checks behind the scenes.
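The split between "public" and underscore-prefixed methods can be illustrated with a minimal, standard-library-only sketch. The classes `SketchBase` and `MeanPredictor` and the specific checks are hypothetical; this is not TemporAI's actual implementation, only the general pattern under the assumption that the public method validates and then delegates.

```python
import abc
from typing import List


class SketchBase(abc.ABC):
    """Hypothetical base class showing the public/underscore split."""

    def __init__(self) -> None:
        self.is_fitted = False

    def fit(self, data: List[float]) -> "SketchBase":
        # Public method: validate, then delegate to the plugin author's `_fit`.
        if len(data) == 0:
            raise ValueError("Cannot fit on empty data.")
        self._fit(data)
        self.is_fitted = True
        return self

    def predict(self, data: List[float]) -> List[float]:
        # Public method: check state, then delegate to `_predict`.
        if not self.is_fitted:
            raise RuntimeError("Call `fit()` before `predict()`.")
        return self._predict(data)

    @abc.abstractmethod
    def _fit(self, data: List[float]) -> "SketchBase":
        ...

    @abc.abstractmethod
    def _predict(self, data: List[float]) -> List[float]:
        ...


class MeanPredictor(SketchBase):
    """The plugin author only writes the underscore methods."""

    def _fit(self, data: List[float]) -> "MeanPredictor":
        self.mean = sum(data) / len(data)
        return self

    def _predict(self, data: List[float]) -> List[float]:
        return [self.mean for _ in data]


model = MeanPredictor().fit([1.0, 2.0, 3.0])
predictions = model.predict([10.0, 20.0])
print(predictions)
```

The user of `MeanPredictor` only ever calls `fit()` and `predict()`; the checks in the base class run regardless of how the underscore methods are written.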

If you haven't implemented some required method for the plugin, Python will notify you by raising an exception when you
attempt to instantiate your plugin (see [Python `abc`](https://docs.python.org/3/library/abc.html)).
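The behaviour of `abc` described above can be reproduced without TemporAI. In this sketch, `SketchClassifierBase` and `IncompletePlugin` are hypothetical names; the subclass deliberately omits `_predict`, so instantiating it raises a `TypeError`.

```python
import abc


class SketchClassifierBase(abc.ABC):
    """Hypothetical base class with two required methods."""

    @abc.abstractmethod
    def _fit(self, data):
        ...

    @abc.abstractmethod
    def _predict(self, data):
        ...


class IncompletePlugin(SketchClassifierBase):
    def _fit(self, data):
        return self

    # `_predict` is deliberately not implemented.


try:
    IncompletePlugin()
    error_message = None
except TypeError as exc:
    # The message names the missing abstract method(s).
    error_message = str(exc)

print(error_message)
```

The exact wording of the message differs between Python versions, but it names the method that still needs to be implemented.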


In our example case, you will need to implement the following methods for `BaseOneOffClassifier`:

```python
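from typing import Any, List

from tempor.data import dataset, samples
from tempor.methods.core import Params
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier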

class MyPlugin(BaseOneOffClassifier):
    # The initializer:
    def __init__(self, **params: Any) -> None:
        ...

    # The _fit implementation.
    def _fit(self, data: dataset.BaseDataset, *args, **kwargs):
        ...

    def _predict(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamples:
        ...

    def _predict_proba(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamples:
        ...
    
    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
        # This method is not currently used in TemporAI (it will be used once AutoML component is implemented).
        # For now, you may just return an empty list.
        ...
``` 

### 3. **Register** the plugin with TemporAI.

Registering your plugin with TemporAI is simple: use the `register_plugin` decorator,
as shown in the example below.

You will need to specify the `name` of your plugin and its `category` in the decorator.

> **Note:** You may omit `plugin_type="method"` below, as `"method"` is the default plugin type.

```python
from tempor.core.plugins import register_plugin
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier

@register_plugin(name="my_plugin", category="prediction.one_off.classification", plugin_type="method")
class MyPlugin(BaseOneOffClassifier):
    ...
```
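For readers unfamiliar with registration decorators, the following is a minimal sketch of the general technique using only the standard library. `PLUGIN_REGISTRY` and `register_sketch_plugin` are hypothetical names, and the logic is an assumption about how such a decorator can work in principle, not a description of how TemporAI's `register_plugin` is implemented.

```python
from typing import Callable, Dict, Type

# Hypothetical registry mapping "<category>.<name>" to the plugin class.
PLUGIN_REGISTRY: Dict[str, Type] = {}


def register_sketch_plugin(name: str, category: str) -> Callable[[Type], Type]:
    """Return a class decorator that records the class under `category` and `name`."""

    def decorator(cls: Type) -> Type:
        full_name = f"{category}.{name}"
        if full_name in PLUGIN_REGISTRY:
            raise ValueError(f"Plugin {full_name!r} is already registered.")
        # Attach the metadata to the class and store it for later lookup.
        cls.name = name
        cls.category = category
        PLUGIN_REGISTRY[full_name] = cls
        return cls

    return decorator


@register_sketch_plugin(name="my_plugin", category="prediction.one_off.classification")
class MySketchPlugin:
    pass


looked_up = PLUGIN_REGISTRY["prediction.one_off.classification.my_plugin"]
print(looked_up.__name__, looked_up.category)
```

The decorator returns the class unchanged apart from the attached metadata, so the decorated class can still be used directly.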

### Note on `__init__` parameters (arguments)

You will also need to define the input parameters (arguments) that will be passed into your plugin's `__init__` in the
following way:

```python
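import dataclasses
from typing import Any

from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier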

# 1. Write dataclass with your __init__ parameters:
@dataclasses.dataclass
class MyPluginParams:
    # Specify the parameter, data type and default value as below:
    lr: float = 0.001
    batch_size: int = 100

class MyPlugin(BaseOneOffClassifier):
    # 2. Set the `ParamsDefinition` class variable in your plugin to this dataclass.
    ParamsDefinition = MyPluginParams
    
    def __init__(self, **params: Any) -> None:
        # 3. Call the parent __init__ as so.
        super().__init__(**params)

        # 4. You will now be able to access these in your class like so:
        print(self.params.lr)
        print(self.params.batch_size) 


# 5. The user will then be able to specify the arguments as necessary when initializing your plugin:
model = MyPlugin(batch_size=22)
```
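To make the role of the dataclass concrete, here is a standard-library-only sketch of how a base class could turn `**params` into a populated params object. `SketchParams`, `SketchPluginBase` and `SketchPlugin` are hypothetical, and the mechanism shown (instantiating the dataclass with the keyword arguments) is an assumption for illustration rather than TemporAI's actual code.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class SketchParams:
    lr: float = 0.001
    batch_size: int = 100


class SketchPluginBase:
    """Hypothetical base class that builds `self.params` from `ParamsDefinition`."""

    ParamsDefinition: Any = None

    def __init__(self, **params: Any) -> None:
        # Instantiating the dataclass fills in defaults for omitted parameters
        # and raises TypeError for parameter names it does not define.
        self.params = self.ParamsDefinition(**params)


class SketchPlugin(SketchPluginBase):
    ParamsDefinition = SketchParams


model = SketchPlugin(batch_size=22)
resolved = dataclasses.asdict(model.params)
print(resolved)
```

Here `lr` keeps its default while `batch_size` takes the user-supplied value, and a misspelled parameter name fails at construction time.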


### Putting it all together

Putting this all together, here is an example of a one-off classifier plugin that always returns `1`s.

In [None]:
import dataclasses
from typing import Any, List

import numpy as np

from tempor.core.plugins import register_plugin
from tempor.methods.core import Params
from tempor.data import dataset, samples
from tempor.methods.prediction.one_off.classification import BaseOneOffClassifier


@dataclasses.dataclass
class MyClassifierParams:
    some_parameter: int = 1
    other_parameter: float = 0.5


@register_plugin(name="my_classifier", category="prediction.one_off.classification", plugin_type="method")
class MyClassifierClassifier(BaseOneOffClassifier):
    ParamsDefinition = MyClassifierParams

    def __init__(self, **param) -> None:
        super().__init__(**param)

    def _fit(self, data: dataset.BaseDataset, *args, **kwargs):
        """Does nothing."""
        return self  # Fit method needs to return `self`.

    def _predict(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamples:
        """Always returns 1"""

        assert data.predictive.targets is not None
        preds = np.ones_like(data.predictive.targets.numpy())

        return samples.StaticSamples.from_numpy(preds, dtype=int)

    def _predict_proba(self, data: dataset.PredictiveDataset, *args: Any, **kwargs: Any) -> samples.StaticSamples:
        """Always returns 1.0"""

        assert data.predictive.targets is not None
        preds = np.ones_like(data.predictive.targets.numpy())

        return samples.StaticSamples.from_numpy(preds, dtype=float)

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
        return []

We can now see our plugin listed in TemporAI:

In [None]:
from tempor import plugin_loader

all_plugins = plugin_loader.list(plugin_type="method")

pprint(all_plugins, indent_guides=True)

my_classifier_found = "my_classifier" in all_plugins["prediction"]["one_off"]["classification"]
print(f"`my_classifier` plugin found in the category 'prediction.one_off.classification': {my_classifier_found}")
assert my_classifier_found

`my_classifier` plugin found in the category 'prediction.one_off.classification': True
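The chained indexing used in the check above can be generalized to any dotted category string. The sketch below assumes the listing is a nested dictionary keyed by category parts with a list of plugin names at the leaf, as the indexing above suggests; `plugins_in_category` and the contents of `example_listing` are hypothetical.

```python
from typing import Any, Dict, List


def plugins_in_category(listing: Dict[str, Any], category: str) -> List[str]:
    """Walk a nested listing along a dotted category and return the names found there."""
    node: Any = listing
    for key in category.split("."):
        node = node[key]  # Raises KeyError if the category does not exist.
    return list(node)


# Hypothetical listing with made-up plugin names, for illustration only.
example_listing = {
    "prediction": {
        "one_off": {"classification": ["some_existing_plugin", "my_classifier"]},
    },
}

found = plugins_in_category(example_listing, "prediction.one_off.classification")
print("my_classifier" in found)
```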


The plugin can be used as normal.

In [None]:
# Get the plugin.

my_classifier = plugin_loader.get("prediction.one_off.classification.my_classifier", plugin_type="method")

print(my_classifier)

MyClassifierClassifier(
    name='my_classifier',
    category='prediction.one_off.classification',
    plugin_type='method',
    params={'some_parameter': 1, 'other_parameter': 0.5}
)


In [None]:
# Fit and predict on some data.

# Named `sine_dataset` to avoid shadowing the `dataset` module imported earlier.
sine_dataset = plugin_loader.get("prediction.one_off.sine", plugin_type="datasource", random_state=42).load()

my_classifier.fit(sine_dataset)

print("Prediction:")
my_classifier.predict(sine_dataset)

Prediction:


| sample_idx | feat_0 |
| --- | --- |
| 0 | 1 |
| 1 | 1 |
| 2 | 1 |
| 3 | 1 |
| 4 | 1 |
| ... | ... |
| 95 | 1 |
| 96 | 1 |
| 97 | 1 |
| 98 | 1 |
