# Hugging Face support

This Jupyter notebook demonstrates the main operations of the plaid Hugging Face bridge:

1. Converting a plaid dataset to a Hugging Face dataset
2. Generating a Hugging Face dataset with a generator
3. Converting a Hugging Face dataset to plaid
4. Saving and loading Hugging Face datasets
5. Handling plaid samples from Hugging Face datasets without converting the complete dataset to plaid


**Each section is documented and explained.**

In [None]:
# Import necessary libraries and functions
from Muscat.Bridges.CGNSBridge import MeshToCGNS
from Muscat.Containers import MeshCreationTools as MCT

import numpy as np
import pickle

from plaid.bridges import huggingface_bridge
from plaid.containers.sample import Sample
from plaid.containers.dataset import Dataset
from plaid.problem_definition import ProblemDefinition


In [None]:
# Print Sample util
def show_sample(sample: Sample):
    print(f"sample = {sample}")
    sample.show_tree()
    print(f"{sample.get_scalar_names() = }")
    print(f"{sample.get_field_names() = }")


## Initialize plaid dataset and problem_definition

In [None]:
# Input data
points = np.array([
    [0.0, 0.0],
    [1.0, 0.0],
    [1.0, 1.0],
    [0.0, 1.0],
    [0.5, 1.5],
])

triangles = np.array([
    [0, 1, 2],
    [0, 2, 3],
    [2, 4, 3],
])


dataset = Dataset()

print("Creating meshes dataset...")
for _ in range(3):

    mesh = MCT.CreateMeshOfTriangles(points, triangles)

    sample = Sample()

    sample.add_tree(MeshToCGNS(mesh))
    sample.add_scalar("scalar", np.random.randn())
    # Vertex fields hold one value per node, cell fields one value per triangle
    sample.add_field("node_field", np.random.rand(1, len(points)), location='Vertex')
    sample.add_field("cell_field", np.random.rand(1, len(triangles)), location='CellCenter')

    dataset.add_sample(sample)

infos = {
    "legal": {
        "owner": "Bob",
        "license": "my_license",
    },
    "data_production": {
        "type": "simulation",
        "physics": "3D example",
    },
}

dataset.set_infos(infos)

print(f" {dataset = }")

problem = ProblemDefinition()
problem.add_output_scalars_names(["scalar"])
problem.add_output_fields_names(["node_field", "cell_field"])
problem.add_input_meshes_names(['/Base/Zone'])

problem.set_task('regression')
problem.set_split({'train':[0,1], 'test':[2]})

print(f" {problem = }")


## Section 1: Convert a plaid dataset to Hugging Face

The `description` field of the Hugging Face dataset is automatically populated with the plaid dataset infos and the problem definition, so that no information is lost and the Hugging Face dataset stays equivalent to its plaid counterpart.

In [None]:
hf_dataset = huggingface_bridge.plaid_dataset_to_huggingface(dataset, problem)
print()
print(f"{hf_dataset = }")
print(f"{hf_dataset.description = }")
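The merging of dataset infos and problem definition into `description`, described above, can be pictured with the minimal stdlib-only sketch below. The layout (infos keys at the top level, problem-definition attributes such as `task` and `split` next to them) and the helper `build_description` are hypothetical illustrations, not the bridge's actual implementation; the notebook only confirms that `description` exposes a `"split"` key.

```python
import copy


def build_description(infos: dict, problem: dict) -> dict:
    """Merge dataset infos and problem-definition attributes into one dict.

    Hypothetical helper: the real bridge may use a different layout.
    """
    overlap = infos.keys() & problem.keys()
    if overlap:
        # Refuse to silently overwrite one source with the other.
        raise ValueError(f"conflicting keys: {sorted(overlap)}")
    # Deep copies so later edits to the inputs do not leak into the description.
    return {**copy.deepcopy(infos), **copy.deepcopy(problem)}


infos = {
    "legal": {"owner": "Bob", "license": "my_license"},
    "data_production": {"type": "simulation", "physics": "3D example"},
}
# Hypothetical problem-definition attributes, mirroring the notebook's setup.
problem = {
    "task": "regression",
    "split": {"train": [0, 1], "test": [2]},
    "output_scalars": ["scalar"],
}

description = build_description(infos, problem)
print(sorted(description))
# ['data_production', 'legal', 'output_scalars', 'split', 'task']
```

Rejecting overlapping keys is a design choice of this sketch: it makes any collision between the two sources explicit instead of letting one silently win.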


## Section 2: Generate a Hugging Face dataset with a generator

In [None]:
def generator():
    # Yield one row per sample, storing the pickled plaid sample as bytes
    for sample_id in range(len(dataset)):
        yield {
            "sample": pickle.dumps(dataset[sample_id]),
        }

hf_dataset_gen = huggingface_bridge.plaid_generator_to_huggingface(generator, infos, problem)
print()
print(f"{hf_dataset_gen = }")
print(f"{hf_dataset_gen.description = }")
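The generator above yields one row per sample, each row being a dictionary with a single `"sample"` column holding the pickled plaid sample as bytes. The stdlib-only sketch below reproduces that row format with a hypothetical `ToySample` dataclass standing in for `Sample`, and checks that unpickling a row restores an equal object. It does not call the Hugging Face API.

```python
import pickle
from dataclasses import dataclass, field


@dataclass
class ToySample:
    """Hypothetical stand-in for plaid's Sample, used only for illustration."""
    scalar: float
    node_field: list = field(default_factory=list)


samples = [ToySample(scalar=float(i), node_field=[i, i + 1]) for i in range(3)]


def generator():
    # One row per sample: a single binary column named "sample".
    for sample in samples:
        yield {"sample": pickle.dumps(sample)}


rows = list(generator())
restored = [pickle.loads(row["sample"]) for row in rows]
print(restored == samples)  # True
```

Storing each sample as one opaque bytes column keeps the Hugging Face schema trivial, at the cost of requiring `pickle.loads` on the way back, which should only be applied to trusted data.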


## Section 3: Convert a Hugging Face dataset to plaid

The plaid dataset infos and problem definition are recovered from the `description` field of the Hugging Face dataset.

In [None]:
dataset_2, problem_2 = huggingface_bridge.huggingface_dataset_to_plaid(hf_dataset)
print()
print(f"{dataset_2 = }")
print(f"{dataset_2.get_infos() = }")
print(f"{problem_2 = }")
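The recovery performed by `huggingface_dataset_to_plaid` can be pictured as separating a merged description dictionary back into dataset infos and problem-definition attributes. In this hedged, stdlib-only sketch, the set `PROBLEM_KEYS` and the helper `split_description` are hypothetical; the bridge performs this step internally and may use a different layout.

```python
# Hypothetical set of keys that belong to the problem definition.
PROBLEM_KEYS = {"task", "split", "output_scalars", "output_fields", "input_meshes"}


def split_description(description: dict) -> tuple[dict, dict]:
    """Split a merged description into (infos, problem) dictionaries."""
    infos = {k: v for k, v in description.items() if k not in PROBLEM_KEYS}
    problem = {k: v for k, v in description.items() if k in PROBLEM_KEYS}
    return infos, problem


description = {
    "legal": {"owner": "Bob", "license": "my_license"},
    "task": "regression",
    "split": {"train": [0, 1], "test": [2]},
}
infos, problem = split_description(description)
print(infos)    # {'legal': {'owner': 'Bob', 'license': 'my_license'}}
print(problem)  # {'task': 'regression', 'split': {'train': [0, 1], 'test': [2]}}
```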


## Section 4: Save and load Hugging Face datasets

### From and to disk

In [None]:
# Save to disk
hf_dataset.save_to_disk("/tmp/path/to/dir")


In [None]:
# Load from disk
from datasets import load_from_disk
loaded_hf_dataset = load_from_disk("/tmp/path/to/dir")

print()
print(f"{loaded_hf_dataset = }")
print(f"{loaded_hf_dataset.description = }")

### From and to the huggingface hub

Pushing to and loading from the hub requires a Hugging Face account with a configured access token, and the `huggingface_hub[cli]` package installed.
Pushing and loading a Hugging Face dataset without loss of information also requires configuring a DatasetCard.

The instructions below are examples only and are not executed by this notebook.

### Push to the hub

First, log in with the Hugging Face CLI:
```bash
huggingface-cli login
```
and enter your access token.

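Then, the following Python instructions push a dataset to the hub:
```python
from datasets import load_dataset_builder
from huggingface_hub import DatasetCard

# Upload the dataset itself
hf_dataset.push_to_hub("chanel/dataset")

# Retrieve the size information computed for the uploaded dataset
datasetInfo = load_dataset_builder("chanel/dataset").info

# Build and upload a dataset card carrying the plaid description.
# `create_string_for_huggingface_dataset_card` is assumed to be provided by
# plaid's huggingface_bridge; the remaining arguments are elided.
card_text = huggingface_bridge.create_string_for_huggingface_dataset_card(
    description=hf_dataset.description,
    download_size_bytes=datasetInfo.download_size,
    dataset_size_bytes=datasetInfo.dataset_size,
    ...)
dataset_card = DatasetCard(card_text)
dataset_card.push_to_hub("chanel/dataset")
```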

This second upload, of the dataset card, is required so that `load_dataset` from the hub populates the `description` field of the downloaded dataset, keeping it compatible with conversion to plaid. Without a dataset card, the `description` field is lost.


### Load from hub

```python
from datasets import load_dataset

hf_dataset = load_dataset("chanel/dataset", split="all_samples")
```

## Section 5: Handle plaid samples from huggingface datasets without converting the complete dataset to plaid

To fully exploit the optimized data handling of the Hugging Face datasets library, information can be extracted from the Hugging Face dataset without converting it to plaid. The ``description`` attribute includes the plaid dataset `_infos` attribute and the plaid problem definition attributes.

In [None]:
print(f"{loaded_hf_dataset.description = }")

Retrieve the first sample of the first split:

In [None]:
split_names = list(loaded_hf_dataset.description["split"].keys())
sample_ids = loaded_hf_dataset.description["split"][split_names[0]]
hf_sample = loaded_hf_dataset[sample_ids[0]]["sample"]

print(f"{hf_sample = }")

``hf_sample`` is a binary object (a pickled plaid sample) that Hugging Face datasets handle efficiently. It can be converted into a plaid sample with ``Sample.model_validate``, a constructor relying on a pydantic validator.

In [None]:
plaid_sample = Sample.model_validate(pickle.loads(hf_sample))

show_sample(plaid_sample)
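The same pattern extends to a whole split. In the stdlib-only sketch below, `rows` is a plain list of `{"sample": bytes}` dictionaries standing in for the Hugging Face dataset, plain dicts stand in for plaid samples, and `iter_split` is a hypothetical helper that unpickles the samples of one split one at a time. With the real objects, `rows` would presumably be `loaded_hf_dataset` and each yielded object would be wrapped with `Sample.model_validate`.

```python
import pickle
from typing import Any, Iterator, Sequence


def iter_split(rows: Sequence[dict], description: dict, split_name: str) -> Iterator[Any]:
    """Yield the unpickled samples of one split, one at a time (hypothetical helper)."""
    for index in description["split"][split_name]:
        # Deserialize lazily: a row is only unpickled when the consumer asks for it.
        yield pickle.loads(rows[index]["sample"])


# Stand-ins: plain dicts instead of plaid samples, a list instead of the HF dataset.
rows = [{"sample": pickle.dumps({"scalar": float(i)})} for i in range(3)]
description = {"split": {"train": [0, 1], "test": [2]}}

train_scalars = [s["scalar"] for s in iter_split(rows, description, "train")]
print(train_scalars)  # [0.0, 1.0]
```

Using a generator keeps memory usage proportional to one sample at a time, which matches the goal of working on the Hugging Face dataset without converting it entirely to plaid.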