In [1]:
import datasets
import pandas as pd

# Fetch the "full" configuration of the Cassava leaf-disease dataset from the Hub.
hf_data = datasets.load_dataset(
    path="pufanyi/cassava-leaf-disease-classification",
    name="full",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pool the original train and validation splits into one frame so they can be
# re-split below with a fixed ratio.
train_df = hf_data["train"].to_pandas()
val_df = hf_data["validation"].to_pandas()
# ignore_index=True: each source frame carries its own index, so a plain concat
# can repeat index labels and make label-based lookups (.loc) ambiguous.
df = pd.concat([train_df, val_df], ignore_index=True)

In [3]:
# Total number of pooled samples (train + validation).
df.shape[0]

21397

In [4]:
# Distinct class ids, kept in order of first appearance in `df`.
labels = df["label"].drop_duplicates().to_numpy()
labels

array([3, 1, 2, 4, 0])

In [5]:
# Map each class id to the sub-frame containing only that class's rows.
data = {label: df[df["label"] == label] for label in labels}

In [6]:
# Report how many rows each class has, to gauge class balance.
for class_id, class_rows in data.items():
    count = len(class_rows)
    print(f"Label {class_id}: {count} samples")

Label 3: 13158 samples
Label 1: 2189 samples
Label 2: 2386 samples
Label 4: 2577 samples
Label 0: 1087 samples


In [7]:
from sklearn.model_selection import train_test_split

# The per-label counts printed above are strongly imbalanced, so stratify on the
# label to keep each class's proportion the same in the train and validation splits.
final_df_train, final_df_val = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df["label"],
)
# Every pooled row must land in exactly one of the two splits.
assert len(final_df_train) + len(final_df_val) == len(df)

In [8]:
len(final_df_train), len(final_df_val)

(17117, 4280)

In [9]:
def gen(df: pd.DataFrame):
    """Yield each row of ``df`` as a ``{column: value}`` dict.

    Passed to ``datasets.Dataset.from_generator`` to stream a split's rows.

    Args:
        df: Frame whose rows are emitted, in row order.

    Yields:
        One dict per row, keyed by column name.
    """
    # to_dict(orient="records") avoids iterrows, which builds a Series per row
    # (slow) and upcasts values when the columns have mixed numeric dtypes.
    yield from df.to_dict(orient="records")


# Rebuild Hugging Face datasets from the re-split frames, reusing the source
# dataset's feature schema; the original test split is carried over unchanged.
final_data = {
    split_name: datasets.Dataset.from_generator(
        gen,
        gen_kwargs={"df": split_df},
        features=hf_data["train"].features,
    )
    for split_name, split_df in [
        ("train", final_df_train),
        ("validation", final_df_val),
    ]
}
final_data["test"] = hf_data["test"]

Generating train split: 17117 examples [00:07, 2413.56 examples/s]
Generating train split: 4280 examples [00:01, 2867.51 examples/s]


In [10]:
# Display the splits and their row counts as a sanity check before uploading.
final_data

{'train': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 17117
 }),
 'validation': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 4280
 }),
 'test': Dataset({
     features: ['image_id', 'image', 'label'],
     num_rows: 1
 })}

In [11]:
# Wrap the plain dict of splits in a DatasetDict so all splits can be uploaded
# together with push_to_hub.
final_data = datasets.DatasetDict(final_data)

In [12]:
# Upload every split to the Hub under the "resized" configuration.
# config_name is passed by keyword: older `datasets` releases (before the
# config_name parameter existed) had a different second positional parameter,
# so a positional string could be silently bound to the wrong argument.
# NOTE(review): nothing visible in this notebook resizes the images; they are
# only re-split from the "full" config. Confirm the config name is intended.
final_data.push_to_hub(
    "pufanyi/cassava-leaf-disease-classification",
    config_name="resized",
)

Map: 100%|██████████| 3424/3424 [00:01<00:00, 2693.15 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:01<00:00, 31.55ba/s]
Map: 100%|██████████| 3424/3424 [00:01<00:00, 2397.87 examples/s]24.61s/it]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:01<00:00, 32.78ba/s]
Map: 100%|██████████| 3423/3423 [00:01<00:00, 2405.90 examples/s]21.52s/it]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:01<00:00, 32.11ba/s]
Map: 100%|██████████| 3423/3423 [00:01<00:00, 2357.82 examples/s]20.69s/it]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:01<00:00, 32.65ba/s]
Map: 100%|██████████| 3423/3423 [00:01<00:00, 2576.16 examples/s]22.14s/it]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:01<00:00, 32.63ba/s]
Uploading the dataset shards: 100%|██████████| 5/5 [01:51<00:00, 22.34s/it]
Map: 100%|██████████| 2140/2140 [00:00<00:00, 3321.77 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 22/22 [00:0

CommitInfo(commit_url='https://huggingface.co/datasets/pufanyi/cassava-leaf-disease-classification/commit/f75b4f48e2ce96d25dc91c5cdf8feff478d89e54', commit_message='Upload dataset', commit_description='', oid='f75b4f48e2ce96d25dc91c5cdf8feff478d89e54', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/pufanyi/cassava-leaf-disease-classification', endpoint='https://huggingface.co', repo_type='dataset', repo_id='pufanyi/cassava-leaf-disease-classification'), pr_revision=None, pr_num=None)