# Data Distribution

Explore the available dataset distribution and splitting strategies, and visualize the label distributions they produce.

In [None]:
import pandas as pd
from fl_manager.utils.visualization.plotters.dataset_comparison_heatmap_plotter import DatasetComparisonHeatMapPlotter
from fl_manager.utils.visualization.plotters.dataset_comparison_plotter import DatasetComparisonPlotter

from fl_manager.core.meta_registry import MetaRegistry
from fl_manager.core.schemas.pandas_dataset import PandasDataset

In [None]:
# Fetch the "dataset_distributor" registry, then call its `list()` so the
# notebook displays the result (presumably the registered distributor names).
distributor_registry = MetaRegistry.get("dataset_distributor")
distributor_registry.list()

### Synthetic Dataset

In [None]:
# Size parameters of the synthetic dataset. Change them here rather than
# editing every literal in the column definitions below.
NUM_SAMPLES = 1000
NUM_CLASSES = 10

# Two integer feature columns plus a label column that cycles through
# 0..NUM_CLASSES-1, so every class receives the same number of rows whenever
# NUM_SAMPLES is a multiple of NUM_CLASSES.
data = {
  "feature1": range(NUM_SAMPLES),
  "feature2": range(NUM_SAMPLES, 2 * NUM_SAMPLES),
  "label": [i % NUM_CLASSES for i in range(NUM_SAMPLES)],
}
df = pd.DataFrame(data)

### Data Distribution & Split

In [None]:
# Fractions handed to the "proportion" splitter.
# NOTE(review): presumably ordered as (train, val, test) — confirm against the splitter.
split_proportions = [0.89, 0.0, 0.11]
splitter_registry = MetaRegistry.get("dataset_splitter")
splitter = splitter_registry.create("proportion", split_proportions)

# "iid" distributor without a server share.
# NOTE(review): the positional 10 is presumably the number of clients — confirm.
distributor_registry = MetaRegistry.get("dataset_distributor")
distributor = distributor_registry.create("iid", 10, with_server=False)

In [None]:
# Run the distributor directly on the raw DataFrame and keep the result.
# NOTE(review): `_distribute` is underscore-prefixed (non-public) and
# `distributed` is not read again in the visible cells — confirm whether this
# call is still needed given the `get_dataset_distribution` call in the next cell.
distributed = distributor._distribute(df)

In [None]:
# Wrap the whole frame as the `train` part of a PandasDataset, pass it through
# the IID distributor, then split whatever the distributor returns.
full_dataset = PandasDataset(train=df)
client_distribution = distributor.get_dataset_distribution(full_dataset)
split = splitter.split(client_distribution)
# Last expression of the cell: display the split result.
split

In [None]:
# Display the IID distributor's `global_test_split` attribute.
# NOTE(review): presumably populated by the earlier `get_dataset_distribution` call — confirm.
distributor.global_test_split

In [None]:
# Keyword options for the "dirichlet" distributor, collected in one place so
# they are easy to tweak between runs.
dirichlet_options = {
  "target_col": 'label',
  "balancing": False,
  "num_clients": 3,
  "alpha": 0.5,
  "with_server": False,
  "min_distribution_size": 0,
}
dirichlet_distributor = MetaRegistry.get("dataset_distributor").create("dirichlet", **dirichlet_options)
# Distribute the full frame (wrapped as the `train` part); the returned value is displayed.
dirichlet_distributor.get_dataset_distribution(PandasDataset(train=df))

In [None]:
# Compare the 'label' column across the datasets held by the Dirichlet distributor.
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
dirichlet_parts = dirichlet_distributor._distributed_dataset
plotter = DatasetComparisonPlotter(dirichlet_parts, 'label', x_label='Client', color_palette='deep')
plotter.run()

In [None]:
# Same comparison for the IID distributor, this time passing explicit hatch
# patterns (ten entries, matching the ten label classes of the synthetic data).
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
iid_parts = distributor._distributed_dataset
class_hatches = ['/o', '\\', '|*', '-', '+|', 'x', 'o', 'O', '.', '*']
plotter = DatasetComparisonPlotter(iid_parts, 'label', 'Client', color_palette='deep',
                                   hatch_patterns=class_hatches)
plotter.run()

In [None]:
# Heat-map variant of the 'label' comparison for the Dirichlet distributor's datasets.
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
dirichlet_parts = dirichlet_distributor._distributed_dataset
plotter = DatasetComparisonHeatMapPlotter(dirichlet_parts, 'label', 'Client')
plotter.run()

In [None]:
# Heat-map variant of the 'label' comparison for the IID distributor's datasets.
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
iid_parts = distributor._distributed_dataset
plotter = DatasetComparisonHeatMapPlotter(iid_parts, 'label', 'Client')
plotter.run()

In [None]:
# Fetch the "dataset_splitter" registry, then call its `list()` so the
# notebook displays the result (presumably the registered splitter names).
splitter_registry = MetaRegistry.get("dataset_splitter")
splitter_registry.list()

In [None]:
# Keyword options for the "stratified" splitter, keyed on the 'label' column.
stratified_options = {
  "target_col": "label",
  "min_samples_per_class": 20,
}
stratified_splitter = MetaRegistry.get("dataset_splitter").create("stratified", **stratified_options)

In [None]:
# Take element 0 of the Dirichlet distributor's datasets (presumably the first
# client's share), wrap it as the `train` part and run the stratified splitter on it.
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
first_client_data = dirichlet_distributor._distributed_dataset[0]
s = stratified_splitter.split(PandasDataset(train=first_client_data))
# Last expression of the cell: display the split result.
s

In [None]:
# Compare the 'label' column across the three parts of the stratified split.
split_names = ['train', 'val', 'test']
stratified_parts = [s.train, s.val, s.test]
plotter = DatasetComparisonPlotter(stratified_parts, 'label', 'Split', split_names, color_palette='deep')
plotter.run()

In [None]:
# NOTE: this rebinds `splitter`, replacing the proportion splitter created
# earlier in the notebook; re-running earlier cells afterwards will use this one.
proportions = [0.8, 0.1, 0.1]
splitter = MetaRegistry.get("dataset_splitter").create("proportion", proportions)

# Split element 0 of the Dirichlet distributor's datasets (presumably the first
# client's share), wrapped as the `train` part of a PandasDataset.
# NOTE(review): reads the non-public `_distributed_dataset` attribute.
first_client_data = dirichlet_distributor._distributed_dataset[0]
d = splitter.split(PandasDataset(train=first_client_data))
# Last expression of the cell: display the split result.
d

In [None]:
# Compare the 'label' column across the three parts of the proportion split.
split_names = ['train', 'val', 'test']
proportion_parts = [d.train, d.val, d.test]
plotter = DatasetComparisonPlotter(proportion_parts, 'label', 'Split', split_names, color_palette='deep')
plotter.run()

In [None]:
# Heat-map variant of the 'label' comparison across the proportion split's parts.
split_names = ['train', 'val', 'test']
proportion_parts = [d.train, d.val, d.test]
plotter = DatasetComparisonHeatMapPlotter(proportion_parts, 'label', 'Split', split_names)
plotter.run()