# Medicinal Plant Dataset Exploration

Use this notebook to inspect dataset health, class distribution, and a few sample images before/after transformations. It relies on `config.yaml` so tweaks there propagate automatically.

In [None]:
from pathlib import Path
import sys
import yaml
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import random

# Apply the 'seaborn-v0_8' matplotlib style to every figure in this notebook.
plt.style.use('seaborn-v0_8')

# Locate the repository root by checking the current directory and its two
# nearest ancestors for both `config.yaml` and a `src` entry, so the notebook
# works whether it is launched from the root or from a nested folder.
NOTEBOOK_CWD = Path.cwd().resolve()
PROJECT_ROOT = None
for candidate in [NOTEBOOK_CWD, NOTEBOOK_CWD.parent, NOTEBOOK_CWD.parent.parent]:
    if (candidate / "config.yaml").exists() and (candidate / "src").exists():
        PROJECT_ROOT = candidate
        break
if PROJECT_ROOT is None:
    raise RuntimeError("Unable to locate project root. Please run the notebook from within the repository.")

# Put the root on sys.path so `from src... import ...` resolves in later cells.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
print(f"Project root: {PROJECT_ROOT}")

# Read the config with an explicit encoding so parsing does not depend on the
# platform's locale default.
with open(PROJECT_ROOT / "config.yaml", "r", encoding="utf-8") as f:
    # safe_load returns None for an empty document; fall back to an empty
    # mapping so later `cfg.get(...)` calls use their defaults instead of
    # failing with AttributeError on None.
    cfg = yaml.safe_load(f) or {}
cfg

In [None]:
from src.data import _is_valid_image
import pandas as pd

# Directory of each dataset split, taken from the config when present and
# otherwise defaulting to data/<split> under the project root.
SPLITS = {
    "train": PROJECT_ROOT / cfg.get("train_dir", "data/train"),
    "val": PROJECT_ROOT / cfg.get("val_dir", "data/val"),
    "test": PROJECT_ROOT / cfg.get("test_dir", "data/test"),
}

# One record per (split, class directory) with total / valid / invalid file counts.
records = []
for split_name, split_path in SPLITS.items():
    if not split_path.exists():
        print(f"[WARN] Missing split directory: {split_name} -> {split_path}")
        continue
    class_dirs = sorted(entry for entry in split_path.iterdir() if entry.is_dir())
    for class_dir in class_dirs:
        # Only regular files count; nested directories are ignored.
        files = [entry for entry in class_dir.iterdir() if entry.is_file()]
        n_valid = sum(1 for entry in files if _is_valid_image(str(entry)))
        records.append(
            {
                "split": split_name,
                "class_id": class_dir.name,
                "total_files": len(files),
                "valid_files": n_valid,
                "invalid_files": len(files) - n_valid,
            }
        )

counts_df = pd.DataFrame(records)
counts_df.head()

In [None]:
# Per-split totals of valid/invalid files, then summary statistics of the
# per-class valid-image counts.
if counts_df.empty:
    print("No split data found.")
else:
    per_split = counts_df.groupby("split")
    display(per_split[["valid_files", "invalid_files"]].sum())
    display(per_split["valid_files"].describe())

In [None]:
# Side-by-side histograms of valid images per class for the train and val splits.
if counts_df.empty:
    print("No data to plot.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    for ax, split in zip(axes, ["train", "val"]):
        per_class = counts_df.loc[counts_df["split"] == split, "valid_files"]
        if per_class.empty:
            # Keep the panel as a labelled placeholder when the split has no rows.
            ax.set_title(f"{split} (no data)")
            ax.axis("off")
        else:
            ax.hist(per_class, bins=20, color="#2a9d8f")
            ax.set_title(f"{split} class counts")
            ax.set_xlabel("valid images per class")
            ax.set_ylabel("frequency")
    plt.tight_layout()

In [None]:
from src.data import build_transforms

# Build the transform once from the configured image size, requesting the
# training variant (is_train=True); show_random_sample() below reuses it.
# NOTE(review): presumably the training variant includes random augmentation —
# confirm in src.data.build_transforms.
transform = build_transforms(cfg["img_size"], is_train=True)

def show_random_sample(split="train"):
    """Plot one random valid image from ``split`` beside its transformed version.

    A class directory is chosen at random first, then a valid image inside it,
    so the class is sampled independently of how many images it contains.

    Args:
        split: Key into the module-level ``SPLITS`` mapping.

    Raises:
        ValueError: If ``split`` is not a key of ``SPLITS`` or its directory
            does not exist.
        RuntimeError: If the split has no class directories, or the chosen
            class contains no file accepted by ``_is_valid_image``.
    """
    split_path = SPLITS.get(split)
    if split_path is None or not split_path.exists():
        raise ValueError(f"Unknown or missing split: {split}")
    class_dirs = [p for p in split_path.iterdir() if p.is_dir()]
    if not class_dirs:
        raise RuntimeError(f"No class directories found in {split_path}")
    chosen_class = random.choice(class_dirs)
    valid_images = [p for p in chosen_class.iterdir() if p.is_file() and _is_valid_image(str(p))]
    if not valid_images:
        raise RuntimeError(f"No valid images in class {chosen_class}")
    chosen_image = random.choice(valid_images)
    # Image.open keeps the file handle open until the image is closed; the
    # context manager releases it, while convert() yields an in-memory image
    # that remains usable afterwards.
    with Image.open(chosen_image) as source:
        pil_image = source.convert("RGB")
    tensor_image = transform(pil_image)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(pil_image)
    axes[0].set_title(f"Original ({split}/{chosen_class.name})")
    axes[0].axis('off')
    # Move the first axis last so imshow receives a channels-last array.
    # NOTE(review): assumes transform returns a (C, H, W) tensor; if it also
    # normalizes pixel values, imshow will clip them — confirm in src.data.
    axes[1].imshow(tensor_image.permute(1, 2, 0))
    axes[1].set_title('Transformed (tensor)')
    axes[1].axis('off')
    plt.tight_layout()

# Preview one randomly chosen training image; re-run the cell for another sample.
show_random_sample('train')