# CelluScan 🧫

CelluScan is a deep learning model designed to classify blood cells, mainly white blood cells, and to detect abnormalities.

This notebook covers the full pipeline from data cleaning and augmentation, through dataloader creation, to model training and evaluation.

## Features

- **Classifies multiple blood cell types** (e.g., Neutrophil, Lymphocyte, Monocyte)
- **Handles class imbalance** via augmentation
- **Easily extensible** for new classes or abnormality detection

## Workflow Overview

1. Data cleaning and balancing
2. Data augmentation for minority classes
3. Dataloader creation and splitting
4. Model training and evaluation


## 0. Imports & Colors

We import all necessary libraries for data handling, augmentation, and model building.

We also define a set of ANSI color codes for improved terminal output readability.


In [27]:
import torch
import torchvision

from torch import nn
from torchvision import transforms
from going_modular.GPU_check import GPU_check

# ANSI escape codes for colors in Python
colors = {
    "reset": "\033[0m",

    # Regular Colors
    "black": "\033[30m",
    "red": "\033[31m",
    "green": "\033[32m",
    "yellow": "\033[33m",
    "blue": "\033[34m",
    "magenta": "\033[35m",
    "cyan": "\033[36m",
    "white": "\033[37m",

    # Bright Colors
    "bright_black": "\033[90m",
    "bright_red": "\033[91m",
    "bright_green": "\033[92m",
    "bright_yellow": "\033[93m",
    "bright_blue": "\033[94m",
    "bright_magenta": "\033[95m",
    "bright_cyan": "\033[96m",
    "bright_white": "\033[97m",

    # Background Colors
    "bg_black": "\033[40m",
    "bg_red": "\033[41m",
    "bg_green": "\033[42m",
    "bg_yellow": "\033[43m",
    "bg_blue": "\033[44m",
    "bg_magenta": "\033[45m",
    "bg_cyan": "\033[46m",
    "bg_white": "\033[47m",

    # Bright Backgrounds
    "bg_bright_black": "\033[100m",
    "bg_bright_red": "\033[101m",
    "bg_bright_green": "\033[102m",
    "bg_bright_yellow": "\033[103m",
    "bg_bright_blue": "\033[104m",
    "bg_bright_magenta": "\033[105m",
    "bg_bright_cyan": "\033[106m",
    "bg_bright_white": "\033[107m",

    # Styles
    "bold": "\033[1m",
    "dim": "\033[2m",
    "italic": "\033[3m",
    "underline": "\033[4m",
    "blink": "\033[5m",
    "reverse": "\033[7m",
    "hidden": "\033[8m"
}

device = GPU_check()
device

CUDA available: True
CUDA devices: 1
Current device: 0
Device name: NVIDIA GeForce RTX 3060 Laptop GPU


'cuda'
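
To illustrate how the `colors` table is meant to be used, here is a minimal sketch of a helper that wraps a message in a color code and always appends the reset code. The `colorize` function is a hypothetical addition, not part of the notebook, and the small `colors` dict below repeats only three of the entries defined above so that the snippet runs on its own.

```python
# Subset of the ANSI table defined in the cell above, repeated so this runs standalone.
colors = {
    "reset": "\033[0m",
    "green": "\033[32m",
    "blue": "\033[34m",
}


def colorize(text: str, color: str) -> str:
    """Wrap text in an ANSI color code (hypothetical helper, not in the notebook)."""
    return f"{colors[color]}{text}{colors['reset']}"


message = colorize("[INFO] Starting augmentation...", "green")
print(message)
```

Appending the reset code in one place prevents a forgotten `colors['reset']` from coloring all subsequent output.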

## 1. Cleaning the Data

### Initial Collected Data

Below is the class distribution in the raw dataset:

| Class Name           | Count |
| -------------------- | ----- |
| Immature Granulocyte | 151   |
| Promyelocytes        | 592   |
| Myeloblast           | 1,000 |
| Metamyelocytes       | 1,015 |
| Myelocytes           | 1,137 |
| Erythroblast         | 1,551 |
| Band Neutrophil      | 1,634 |
| Basophil             | 1,653 |
| Platelets            | 2,348 |
| Segmented Neutrophil | 2,646 |
| Monocyte             | 5,046 |
| Neutrophil           | 6,779 |
| Eosinophil           | 7,141 |
| Lymphocyte           | 8,685 |

### Our Goal

- **Target:** At least 4,500 samples per class for balanced training.
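- **How:** For each class below the target, generate augmented copies of randomly chosen existing images using `TrivialAugmentWide`. Classes already at or above the target are left unchanged.
- **Result:** All classes will have at least 4,500 images, improving model robustness and reducing bias toward the majority classes.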

The following code block performs this balancing and augmentation automatically.


In [6]:
from torchvision import transforms
from torchvision.datasets import ImageFolder

from pathlib import Path
from PIL import Image
from tqdm import tqdm
import random

# Set paths
DATA_DIR = Path("Blood cells datasets/")
TARGET_COUNT = 4500 # Target minimum number of samples per class

# Define the augmentation transformation
augment_transform = transforms.TrivialAugmentWide()

# Load dataset
base_dataset = ImageFolder(root=DATA_DIR)
class_to_idx = base_dataset.class_to_idx # e.g., {"Basophil": 0, ...}

# Create a mapping from class index to class name
idx_to_class = {v: k for k, v in class_to_idx.items()}

# Count the current number of images in each class
from collections import defaultdict
class_counts = defaultdict(int)  # image count per class, defaulting to 0
for path, label in base_dataset.samples:  # iterate over every (image path, label index) pair
    class_name = idx_to_class[label]  # convert the label index to its class name
    class_counts[class_name] += 1  # increment the count for that class

# Print the class counts as a sanity check
print(class_counts)
print(f"\n{colors['green']}[INFO] Starting augmentation...{colors['reset']}")


# Loop over classes and duplicate if needed
for class_name, count in class_counts.items():
    if count >= TARGET_COUNT:
        continue  # Already balanced

    folder_path = DATA_DIR / class_name
    images = list(folder_path.glob("*"))  # All image files in the class folder

    to_generate = TARGET_COUNT - count
    print(f"Augmenting {class_name}: current={count}, generating={to_generate}")

    for i in tqdm(range(to_generate)):
        # Randomly pick an existing image
        src_path = random.choice(images)
        with Image.open(src_path) as img:
            img = img.convert("RGB")  # ensure consistency
            augmented = augment_transform(img)

            # Create new file name
            base_name = src_path.stem
            new_filename = f"{base_name}_aug{i}.jpg"
            new_path = folder_path / new_filename

            # Save
            augmented.save(new_path)

print(f"\n{colors['blue']}[INFO] Augmentation complete. All classes now have at least {TARGET_COUNT:,} samples.{colors['reset']}")


defaultdict(<class 'int'>, {'Band Neutrophil': 3449, 'Basophil': 1653, 'Eosinophil': 7141, 'Erythroblast': 1551, 'Immature Granulocyte': 151, 'Lymphocyte': 8685, 'Metamyelocytes': 1015, 'Monocyte': 5046, 'Myeloblast': 1000, 'Myelocytes': 1137, 'Neutrophil': 6779, 'Platelets': 2348, 'Promyelocytes': 592, 'Segmented Neutrophil': 2646})

[INFO] Starting augmentation...
Augmenting Band Neutrophil: current=3449, generating=1051


100%|██████████| 1051/1051 [00:09<00:00, 113.39it/s]


Augmenting Basophil: current=1653, generating=2847


100%|██████████| 2847/2847 [00:22<00:00, 127.06it/s]


Augmenting Erythroblast: current=1551, generating=2949


100%|██████████| 2949/2949 [00:22<00:00, 132.86it/s]


Augmenting Immature Granulocyte: current=151, generating=4349


100%|██████████| 4349/4349 [00:13<00:00, 324.73it/s]


Augmenting Metamyelocytes: current=1015, generating=3485


100%|██████████| 3485/3485 [00:17<00:00, 202.56it/s]


Augmenting Myeloblast: current=1000, generating=3500


100%|██████████| 3500/3500 [00:17<00:00, 200.56it/s]


Augmenting Myelocytes: current=1137, generating=3363


100%|██████████| 3363/3363 [00:15<00:00, 212.41it/s]


Augmenting Platelets: current=2348, generating=2152


100%|██████████| 2152/2152 [00:19<00:00, 110.28it/s]


Augmenting Promyelocytes: current=592, generating=3908


100%|██████████| 3908/3908 [00:19<00:00, 198.90it/s]


Augmenting Segmented Neutrophil: current=2646, generating=1854


100%|██████████| 1854/1854 [00:20<00:00, 90.70it/s] 


[INFO] Augmentation complete. All classes now have at least 4,500 samples.
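
The `generating=` figures in the log are simply the gap between each class count and the target. The sketch below recomputes them without touching any files, using the counts copied from the printed `defaultdict` above. The names `to_generate` (as a dict) and `balanced_counts` are introduced here for illustration only.

```python
TARGET_COUNT = 4500  # same target as in the augmentation cell

# Per-class counts copied from the printed defaultdict above.
class_counts = {
    "Band Neutrophil": 3449, "Basophil": 1653, "Eosinophil": 7141,
    "Erythroblast": 1551, "Immature Granulocyte": 151, "Lymphocyte": 8685,
    "Metamyelocytes": 1015, "Monocyte": 5046, "Myeloblast": 1000,
    "Myelocytes": 1137, "Neutrophil": 6779, "Platelets": 2348,
    "Promyelocytes": 592, "Segmented Neutrophil": 2646,
}

# Images to generate per class: zero for classes already at or above the target.
to_generate = {name: max(0, TARGET_COUNT - n) for name, n in class_counts.items()}

# Expected class sizes once augmentation has run.
balanced_counts = {name: n + to_generate[name] for name, n in class_counts.items()}

for name, extra in to_generate.items():
    if extra:
        print(f"{name}: +{extra} -> {balanced_counts[name]}")
print(f"Total images after balancing: {sum(balanced_counts.values())}")
```

Classes above the target (Eosinophil, Lymphocyte, Monocyte, Neutrophil) keep their original size, so the dataset ends up with a floor of 4,500 per class rather than perfectly equal class sizes.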





## 2. Get Dataloaders

After cleaning and balancing the data, we prepare the dataset for model training.

### Steps:

1. **Resize and transform images** to a consistent shape for model input.
2. **Split the dataset** into training (80%) and testing (20%) sets.
3. **Create PyTorch dataloaders** for efficient batch processing.

This ensures that our model receives well-structured and balanced data for both training and evaluation.


In [21]:
from pathlib import Path

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

DATA_DIR = Path("Blood cells datasets/")

# Define a transform that resizes all images to the same shape
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

# Get dataset using ImageFolder
dataset = datasets.ImageFolder(
    root=DATA_DIR,
    transform=transform
)

# Define train and test lengths (80% train, 20% test)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Randomly split the dataset into train and test sets (seeded so the split is reproducible)
train_dataset, test_dataset = random_split(
    dataset,
    lengths=[train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Get dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Get class names
class_names = dataset.classes

print(f"Training dataset = {train_dataset}")
print(f"Testing dataset = {test_dataset}")
print(f"Training dataloader = {train_dataloader}")
print(f"Testing dataloader = {test_dataloader}")
print(f"Class names = {class_names}")

Training dataset = <torch.utils.data.dataset.Subset object at 0x000001D09CB41480>
Testing dataset = <torch.utils.data.dataset.Subset object at 0x000001D09CB43580>
Training dataloader = <torch.utils.data.dataloader.DataLoader object at 0x000001D09CE0F820>
Testing dataloader = <torch.utils.data.dataloader.DataLoader object at 0x000001D09CE0F130>
Class names = ['Band Neutrophil', 'Basophil', 'Eosinophil', 'Erythroblast', 'Immature Granulocyte', 'Lymphocyte', 'Metamyelocytes', 'Monocyte', 'Myeloblast', 'Myelocytes', 'Neutrophil', 'Platelets', 'Promyelocytes', 'Segmented Neutrophil']
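
`random_split` draws a random permutation of the indices and cuts it at the requested lengths; passing a seeded `torch.Generator` through its `generator` argument makes that permutation repeatable across runs. The sketch below shows the same idea with the standard library only, so it can run without PyTorch. The `split_indices` function and the sample size of 1,000 are illustrative assumptions, not part of the notebook.

```python
import random


def split_indices(n_samples: int, train_fraction: float = 0.8, seed: int = 42):
    """Shuffle indices with a seeded RNG and cut them into train/test lists."""
    train_size = int(train_fraction * n_samples)  # same rounding as the cell above
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, global random state untouched
    return indices[:train_size], indices[train_size:]


# 1,000 is an arbitrary illustrative size, not the real dataset length.
train_idx, test_idx = split_indices(1000)
print(len(train_idx), len(test_idx))

# The same seed reproduces the same split.
assert split_indices(1000) == (train_idx, test_idx)
# Train and test never share a sample.
assert set(train_idx).isdisjoint(test_idx)
```

Computing the test size as the remainder, as the cell above does with `len(dataset) - train_size`, guarantees the two lengths always add up to the full dataset even when 80% is not a whole number.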


### 2.1 Using a Subset of the Data

For rapid prototyping during early experimentation, we use only 20% of the data from each class.

This approach allows for faster training and debugging, especially when working with large datasets.

The function below samples 20% of each class and returns a corresponding dataloader.


In [25]:
from torch.utils.data import DataLoader, Subset
from collections import defaultdict
import random

def get_20_percent_dataloader(dataset: torch.utils.data.Dataset,
                              batch_size: int = 32,
                              shuffle: bool = True,
                              seed: int = 42) -> DataLoader:
    """
    Returns a DataLoader containing 20% of each class from the given dataset.

    Args:
        dataset (torch.utils.data.Dataset or torch.utils.data.Subset): A dataset object (e.g. ImageFolder or Subset).
        batch_size (int): Batch size for the returned DataLoader.
        shuffle (bool): Whether to shuffle the DataLoader output.
        seed (int): Random seed for reproducibility.

    Returns:
        DataLoader: A PyTorch DataLoader with 20% of each class from the dataset.
    """
    random.seed(seed)

    # Get true targets and indices from base dataset
    if isinstance(dataset, Subset):
        base_dataset = dataset.dataset
        base_indices = dataset.indices
        targets = [base_dataset.targets[i] for i in base_indices]
    else:
        base_dataset = dataset
        base_indices = list(range(len(dataset)))
        targets = base_dataset.targets
        

    # Group indices by class label
    class_to_idx = defaultdict(list)
    for base_idx, label in zip(base_indices, targets):
        class_to_idx[label].append(base_idx)

    # Sample 20% of each class
    selected_indices = []
    for label, idxs in class_to_idx.items():
        n = max(1, int(0.2 * len(idxs)))
        selected_indices.extend(random.sample(idxs, n))

    # Create a subset and dataloader
    subset = Subset(base_dataset, selected_indices)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=shuffle)

    return dataloader



# Create dataloaders with 20% of each class (sampling per class keeps every class represented)
train_dataloader_20_percent = get_20_percent_dataloader(dataset=train_dataset,
                                                        batch_size=32,
                                                        shuffle=True)

test_dataloader_20_percent = get_20_percent_dataloader(dataset=test_dataset,
                                                       batch_size=32,
                                                       shuffle=False)

train_dataloader_20_percent, test_dataloader_20_percent

(<torch.utils.data.dataloader.DataLoader at 0x1d09e33b550>,
 <torch.utils.data.dataloader.DataLoader at 0x1d09e33b3d0>)

In [26]:
# Inspect one batch
images_batch, labels_batch = next(iter(train_dataloader_20_percent))

images_batch.shape, labels_batch.shape

(torch.Size([32, 3, 224, 224]), torch.Size([32]))
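
The per-class sampling rule inside `get_20_percent_dataloader` can be checked in isolation on a plain list of labels. The sketch below reproduces only the grouping and sampling steps, without any dataset or dataloader. The function name `sample_fraction_per_class` and the toy labels are made up for illustration.

```python
import random
from collections import Counter, defaultdict


def sample_fraction_per_class(labels, fraction=0.2, seed=42):
    """Return indices covering `fraction` of each class (at least one per class)."""
    rng = random.Random(seed)

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Draw the same share from every class, without replacement.
    selected = []
    for idxs in by_class.values():
        n = max(1, int(fraction * len(idxs)))  # same rule as get_20_percent_dataloader
        selected.extend(rng.sample(idxs, n))
    return selected


# Toy labels (made up for illustration): 50 of class 0, 10 of class 1, 3 of class 2.
labels = [0] * 50 + [1] * 10 + [2] * 3
picked = sample_fraction_per_class(labels)
per_class = Counter(labels[i] for i in picked)
print(per_class)
```

The `max(1, ...)` floor matters for very small classes: 20% of three samples rounds down to zero, and without the floor that class would vanish from the subset.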