In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

In [None]:
# Notebook-wide configuration: the random seed and the dataset locations.
# All paths are relative to the current working directory.
SEED = 42
DATA_DIR = Path("data")
DATASET_ROOT = DATA_DIR / "mushrooms"
LABELS_TXT = DATASET_ROOT / "mushrooms.txt"      # text file listing the class labels
MUSHROOMS_DIR = DATASET_ROOT / "data" / "data"   # globbed below as <label>/<file>

In [None]:
# Collect every file matching MUSHROOMS_DIR/<label>/<file>; the name of the
# parent directory is used as the class label.
#
# Path.glob() yields entries in filesystem-dependent order, so the paths are
# sorted to make the row order (and anything derived from it, such as a seeded
# split) reproducible across machines. Non-file matches (e.g. nested
# directories) are skipped so they are not recorded as images.
image_lst = sorted(path for path in MUSHROOMS_DIR.glob("*/*") if path.is_file())
label_lst = [path.parent.name for path in image_lst]

dataset_df = pd.DataFrame({"image": image_lst, "label": label_lst})
dataset_df.head()

In [None]:
# Histogram of samples per label, used to eyeball class balance.
sns.set_style("white")
label_hist_ax = sns.histplot(dataset_df["label"])
# Hide the x tick labels (one per label value) to keep the axis readable.
label_hist_ax.set_xticklabels([])
plt.show()

In [None]:
# Split into train / validation subsets (scikit-learn's default test_size).
#
# The split MUST shuffle: rows come from a directory-by-directory listing, so
# they are typically grouped by label. With shuffle=False the validation set
# would be the tail of the frame (only the last classes) and random_state
# would be ignored entirely. SEED keeps the shuffled split reproducible.
# No stratification is requested because the dataset is assumed to be balanced.
train_df, val_df = train_test_split(dataset_df, random_state=SEED, shuffle=True)

# reset_index/assign return new frames, which avoids in-place mutation of the
# split results (and the SettingWithCopyWarning that column assignment on them
# can trigger).
train_df = train_df.reset_index(drop=True).assign(is_valid=False)
val_df = val_df.reset_index(drop=True).assign(is_valid=True)

# ignore_index gives the combined frame a clean 0..n-1 index instead of
# duplicated labels from the two parts.
dataset_df = pd.concat([train_df, val_df], ignore_index=True)
dataset_df.head()

In [None]:
# Write the table to data/mushrooms.csv without the index column.
csv_path = DATA_DIR / "mushrooms.csv"
dataset_df.to_csv(csv_path, index=False)

In [None]:
# Sanity check: the labels found in the DataFrame must match the label list
# stored in LABELS_TXT.
df_labels = np.sort(dataset_df["label"].unique())
labels = np.sort(np.loadtxt(LABELS_TXT, dtype=str))

# np.array_equal returns False when the two arrays differ in length, whereas
# `(df_labels == labels).all()` breaks in that case (the comparison either
# raises or collapses to a plain bool without `.all()`, depending on the
# NumPy version).
print(np.array_equal(df_labels, labels))