In [24]:
import datasets
import pandas as pd

# Fetch the "full" configuration of the cassava dataset from the Hugging Face
# Hub; the train, validation and test splits are read from it further down.
hf_data = datasets.load_dataset(
    path="pufanyi/cassava-leaf-disease-classification",
    name="full",
)

In [25]:
# Materialise the original train and validation splits as DataFrames and pool
# them into a single frame so the data can be re-split further down.
train_df = hf_data["train"].to_pandas()
val_df = hf_data["validation"].to_pandas()
# ignore_index=True gives the pooled frame one fresh 0..n-1 index. Without it
# the row labels of both inputs are carried over and can collide, which makes
# any label-based lookup on `df` ambiguous.
df = pd.concat([train_df, val_df], ignore_index=True)

In [26]:
# Total number of pooled rows (train + validation).
df.shape[0]

21397

In [27]:
# Distinct class ids found in the pooled frame, in order of first appearance.
labels = pd.unique(df["label"])
labels

array([0, 3, 1, 2, 4])

In [28]:
# Bucket the pooled rows by class id: one sub-frame per label.
data = {label: df[df["label"] == label] for label in labels}

In [29]:
# Report how many rows each class contributes to the pooled data.
for label, subset in data.items():
    print(f"Label {label}: {len(subset)} samples")

Label 0: 1087 samples
Label 3: 13158 samples
Label 1: 2189 samples
Label 2: 2386 samples
Label 4: 2577 samples


In [43]:
from sklearn.model_selection import train_test_split

# Hold out 5% of the pooled rows as the new validation set.
# The classes are heavily imbalanced (label 3 accounts for 13158 of the 21397
# rows, label 0 for only 1087), so the split is stratified on the label to keep
# the class proportions the same in both parts. random_state pins the shuffle
# so the split is reproducible.
final_df_train, final_df_val = train_test_split(
    df,
    test_size=0.05,
    random_state=42,
    stratify=df["label"],
)

In [44]:
# Row counts of the new (train, validation) parts.
tuple(len(part) for part in (final_df_train, final_df_val))

(20327, 1070)

In [46]:
def gen(df: pd.DataFrame):
    """Yield each row of ``df`` as a ``{column label: value}`` dict.

    Used as the generator callable for ``datasets.Dataset.from_generator``.

    Rows are read with ``itertuples`` rather than ``iterrows``: ``iterrows``
    builds one Series per row, which forces all values of that row into a
    single common dtype (e.g. an integer label turns into a float as soon as
    the frame also has a float column). Reading column-wise keeps every value
    in its own column's type.

    Args:
        df: Frame whose rows are emitted in order; the index is ignored.

    Yields:
        dict: One mapping per row, keyed by the column labels of ``df``.
    """
    columns = list(df.columns)
    # name=None yields plain tuples, so column labels that are not valid
    # Python identifiers are never renamed.
    for values in df.itertuples(index=False, name=None):
        yield dict(zip(columns, values))


# Rebuild the two re-split frames as Hugging Face datasets, reusing the feature
# schema of the original train split. The original test split is attached
# unchanged afterwards.
final_data = {
    split_name: datasets.Dataset.from_generator(
        gen, gen_kwargs={"df": frame}, features=hf_data["train"].features
    )
    for split_name, frame in (
        ("train", final_df_train),
        ("validation", final_df_val),
    )
}
final_data["test"] = hf_data["test"]

Generating train split: 20327 examples [00:04, 4858.34 examples/s]
Generating train split: 1070 examples [00:00, 4196.34 examples/s]


In [47]:
# Display the assembled splits (features and row counts) as a sanity check.
final_data

{'train': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 20327
 }),
 'validation': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 1070
 }),
 'test': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 1
 })}

In [48]:
# Wrap the plain dict of splits in a DatasetDict, the object the upload cell
# calls push_to_hub on. Note that this rebinds `final_data` from a dict to a
# DatasetDict.
final_data = datasets.DatasetDict(final_data)

In [49]:
# Upload all splits as the "full" configuration of the Hub repository.
# NOTE(review): this is the same repo/config the notebook loads from at the
# top, so the upload presumably replaces the data this notebook was built from;
# re-running from the top would then re-split already re-split data. Confirm
# this is intended.
final_data.push_to_hub(
    repo_id="pufanyi/cassava-leaf-disease-classification",
    config_name="full",
)

Map: 100%|██████████| 4066/4066 [00:01<00:00, 3499.41 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 61.63ba/s]
Map: 100%|██████████| 4066/4066 [00:01<00:00, 3381.06 examples/s]31.74s/it]
Creating parquet from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 62.99ba/s]
Map: 100%|██████████| 4065/4065 [00:01<00:00, 3327.36 examples/s]27.97s/it]
Creating parquet from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 52.86ba/s]
Map: 100%|██████████| 4065/4065 [00:01<00:00, 3319.88 examples/s]26.27s/it]
Creating parquet from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 57.85ba/s]
Map: 100%|██████████| 4065/4065 [00:01<00:00, 3623.94 examples/s]24.56s/it]
Creating parquet from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 58.56ba/s]
Uploading the dataset shards: 100%|██████████| 5/5 [02:07<00:00, 25.44s/it]
Map: 100%|██████████| 1070/1070 [00:00<00:00, 8922.05 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 11/11 [00:0

CommitInfo(commit_url='https://huggingface.co/datasets/pufanyi/cassava-leaf-disease-classification/commit/c8b14bae522158c357a845544ad2ba510774d78b', commit_message='Upload dataset', commit_description='', oid='c8b14bae522158c357a845544ad2ba510774d78b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/pufanyi/cassava-leaf-disease-classification', endpoint='https://huggingface.co', repo_type='dataset', repo_id='pufanyi/cassava-leaf-disease-classification'), pr_revision=None, pr_num=None)