### 🧮 Class Distribution Analysis

In this notebook, you'll analyze the **class distribution** in the training dataset.
Your goal is to count how many samples exist for each class and visualize the results.
This helps you detect **class imbalance**, which is common in real-world datasets!
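
To make "class imbalance" concrete, here is a minimal sketch on a small made-up label list. The labels and the names `toy_labels`, `toy_counts`, and `imbalance_ratio` are assumptions for illustration only and are not part of the grocery dataset. One simple way to summarize imbalance is to divide the size of the largest class by the size of the smallest:

```python
from collections import Counter

# Hypothetical toy labels (NOT from the grocery dataset): class 0 dominates
toy_labels = [0, 0, 0, 0, 1, 1, 2]

# Counter builds a {label: count} mapping in a single pass
toy_counts = Counter(toy_labels)

# Ratio of the largest class to the smallest class;
# a value of 1.0 would mean perfectly balanced classes
largest = max(toy_counts.values())
smallest = min(toy_counts.values())
imbalance_ratio = largest / smallest

print(dict(toy_counts))
print(f"Imbalance ratio: {imbalance_ratio:.1f}")
```

A ratio far above 1 suggests that some classes are underrepresented, which is worth keeping in mind when you read the bar chart for the real training set.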

👉 You only need to complete one part: counting how many images there are for each class.


In [None]:
# 📦 Imports
from data_utils import GroceryDataset, get_classes, TRAIN_CSV
import matplotlib.pyplot as plt

# 📂 Load class names and dataset
class_names = get_classes()
dataset = GroceryDataset(csv_file=TRAIN_CSV)

# 📝 TASK: Count how many samples there are for each class
# - Loop over the dataset
# - Extract the label for each sample
# - Count how often each label appears
# - Store the result in a dictionary

# ❗ Replace this with your own code:
label_counts = {}  # maps each class id to its number of samples

for _, label in dataset:
    # .get() returns 0 the first time a label is seen
    label_counts[label] = label_counts.get(label, 0) + 1

# ✅ After you're done, we’ll plot the class distribution

# 🎨 Prepare data for plotting
class_ids = sorted(label_counts.keys())
counts = [label_counts[class_id] for class_id in class_ids]
names = [class_names[class_id] for class_id in class_ids]

# 📊 Plot the bar chart
plt.figure(figsize=(12, 6))
plt.bar(names, counts)
plt.xticks(rotation=45, ha='right')
plt.title("Class Distribution in Training Set")
plt.xlabel("Class")
plt.ylabel("Number of Samples")
plt.tight_layout()
plt.show()
