# Add a new plugin

By default, the library imports all the files matching "plugin\_\*.py" from src/synthcity/plugins, and loads all the classes which implement the [Plugin interface](src/synthcity/plugins/core/plugin.py).

Each plugin must implement the following methods:
- hyperparameter_space() - a static method that returns the hyperparameters that can be tuned during AutoML.
- type() - a static method that returns the type of the plugin. e.g., debug, generative, bayesian, etc.
- name() - a static method that returns the name of the plugin. e.g., ctgan, random_noise, etc.
- _fit() - internal method, called by `fit` on each training set.
- _generate() - internal method, called by `generate`.

## Existing plugins

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins

# Instantiate the plugin collection used throughout this notebook.
generators = Plugins()

# Display what the collection currently lists.
generators.list()

## Example plugin: Generate 0-1

In [None]:
# stdlib
from typing import Any, List

# third party
import numpy as np
import pandas as pd

# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema


class ZeroOnePlugin(Plugin):
    """Debug plugin that emits random 0/1 values.

    Fitting only records how many columns the training data has; generation
    then fills a ``count x features_count`` table with integers drawn from
    {0, 1}, without looking at the training values or at ``syn_schema``.

    NOTE: the tutorial text following this example states that generating
    with this plugin does not succeed, because the output does not comply
    with the constraints the Plugin interface enforces.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
        """Return the tunable hyperparameters: none for this plugin."""
        return []

    @staticmethod
    def type() -> str:
        """Return the plugin type."""
        return "debug"

    @staticmethod
    def name() -> str:
        """Return the plugin name."""
        return "zero_one"

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "ZeroOnePlugin":
        """Store the number of columns of ``X`` and return ``self``."""
        n_columns = X.shape[1]
        self.features_count = n_columns
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any):
        """Return a loader over ``count`` rows of random integers in {0, 1}."""
        sample_shape = (count, self.features_count)
        # The upper bound of randint is exclusive, so only 0 and 1 are drawn.
        binary_values = np.random.randint(0, 2, size=sample_shape)
        return GenericDataLoader(binary_values)

In [None]:
# Add the new plugin to the collection, using the same name that
# ZeroOnePlugin.name() returns.

generators.add("zero_one", ZeroOnePlugin)

In [None]:
# Check the plugins list again, now that "zero_one" has been added.
generators.list()

In [None]:
# Load reference data

# third party
from sklearn.datasets import load_breast_cancer

# Load the breast cancer dataset as pandas objects (as_frame=True), split into
# features X and target y (return_X_y=True).
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Only the features are wrapped in a synthcity data loader.
loader = GenericDataLoader(X)

# Display the loaded data.
loader.dataframe()

In [None]:
# Train the new plugin

# Look up the plugin registered as "zero_one".
gen = generators.get("zero_one")

# Fit it on the reference data.
gen.fit(loader)

In [None]:
# Generate some new data
# NOTE: according to the tutorial text that follows this cell, this call does
# not succeed with the zero_one plugin.

gen.generate(count=10)

### Oops, this didn't work.

__The Plugin interface enforces the new generated data to:__
 - satisfy the same constraints as the training set.
 - Or to satisfy the constraints provided at inference time (if provided).
 
 
 If the generated dataframe fails to comply, an exception will be raised.

Let's try again

## A functional plugin: Integrate SDV CTGAN

In this example, we will show how to integrate the CTGAN implementation from the [SDV library](https://github.com/sdv-dev/SDV).

__Note__ : `synthcity` also includes a dedicated CTGAN re-implementation.

In [None]:
# stdlib
from typing import Any, List

# third party
import numpy as np
import pandas as pd
from ctgan import CTGANSynthesizer

# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.distribution import (
    Distribution,
    IntegerDistribution,
)


class sdv_ctgan_plugin(Plugin):
    """SDV CTGAN integration in synthcity.

    Wraps ``ctgan.CTGANSynthesizer`` behind the synthcity ``Plugin`` interface.

    Args:
        embedding_n_units: passed to the synthesizer as ``embedding_dim``.
        n_iter: passed to the synthesizer as ``epochs``.
        batch_size: passed to the synthesizer as ``batch_size``.
        cat_limit: columns with fewer than this many unique values are
            treated as discrete when fitting.
        **kwargs: forwarded to ``Plugin.__init__``.
    """

    def __init__(
        self,
        embedding_n_units: int = 128,
        n_iter: int = 2000,
        batch_size: int = 100,
        cat_limit: int = 15,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.cat_limit = cat_limit
        self.model = CTGANSynthesizer(
            embedding_dim=embedding_n_units,
            batch_size=batch_size,
            epochs=n_iter,
            verbose=False,
        )

    @staticmethod
    def name() -> str:
        """Return the plugin name."""
        return "sdv_ctgan"

    @staticmethod
    def type() -> str:
        """Return the plugin type."""
        return "debug"

    @staticmethod
    def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
        """
        We can customize the hyperparameter space, and use it in AutoML benchmarks.

        The names match the constructor arguments ``embedding_n_units``,
        ``batch_size`` and ``n_iter``.
        """
        return [
            IntegerDistribution(name="embedding_n_units", low=100, high=500, step=50),
            IntegerDistribution(name="batch_size", low=100, high=300, step=50),
            IntegerDistribution(name="n_iter", low=100, high=500, step=50),
        ]

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "sdv_ctgan_plugin":
        """Select the discrete columns by unique-value count, then train the CTGAN.

        A column is considered discrete when it has fewer than
        ``self.cat_limit`` unique values.

        Returns:
            ``self``.
        """
        discrete_columns = [
            col for col in X.columns if len(X[col].unique()) < self.cat_limit
        ]

        self.model.fit(X.dataframe(), discrete_columns=discrete_columns)
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader:
        """Sample ``count`` rows from the trained model.

        Sampling is delegated to the inherited ``_safe_generate`` helper, which
        receives the model's ``sample`` callback together with ``syn_schema``.
        """
        return self._safe_generate(self.model.sample, count, syn_schema)

In [None]:
# Add the SDV CTGAN plugin to the collection, using the same name that
# sdv_ctgan_plugin.name() returns.
generators.add("sdv_ctgan", sdv_ctgan_plugin)

# Check the plugins list again.
generators.list()

In [None]:
# Train the new plugin

# n_iter=100 is presumably forwarded to the plugin constructor, replacing its
# default of 2000 — TODO confirm against Plugins.get.
gen = generators.get("sdv_ctgan", n_iter=100)

# Fit it on the reference data.
gen.fit(loader)

In [None]:
# Generate some new data and display it as a dataframe

gen.generate(count=10).dataframe()

In [None]:
# Custom generation constraints

# synthcity absolute
from synthcity.plugins.core.constraints import Constraints

# Rule: the "worst radius" column must be greater than 15.
constraints = Constraints(rules=[("worst radius", ">", 15)])

generated = gen.generate(count=10, constraints=constraints)

# The rule applies to every generated row, so check all of them; `.any()`
# would pass even if a single row complied. The explicit length check keeps
# the assertion from passing vacuously on an empty result.
worst_radius = generated["worst radius"]
assert len(worst_radius) > 0
assert (worst_radius > 15).all()

generated.dataframe()

## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the repos! This helps raise awareness of the tools we're building.


### Check out other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
