[![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vanderschaarlab/temporai/blob/main/tutorials/extending/tutorial03_custom_datasource.ipynb)

# Extending TemporAI Tutorial 03: Writing a Custom Data Source Plugin

This tutorial shows how to extend TemporAI by writing a custom *data source* plugin.

**Note**

See also the "Writing a Custom `Plugin` 101" section in the "Writing a Custom Method Plugin" tutorial.

### Inherit from the appropriate **base class** for the category of the data source plugin you are writing.

First, determine which category your data source plugin belongs to.

You can list all the data source plugin categories as follows:

In [None]:
from tempor import plugin_loader

plugin_categories = plugin_loader.list_categories(plugin_type="datasource")

list(plugin_categories.keys())

['prediction.one_off',
 'prediction.temporal',
 'time_to_event',
 'treatments.one_off',
 'treatments.temporal']

Remember that you can also list the existing data source plugins and see which category each one belongs to:

In [None]:
all_plugins = plugin_loader.list(plugin_type="datasource")

from rich.pretty import pprint  # For prettifying the print output only.

pprint(all_plugins, indent_guides=True)

Let's say you would like to write a plugin of category `"prediction.one_off"`.

You can find which base class you need to inherit from as follows.

In [None]:
plugin_categories = plugin_loader.list_categories(plugin_type="datasource")

print("Base classes for all categories:")
pprint(plugin_categories, indent_guides=False)

print("Base class you need:")
print(plugin_categories["prediction.one_off"])

Base classes for all categories:


Base class you need:
<class 'tempor.datasources.datasource.OneOffPredictionDataSource'>


You can then look up this class in the TemporAI source code to see the methods it defines and their signatures.
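Besides reading the source, you can ask Python directly which methods an abstract base class still expects. The sketch below uses a hypothetical stand-in class, `ToyDataSourceBase` (not part of TemporAI), that only mimics the three methods described in the next section. The standard-library tools it relies on, `__abstractmethods__` and `inspect.signature`, should presumably work the same way on the real base class (e.g. `OneOffPredictionDataSource`), since it is built on Python `abc`.

```python
import abc
import inspect
from typing import Optional


class ToyDataSourceBase(abc.ABC):
    """Hypothetical stand-in for a TemporAI data source base class."""

    @abc.abstractmethod
    def load(self) -> object:
        """Return the dataset."""

    @abc.abstractmethod
    def dataset_dir(self) -> Optional[str]:
        """Return the data subdirectory, or None."""

    @abc.abstractmethod
    def url(self) -> Optional[str]:
        """Return the data URL, or None."""


def describe_required_methods(cls: type) -> dict:
    """Map each still-abstract method name of `cls` to its signature string."""
    return {
        name: str(inspect.signature(getattr(cls, name)))
        for name in sorted(cls.__abstractmethods__)
    }


# Print each required method together with its signature.
for name, sig in describe_required_methods(ToyDataSourceBase).items():
    print(f"{name}{sig}")
```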

### Implement the **methods** the plugin needs.

`DataSource` plugins require the following methods to be implemented:
* `load()` which returns the appropriate `DataSet`.
* `dataset_dir()` which returns a string with the subdirectory where any data files will be stored. If there are no data files, return `None`.
* `url()` which returns the data URL, if relevant. If not applicable, return `None`.

The initializer `__init__()` can take keyword arguments related to the initialization of the dataset, such as the number of samples or a random seed.

If you have not implemented a required method, Python will notify you by raising a `TypeError` when you
attempt to instantiate your plugin (see [Python `abc`](https://docs.python.org/3/library/abc.html)).
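To see this behavior in isolation, here is a small sketch with a hypothetical base class, `ToyDataSourceBase` (not part of TemporAI), that declares the same three methods as abstract. A subclass that forgets `url()` cannot be instantiated, while a complete one can.

```python
import abc
from typing import Optional


class ToyDataSourceBase(abc.ABC):
    """Hypothetical stand-in for a TemporAI data source base class."""

    @abc.abstractmethod
    def load(self) -> object: ...

    @abc.abstractmethod
    def dataset_dir(self) -> Optional[str]: ...

    @abc.abstractmethod
    def url(self) -> Optional[str]: ...


class IncompleteSource(ToyDataSourceBase):
    # `url()` is deliberately missing, so this class stays abstract.
    def load(self) -> object:
        return []

    def dataset_dir(self) -> Optional[str]:
        return None


class CompleteSource(IncompleteSource):
    # Implementing the last abstract method makes the class concrete.
    def url(self) -> Optional[str]:
        return None


try:
    IncompleteSource()
except TypeError as err:
    print(f"TypeError: {err}")

source = CompleteSource()  # Instantiates fine: all abstract methods are implemented.
print(source.url())
```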


### **Register** the plugin with TemporAI.

Registering your plugin with TemporAI is simple: use the `register_plugin` decorator,
as shown in the example below.

You will need to specify the `name` of your plugin and its `category` in the decorator.

The `plugin_type` needs to be set to `"datasource"`.

```python
from tempor.core.plugins import register_plugin
from tempor.datasources.datasource import OneOffPredictionDataSource

@register_plugin(name="my_plugin", category="prediction.one_off", plugin_type="datasource")
class MyPlugin(OneOffPredictionDataSource):
    ...
```
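Conceptually, a registration decorator records the class under its plugin type, category and name, then returns the class unchanged so it can still be used as usual. The sketch below is a toy illustration of that idea with a hypothetical `toy_register_plugin` and a plain dictionary registry; it is not TemporAI's actual implementation. Note how the key `"category.name"` mirrors the fully qualified name later passed to `plugin_loader.get`.

```python
from typing import Callable, Dict, Type

# Hypothetical registry: plugin_type -> fully qualified name ("category.name") -> class.
TOY_REGISTRY: Dict[str, Dict[str, Type]] = {}


def toy_register_plugin(name: str, category: str, plugin_type: str) -> Callable[[Type], Type]:
    """Return a class decorator that records the class in TOY_REGISTRY."""

    def decorator(cls: Type) -> Type:
        fqn = f"{category}.{name}"
        by_type = TOY_REGISTRY.setdefault(plugin_type, {})
        if fqn in by_type:
            raise ValueError(f"Plugin '{fqn}' of type '{plugin_type}' is already registered")
        by_type[fqn] = cls
        return cls  # The class itself is returned unchanged.

    return decorator


@toy_register_plugin(name="my_plugin", category="prediction.one_off", plugin_type="datasource")
class MyToyPlugin:
    pass


print(TOY_REGISTRY)
```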

### Example

Putting all of this together in a minimal example:

In [None]:
from typing import Any, Optional

import numpy as np

from tempor.data.dataset import OneOffPredictionDataset
from tempor.core.plugins import register_plugin
from tempor.datasources.datasource import OneOffPredictionDataSource


@register_plugin(name="my_datasource", category="prediction.one_off", plugin_type="datasource")
class MyDataSource(OneOffPredictionDataSource):
    def __init__(self, random_seed: int = 123, **kwargs: Any) -> None:
        # Forward any remaining keyword arguments to the base class.
        super().__init__(**kwargs)
        self.random_seed = random_seed

    def url(self) -> Optional[str]:
        # The data is generated in memory, so there is no data URL.
        return None

    def dataset_dir(self) -> Optional[str]:
        # No data files are stored on disk.
        return None

    def load(self) -> OneOffPredictionDataset:
        np.random.seed(self.random_seed)  # Make the generated data reproducible.
        return OneOffPredictionDataset(
            # 100 samples, 30 time steps, 10 features.
            time_series=np.random.normal(size=(100, 30, 10)),
            # One static target value per sample.
            targets=np.random.normal(size=(100, 1)),
        )
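The example seeds NumPy's global random state with `np.random.seed`, which also affects any other code that draws from `np.random`. If you prefer to keep generation local to the plugin, one alternative (a swapped-in technique, not what this tutorial uses) is a dedicated generator from `np.random.default_rng`. The sketch shows only the array generation, as a hypothetical helper `make_arrays`, so it runs without TemporAI; inside `load()` you would pass the two arrays to `OneOffPredictionDataset` as above. Note that the generated values will differ from those produced with `np.random.seed`.

```python
import numpy as np


def make_arrays(random_seed: int = 123, n_samples: int = 100, n_timesteps: int = 30, n_features: int = 10):
    """Hypothetical helper: generate covariates and targets from a local generator."""
    rng = np.random.default_rng(random_seed)  # Local generator; global np.random state is untouched.
    time_series = rng.normal(size=(n_samples, n_timesteps, n_features))
    targets = rng.normal(size=(n_samples, 1))
    return time_series, targets


ts_a, y_a = make_arrays(random_seed=123)
ts_b, y_b = make_arrays(random_seed=123)
print(ts_a.shape, y_a.shape)       # (100, 30, 10) (100, 1)
print(np.array_equal(ts_a, ts_b))  # True: same seed, same data
```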

Our plugin is now listed in TemporAI:

In [None]:
from tempor import plugin_loader

all_plugins = plugin_loader.list(plugin_type="datasource")

pprint(all_plugins, indent_guides=True)

my_datasource_found = "my_datasource" in all_plugins["prediction"]["one_off"]
print(f"`my_datasource` plugin found in the category 'prediction.one_off': {my_datasource_found}")
assert my_datasource_found
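As the check above shows, the listing returned by `plugin_loader.list()` is a nested dictionary in which each dot-separated part of the category is one level of nesting. A small hypothetical helper, `names_in_category`, can walk that structure for any category string. It is exercised here on a toy dictionary with placeholder plugin names shaped like the lookup above; the real listing depends on your TemporAI version and installed plugins.

```python
from typing import Any, Dict, List


def names_in_category(plugins: Dict[str, Any], category: str) -> List[str]:
    """Walk a nested plugin listing using a dotted category, e.g. "prediction.one_off"."""
    node: Any = plugins
    for part in category.split("."):
        node = node[part]  # Raises KeyError if the category does not exist.
    return list(node)  # Works whether the leaf is a list of names or a dict keyed by name.


# Toy listing with placeholder names, shaped like all_plugins["prediction"]["one_off"].
toy_listing = {
    "prediction": {
        "one_off": ["other_plugin", "my_datasource"],
        "temporal": ["another_plugin"],
    },
    "time_to_event": ["yet_another_plugin"],
}

print(names_in_category(toy_listing, "prediction.one_off"))  # ['other_plugin', 'my_datasource']
print(names_in_category(toy_listing, "time_to_event"))       # ['yet_another_plugin']
```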

`my_datasource` plugin found in the category 'prediction.one_off': True


The plugin can now be used like any other data source plugin.

In [None]:
# Get the plugin.

my_datasource = plugin_loader.get("prediction.one_off.my_datasource", plugin_type="datasource")

print(my_datasource)

<__main__.MyDataSource object at 0x7fa2cf032850>


In [None]:
# Load data.

dataset = my_datasource.load()

dataset

OneOffPredictionDataset(
    time_series=TimeSeriesSamples([100, *, 10]),
    predictive=OneOffPredictionTaskData(targets=StaticSamples([100, 1]))
)

In [None]:
# Preview covariates.

dataset.time_series

| sample_idx | time_idx | feat_0 | feat_1 | feat_2 | feat_3 | feat_4 | feat_5 | feat_6 | feat_7 | feat_8 | feat_9 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | -1.085631 | 0.997345 | 0.282978 | -1.506295 | -0.578600 | 1.651437 | -2.426679 | -0.428913 | 1.265936 | -0.866740 |
| 0 | 1 | -0.678886 | -0.094709 | 1.491390 | -0.638902 | -0.443982 | -0.434351 | 2.205930 | 2.186786 | 1.004054 | 0.386186 |
| 0 | 2 | 0.737369 | 1.490732 | -0.935834 | 1.175829 | -1.253881 | -0.637752 | 0.907105 | -1.428681 | -0.140069 | -0.861755 |
| 0 | 3 | -0.255619 | -2.798589 | -1.771533 | -0.699877 | 0.927462 | -0.173636 | 0.002846 | 0.688223 | -0.879536 | 0.283627 |
| 0 | 4 | -0.805367 | -1.727669 | -0.390900 | 0.573806 | 0.338589 | -0.011830 | 2.392365 | 0.412912 | 0.978736 | 2.238143 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 99 | 25 | -0.567276 | -1.011354 | -0.263128 | 0.281661 | 0.850365 | 0.675597 | 0.518956 | 1.458113 | 0.514021 | -0.845099 |
| 99 | 26 | -0.074948 | 2.889178 | -0.055376 | -1.284538 | -0.215400 | -0.002616 | -0.406990 | -0.089739 | 0.264811 | 1.060700 |
| 99 | 27 | 0.167216 | -0.226127 | 1.517813 | 2.083333 | -1.053875 | -0.212461 | 1.006044 | -0.253001 | 0.298598 | -1.256375 |
| 99 | 28 | 1.212878 | -1.656727 | 0.702245 | 0.047495 | -0.736849 | -0.050498 | 0.285193 | 0.735459 | -0.384255 | -0.262967 |
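Each row of the preview corresponds to one (`sample_idx`, `time_idx`) pair of the `(100, 30, 10)` array passed as `time_series`, with the 10 feature columns `feat_0` to `feat_9`. The NumPy-only sketch below illustrates that layout, assuming samples and time steps are stacked in row-major order as the preview suggests; it does not call TemporAI and uses its own random data.

```python
import numpy as np

n_samples, n_timesteps, n_features = 100, 30, 10
arr = np.random.default_rng(0).normal(size=(n_samples, n_timesteps, n_features))

# Stack samples and time steps into rows: one row per (sample_idx, time_idx) pair.
rows = arr.reshape(n_samples * n_timesteps, n_features)

# Row k holds sample k // n_timesteps at time step k % n_timesteps.
k = 3 * n_timesteps + 4
sample_idx, time_idx = divmod(k, n_timesteps)
print(sample_idx, time_idx)                                 # 3 4
print(np.array_equal(rows[k], arr[sample_idx, time_idx]))  # True
```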


In [None]:
# Preview targets.

dataset.predictive.targets

| sample_idx | feat_0 |
|---|---|
| 0 | -1.054170 |
| 1 | -0.783011 |
| 2 | 1.827901 |
| 3 | 1.746807 |
| 4 | 1.328258 |
| ... | ... |
| 95 | -0.766137 |
| 96 | 1.112182 |
| 97 | 0.076831 |
| 98 | -1.566442 |
