# Experimenting with the Data Loader

I wrote this script to validate the data loader for the MedMNIST dataset because model training was producing unexpected results. The goal is to confirm that the loader returns the images and labels I expect before looking for problems elsewhere in the training pipeline.

In [4]:
from medmnist import ChestMNIST
import torch
from medmnist import INFO
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.transforms as transforms

# Reset matplotlib to its built-in default style so the figures below do not
# inherit a style left over from earlier work in the same kernel.
plt.style.use("default")

# Image edge length requested from medmnist; kept in one place so all three
# splits are guaranteed to use the same resolution.
IMAGE_SIZE = 64

# Load the datasets (no transform is attached here; samples are converted
# explicitly with `transform` further down).
train_dataset = ChestMNIST(split="train", download=True, size=IMAGE_SIZE)
val_dataset = ChestMNIST(split="val", download=True, size=IMAGE_SIZE)
test_dataset = ChestMNIST(split="test", download=True, size=IMAGE_SIZE)

# Print dataset sizes
for split_name, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name} dataset size: {len(split_dataset)}")

# Get the chest info and labels.
# The label mapping is keyed by the class index as a string ("0", "1", ...).
# Deriving the class count from the mapping, rather than hard-coding 14, keeps
# CHEST_CLASSES in step with the dataset metadata, while indexing by position
# keeps the names in class-index order.
chest_info = INFO["chestmnist"]
label_names = chest_info["label"]
CHEST_CLASSES = [label_names[str(idx)] for idx in range(len(label_names))]

# Create transform for the images: conversion to a tensor only, no augmentation
# or normalisation, so what is displayed is the raw pixel content.
transform = transforms.Compose([transforms.ToTensor()])

# 1. Visualize sample images with their labels
# Show the first ten training samples in a 2x5 grid, each titled with the
# conditions flagged in its label.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for sample_idx, panel in enumerate(axes.flat):
    image, target = train_dataset[sample_idx]
    # squeeze() drops size-1 dimensions so imshow receives a 2-D array.
    panel.imshow(transform(image).squeeze(), cmap="gray")

    # Convert the label into readable condition names.
    if not isinstance(target, (torch.Tensor, np.ndarray)):
        # Non-array label: use it directly as an index into the class names.
        caption = CHEST_CLASSES[target]
    else:
        # Array label: collect the name of every position holding a truthy value.
        active_conditions = []
        for class_idx, flag in enumerate(target):
            if flag:
                active_conditions.append(CHEST_CLASSES[class_idx])
        caption = "\n".join(active_conditions) if active_conditions else "Normal"

    panel.set_title(f"Condition:\n{caption}", fontsize=8)
    panel.axis("off")

# y > 1 places the overall title above the subplot titles.
fig.suptitle("Sample Chest X-ray Images", y=1.05)
fig.tight_layout()
plt.show()

# 2. Analyze class distribution
# Prefer the dataset's own `labels` array when it exists: iterating the dataset
# item by item also loads every image, only to discard it. Fall back to
# iteration for dataset objects that do not expose that attribute.
# NOTE(review): `labels` holds the raw stored labels; if a target transform is
# ever attached to the dataset, confirm this still matches what __getitem__ returns.
labels = getattr(train_dataset, "labels", None)
if labels is None:
    labels = [label for _, label in train_dataset]

num_samples = len(labels)

if num_samples > 0 and isinstance(labels[0], (torch.Tensor, np.ndarray)):
    # For multi-label case: stack the labels into one 2-D array and sum down
    # axis 0, giving one positive count per label column.
    labels_array = np.asarray(labels)
    label_counts = labels_array.sum(axis=0)
    percentages = label_counts / num_samples * 100

    # Create distribution plot
    dist_fig, dist_ax = plt.subplots(figsize=(15, 6))
    sns.barplot(x=CHEST_CLASSES, y=label_counts, ax=dist_ax)
    dist_ax.set_title(
        f"Class Distribution in Training Set\nTotal samples: {len(train_dataset)}"
    )
    dist_ax.set_xlabel("Condition")
    dist_ax.set_ylabel("Count")
    # Rotate the class names so the long ones do not overlap.
    plt.setp(dist_ax.get_xticklabels(), rotation=45, ha="right")
    dist_fig.tight_layout()
    plt.show()

    # Print statistics
    print("\nLabel Statistics:")
    print(f"Total number of samples: {len(train_dataset)}")
    for class_name, count, percentage in zip(CHEST_CLASSES, label_counts, percentages):
        print(f"{class_name}: {count} ({percentage:.1f}%)")

    # 3. Optional: Print class imbalance ratio
    most_common_idx = int(np.argmax(label_counts))
    least_common_idx = int(np.argmin(label_counts))
    max_count = label_counts[most_common_idx]
    min_count = label_counts[least_common_idx]
    # A class with no positive samples would make the ratio a division by zero;
    # report it as infinite instead of triggering a NumPy runtime warning.
    imbalance_ratio = max_count / min_count if min_count > 0 else float("inf")
    print(f"\nImbalance ratio (max/min): {imbalance_ratio:.2f}")
    print(f"Most common condition: {CHEST_CLASSES[most_common_idx]}")
    print(f"Least common condition: {CHEST_CLASSES[least_common_idx]}")
else:
    # Make the skip visible; in a validation notebook a silent no-op could be
    # mistaken for a passed check.
    print("Class distribution skipped: no array-valued labels found in the training set.")

ImportError: cannot import name 'ChestMNIST' from 'medmnist' (/home/a/.cache/pypoetry/virtualenvs/mnist-chest-xray-classification-TNGwvxiU-py3.12/lib/python3.12/site-packages/medmnist/__init__.py)