# Dataset Formatter

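This notebook exercises the dataset formatter components: it loads the MNIST training split through a dataset reader, reshapes its columns with a composite formatter, and feeds the formatted data to a PyTorch `DataLoader`.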

In [None]:
import pandas as pd
from torch.utils.data import DataLoader

from fl_manager.components.preprocessors.torchvision import constants as vision_constants
from fl_manager.core.components.datasets import DataFrameDatasetRegistry
from fl_manager.core.components.formatters import DatasetFormatterComposite, DatasetFormatterRegistry
from fl_manager.core.components.preprocessors import DatasetPreprocessorComposite, DatasetPreprocessorRegistry
from fl_manager.core.components.readers import DatasetReaderRegistry
from fl_manager.core.schemas.pandas_dataset import PandasDataset

### Load Dataset

In [None]:
# Inspect what the reader registry exposes; presumably the registered reader
# keys (such as "huggingface", used in the next cell) — TODO confirm.
DatasetReaderRegistry.list()

In [None]:
# Build the reader registered under "huggingface". The remaining arguments
# look like the dataset id and a split mapping — TODO confirm their exact
# meaning against the reader's signature.
reader = DatasetReaderRegistry.create(
  "huggingface",
  "ylecun/mnist",
  {"train": "train"},
)

# Fetch the dataset and keep a handle on its training split, which the
# following cells format and load.
data = reader.fetch_dataset()
train_data = data.train

In [None]:
# Show which splits the reader reports.
# NOTE(review): this reads a private attribute (leading underscore); fine for
# ad-hoc inspection in a notebook, but it is not part of a public interface.
reader._available_splits

In [None]:
# Preview the raw training data before formatting (first two rows, assuming
# train_data is a pandas DataFrame — TODO confirm).
train_data.head(2)

### Set Up Operations and Run Formatter

In [None]:
# Inspect what the formatter registry exposes; presumably the registered
# formatter keys (such as "dict_extractor", used in the next cell) — TODO confirm.
DatasetFormatterRegistry.list()

In [None]:
# Assemble a composite formatter from three registry-created steps, added in
# the order below. The step descriptions are inferred from the registry keys
# and arguments — TODO confirm against the formatter implementations.
dataset_formatter = DatasetFormatterComposite()

# Presumably pulls the "bytes" entry out of the dict stored in "image".
extract_image_bytes = DatasetFormatterRegistry.create("dict_extractor", "image", ["bytes"])
dataset_formatter.add(extract_image_bytes)

# Presumably drops the original "image" column.
drop_image_column = DatasetFormatterRegistry.create("column_dropper", "image")
dataset_formatter.add(drop_image_column)

# Presumably renames "image__bytes" to "image".
rename_bytes_column = DatasetFormatterRegistry.create("column_rename", "image__bytes", "image")
dataset_formatter.add(rename_bytes_column)

In [None]:
# Run the composite formatter on the training data and keep the result for
# the DataLoader section below.
f_train_data = dataset_formatter.run(train_data)

In [None]:
# Preview the formatted training data (first two rows, assuming the formatter
# returns a pandas DataFrame — TODO confirm).
f_train_data.head(2)

### DataLoader

In [None]:
# Assemble a composite preprocessor from two registry-created steps, added in
# the order below. What each step does is inferred from its registry key —
# TODO confirm against the preprocessor implementations.
dataset_preprocessor = DatasetPreprocessorComposite()

bytes_to_tensor = DatasetPreprocessorRegistry.create("bytes_to_tensor")
dataset_preprocessor.add(bytes_to_tensor)

# MNIST_NORM_VALUES is unpacked as keyword arguments for the normalization step.
tensor_normalization = DatasetPreprocessorRegistry.create(
  "tensor_normalization",
  **vision_constants.MNIST_NORM_VALUES,
)
dataset_preprocessor.add(tensor_normalization)

In [None]:
# Look up the registered dataset entry and call it with the per-column
# preprocessors and a column list (presumably the columns each sample
# exposes — TODO confirm).
dataset_factory = DataFrameDatasetRegistry.get('torch_dataframe_transforms_dataset')
dataset_builder = dataset_factory(
  {"image": dataset_preprocessor},
  ["image", "label"],
)

# Only the training split carries data; val and test are empty frames.
pandas_splits = PandasDataset(
  train=f_train_data,
  val=pd.DataFrame(),
  test=pd.DataFrame(),
)
dataset = dataset_builder.get_dataset(pandas_splits)

In [None]:
# Wrap the training dataset in a DataLoader and pull a single batch of two
# samples to check that formatting and preprocessing work end to end.
dataloader = DataLoader(dataset.train, batch_size=2)

# Fetch the first batch directly instead of starting a loop only to break out
# of it. An empty loader now fails here (StopIteration) rather than on the
# later reference to an undefined `sample`.
sample = next(iter(dataloader))
sample