# Extending TemporAI Tutorial 04: Writing a Custom Metric Plugin

This tutorial shows how to extend TemporAI by writing a custom *metric* plugin.

*Skip the cell below if you are not on Google Colab or already have TemporAI installed:*

In [None]:
%pip install temporai

# Or from the repo, for the latest version:
# %pip install git+https://github.com/vanderschaarlab/temporai.git

**Note**

See also the "Writing a Custom `Plugin` 101" section in the "Writing a Custom Method Plugin" tutorial.

### Inherit from the appropriate **base class** for the category of the metric plugin you are writing.

First, determine which category of metric plugin you are writing.

You can list all the metric plugin categories as follows:

In [None]:
from tempor import plugin_loader

plugin_categories = plugin_loader.list_categories(plugin_type="metric")

list(plugin_categories.keys())

['prediction.one_off.classification',
 'prediction.one_off.regression',
 'time_to_event']

Remember you can also see the existing metric plugins and how they correspond to different categories, as follows:

In [None]:
all_plugins = plugin_loader.list(plugin_type="metric")

from rich.pretty import pprint  # For prettifying the print output only.

pprint(all_plugins, indent_guides=True)

Let's say you would like to write a plugin of category `"prediction.one_off.regression"`.

You can find which base class you need to inherit from as follows.

In [None]:
plugin_categories = plugin_loader.list_categories(plugin_type="metric")

print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)

print("Base class you need:")
print(plugin_categories["prediction.one_off.regression"])

Base classes for all categories:


Base class you need:
<class 'tempor.metrics.metric.OneOffRegressionMetric'>


You can then look up this class in the TemporAI source code to see its method signatures and other details.

### Implement the **methods** the plugin needs.

`Metric` plugins require the following methods to be implemented:
* `direction` property, which returns either `"minimize"` or `"maximize"`, representing the "good" direction of the metric.
* `_evaluate()` which takes in the actual and predicted values and returns the evaluated metric(s).
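To illustrate what `direction` means in practice, here is a minimal standard-library sketch of how a caller (for example, a model-selection loop) might use it to decide which of two scores is better. The helpers `is_better` and `best_score` are hypothetical and are not part of TemporAI.

```python
from typing import Literal, Sequence

# The two values a metric's `direction` property may return.
Direction = Literal["minimize", "maximize"]


def is_better(candidate: float, incumbent: float, direction: Direction) -> bool:
    """Return True if `candidate` is strictly better than `incumbent` under `direction`."""
    if direction == "maximize":
        return candidate > incumbent
    if direction == "minimize":
        return candidate < incumbent
    raise ValueError(f"Unknown direction: {direction!r}")


def best_score(scores: Sequence[float], direction: Direction) -> float:
    """Return the best score in `scores` under `direction`."""
    best = scores[0]
    for score in scores[1:]:
        if is_better(score, best, direction):
            best = score
    return best


# A D2-style score (higher is better) vs. an error such as MAE (lower is better).
print(best_score([0.10, 0.27, 0.05], "maximize"))  # 0.27
print(best_score([0.90, 0.40, 0.60], "minimize"))  # 0.4
```

Declaring the direction on the metric itself means generic code like this never needs to hard-code which metrics are "higher is better".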

If you have not implemented a required method, Python will raise a `TypeError` when you attempt to instantiate your
plugin (see [Python `abc`](https://docs.python.org/3/library/abc.html)).
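The mechanism is plain Python `abc`. The sketch below uses a hypothetical `SketchMetric` base class (not TemporAI's real `Metric` class) with an abstract `direction` property and an abstract `_evaluate()`. It also assumes, for illustration only, that calling the metric forwards its arguments to `_evaluate()`; check the TemporAI source for how the real base class dispatches.

```python
import abc
from typing import Any, Sequence


class SketchMetric(abc.ABC):
    """Hypothetical stand-in for a metric base class (not TemporAI's real one)."""

    @property
    @abc.abstractmethod
    def direction(self) -> str:
        """Return "minimize" or "maximize"."""

    @abc.abstractmethod
    def _evaluate(self, actual: Sequence[float], predicted: Sequence[float], *args: Any, **kwargs: Any) -> float:
        """Compute the metric value."""

    def __call__(self, actual: Sequence[float], predicted: Sequence[float], *args: Any, **kwargs: Any) -> float:
        # Assumption: calling the metric forwards all arguments to `_evaluate`.
        return self._evaluate(actual, predicted, *args, **kwargs)


class IncompleteMetric(SketchMetric):
    """Implements `direction` but forgets `_evaluate`."""

    @property
    def direction(self) -> str:
        return "minimize"


class MeanAbsoluteError(SketchMetric):
    """Complete subclass: both abstract members are implemented."""

    @property
    def direction(self) -> str:
        return "minimize"

    def _evaluate(self, actual: Sequence[float], predicted: Sequence[float], *args: Any, **kwargs: Any) -> float:
        return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


try:
    IncompleteMetric()
    incomplete_instantiated = True
except TypeError as err:  # Raised by `abc` because `_evaluate` is still abstract.
    incomplete_instantiated = False
    print(f"TypeError: {err}")

print(MeanAbsoluteError()([1, 2, 3], [1, 3, 3]))  # 0.3333333333333333
```

The exact wording of the `TypeError` message varies between Python versions, but the error is always raised at instantiation time, not when the missing method is first called.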


### **Register** the plugin with TemporAI.

Registering your plugin with TemporAI is simple: use the `register_plugin` decorator, as shown in the example below.

You will need to specify the `name` of your plugin and its `category` in the decorator.

The `plugin_type` needs to be set to `"metric"`.

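```python
from tempor.core.plugins import register_plugin
from tempor.metrics.metric import OneOffRegressionMetric


@register_plugin(name="my_metric", category="prediction.one_off.regression", plugin_type="metric")
class MyMetric(OneOffRegressionMetric):
    ...  # Implement the `direction` property and `_evaluate()` here.
```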
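Conceptually, a registration decorator records the class in a lookup table keyed by its fully qualified name (`"<category>.<name>"`) and returns the class unchanged. The toy version below is only a sketch of that idea under this assumption; `register_plugin_sketch`, `get_plugin_sketch` and `REGISTRY` are hypothetical names, and TemporAI's actual `register_plugin` may store additional metadata or perform extra validation.

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T", bound=type)

# Toy registry: plugin_type -> {"<category>.<name>": class}.
REGISTRY: Dict[str, Dict[str, type]] = {}


def register_plugin_sketch(name: str, category: str, plugin_type: str) -> Callable[[T], T]:
    """Return a class decorator that records the class under "<category>.<name>"."""

    def decorator(cls: T) -> T:
        REGISTRY.setdefault(plugin_type, {})[f"{category}.{name}"] = cls
        return cls  # The class itself is returned unchanged.

    return decorator


def get_plugin_sketch(fqn: str, plugin_type: str) -> object:
    """Look up a registered class by its fully qualified name and instantiate it."""
    return REGISTRY[plugin_type][fqn]()


@register_plugin_sketch(name="my_metric", category="prediction.one_off.regression", plugin_type="metric")
class MyMetric:
    pass


instance = get_plugin_sketch("prediction.one_off.regression.my_metric", plugin_type="metric")
print(type(instance).__name__)  # MyMetric
```

Because the decorator runs when the class body is executed, simply defining (or importing) the class is enough to make it discoverable.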

### Example

Now putting this together in a minimal example.

In [None]:
from typing import Any

import numpy as np
from sklearn.metrics import d2_pinball_score

from tempor.core import plugins
from tempor.metrics import metric, metric_typing


@plugins.register_plugin(name="my_metric", category="prediction.one_off.regression", plugin_type="metric")
class MyOneOffRegressionMetric(metric.OneOffRegressionMetric):
    """My custom metric, here we use D2 pinball score as per `sklearn`."""

    @property
    def direction(self) -> metric_typing.MetricDirection:  # noqa: D102
        return "maximize"

    def _evaluate(self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any) -> float:
        return d2_pinball_score(actual, predicted, alpha=kwargs.get("alpha", 0.5))

We now see our plugin in TemporAI:

In [None]:
from tempor import plugin_loader

all_plugins = plugin_loader.list(plugin_type="metric")

pprint(all_plugins, indent_guides=True)

my_metric_found = "my_metric" in all_plugins["prediction"]["one_off"]["regression"]
print(f"`my_metric` plugin found in the category 'prediction.one_off.regression': {my_metric_found}")
assert my_metric_found

`my_metric` plugin found in the category 'prediction.one_off.regression': True
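The nested lookup `all_plugins["prediction"]["one_off"]["regression"]` mirrors the dotted category name, and the fully qualified plugin name is the category plus the plugin name. The sketch below walks a nested dict one dotted component at a time. The `example_tree` data is hypothetical (only `my_metric` and the category names come from this tutorial), and it assumes each leaf is a list of plugin names, consistent with the `in` check used above.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical listing for illustration; only "my_metric" comes from this tutorial.
example_tree: Dict[str, Any] = {
    "prediction": {
        "one_off": {
            "classification": ["some_classification_metric"],
            "regression": ["some_regression_metric", "my_metric"],
        }
    },
    "time_to_event": ["some_survival_metric"],
}


def lookup_category(tree: Dict[str, Any], category: str) -> List[str]:
    """Walk the nested dict one dotted component at a time."""
    node: Any = tree
    for key in category.split("."):
        node = node[key]
    return node


def split_fqn(fqn: str) -> Tuple[str, str]:
    """Split "<category>.<name>" at the last dot."""
    category, _, name = fqn.rpartition(".")
    return category, name


category, name = split_fqn("prediction.one_off.regression.my_metric")
print(category, name)  # prediction.one_off.regression my_metric
print(name in lookup_category(example_tree, category))  # True
```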


The plugin can be used as normal.

In [None]:
# Get the plugin.

my_metric = plugin_loader.get("prediction.one_off.regression.my_metric", plugin_type="metric")

print(my_metric)

MyOneOffRegressionMetric(
    name='my_metric',
    description='My custom metric, here we use D2 pinball score as per `sklearn`.'
)


In [None]:
# Use the metric.

y_true = np.asarray([1, 2, 3])
y_pred = np.asarray([1, 3, 3])

metric_value = my_metric(y_true, y_pred, alpha=0.3)
print(metric_value)

0.27083333333333337
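To see where this number comes from, the D2 pinball score is `1 - L(y, y_pred) / L(y, q_alpha)`, where `L` is the mean pinball loss and `q_alpha` is the empirical `alpha`-quantile of the actual values, used as a constant baseline prediction. The standard-library sketch below reproduces the value by hand. It assumes 1-D targets, no sample weights, and that `sklearn` computes the baseline quantile with NumPy's default linear interpolation (consistent with the printed value). It also does not guard against a zero denominator, which occurs when all actual values are equal.

```python
import math
from typing import Sequence


def pinball_loss(y_true: Sequence[float], y_pred: Sequence[float], alpha: float) -> float:
    """Mean pinball (quantile) loss."""
    total = 0.0
    for y, q in zip(y_true, y_pred):
        diff = y - q
        # Under-prediction is weighted by alpha, over-prediction by (1 - alpha).
        total += alpha * diff if diff >= 0 else (alpha - 1) * diff
    return total / len(y_true)


def linear_quantile(values: Sequence[float], alpha: float) -> float:
    """Quantile with linear interpolation between the two nearest order statistics."""
    xs = sorted(values)
    pos = alpha * (len(xs) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def d2_pinball(y_true: Sequence[float], y_pred: Sequence[float], alpha: float = 0.5) -> float:
    """1 - (model pinball loss) / (pinball loss of a constant alpha-quantile baseline)."""
    baseline = [linear_quantile(y_true, alpha)] * len(y_true)
    return 1.0 - pinball_loss(y_true, y_pred, alpha) / pinball_loss(y_true, baseline, alpha)


print(d2_pinball([1, 2, 3], [1, 3, 3], alpha=0.3))
```

Here the model's mean loss is `0.7 / 3` (one over-prediction of 1, weighted by `1 - 0.3`), the baseline quantile is `1.6` with mean loss `0.96 / 3`, so the score is `1 - 0.7 / 0.96`, approximately `0.2708`, matching the output above up to floating-point rounding.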


## 🎉 Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards *Machine learning and AI for Medicine*, you can do so in the following ways!



### ⭐ Star [TemporAI](https://github.com/vanderschaarlab/temporai) on GitHub

- The easiest way to help our community is by just starring the repos! This helps raise awareness of the tools we're building.



### Check out other projects from [vanderschaarlab](https://github.com/vanderschaarlab)
- 📝 [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- 📊 [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
- 🤖 [SynthCity](https://github.com/vanderschaarlab/synthcity)
 