In [None]:
pip install pydicom nibabel numpy torch torchvision segmentation-models-pytorch scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install albumentations

Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install pandas

Note: you may need to restart the kernel to use updated packages.


In [None]:

import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from scipy.ndimage import zoom
from operator import itemgetter
from sklearn.metrics import jaccard_score

# Seed the Python, NumPy and PyTorch (CPU + CUDA) generators with the same
# value so the shuffle/split below and the weight initialisation repeat.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(42)
del seed_rng  # keep the module namespace identical to a plain sequence of calls

# Dataset locations.
img_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Images"
mask_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Annotations"
metadata_path = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/segmentation_with_racegender.csv"

# Segmentation / loader settings.
num_classes = 9
batch_size = 8

# Deterministic pipeline: resize to 256x256, normalise the single channel,
# then convert to a tensor.
test_augmentations = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(mean=(0.485,), std=(0.229,)),
        ToTensorV2(),
    ]
)

class MulticlassHipSegmentationDataset(Dataset):
    """Dataset of DICOM images paired with NIfTI multi-class masks.

    Each item is a tuple ``(image, one_hot_mask, racegender)``:

    * ``image`` - the DICOM pixel array as float32, after ``transforms``
      and ``preprocessing`` when those are given.
    * ``one_hot_mask`` - float32 array with ``num_classes`` leading channels.
    * ``racegender`` - the ``racegender`` value of the first metadata row
      whose ``id`` equals the numeric prefix of the image file name, or
      ``None`` when there is no such row or the value is missing.

    Args:
        img_root: Directory containing the DICOM image files.
        mask_root: Directory containing the NIfTI mask files.
        metadata_df: DataFrame with at least ``id`` and ``racegender`` columns.
        paired_files: Sequence of ``(image_file, mask_file)`` name pairs.
        num_classes: Number of channels produced by ``one_hot_encode``.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` returning a mapping with
            ``"image"`` and ``"mask"`` keys; applied before one-hot encoding.
        preprocessing: Optional callable with the same calling convention,
            applied after one-hot encoding.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing
        # Not read anywhere in this class; kept so existing attribute access still works.
        self.used_pairs = set()

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, transform and one-hot encode the pair at ``idx``.

        Returns:
            ``(image, one_hot_mask, racegender)``, or ``(None, None, None)``
            when the mask file does not exist on disk.
            NOTE(review): a ``None`` item presumably cannot be batched by the
            default DataLoader collate function - confirm before relying on it.

        Raises:
            ValueError: If the mask still has several channels after the
                first slice of a 3-D volume has been selected.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            return None, None, None

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        image = dicom_image.pixel_array.astype(np.float32)
        annotation = nib.load(mask_path)
        annotation_data = annotation.get_fdata()
        if len(annotation_data.shape) == 3:
            # Keep only the first slice along the last axis.
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)

        if annotation_data.ndim > 2 and annotation_data.shape[-1] != 1:
            raise ValueError('Mask has multiple channels')

        if image.shape != annotation_data.shape:
            # order=0 (nearest neighbour) so resampling never invents label values.
            zoom_factors = np.array(image.shape) / np.array(annotation_data.shape)
            annotation_data = zoom(annotation_data, zoom_factors, order=0)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        racegender = self._lookup_racegender(image_file)

        return image, annotation_data_onehot, racegender

    def _lookup_racegender(self, image_file):
        """Return the metadata ``racegender`` for ``image_file`` or ``None``.

        The patient id is the part of the file name before the first ``.``,
        parsed as a number. The first matching metadata row wins. The lookup
        uses only ``self.metadata_df`` (no module-level state) and never
        truth-tests an array, so zero or multiple matches are both handled.
        """
        patient_id = int(float(image_file.split(".")[0]))
        matches = self.metadata_df.loc[self.metadata_df['id'] == patient_id, 'racegender']
        if matches.empty:
            return None
        racegender = matches.iloc[0]
        return None if pd.isna(racegender) else racegender

    def one_hot_encode(self, mask):
        """Return a ``(num_classes, *mask.shape)`` float32 one-hot array.

        Values outside ``range(num_classes)`` leave every channel at zero.
        """
        one_hot_mask = np.zeros((self.num_classes, *mask.shape), dtype=np.float32)
        for class_idx in range(self.num_classes):
            one_hot_mask[class_idx][mask == class_idx] = 1.0
        return one_hot_mask

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then flip it around the vertical axis."""
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))

# Pair every image with the mask that shares its base name; a set keeps the
# membership test O(1) per image.
available_masks = set(mask_files)
paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

random.shuffle(paired_files)

# 70 % train / 10 % validation / remainder test.
train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    classes=num_classes,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

valid_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, valid_pairs, num_classes,
    transforms=test_augmentations
)

valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)

num_epochs = 100

racegenders = metadata_df['racegender'].unique()


def _interleave_round_robin(groups):
    """Take one element from each group in turn until every group is exhausted."""
    ordered = []
    longest = max((len(group) for group in groups), default=0)
    for position in range(longest):
        for group in groups:
            if position < len(group):
                ordered.append(group[position])
    return ordered


# Training-set indices grouped by racegender.
racegender_data_indices = {racegender: [] for racegender in racegenders}

# Resolve each training pair's racegender straight from the metadata table
# (same id parsing and first-match rule as the dataset) instead of loading
# and transforming every image just to read its label.
for idx, (image_file, _) in enumerate(train_pairs):
    patient_id = int(float(image_file.split(".")[0]))
    matches = metadata_df.loc[metadata_df['id'] == patient_id, 'racegender']
    if matches.empty or pd.isna(matches.iloc[0]):
        # Samples without a usable racegender stay out of the sampling order.
        continue
    racegender_data_indices[matches.iloc[0]].append(idx)

for epoch in range(num_epochs):
    # Reshuffle inside each group, then cycle through the groups so consecutive
    # samples rotate over racegenders. Every index appears exactly once per
    # epoch; the list is passed as the sampler so this order is preserved.
    interleaved_indices = _interleave_round_robin(
        [random.sample(indices, len(indices)) for indices in racegender_data_indices.values()]
    )

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        sampler=interleaved_indices,
        num_workers=2
    )

    model.train()

    epoch_iou_list = []

    for batch_idx, (images, masks, _) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        # CrossEntropyLoss expects class indices, so collapse the one-hot channels.
        target_classes = masks.argmax(dim=1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, target_classes)
        loss.backward()
        optimizer.step()

        predicted_masks = torch.argmax(outputs, dim=1)
        batch_iou = jaccard_score(
            target_classes.cpu().numpy().flatten(),
            predicted_masks.cpu().numpy().flatten(),
            average='micro'
        )
        epoch_iou_list.append(batch_iou)

    epoch_iou_avg = np.mean(epoch_iou_list)

    model.eval()
    valid_iou_list = []

    with torch.no_grad():
        for batch_idx, (images, masks, racegender) in enumerate(valid_loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            predicted_masks = torch.argmax(outputs, dim=1)

            valid_iou = jaccard_score(
                masks.argmax(dim=1).cpu().numpy().flatten(),
                predicted_masks.cpu().numpy().flatten(),
                average='micro'
            )

            valid_iou_list.append(valid_iou)

    valid_iou_avg = np.mean(valid_iou_list)

    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {epoch_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")


  racegender = racegender_info[0] if racegender_info and racegender_info[0] in racegenders else None
  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0691 - Validation IoU: 0.0603


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.1797 - Validation IoU: 0.2046


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.3328 - Validation IoU: 0.3787


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.4616 - Validation IoU: 0.5220


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.5706 - Validation IoU: 0.6447


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.6513 - Validation IoU: 0.6842


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.7092 - Validation IoU: 0.7245


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.7512 - Validation IoU: 0.7456


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.7812 - Validation IoU: 0.7629


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.8097 - Validation IoU: 0.7767


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.8386 - Validation IoU: 0.8033


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.8753 - Validation IoU: 0.8165


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.9030 - Validation IoU: 0.8316


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.9150 - Validation IoU: 0.8337


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.9226 - Validation IoU: 0.8482


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.9276 - Validation IoU: 0.8373


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.9331 - Validation IoU: 0.8468


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.9371 - Validation IoU: 0.8508


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.9397 - Validation IoU: 0.8520


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.9429 - Validation IoU: 0.8558


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.9442 - Validation IoU: 0.8581


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.9474 - Validation IoU: 0.8574


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.9483 - Validation IoU: 0.8525


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.9507 - Validation IoU: 0.8613


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.9518 - Validation IoU: 0.8626


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.9533 - Validation IoU: 0.8583


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.9553 - Validation IoU: 0.8637


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.9582 - Validation IoU: 0.8622


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.9581 - Validation IoU: 0.8610


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.9582 - Validation IoU: 0.8649


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.9598 - Validation IoU: 0.8638


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9615 - Validation IoU: 0.8659


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9618 - Validation IoU: 0.8683


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9626 - Validation IoU: 0.8734


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9634 - Validation IoU: 0.8726


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9640 - Validation IoU: 0.8738


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9649 - Validation IoU: 0.8713


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9652 - Validation IoU: 0.8744


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9653 - Validation IoU: 0.8762


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9658 - Validation IoU: 0.8794


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9664 - Validation IoU: 0.8743


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9663 - Validation IoU: 0.8809


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9660 - Validation IoU: 0.8677


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9670 - Validation IoU: 0.8861


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9674 - Validation IoU: 0.8819


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9684 - Validation IoU: 0.8867


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9693 - Validation IoU: 0.8851


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9729 - Validation IoU: 0.8878


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9781 - Validation IoU: 0.8894


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9809 - Validation IoU: 0.8860


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9835 - Validation IoU: 0.8931


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9846 - Validation IoU: 0.8927


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9860 - Validation IoU: 0.8911


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9868 - Validation IoU: 0.8952


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9876 - Validation IoU: 0.8941


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9868 - Validation IoU: 0.8932


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9876 - Validation IoU: 0.8947


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9884 - Validation IoU: 0.8953


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9884 - Validation IoU: 0.8954


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9891 - Validation IoU: 0.8976


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9893 - Validation IoU: 0.9006


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9892 - Validation IoU: 0.8970


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9897 - Validation IoU: 0.8984


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9896 - Validation IoU: 0.8958


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9896 - Validation IoU: 0.8956


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9899 - Validation IoU: 0.8984


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9903 - Validation IoU: 0.8971


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9905 - Validation IoU: 0.8970


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9905 - Validation IoU: 0.9032


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9911 - Validation IoU: 0.9045


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9913 - Validation IoU: 0.9019


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9918 - Validation IoU: 0.9016


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9922 - Validation IoU: 0.9062


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9923 - Validation IoU: 0.9029


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9921 - Validation IoU: 0.9019


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9921 - Validation IoU: 0.9067


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9921 - Validation IoU: 0.9043


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9921 - Validation IoU: 0.9052


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9920 - Validation IoU: 0.9017


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9925 - Validation IoU: 0.9000


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9923 - Validation IoU: 0.9053


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9921 - Validation IoU: 0.9054


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9924 - Validation IoU: 0.9004


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9924 - Validation IoU: 0.9027


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9933 - Validation IoU: 0.9055


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9932 - Validation IoU: 0.9025


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9934 - Validation IoU: 0.9071


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9933 - Validation IoU: 0.9054


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9930 - Validation IoU: 0.9020


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9930 - Validation IoU: 0.9079


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9936 - Validation IoU: 0.9098


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9936 - Validation IoU: 0.8987


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9937 - Validation IoU: 0.9106


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9939 - Validation IoU: 0.9086


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9941 - Validation IoU: 0.9068


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9945 - Validation IoU: 0.9072


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9943 - Validation IoU: 0.9085


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9943 - Validation IoU: 0.9024


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9943 - Validation IoU: 0.9073


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9941 - Validation IoU: 0.9064
