# Medicinal Plant Dataset Exploration

Use this notebook to sanity-check split directories, inspect class balance, and visualize samples before training. It reads from `config.yaml` so path changes propagate automatically.

What you get:
- quick counts of valid/invalid files per split
- class imbalance snapshots (top/bottom classes)
- visual checks for distribution and augmentations


In [None]:
from pathlib import Path
import sys
import yaml
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import random

# Shared plot style and pandas display defaults for every cell below:
# at most 20 rows per frame, floats shown with thousands separators and 2 decimals.
plt.style.use("seaborn-v0_8")
pd.options.display.max_rows = 20
pd.set_option("display.float_format", "{:,.2f}".format)

# Locate the repository root: the nearest directory, starting at the kernel's
# working directory and walking upwards, that holds both `config.yaml` and `src/`.
# Walking every ancestor (not just two levels) lets the notebook run from any
# depth inside the repository; the nearest match still wins.
NOTEBOOK_CWD = Path.cwd().resolve()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (NOTEBOOK_CWD, *NOTEBOOK_CWD.parents)
        if (candidate / "config.yaml").exists() and (candidate / "src").exists()
    ),
    None,
)
if PROJECT_ROOT is None:
    raise RuntimeError("Unable to locate project root. Please run the notebook from within the repository.")

# Put the project root on sys.path so later cells can import from `src`.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

with open(PROJECT_ROOT / "config.yaml", "r", encoding="utf-8") as f:
    # An empty YAML document parses to None; fall back to {} so the .get()
    # lookups below use their defaults instead of raising AttributeError.
    cfg = yaml.safe_load(f) or {}

# Config paths are joined onto the project root; the second argument of each
# .get() is the default used when the key is absent from config.yaml.
DATA_ROOT = PROJECT_ROOT / cfg.get("data_dir", "data")
SPLITS = {
    "train": PROJECT_ROOT / cfg.get("train_dir", "data/train"),
    "val": PROJECT_ROOT / cfg.get("val_dir", "data/val"),
    "test": PROJECT_ROOT / cfg.get("test_dir", "data/test"),
}

print(f"Project root: {PROJECT_ROOT}")
print(f"Data root: {DATA_ROOT}")
cfg


In [None]:
from src.data import _is_valid_image

# Build one record per (split, class directory) with the number of files found
# and how many of them pass `_is_valid_image`.
COUNT_COLUMNS = ["split", "class_id", "total_files", "valid_files", "invalid_files"]

records = []
missing_splits = []

for split_name, split_path in SPLITS.items():
    # is_dir() rather than exists(): a stray file sitting at the split path is
    # reported as missing instead of making iterdir() raise NotADirectoryError.
    if not split_path.is_dir():
        missing_splits.append(split_name)
        continue

    for cls_dir in sorted(p for p in split_path.iterdir() if p.is_dir()):
        # Only direct children that are regular files are counted.
        class_files = [p for p in cls_dir.iterdir() if p.is_file()]
        total_files = len(class_files)
        valid_files = sum(1 for p in class_files if _is_valid_image(str(p)))

        records.append(
            {
                "split": split_name,
                "class_id": cls_dir.name,
                "total_files": total_files,
                "valid_files": valid_files,
                "invalid_files": total_files - valid_files,
            }
        )

# Explicit columns keep the schema stable even when no records were collected.
counts_df = pd.DataFrame(records, columns=COUNT_COLUMNS)
display(counts_df.head(10))

# The two warnings are independent: a missing split must not hide the fact
# that the remaining splits produced no rows at all.
if missing_splits:
    print(f"[WARN] missing split directories: {', '.join(missing_splits)}")
if counts_df.empty:
    print("[WARN] no data found. Check config paths above.")


In [None]:
if not counts_df.empty:
    # Per-split file totals plus the percentage of files that are valid images.
    file_cols = ["total_files", "valid_files", "invalid_files"]
    per_split_totals = counts_df.groupby("split")[file_cols].sum()
    per_split_totals["valid_pct"] = (
        100 * per_split_totals["valid_files"] / per_split_totals["total_files"]
    )
    split_summary = per_split_totals.round(2).sort_index()

    # Class coverage: how many class directories each split has, and how many
    # of those contain at least one valid image.
    class_dirs = counts_df.groupby("split")["class_id"].nunique().rename("class_dirs")
    has_valid = counts_df["valid_files"] > 0
    non_empty_classes = (
        counts_df.loc[has_valid]
        .groupby("split")["class_id"]
        .nunique()
        .rename("classes_with_images")
    )
    coverage_raw = pd.concat([class_dirs, non_empty_classes], axis=1)
    # Splits with no non-empty class come out of the concat as NaN; show them as 0.
    coverage = coverage_raw.fillna(0).astype(int)

    # Valid/invalid counts per class summed over all splits, largest first.
    summed_by_class = counts_df.groupby("class_id")[["valid_files", "invalid_files"]].sum()
    class_totals = summed_by_class.sort_values("valid_files", ascending=False)

    print("Per-split totals (valid vs invalid files):")
    display(split_summary)

    print("Per-split class coverage:")
    display(coverage)

    print("Top classes by valid images (all splits combined):")
    display(class_totals.head(10))
else:
    print("No split data to summarize.")


In [None]:
if counts_df.empty:
    print("No class distribution to inspect.")
else:
    # Train rows ordered from the class with the most valid files to the least.
    is_train = counts_df["split"] == "train"
    train_counts = counts_df.loc[is_train].sort_values("valid_files", ascending=False)
    preview_cols = ["class_id", "valid_files", "invalid_files"]

    if not train_counts.empty:
        top10 = train_counts.head(10).loc[:, preview_cols]
        # The last ten rows are re-sorted ascending so the smallest class is listed first.
        bottom10 = train_counts.tail(10).sort_values("valid_files").loc[:, preview_cols]

        print("Largest train classes (by valid files):")
        display(top10.reset_index(drop=True))

        print("Smallest train classes (by valid files):")
        display(bottom10.reset_index(drop=True))
    else:
        print("Train split is empty; skipping class balance preview.")


In [None]:
if counts_df.empty:
    print("No data to plot.")
else:
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    # Panel 1: stacked valid/invalid totals per split. counts_df is non-empty
    # in this branch, so the grouped totals always hold at least one split.
    totals = counts_df.groupby("split")[["valid_files", "invalid_files"]].sum()
    totals.plot(kind="bar", stacked=True, ax=axes[0], color=["#2a9d8f", "#e76f51"])
    axes[0].set_title("Valid vs invalid images by split")
    axes[0].set_ylabel("images")
    axes[0].legend(loc="upper right")

    # Panel 2: overlaid histograms of per-class valid counts for train and val.
    drew_histogram = False
    for split, color in [("train", "#264653"), ("val", "#2a9d8f")]:
        subset = counts_df[counts_df["split"] == split]["valid_files"]
        if not subset.empty:
            axes[1].hist(subset, bins=20, alpha=0.65, label=split, color=color)
            drew_histogram = True
    axes[1].set_title("Per-class valid image counts")
    axes[1].set_xlabel("images per class")
    axes[1].set_ylabel("frequency")
    # Only request a legend when something labelled was drawn; otherwise
    # matplotlib warns that there are no artists to put in it.
    if drew_histogram:
        axes[1].legend()

    # Panel 3: the five smallest and five largest train classes by valid images.
    train_counts = counts_df[counts_df["split"] == "train"].sort_values("valid_files")
    if train_counts.empty:
        axes[2].axis("off")
        axes[2].set_title("No train data")
    else:
        # With ten or fewer classes the two ends overlap, so plot every class
        # once instead of concatenating duplicated rows.
        if len(train_counts) <= 10:
            combined = train_counts
        else:
            smallest = train_counts.head(5)
            largest = train_counts.tail(5)
            combined = pd.concat([smallest, largest])
        axes[2].barh(combined["class_id"], combined["valid_files"], color="#457b9d")
        axes[2].set_title("Train: smallest vs largest classes")
        axes[2].set_xlabel("valid images")
        # barh draws the first row at the bottom; invert so the smallest class is on top.
        axes[2].invert_yaxis()

    plt.tight_layout()
    plt.show()


## Class distribution shares

How much of each split is dominated by a handful of classes? The chart below shows the proportion of images contributed by the top classes (everything beyond the top 15 is grouped under `other`).


In [None]:
if counts_df.empty:
    print("No class distribution to plot.")
else:
    # One horizontal bar chart per configured split. squeeze=False always
    # returns a 2-D axes array, so a single split needs no special case.
    ncols = len(SPLITS)
    fig, axes = plt.subplots(1, ncols, figsize=(7 * ncols, 5), squeeze=False)

    for ax, split in zip(axes[0], SPLITS):
        subset = counts_df[counts_df["split"] == split].sort_values("valid_files", ascending=False)
        if subset.empty:
            ax.axis("off")
            ax.set_title(f"{split} (no data)")
            continue

        total_valid = subset["valid_files"].sum()
        if total_valid == 0:
            # Class directories exist but hold no valid images. Dividing by
            # zero would give NaN shares and set_xlim would reject the NaN limit.
            ax.axis("off")
            ax.set_title(f"{split} (no valid images)")
            continue

        # Top 15 classes are shown individually; the remainder is pooled as "other".
        top = subset.head(15)
        other_count = subset.iloc[15:]["valid_files"].sum()

        labels = list(top["class_id"])
        shares = list((top["valid_files"] / total_valid * 100).round(2))
        if other_count > 0:
            labels.append("other")
            shares.append(round(other_count / total_valid * 100, 2))

        # Reverse so the largest class ends up at the top of the horizontal chart.
        ax.barh(labels[::-1], shares[::-1], color="#2a9d8f")
        ax.set_title(f"{split}: share of images by class")
        ax.set_xlabel("% of split")
        ax.set_xlim(0, max(shares) * 1.15)

    plt.tight_layout()
    plt.show()


## Imbalance spotlight

Two quick visuals to surface long-tail issues:
- **Cumulative share**: how many of the largest classes it takes to cover 80% and 90% of a split's images.
- **Count ratio**: how each train class's image count compares to the train median (log-scaled x-axis).


In [None]:
if counts_df.empty:
    print("No data for imbalance charts.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(18, 5))

    # Cumulative share per split (Lorenz-style): classes ranked from most to
    # fewest valid images, y = running percentage of the split covered.
    drew_curve = False
    for split, color in [("train", "#2a9d8f"), ("val", "#e76f51")]:
        subset = (
            counts_df[(counts_df["split"] == split) & (counts_df["valid_files"] > 0)]
            .sort_values("valid_files", ascending=False)
        )
        if subset.empty:
            continue

        counts = subset["valid_files"].to_numpy()
        cumulative = (counts.cumsum() / counts.sum()) * 100
        x = range(1, len(counts) + 1)

        axes[0].plot(x, cumulative, marker="o", markersize=3, label=f"{split} ({len(counts)} classes)", color=color)
        drew_curve = True

        # Mark the first class rank at which each coverage threshold is reached.
        for threshold in (80, 90):
            idx = next((rank for rank, share in enumerate(cumulative, start=1) if share >= threshold), None)
            if idx is not None:
                axes[0].axvline(idx, color=color, linestyle="--", alpha=0.25)
                axes[0].text(idx, threshold + 2, f"{threshold}% in {idx} classes", color=color, ha="right", va="bottom", fontsize=9)

    axes[0].set_title("Cumulative share of images by class rank")
    axes[0].set_xlabel("class rank (most to least images)")
    axes[0].set_ylabel("% of split covered")
    axes[0].set_ylim(0, 105)
    axes[0].grid(alpha=0.2)
    # Skip the legend when neither split produced a curve; matplotlib would
    # otherwise warn about having no labelled artists.
    if drew_curve:
        axes[0].legend()

    # Distribution of class counts vs median (train split)
    train_counts = counts_df[(counts_df["split"] == "train") & (counts_df["valid_files"] > 0)]
    if train_counts.empty:
        axes[1].axis("off")
        axes[1].set_title("Train split empty")
    else:
        ratios = train_counts["valid_files"] / train_counts["valid_files"].median()

        # The x-axis is logarithmic, so use log-spaced bin edges: linear edges
        # would render as bars of wildly different widths on a log axis.
        n_bins = 25
        lowest, highest = float(ratios.min()), float(ratios.max())
        if lowest < highest:
            bins = [lowest * (highest / lowest) ** (step / n_bins) for step in range(n_bins + 1)]
            # Pin the last edge exactly so rounding cannot drop the largest class.
            bins[-1] = highest
        else:
            # Every class has the same count; a single value needs no custom edges.
            bins = n_bins

        axes[1].hist(ratios, bins=bins, color="#264653", alpha=0.85)
        axes[1].axvline(1, color="black", linestyle="--", linewidth=1)
        axes[1].set_xscale("log")
        axes[1].set_title("Train class counts vs median (log scale)")
        axes[1].set_xlabel("count / median count")
        axes[1].set_ylabel("number of classes")

        # Rows with zero valid files were filtered out above, so min() is at least 1.
        max_min_ratio = train_counts["valid_files"].max() / train_counts["valid_files"].min()
        axes[1].text(0.98, 0.95, f"max/min: {max_min_ratio:.1f}x", transform=axes[1].transAxes, ha="right", va="top", fontsize=9, bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.7))

    plt.tight_layout()
    plt.show()


In [None]:
import torch
from src.data import build_transforms, IMAGENET_MEAN, IMAGENET_STD

# Transform pipeline built with is_train=True, plus the per-channel statistics
# (reshaped to (3, 1, 1) for broadcasting) used to undo normalisation for display.
# Values under cfg["data_cfg"] take precedence over the ImageNet constants.
transform = build_transforms(cfg["img_size"], is_train=True)
norm_cfg = cfg.get("data_cfg", {})
mean = torch.tensor(norm_cfg.get("mean", IMAGENET_MEAN)).view(3, 1, 1)
std = torch.tensor(norm_cfg.get("std", IMAGENET_STD)).view(3, 1, 1)

def _denormalize(img_tensor):
    """Invert normalisation using the module-level ``mean`` and ``std`` tensors.

    Computes ``img_tensor * std + mean``; ``mean`` and ``std`` have shape
    (3, 1, 1), so ``img_tensor`` must broadcast against that shape.
    """
    rescaled = img_tensor * std
    return rescaled + mean

def show_random_sample(split="train", n=4, seed=None, apply_transform=True):
    """Plot up to ``n`` randomly chosen valid images from a split.

    Each chosen image gets one row: the original on the left and, when
    ``apply_transform`` is true, the output of the module-level ``transform``
    (de-normalised and clamped to [0, 1]) on the right.

    Args:
        split: Key into ``SPLITS`` naming the split to sample from.
        n: Maximum number of images to show; fewer are shown if the split
            holds fewer valid images. Must be at least 1.
        seed: Seed for the ``random.Random`` instance that picks the images.
            It is not passed to ``transform``, so any randomness inside the
            transform pipeline is not controlled by it.
        apply_transform: Whether to add the transformed column.

    Raises:
        ValueError: If ``n`` is below 1, or the split is unknown or its
            directory does not exist.
        RuntimeError: If the split has no class directories or no valid images.
    """
    if n < 1:
        # Fail early with a clear message instead of an opaque error from
        # plt.subplots() being asked for zero rows.
        raise ValueError(f"n must be at least 1, got {n}")

    rng = random.Random(seed)
    split_path = SPLITS.get(split)
    if split_path is None or not split_path.exists():
        raise ValueError(f"Unknown or missing split: {split}")

    class_dirs = [p for p in split_path.iterdir() if p.is_dir()]
    if not class_dirs:
        raise RuntimeError(f"No class directories found in {split_path}")

    # rglob() yields paths in filesystem order, which varies between machines.
    # Sorting first makes a given seed select the same files everywhere.
    valid_images = sorted(
        p for p in split_path.rglob("*") if p.is_file() and _is_valid_image(str(p))
    )
    if not valid_images:
        raise RuntimeError(f"No valid images found in {split_path}")

    chosen = rng.sample(valid_images, k=min(n, len(valid_images)))
    cols = 2 if apply_transform else 1
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(len(chosen), cols, figsize=(6 * cols, 3 * len(chosen)), squeeze=False)

    for ax_row, img_path in zip(axes, chosen):
        # convert() returns a fully loaded copy, so the underlying file handle
        # can be closed as soon as the with-block exits.
        with Image.open(img_path) as raw_image:
            pil_image = raw_image.convert("RGB")
        ax_row[0].imshow(pil_image)
        ax_row[0].set_title(f"Original: {img_path.parent.name}")
        ax_row[0].axis("off")

        if apply_transform:
            tensor_image = transform(pil_image)
            # Undo normalisation, move channels last for imshow, clip to the displayable range.
            vis_tensor = _denormalize(tensor_image).permute(1, 2, 0).clamp(0, 1)
            ax_row[1].imshow(vis_tensor)
            ax_row[1].set_title("Transformed (train pipeline)")
            ax_row[1].axis("off")

    plt.tight_layout()

# Preview three train images (image selection seeded with 0) beside their transformed versions.
show_random_sample(split="train", n=3, seed=0)
