In [None]:
pip install pydicom nibabel numpy torch torchvision segmentation-models-pytorch scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install albumentations

Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install pandas

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from scipy.ndimage import zoom
from operator import itemgetter
from sklearn.metrics import jaccard_score

# Seed every random number generator this notebook imports (Python, NumPy,
# PyTorch CPU and PyTorch CUDA) with the same fixed value so that a fresh
# "Restart & Run All" starts from the same RNG state.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(42)

# --- Experiment configuration -------------------------------------------------
# num_classes: number of label values; it sizes both the one-hot masks and the
#              model's output channels.
# batch_size:  mini-batch size used by the training and validation loaders.
num_classes = 7
batch_size = 8

# Root folder of the hip dataset. It defaults to the original absolute location
# but can be redirected with the JHIR_HIP_ROOT environment variable, so the
# notebook can run on another machine without editing the code.
hip_data_root = os.environ.get("JHIR_HIP_ROOT", "/home/ealam/JHIR_Hip_Knee_Datasets/Hip")
img_root = f"{hip_data_root}/Images"
mask_root = f"{hip_data_root}/Annotations"
metadata_path = f"{hip_data_root}/segmentation_with_racegender.csv"
# Deterministic preprocessing pipeline: resize to 256x256, normalise the single
# intensity channel, then convert to a torch tensor. Despite the "test" name,
# this same pipeline is what the training and validation datasets in this
# notebook are constructed with.
# NOTE(review): A.Normalize is left at its default pixel scale, which presumes
# 8-bit intensities; the DICOM pixel arrays may have a wider range - confirm.
_eval_pipeline_steps = [
    A.Resize(height=256, width=256),
    A.Normalize(mean=(0.485,), std=(0.229,)),
    ToTensorV2(),
]
test_augmentations = A.Compose(_eval_pipeline_steps)

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI label-map dataset for multiclass segmentation.

    Each item is built from one ``(image_file, mask_file)`` pair: the DICOM
    image is read from ``img_root``, the NIfTI annotation from ``mask_root``,
    the annotation is re-oriented and resampled onto the image grid, the
    optional transforms are applied, and the label map is one-hot encoded.

    Args:
        img_root: Directory containing the DICOM images.
        mask_root: Directory containing the NIfTI annotations.
        metadata_df: DataFrame with at least the columns ``id`` and
            ``racegender``.
        paired_files: Sequence of ``(image_file, mask_file)`` file-name pairs.
        num_classes: Number of label values; one one-hot plane is produced for
            each value in ``0 .. num_classes - 1``.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with the
            keys ``"image"`` and ``"mask"``. Applied before one-hot encoding.
        preprocessing: Optional callable with the same calling convention,
            applied after one-hot encoding.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing
        # Not read anywhere in this class; kept so existing external users of
        # the attribute keep working.
        self.used_pairs = set()

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, transform and encode the pair at position ``idx``.

        Returns:
            ``(image, one_hot_mask, racegender)``, or ``None`` when the mask
            file does not exist on disk.

        Raises:
            ValueError: If the re-oriented mask still has more than two
                dimensions and its last dimension is not 1.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # NOTE: PyTorch's default collate function cannot batch None, so
            # callers are expected to pass only pairs whose mask exists.
            return None

        image_path = os.path.join(self.img_root, image_file)
        image = pydicom.dcmread(image_path).pixel_array.astype(np.float32)

        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            # Volumetric annotation: only the first slice is used.
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)

        if annotation_data.ndim > 2 and annotation_data.shape[-1] != 1:
            raise ValueError('Mask has multiple channels')

        if image.shape != annotation_data.shape:
            # Nearest-neighbour resampling (order=0) keeps label values intact.
            zoom_factors = np.array(image.shape) / np.array(annotation_data.shape)
            annotation_data = zoom(annotation_data, zoom_factors, order=0)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        racegender = self._lookup_racegender(image_file)

        return image, annotation_data_onehot, racegender

    def _lookup_racegender(self, image_file):
        """Return the ``racegender`` value for the patient id encoded in ``image_file``.

        The id is the part of the file name before the first dot. The first
        matching metadata row wins; ``'Unknown'`` is returned when no row
        matches.
        """
        patient_id = int(float(image_file.split(".")[0]))
        matches = self.metadata_df.loc[self.metadata_df['id'] == patient_id, 'racegender'].values
        # Test the length explicitly: the truth value of a NumPy array is
        # ambiguous (it raises for more than one element).
        return matches[0] if len(matches) > 0 else 'Unknown'

    def one_hot_encode(self, mask):
        """Encode a label map as a ``(num_classes, *mask.shape)`` float32 array.

        Plane ``c`` is 1.0 where ``mask == c`` and 0.0 elsewhere; values outside
        ``0 .. num_classes - 1`` produce zeros in every plane.
        """
        labels = np.asarray(mask)
        # Shape (num_classes, 1, ..., 1) so the comparison broadcasts over the mask.
        class_ids = np.arange(self.num_classes).reshape((-1,) + (1,) * labels.ndim)
        return (labels[np.newaxis] == class_ids).astype(np.float32)

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate the mask 90 degrees clockwise, then mirror it left-right.

        Those two steps together are exactly a transpose of the first two axes,
        so the result is computed with NumPy in a single copy. The returned
        array is C-contiguous.
        """
        return np.ascontiguousarray(np.swapaxes(np.asarray(mask), 0, 1))

# Patient metadata table; its 'id' and 'racegender' columns are read further down.
metadata_df = pd.read_csv(metadata_path)

# Sorted listings make the pairing order, and therefore the seeded shuffle
# below, independent of the file system's directory order.
image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))

# Keep only the images that have an annotation named "<image stem>.nii.gz".
# Membership is tested against a set so that pairing is linear in the number of
# files instead of scanning the whole mask list once per image.
available_masks = set(mask_files)

paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

random.shuffle(paired_files)

# 70 % train, 10 % validation, and whatever remains (about 20 %) for test.
train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

# The training and validation datasets are built identically and share the
# same deterministic transform pipeline; only the file pairs differ.
train_set = MulticlassHipSegmentationDataset(
    img_root=img_root,
    mask_root=mask_root,
    metadata_df=metadata_df,
    paired_files=train_pairs,
    num_classes=num_classes,
    transforms=test_augmentations,
)
valid_set = MulticlassHipSegmentationDataset(
    img_root=img_root,
    mask_root=mask_root,
    metadata_df=metadata_df,
    paired_files=valid_pairs,
    num_classes=num_classes,
    transforms=test_augmentations,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net with a ResNet-18 encoder, single-channel input and one output channel
# per class; nn.Module.to returns the module itself, so the call can be chained.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    classes=num_classes,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# Validation batches are served in a fixed order.
valid_loader = DataLoader(valid_set, shuffle=False, batch_size=batch_size, num_workers=2)

num_epochs = 50

racegenders = metadata_df['racegender'].unique()

# Training-sample indices bucketed by racegender group.
racegender_data_indices = {racegender: [] for racegender in racegenders}

# Resolve each training sample's group directly from the metadata table.
# Iterating the dataset itself would decode every DICOM/NIfTI pair just to
# read one label per sample.
for idx, (image_file, _) in enumerate(train_set.paired_files):
    patient_id = int(float(image_file.split(".")[0]))
    matches = train_set.metadata_df.loc[
        train_set.metadata_df['id'] == patient_id, 'racegender'
    ].values
    racegender = matches[0] if len(matches) > 0 else 'Unknown'
    if pd.isna(racegender):
        # A missing label is grouped with the patients that have no metadata row.
        racegender = 'Unknown'
    # setdefault: 'Unknown' is not among the metadata groups, so a plain
    # dictionary lookup would raise KeyError for such samples.
    racegender_data_indices.setdefault(racegender, []).append(idx)


class RoundRobinGroupSampler(torch.utils.data.Sampler):
    """Yield every index exactly once per epoch, taking the groups in turn.

    On each pass the indices inside every group are reshuffled (using the
    torch RNG), then one index is drawn from each group in rotation. When a
    smaller group runs out it is skipped, so the tail of an epoch contains
    only the larger groups.

    Args:
        indices_by_group: Mapping from group label to a list of dataset indices.
    """

    def __init__(self, indices_by_group):
        # Empty groups are dropped so they never occupy a slot in the rotation.
        self.groups = [list(indices) for indices in indices_by_group.values() if len(indices) > 0]

    def __iter__(self):
        shuffled = [
            [group[position] for position in torch.randperm(len(group)).tolist()]
            for group in self.groups
        ]
        longest = max((len(group) for group in shuffled), default=0)
        for position in range(longest):
            for group in shuffled:
                if position < len(group):
                    yield group[position]

    def __len__(self):
        return sum(len(group) for group in self.groups)


# Sampler that preserves the interleaved group pattern. A SubsetRandomSampler
# would shuffle the indices again and destroy the pattern.
sampler = RoundRobinGroupSampler(racegender_data_indices)

# One realisation of the interleaved epoch order, kept for inspection.
interleaved_indices = list(sampler)

def batch_micro_iou(logits, targets):
    """Micro-averaged Jaccard index over all pixels of one batch.

    Args:
        logits: Model output; the class axis is dimension 1.
        targets: Reference class index per pixel.
    """
    predicted_masks = torch.argmax(logits, dim=1)
    return jaccard_score(
        targets.cpu().numpy().flatten(),
        predicted_masks.cpu().numpy().flatten(),
        average='micro'
    )


def mean_or_nan(values):
    """Mean of ``values``; NaN (without a NumPy warning) when the list is empty."""
    return float(np.mean(values)) if values else float("nan")


# The loader does not change between epochs, so it is built once. The sampler
# is re-iterated every time the loader is iterated, which gives a fresh order
# per epoch and carries the racegender pattern.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=2
)

for epoch in range(num_epochs):
    model.train()

    epoch_iou_list = []

    # No "images is None" check here: a None sample would already fail inside
    # the DataLoader's collation, so such a check could never be reached.
    for batch_idx, (images, masks, _) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        # The loss expects class indices, so collapse the one-hot class axis.
        targets = masks.argmax(dim=1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_iou_list.append(batch_micro_iou(outputs, targets))

    epoch_iou_avg = mean_or_nan(epoch_iou_list)

    model.eval()
    valid_iou_list = []

    with torch.no_grad():
        for batch_idx, (images, masks, racegender) in enumerate(valid_loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            valid_iou = batch_micro_iou(outputs, masks.argmax(dim=1))
            valid_iou_list.append(valid_iou)

    valid_iou_avg = mean_or_nan(valid_iou_list)

    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {epoch_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")


Epoch [1/50] - Train IoU: 0.0892 - Validation IoU: 0.0602
Epoch [2/50] - Train IoU: 0.1282 - Validation IoU: 0.1177
Epoch [3/50] - Train IoU: 0.1814 - Validation IoU: 0.1518
Epoch [4/50] - Train IoU: 0.2204 - Validation IoU: 0.1869
Epoch [5/50] - Train IoU: 0.2857 - Validation IoU: 0.2272
Epoch [6/50] - Train IoU: 0.3541 - Validation IoU: 0.2785
Epoch [7/50] - Train IoU: 0.4238 - Validation IoU: 0.3346
Epoch [8/50] - Train IoU: 0.4863 - Validation IoU: 0.3707
Epoch [9/50] - Train IoU: 0.5346 - Validation IoU: 0.4270
Epoch [10/50] - Train IoU: 0.5581 - Validation IoU: 0.4675
Epoch [11/50] - Train IoU: 0.6064 - Validation IoU: 0.4980
Epoch [12/50] - Train IoU: 0.6271 - Validation IoU: 0.5364
Epoch [13/50] - Train IoU: 0.6349 - Validation IoU: 0.5732
Epoch [14/50] - Train IoU: 0.6883 - Validation IoU: 0.6046
Epoch [15/50] - Train IoU: 0.6723 - Validation IoU: 0.6157
Epoch [16/50] - Train IoU: 0.7076 - Validation IoU: 0.6369
Epoch [17/50] - Train IoU: 0.7318 - Validation IoU: 0.6641
Epoch 