# Tutorial 0: Basic examples

In [None]:
import warnings
import sys

# Ignore all warnings from here on. This runs before the library imports below,
# presumably so that warnings emitted while importing them are hidden as well.
warnings.filterwarnings("ignore")

from sklearn.datasets import load_diabetes
import synthcity.logger as log

from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

# Attach a stderr sink to the synthcity logger at INFO level.
log.add(sink=sys.stderr, level="INFO")

# Load the scikit-learn diabetes dataset as features (X) and targets (y), both
# pandas objects (as_frame=True), then append the targets as a "target" column.
X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y

# Display the combined table (last expression of the cell).
X

In [None]:
# Wrap the table in a synthcity loader: "target" is the label column and
# "sex" is flagged as a sensitive attribute.
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)

## List the available generative models

In [None]:
from synthcity.plugins import Plugins

# Show the names of the available generator plugins (displayed as the cell output).
Plugins().list()

## Load and train a generative model

In [None]:
from synthcity.plugins import Plugins

# Look up the "marginal_distributions" generator plugin by name.
syn_model = Plugins().get("marginal_distributions")

# Train the generator on the real data wrapped by `loader`.
syn_model.fit(loader)

## Generate new data using the model

In [None]:
# Sample 10 synthetic rows and display them as a DataFrame.
syn_model.generate(count=10).dataframe()

## Generate new data under some constraints

In [None]:
# Constraint: target <= 100
from synthcity.plugins.core.constraints import Constraints

# Rules are (column, operator, value) triples.
constraints = Constraints(rules=[("target", "<=", 100)])

generated = syn_model.generate(count=10, constraints=constraints)

# A constraint must hold for EVERY generated row. `.any()` would pass even if
# nearly all rows violated it, so check with `.all()`.
assert (generated["target"] <= 100).all()

generated.dataframe()

In [None]:
# Constraint: target > 150

from synthcity.plugins.core.constraints import Constraints

# Rules are (column, operator, value) triples.
constraints = Constraints(rules=[("target", ">", 150)])

generated = syn_model.generate(count=10, constraints=constraints)

# A constraint must hold for EVERY generated row. `.any()` would pass even if
# nearly all rows violated it, so check with `.all()`.
assert (generated["target"] > 150).all()

generated.dataframe()

## Serialization

In [None]:
# save_to_file / load_from_file are imported but not used in this notebook;
# presumably they are the file-based counterparts of save / load.
from synthcity.utils.serialization import save, load, save_to_file, load_from_file

# Serialize the fitted model into an in-memory object.
buff = save(syn_model)

# Display the type of the serialized object.
type(buff)

In [None]:
# Rebuild a model from the serialized object produced by `save`.
reloaded = load(buff)

# Display the plugin name to confirm the round trip.
reloaded.name()

## Plot real vs. synthetic distributions

In [None]:
import matplotlib.pyplot as plt

# Ask the fitted model to draw its comparison plots against the real data in
# `loader`, using the pyplot module as the plotting backend.
syn_model.plot(plt, loader)

plt.show()

## Benchmark the quality of plugins

In [None]:
from synthcity.benchmark import Benchmarks

# Import Constraints here too, so this cell does not silently depend on an
# earlier cell having been run.
from synthcity.plugins.core.constraints import Constraints

# Use the symbolic operator, consistent with the other constraint examples.
constraints = Constraints(rules=[("target", ">=", 150)])

# Benchmark both plugins on `loader`. synthetic_size, synthetic_constraints
# and repeats configure the synthetic data generated for scoring. Each entry
# appears to be (test name, plugin name, plugin kwargs).
score = Benchmarks.evaluate(
    [
        ("marginal_distributions", "marginal_distributions", {}),
        ("dummy_sampler", "dummy_sampler", {}),
    ],
    loader,
    synthetic_size=1000,
    synthetic_constraints=constraints,
    repeats=2,
)

In [None]:
# Print the benchmark results returned by Benchmarks.evaluate.
Benchmarks.print(score)

In [None]:
import pandas as pd
import numpy as np

# Build a (metric x plugin) table of mean scores. Also collect each metric's
# optimisation direction ("minimize" or otherwise), which `highlights` uses.
means = []
directions = {}
for plugin in score:
    means.append(score[plugin]["mean"])
    # Merge directions across plugins instead of keeping only the last
    # plugin's, so a metric reported by any plugin has a direction.
    directions.update(score[plugin]["direction"].to_dict())

out = pd.concat(means, axis=1)
# Label the columns by plugin name. `set_axis(..., inplace=True)` was removed
# in pandas 2.0; assigning `.columns` works across pandas versions.
out.columns = list(score.keys())

# CSS snippets applied to the best / worst / other cells of each metric row.
bad_highlight = "background-color: lightcoral;"
ok_highlight = "background-color: green;"
default = ""


def highlights(row):
    """Return one CSS style string per cell of a metric row.

    ``row.name`` is the metric name; it is looked up in the module-level
    ``directions`` mapping. The best value gets ``ok_highlight``, where best
    means lowest for "minimize" and highest otherwise. The worst value gets
    ``bad_highlight``. All other cells, including NaN cells, get ``default``.
    """
    metric = row.name
    # Series.min/max skip NaN, so one missing score cannot poison the best or
    # worst lookup. np.min on the raw values would return NaN and then no
    # cell would ever match as "best".
    if directions[metric] == "minimize":
        best_val, worst_val = row.min(), row.max()
    else:
        best_val, worst_val = row.max(), row.min()

    styles = []
    for val in row.values:
        if val == best_val:
            styles.append(ok_highlight)
        elif val == worst_val:
            styles.append(bad_highlight)
        else:
            styles.append(default)

    return styles


# Style the mean-score table row by row (axis=1: one metric per call), so
# `highlights` marks the best plugin green and the worst light coral.
out.style.apply(highlights, axis=1)

# 