In [18]:
# Import necessary libraries for file handling, data manipulation, and visualization
import os
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Import libraries for working with images and transformations
from PIL import Image
import cv2 as cv

# Import PyTorch modules for model building, data handling, and evaluation
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.models.quantization as quant_models
from torch.utils.data import Dataset, DataLoader, Subset
from torchinfo import summary

# Import libraries for machine learning metrics and model evaluation
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, mean_squared_error, mean_absolute_error, r2_score
import torchmetrics
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

# Set the seed.
seed = 42
torch.manual_seed(seed)

<torch._C.Generator at 0x753ed31df230>

In [19]:
data_dir="/workspace/cp-anemia-detection/data/cp-anemia/"
weights_dir="/workspace/cp-anemia-detection/data/notebooks/weights/"
anemic_dir=data_dir+"/Anemic"
non_anemic_dir=data_dir+"/Non-anemic"
signature = "02042024"

In [20]:
data_sheet_path = data_dir+"Anemia_Data_Collection_Sheet.csv"
data_sheet = pd.read_csv(data_sheet_path)
display(data_sheet)

Unnamed: 0,IMAGE_ID,HB_LEVEL,Severity,Age(Months),GENDER,REMARK,HOSPITAL,CITY/TOWN,MUNICIPALITY/DISTRICT,REGION,COUNTRY
0,Image_001,9.80,Moderate,6,Female,Anemic,Nkawie-Toase Government Hospital,Nkawie-Toase,Atwima Nwabiagya South,Ashanti,Ghana
1,Image_002,9.90,Moderate,24,Male,Anemic,Ejusu Government Hospital,Ejusu,Ejusu Municipality,Ashanti,Ghana
2,Image_003,11.10,Non-Anemic,24,Female,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
3,Image_004,12.50,Non-Anemic,12,Male,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
4,Image_005,9.90,Moderate,24,Male,Anemic,Sunyani Municipal Hospital,Sunyani,Sunyani Municipality,Bono,Ghana
...,...,...,...,...,...,...,...,...,...,...,...
705,Image_706,12.80,Non-Anemic,48,Male,Non-anemic,Bolgatanga Regional Hospital,Bolgatanga,Bolgatanga Municipality,Upper East,Ghana
706,Image_707,11.47,Non-Anemic,48,Female,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
707,Image_708,11.60,Non-Anemic,60,Male,Non-anemic,Komfo Anokye Teaching Hospital,Kumasi,Kumasi Metropolitan,Ashanti,Ghana
708,Image_709,12.10,Non-Anemic,48,Male,Non-anemic,Bolgatanga Regional Hospital,Bolgatanga,Bolgatanga Municipality,Upper East,Ghana


In [21]:
# Mapping diagnosis to severity
severity_mapping = {
    "Non-Anemic": 0,
    "Mild": 1,
    "Moderate": 2,
    "Severe": 3,
}

data_sheet['Severity'] = data_sheet['Severity'].map(severity_mapping)
display(data_sheet)

Unnamed: 0,IMAGE_ID,HB_LEVEL,Severity,Age(Months),GENDER,REMARK,HOSPITAL,CITY/TOWN,MUNICIPALITY/DISTRICT,REGION,COUNTRY
0,Image_001,9.80,2,6,Female,Anemic,Nkawie-Toase Government Hospital,Nkawie-Toase,Atwima Nwabiagya South,Ashanti,Ghana
1,Image_002,9.90,2,24,Male,Anemic,Ejusu Government Hospital,Ejusu,Ejusu Municipality,Ashanti,Ghana
2,Image_003,11.10,0,24,Female,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
3,Image_004,12.50,0,12,Male,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
4,Image_005,9.90,2,24,Male,Anemic,Sunyani Municipal Hospital,Sunyani,Sunyani Municipality,Bono,Ghana
...,...,...,...,...,...,...,...,...,...,...,...
705,Image_706,12.80,0,48,Male,Non-anemic,Bolgatanga Regional Hospital,Bolgatanga,Bolgatanga Municipality,Upper East,Ghana
706,Image_707,11.47,0,48,Female,Non-anemic,Ahmadiyya Muslim Hospital,Tachiman,Techiman Municipality,Bono-East,Ghana
707,Image_708,11.60,0,60,Male,Non-anemic,Komfo Anokye Teaching Hospital,Kumasi,Kumasi Metropolitan,Ashanti,Ghana
708,Image_709,12.10,0,48,Male,Non-anemic,Bolgatanga Regional Hospital,Bolgatanga,Bolgatanga Municipality,Upper East,Ghana


In [22]:
# Define data augmentations or transformations
transform = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.RandomHorizontalFlip(p=np.random.rand()),
    transforms.RandomVerticalFlip(p=np.random.rand()),
    transforms.RandomRotation(degrees=np.random.randint(0, 360)),
    transforms.RandomAffine(degrees=np.random.randint(0, 360)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Custom dataset class
class CPAnemiCDataset(Dataset):
    def __init__(self, dir, df, transform=None):
        self.dir = dir
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row['IMAGE_ID']
        img_folder = row['REMARK']
        img_path = os.path.join(self.dir, img_folder, img_id + ".png")
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)

        multiclass_label = torch.tensor(row['Severity'])
        hb_level = torch.tensor(row['HB_LEVEL'])

        return img, multiclass_label, hb_level

    # Load the dataset
image_dataset = CPAnemiCDataset(data_dir, data_sheet, transform=transform)
train_dataset, test_dataset = train_test_split(image_dataset, test_size=0.20, shuffle=True)

print(f"Image Dataset Size (All): {len(image_dataset)}, \
        Train Size: {len(train_dataset)}, \
        Test Size: {len(test_dataset)}")

BATCH_SIZE = 32
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Image Dataset Size (All): 710,         Train Size: 568,         Test Size: 142


In [29]:
# Default device
device = torch.device('cpu')

# Check for CUDA availability
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print("CUDA is not available, using CPU.")

print(f"Selected device: {device}")

CUDA is not available, using CPU.
Selected device: cpu


In [30]:
def get_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    model_size = "Model Size: %.2f MB" %(os.path.getsize("tmp.pt")/1e6)
    os.remove('tmp.pt')
    return model_size

# Static Weighting Function. Set eta_class to desired importance (Classification > .5, Regression < .5, Equal == .5)
def sw_loss(loss_class, loss_reg, eta_class=0.5):
    eta_reg = 1 - eta_class
    total_loss = (eta_class * loss_class) + (eta_reg * loss_reg)
    return total_loss

In [44]:
import torch
import torch.nn.functional as F

def train(dataloader, model, loss_fn_class, loss_fn_reg, optimizer):
    model.train()
    total_loss = 0
    total_ce_loss = 0
    total_mse_loss = 0
    correct = 0
    total_samples = 0
    total_mae = 0

    for _, (img, multiclass, hb_level) in enumerate(dataloader):
        img = img.to(device)
        multiclass = multiclass.to(device).long()  # Ensure correct type for CrossEntropyLoss
        hb_level = hb_level.to(device).unsqueeze(1).float()

        optimizer.zero_grad()

        # Forward pass
        class_pred, reg_pred = model(img)

        # Compute losses
        ce_loss = loss_fn_class(class_pred, multiclass)  # CrossEntropy for classification
        mse_loss = loss_fn_reg(reg_pred, hb_level)  # MSE for regression
        loss = sw_loss(ce_loss, mse_loss, 0.7)  # Weighted loss

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Track total losses
        total_loss += loss.item()
        total_ce_loss += ce_loss.item()
        total_mse_loss += mse_loss.item()

        # Compute classification accuracy
        class_probs = F.softmax(class_pred, dim=1)
        highest_prob_class = torch.argmax(class_probs, dim=1)
        correct += (highest_prob_class == multiclass).sum().item()
        total_samples += multiclass.size(0)

        # Compute regression MAE
        total_mae += torch.abs(reg_pred - hb_level).sum().item()

    avg_loss = total_loss / len(dataloader)
    avg_ce_loss = total_ce_loss / len(dataloader)
    avg_mse_loss = total_mse_loss / len(dataloader)
    classification_accuracy = correct / total_samples
    regression_mae = total_mae / total_samples

    return avg_loss, classification_accuracy, regression_mae, avg_ce_loss, avg_mse_loss

In [46]:
def eval(dataloader, model, loss_fn_class, loss_fn_reg):
    model.eval()
    total_loss = 0
    total_ce_loss = 0
    total_mse_loss = 0
    correct = 0
    total_samples = 0
    total_mae = 0

    with torch.no_grad():
        for _, (img, multiclass, hb_level) in enumerate(dataloader):
            img = img.to(device)
            multiclass = multiclass.to(device).long()  # Ensure correct type for CrossEntropyLoss
            hb_level = hb_level.to(device).unsqueeze(1).float()

            # Forward pass
            class_pred, reg_pred = model(img)

            # Compute losses
            ce_loss = loss_fn_class(class_pred, multiclass)  # CrossEntropy for classification
            mse_loss = loss_fn_reg(reg_pred, hb_level)  # MSE for regression
            loss = sw_loss(ce_loss, mse_loss, 0.7)  # Weighted loss

            # Track total losses
            total_loss += loss.item()
            total_ce_loss += ce_loss.item()
            total_mse_loss += mse_loss.item()

            # Compute classification accuracy
            class_probs = F.softmax(class_pred, dim=1)
            highest_prob_class = torch.argmax(class_probs, dim=1)
            correct += (highest_prob_class == multiclass).sum().item()
            total_samples += multiclass.size(0)

            # Compute regression MAE
            total_mae += torch.abs(reg_pred - hb_level).sum().item()

    avg_loss = total_loss / len(dataloader)
    avg_ce_loss = total_ce_loss / len(dataloader)
    avg_mse_loss = total_mse_loss / len(dataloader)
    classification_accuracy = correct / total_samples
    regression_mae = total_mae / total_samples

    return avg_loss, classification_accuracy, regression_mae, avg_ce_loss, avg_mse_loss

In [42]:
class MobileNetMultiOutput(nn.Module):
    def __init__(self):
        super(MobileNetMultiOutput, self).__init__()
        self.mobilenet = models.mobilenet_v2(pretrained=False)
        num_ftrs = self.mobilenet.classifier[1].in_features
        self.mobilenet.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, 5)  # 4-class classification + 1 regression output
        )

    def forward(self, x):
        output = self.mobilenet(x)
        class_output = output[:, :4]  # First 4 values = class probabilities
        reg_output = output[:, 4]  # Last value = Hb level estimate
        return class_output, reg_output  # Return as two separate outputs

# Load the modified MobileNet
model = MobileNetMultiOutput().to(device)

In [None]:
# # Training parameters
# BATCH_SIZE = 32
# EPOCHS = 10
# FOLDS = 5

# # Device configuration
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Initialize model and loss functions
# model = MobileNetMultiOutput().to(device)

# loss_fn_class = torch.nn.CrossEntropyLoss()  # Multi-class classification loss
# loss_fn_reg = torch.nn.MSELoss()  # Regression loss
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# # 5-Fold Cross Validation
# kf = KFold(n_splits=FOLDS, shuffle=True, random_state=42)

# # Directory to save the best model
# weights_dir = "weights"
# os.makedirs(weights_dir, exist_ok=True)

# best_val_acc = -float('inf')  # Track best validation accuracy
# train_metrics_df = []
# val_metrics_df = []

# # Training loop
# for epoch in range(EPOCHS):
#     print(f"\nEpoch {epoch+1}/{EPOCHS}")
#     fold = 1

#     for train_idx, val_idx in kf.split(range(len(image_dataset))):  # FIX: Ensure correct splitting
#         train_subset = Subset(image_dataset, train_idx)
#         val_subset = Subset(image_dataset, val_idx)

#         train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
#         val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

#         if fold == FOLDS:
#             # Validation phase
#             avg_val_loss, val_acc, val_ce_loss, val_mse_loss = eval(val_loader, model, loss_fn_class, loss_fn_reg)

#             print(f"Validation: Fold {fold} - Total Loss: {avg_val_loss:.4f}, Accuracy: {val_acc:.4f}, CrossEntropy: {val_ce_loss:.4f}, MSE: {val_mse_loss:.4f}")

#             # Save model with the best validation accuracy
#             if val_acc > best_val_acc:
#                 best_val_acc = val_acc
#                 val_metrics_dict = {"Loss": avg_val_loss, "Accuracy": val_acc}
#                 val_metrics_df.append(val_metrics_dict)
#                 torch.save(model.state_dict(), f"{weights_dir}/model_best_accuracy_{signature}.pth")
#                 print(f"Best model saved with Accuracy: {best_val_acc:.4f}")

#         else:
#             # Training phase
#             avg_train_loss, train_acc, train_ce_loss, train_mse_loss = train(train_loader, model, loss_fn_class, loss_fn_reg, optimizer)

#             print(f"Training: Fold {fold} - Total Loss: {avg_train_loss:.4f}, Accuracy: {train_acc:.4f}, CrossEntropy: {train_ce_loss:.4f}, MSE: {train_mse_loss:.4f}")

#             train_metrics_dict = {"Loss": avg_train_loss, "Accuracy": train_acc, "CrossEntropy": train_ce_loss, "MSE":train_mse_loss}
#             train_metrics_df.append(train_metrics_dict)

#         fold += 1  # Move to next fold

# # Ensure `get_model_size()` exists or remove this line
# print(get_model_size(model))


Epoch 1/10
Training: Fold 1 - Total Loss: 27.7721, Accuracy: 0.2113, CrossEntropy: 1.3738, MSE: 66.8065
Training: Fold 2 - Total Loss: 14.9889, Accuracy: 0.3363, CrossEntropy: 1.3414, MSE: 31.4981
Training: Fold 3 - Total Loss: 6.5687, Accuracy: 0.3785, CrossEntropy: 1.2595, MSE: 12.0592
Training: Fold 4 - Total Loss: 3.3187, Accuracy: 0.4173, CrossEntropy: 1.3444, MSE: 6.0949
Validation: Fold 5 - Total Loss: 4.5402, Accuracy: 0.3099, CrossEntropy: 1.1267, MSE: 20.7433
Best model saved with Accuracy: 0.3099

Epoch 2/10
Training: Fold 1 - Total Loss: 2.4814, Accuracy: 0.4120, CrossEntropy: 1.2473, MSE: 5.8315
Training: Fold 2 - Total Loss: 2.3825, Accuracy: 0.3961, CrossEntropy: 1.2555, MSE: 7.4170
Training: Fold 3 - Total Loss: 2.3606, Accuracy: 0.3627, CrossEntropy: 1.1308, MSE: 6.4752
Training: Fold 4 - Total Loss: 2.4772, Accuracy: 0.4208, CrossEntropy: 1.1836, MSE: 3.9011
Validation: Fold 5 - Total Loss: 2.4764, Accuracy: 0.3873, CrossEntropy: 1.0486, MSE: 7.7985
Best model saved 

In [48]:
# Training parameters
BATCH_SIZE = 32
EPOCHS = 150
FOLDS = 5

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and loss functions
model = MobileNetMultiOutput().to(device)

loss_fn_class = torch.nn.CrossEntropyLoss()  # Multi-class classification loss
loss_fn_reg = torch.nn.MSELoss()  # Regression loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# 5-Fold Cross Validation
kf = KFold(n_splits=FOLDS, shuffle=True, random_state=42)

# Directory to save the best model
weights_dir = "weights"
os.makedirs(weights_dir, exist_ok=True)

best_val_acc = -float('inf')  # Track best validation accuracy
train_metrics_df = []
val_metrics_df = []

# Training loop
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    fold = 1

    for train_idx, val_idx in kf.split(range(len(image_dataset))):  # FIX: Ensure correct splitting
        train_subset = Subset(image_dataset, train_idx)
        val_subset = Subset(image_dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

        if fold == FOLDS:
            # Validation phase
            avg_val_loss, val_acc, val_mae_loss, val_ce_loss, val_mse_loss = eval(val_loader, model, loss_fn_class, loss_fn_reg)

            print(f"Validation: Fold {fold} - Total Loss: {avg_val_loss:.4f}, Accuracy: {val_acc:.4f}, CrossEntropy: {val_ce_loss:.4f}, MSE: {val_mse_loss:.4f}, MAE: {val_mae_loss:.4f}")

            # Save model with the best validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                val_metrics_dict = {"Loss": avg_val_loss, "Accuracy": val_acc}
                val_metrics_df.append(val_metrics_dict)
                torch.save(model.state_dict(), f"{weights_dir}/model_best_accuracy_{signature}.pth")
                print(f"Best model saved with Accuracy: {best_val_acc:.4f}")

        else:
            # Training phase
            avg_train_loss, train_acc, train_mae_loss, train_ce_loss, train_mse_loss = train(train_loader, model, loss_fn_class, loss_fn_reg, optimizer)

            print(f"Training: Fold {fold} - Total Loss: {avg_train_loss:.4f}, Accuracy: {train_acc:.4f}, CrossEntropy: {train_ce_loss:.4f}, MSE: {train_mse_loss:.4f}, MAE: {train_mae_loss:.4f}")

            train_metrics_dict = {"Loss": avg_train_loss, "Accuracy": train_acc, "CrossEntropy": train_ce_loss, "MSE":train_mse_loss}
            train_metrics_df.append(train_metrics_dict)

        fold += 1  # Move to next fold

# Ensure `get_model_size()` exists or remove this line
print(get_model_size(model))


Epoch 1/150
Training: Fold 1 - Total Loss: 25.9885, Accuracy: 0.3169, CrossEntropy: 1.3106, MSE: 83.5700, MAE: 279.6700
Training: Fold 2 - Total Loss: 12.7453, Accuracy: 0.3292, CrossEntropy: 1.3104, MSE: 39.4269, MAE: 186.0849
Training: Fold 3 - Total Loss: 5.0892, Accuracy: 0.3345, CrossEntropy: 1.2994, MSE: 13.9322, MAE: 102.1539
Training: Fold 4 - Total Loss: 2.8084, Accuracy: 0.3715, CrossEntropy: 1.2913, MSE: 6.3484, MAE: 65.1593
Validation: Fold 5 - Total Loss: 5.9686, Accuracy: 0.3099, CrossEntropy: 1.2517, MSE: 16.9746, MAE: 103.4768
Best model saved with Accuracy: 0.3099

Epoch 2/150
Training: Fold 1 - Total Loss: 2.4996, Accuracy: 0.4085, CrossEntropy: 1.2675, MSE: 5.3745, MAE: 58.7573
Training: Fold 2 - Total Loss: 2.3806, Accuracy: 0.4014, CrossEntropy: 1.2389, MSE: 5.0444, MAE: 56.6611
Training: Fold 3 - Total Loss: 2.3515, Accuracy: 0.3715, CrossEntropy: 1.2518, MSE: 4.9173, MAE: 55.6357
Training: Fold 4 - Total Loss: 2.4918, Accuracy: 0.4137, CrossEntropy: 1.2425, MSE:

In [51]:
%%time

model.load_state_dict(torch.load(f"{weights_dir}/model_best_accuracy_{signature}.pth"))
avg_test_loss, test_acc, test_mae_loss, test_ce_loss, test_mse_loss = eval(test_loader, model, loss_fn_class, loss_fn_reg)
print(f"Testing: Total Loss: {avg_test_loss:.4f}, Accuracy: {test_acc:.4f}, CrossEntropy: {test_ce_loss:.4f}, MSE: {test_mse_loss:.4f}, MAE: {test_mae_loss:.4f}")


Testing: Total Loss: 1.7491, Accuracy: 0.7394, CrossEntropy: 0.4724, MSE: 4.7280, MAE: 53.4031
CPU times: user 5.63 s, sys: 5.62 ms, total: 5.63 s
Wall time: 244 ms
