# Weighted Random Sampling

Let us verify that `WeightedRandomSampler` correctly balances the class distribution in our training batches. Without weighted sampling, the batches would reflect the dataset's natural, imbalanced class distribution, which can bias the model toward the majority classes during training.

In [None]:
# IPython setup: reload edited project modules before each cell runs
# (autoreload mode 2) and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
from pathlib import Path

# Make the directory two levels above the working directory (presumably the
# repository root -- confirm) importable, so the `jute_disease` imports in the
# next cell resolve. The membership check keeps re-runs from adding duplicates.
project_root = Path("../..").resolve()
root_entry = str(project_root)
if root_entry not in sys.path:
    sys.path.append(root_entry)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from jute_disease.data.jute_datamodule import JuteDataModule
from tqdm import tqdm

from jute_disease.utils.constants import DEFAULT_SEED
from jute_disease.utils.logger import get_logger
from jute_disease.utils.seed import seed_everything

In [None]:
# Logger used for the progress messages emitted throughout this notebook.
logger = get_logger(__name__)
# Seed once, before the data modules and loaders are built, so the sampled
# batches (and therefore the plots) are reproducible across runs.
# NOTE(review): assumes seed_everything seeds the RNGs the DataLoader and
# WeightedRandomSampler draw from -- confirm in jute_disease.utils.seed.
seed_everything(DEFAULT_SEED)

Let us instantiate two `JuteDataModule` instances: one that does not use a weighted sampler and one that does.

In [None]:
def _build_datamodule(use_weighted_sampler):
    """Create a JuteDataModule with the given sampler setting and run prepare/setup."""
    datamodule = JuteDataModule(use_weighted_sampler=use_weighted_sampler)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule


# Baseline: batches follow the dataset's own class proportions.
dm_natural = _build_datamodule(use_weighted_sampler=False)
# Candidate: batches are drawn through the weighted sampler.
dm_weighted = _build_datamodule(use_weighted_sampler=True)

logger.info(f"Classes: {dm_weighted.classes}")

In [None]:
def collect_labels(dm, num_batches=50):
    """Return the labels from the first ``num_batches`` training batches of ``dm``.

    Args:
        dm: Data module exposing ``train_dataloader()``; every batch it yields is
            unpacked as ``(inputs, labels)``.
        num_batches: Maximum number of batches to read. Fewer are read if the
            loader is exhausted first; a non-positive value reads nothing.

    Returns:
        A flat list of labels, built by extending with ``labels.tolist()`` for
        each batch.
    """
    # Local import keeps this notebook cell self-contained.
    from itertools import islice

    loader = dm.train_dataloader()
    limit = max(num_batches, 0)  # islice rejects negative stop values
    all_labels = []
    logger.info(f"Checking {num_batches} batches...")
    # islice stops *before* requesting batch number `limit`. The previous
    # enumerate-then-break loop pulled one extra batch from the loader (loading
    # and transforming its images) only to throw it away.
    for _, labels in tqdm(islice(loader, limit), total=limit):
        all_labels.extend(labels.tolist())
    return all_labels


# Draw the same number of batches from each loader (natural first, then
# weighted) so the two label samples are directly comparable.
sampled_labels = []
for message, datamodule in (
    ("Collecting natural samples...", dm_natural),
    ("Collecting weighted samples...", dm_weighted),
):
    logger.info(message)
    sampled_labels.append(collect_labels(datamodule))
natural_labels, weighted_labels = sampled_labels

Let us visualize and compare the two distributions.

In [None]:
# Side-by-side per-class counts for the two samplers.
# countplot only draws categories that occur in the data, so without an explicit
# `order` a class missing from a sample would shift every later tick label onto
# the wrong bar (or make set_xticklabels fail on a count mismatch). Fixing the
# order to all class indices gives every class a slot, including zero counts.
# Assumes `classes[i]` names label index i, as the original tick labelling did.
panels = (
    ("Natural Sampler Class Distribution", natural_labels, dm_natural.classes),
    ("Weighted Sampler Class Distribution", weighted_labels, dm_weighted.classes),
)
fig, axes = plt.subplots(1, len(panels), figsize=(20, 6))

for ax, (title, labels, classes) in zip(axes, panels):
    class_indices = list(range(len(classes)))
    sns.countplot(x=labels, order=class_indices, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Class")
    ax.set_ylabel("Count")
    # Pin tick positions before labelling them so names and bars stay aligned.
    ax.set_xticks(class_indices)
    ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.grid(axis="y", linestyle="--", alpha=0.7)

plt.tight_layout()

The count plots above show that the weighted random sampler yields a much more balanced class distribution across the sampled training batches than the natural sampler, which makes it the more appropriate choice for our imbalanced dataset.