# 🧮 Class Distribution Analysis

In this notebook, we'll analyze the **class distribution** in the training dataset. We'll count how many samples exist for each class and visualize the results to spot potential imbalances.

In [None]:
from data_utils import GroceryDataset, get_classes, TRAIN_CSV
import matplotlib.pyplot as plt
import pandas as pd
from collections import Counter

📂 Load the training dataset and the class-label mapping

In [None]:
# Lookup table from integer class id to its display name (dict of int → string).
class_names = get_classes()

# Training split, built from the CSV path exported by data_utils.
train_csv_path = TRAIN_CSV
dataset = GroceryDataset(csv_file=train_csv_path)

📊 Count the number of samples per class

In [None]:
# Collect the label of every training sample; the label is taken from
# position 1 of each item returned by dataset[i].
# NOTE(review): dataset[i] returns the whole sample, so this presumably also
# loads/transforms every input just to read its label. If GroceryDataset
# exposes its label column directly, reading that would be much faster.
# int() normalises tensor / numpy scalar labels to plain ints. Tensors hash by
# identity, so without it Counter would put each sample in its own bucket and
# the class_names lookup would miss. Plain int labels pass through unchanged.
labels = [int(dataset[i][1]) for i in range(len(dataset))]

# Count class occurrences (Counter yields 0 for ids that never appear)
label_counts = Counter(labels)

# Report every known class, not only the ones that occur. A class with zero
# samples is the most extreme imbalance and must still appear in the chart.
# Assumes class_names is a dict keyed by integer class id. Labels missing from
# class_names are kept in class_ids, so the lookup below still raises KeyError
# on an unknown id instead of dropping it silently.
class_ids = sorted(set(class_names) | set(label_counts))
counts = [label_counts[class_id] for class_id in class_ids]
names = [class_names[class_id] for class_id in class_ids]

📈 Plot the class distribution as a bar chart

In [None]:
# One bar per class; uses the explicit Figure/Axes interface.
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(names, counts)

# Tilt the class names and right-align them so long labels don't collide.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

ax.set(
    title="Class Distribution in Training Set",
    xlabel="Class",
    ylabel="Number of Samples",
)
fig.tight_layout()
plt.show()