In [None]:
# `%pip` always targets the running kernel's environment; bare `pip` only works via IPython automagic.
%pip install pydicom nibabel numpy torch torchvision segmentation-models-pytorch scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [None]:
# `%pip` always targets the running kernel's environment; bare `pip` only works via IPython automagic.
%pip install albumentations

Note: you may need to restart the kernel to use updated packages.


In [None]:
# `%pip` always targets the running kernel's environment; bare `pip` only works via IPython automagic.
%pip install pandas

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from scipy.ndimage import zoom
from sklearn.metrics import jaccard_score

# Reproducibility: seed every RNG used below. Python's `random` drives the
# train/valid/test shuffle; torch drives model initialisation.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one

num_classes = 9  # segmentation classes: one-hot mask channels and U-Net output channels
batch_size = 8

# Dataset location. Set the KNEE_DATA_ROOT environment variable to run this
# notebook on another machine; the default reproduces the original paths.
data_root = os.environ.get("KNEE_DATA_ROOT", "/home/ealam/JHIR_Hip_Knee_Datasets/Knee")
img_root = os.path.join(data_root, "Images")
mask_root = os.path.join(data_root, "Annotations")
metadata_path = os.path.join(data_root, "segmentation_with_racegender.csv")

# Deterministic preprocessing only (resize, normalise, convert to tensor).
# The training set below is built with this same pipeline, so training sees no
# random augmentation.
# NOTE(review): A.Normalize divides by max_pixel_value (default 255), while the
# dataset feeds raw DICOM pixel_array values, which may span a wider range
# (e.g. 12/16-bit) — confirm the intended intensity scaling.
test_augmentations = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize(mean=(0.485,), std=(0.229,)),
    ToTensorV2(),
])

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI mask dataset for multiclass segmentation.

    Each item is ``(image, one_hot_mask, racegender)``:
      * ``image`` - DICOM pixel data as float32, after ``transforms`` /
        ``preprocessing`` if given.
      * ``one_hot_mask`` - float32 array with ``num_classes`` leading channels
        (or whatever ``preprocessing`` returns for it).
      * ``racegender`` - label looked up in ``metadata_df`` by the numeric id in
        the image file name, or ``None`` if missing or not accepted.

    Args:
        img_root: directory containing the DICOM image files.
        mask_root: directory containing the NIfTI annotation files.
        metadata_df: DataFrame with at least ``id`` and ``racegender`` columns.
        paired_files: sequence of ``(image_file, mask_file)`` names relative to
            ``img_root`` / ``mask_root``.
        num_classes: number of label values (0 .. num_classes-1) to one-hot encode.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            and returning a dict with ``"image"`` and ``"mask"``; applied before
            one-hot encoding.
        preprocessing: optional callable with the same calling convention, applied
            to the image and the one-hot mask after encoding.
        valid_racegenders: accepted racegender labels. Defaults to every unique
            value of ``metadata_df['racegender']``.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes,
                 transforms=None, preprocessing=None, valid_racegenders=None):
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing
        # Resolved at construction time instead of reading a notebook-level
        # global, so the dataset does not depend on cell execution order.
        if valid_racegenders is None:
            valid_racegenders = metadata_df['racegender'].unique()
        self.valid_racegenders = list(valid_racegenders)

    def __len__(self):
        return len(self.paired_files)

    def __getitem__(self, idx):
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # NOTE(review): torch's default collate cannot batch None values, so
            # under a default DataLoader this path raises there. The pairing code
            # in this notebook only keeps images whose mask file exists.
            return None, None, None

        image = self._load_image(image_file)
        annotation_data = self._load_mask(mask_path)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        racegender = self._lookup_racegender(image_file)
        return image, annotation_data_onehot, racegender

    def _load_image(self, image_file):
        """Read a DICOM file under ``img_root`` and return its pixels as float32."""
        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        return dicom_image.pixel_array.astype(np.float32)

    def _load_mask(self, mask_path):
        """Load a NIfTI annotation as a 2-D array and reorient it.

        A 3-D volume is reduced to its first slice along the last axis. The result
        is rotated/flipped by ``calculate_flipped_rotated_mask`` (presumably to
        align the NIfTI orientation with the DICOM pixel grid — confirm).

        Raises:
            ValueError: if the mask still has more than two dimensions and a last
                axis longer than one.
        """
        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            annotation_data = annotation_data[:, :, 0]

        if annotation_data.ndim > 2 and annotation_data.shape[-1] != 1:
            raise ValueError('Mask has multiple channels')

        return self.calculate_flipped_rotated_mask(annotation_data)

    def _lookup_racegender(self, image_file):
        """Return the racegender label for the id encoded in ``image_file``, or None.

        The id is the part of the file name before the first '.', parsed as a
        number (e.g. ``"123.dcm"`` -> 123). If several metadata rows share the
        id, the first is used. Returns None when no row matches or the label is
        not in ``valid_racegenders``.
        """
        patient_id = int(float(image_file.split(".")[0]))
        matches = self.metadata_df.loc[self.metadata_df['id'] == patient_id, 'racegender'].values
        # Check the length explicitly: truth-testing a numpy array raises for
        # more than one element (and for empty arrays on recent NumPy), and would
        # also discard a falsy label such as "".
        if len(matches) == 0:
            return None
        racegender = matches[0]
        return racegender if racegender in self.valid_racegenders else None

    def one_hot_encode(self, mask):
        """Return a float32 array of shape ``(num_classes, *mask.shape)``.

        Channel ``c`` is 1.0 wherever ``mask == c``. Pixels whose value is outside
        ``0 .. num_classes-1`` are all-zero across channels. ``mask`` may be a
        numpy array or anything ``np.asarray`` accepts (e.g. a CPU tensor).
        """
        mask = np.asarray(mask)
        one_hot_mask = np.zeros((self.num_classes, *mask.shape), dtype=np.float32)
        for class_idx in range(self.num_classes):
            one_hot_mask[class_idx][mask == class_idx] = 1.0
        return one_hot_mask

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then mirror it left-right (flip code 1)."""
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        return cv2.flip(rotated_mask, 1)

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
# Set lookup keeps pairing linear in the number of files (list `in` is O(n)).
available_masks = set(mask_files)

# Pair each image "<id>.<ext>" with the annotation "<id>.nii.gz"; images
# without an annotation file are dropped.
paired_files = []
for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

# Shuffle with the seeded `random` module, then split 70% / 10% / remainder.
# NOTE(review): the split is per image file; if several files can belong to
# one patient, patient-level leakage across splits is possible — confirm.
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]
assert len(test_pairs) == test_size

train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)
# The training loop iterates `train_loader`; define it here so the notebook
# works on a fresh kernel (Restart & Run All). Training batches are reshuffled
# every epoch; the worker count matches the validation loader.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

# Validation data: same deterministic preprocessing as training, fixed order.
valid_set = MulticlassHipSegmentationDataset(
    img_root,
    mask_root,
    metadata_df,
    valid_pairs,
    num_classes,
    transforms=test_augmentations,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# U-Net with an ImageNet-initialised ResNet-18 encoder, taking single-channel
# input and emitting one logit map per class.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    classes=num_classes,
).to(device)

# Adam with (coupled, L2-style) weight decay. CrossEntropyLoss expects class-index
# targets, which the training loop obtains via masks.argmax(dim=1).
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

num_epochs = 100

# Distinct racegender labels present in the metadata.
racegenders = metadata_df['racegender'].unique()


def _pixel_iou(masks_onehot, logits):
    """Micro-averaged Jaccard index between one-hot targets and model logits.

    Assumes both tensors are (batch, classes, H, W); per-pixel labels are the
    argmax over dim 1. For single-label multiclass data, micro averaging equals
    correct / (correct + 2 * wrong) over all pixels, background included, so it
    tracks pixel accuracy.
    NOTE(review): consider a per-class (macro) IoU excluding background for reporting.
    """
    return jaccard_score(
        masks_onehot.argmax(dim=1).cpu().numpy().ravel(),
        logits.argmax(dim=1).cpu().numpy().ravel(),
        average='micro',
    )


def _mean_or_nan(values):
    """Mean of ``values`` as a float, or nan for an empty list (avoids numpy's warning)."""
    return float(np.mean(values)) if values else float("nan")


for epoch in range(num_epochs):
    model.train()
    epoch_iou_list = []
    epoch_loss_list = []

    for images, masks, _ in train_loader:
        # Only reachable with a collate_fn that yields None batches; torch's
        # default collate raises on None samples instead.
        if images is None:
            continue

        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks.argmax(dim=1))
        loss.backward()
        optimizer.step()

        epoch_loss_list.append(loss.item())
        # Scored on the logits computed before this step's parameter update.
        epoch_iou_list.append(_pixel_iou(masks, outputs))

    model.eval()
    valid_iou_list = []
    with torch.no_grad():
        for images, masks, _ in valid_loader:
            if images is None:
                continue
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            valid_iou_list.append(_pixel_iou(masks, outputs))

    epoch_iou_avg = _mean_or_nan(epoch_iou_list)
    epoch_loss_avg = _mean_or_nan(epoch_loss_list)
    valid_iou_avg = _mean_or_nan(valid_iou_list)

    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {epoch_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f} - Train Loss: {epoch_loss_avg:.4f}")


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0691 - Validation IoU: 0.0563


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.1796 - Validation IoU: 0.1877


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.3313 - Validation IoU: 0.3505


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.4624 - Validation IoU: 0.5172


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.5700 - Validation IoU: 0.6290


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.6493 - Validation IoU: 0.6892


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.7068 - Validation IoU: 0.7185


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.7492 - Validation IoU: 0.7365


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.7863 - Validation IoU: 0.7536


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.8195 - Validation IoU: 0.7764


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.8561 - Validation IoU: 0.7980


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.8879 - Validation IoU: 0.8147


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.9003 - Validation IoU: 0.8280


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.9101 - Validation IoU: 0.8322


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.9178 - Validation IoU: 0.8340


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.9232 - Validation IoU: 0.8315


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.9277 - Validation IoU: 0.8406


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.9329 - Validation IoU: 0.8395


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.9373 - Validation IoU: 0.8508


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.9463 - Validation IoU: 0.8494


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.9561 - Validation IoU: 0.8609


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.9623 - Validation IoU: 0.8636


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.9640 - Validation IoU: 0.8632


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.9666 - Validation IoU: 0.8660


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.9683 - Validation IoU: 0.8699


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.9702 - Validation IoU: 0.8722


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.9720 - Validation IoU: 0.8745


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.9752 - Validation IoU: 0.8734


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.9751 - Validation IoU: 0.8796


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.9743 - Validation IoU: 0.8703


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.9761 - Validation IoU: 0.8827


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9781 - Validation IoU: 0.8785


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9780 - Validation IoU: 0.8809


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9799 - Validation IoU: 0.8808


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9807 - Validation IoU: 0.8867


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9819 - Validation IoU: 0.8850


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9831 - Validation IoU: 0.8869


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9832 - Validation IoU: 0.8855


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9837 - Validation IoU: 0.8850


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9840 - Validation IoU: 0.8891


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9850 - Validation IoU: 0.8866


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9851 - Validation IoU: 0.8891


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9855 - Validation IoU: 0.8896


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9856 - Validation IoU: 0.8926


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9859 - Validation IoU: 0.8945


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9865 - Validation IoU: 0.8938


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9866 - Validation IoU: 0.8921


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9872 - Validation IoU: 0.8928


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9873 - Validation IoU: 0.8941


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9863 - Validation IoU: 0.8889


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9876 - Validation IoU: 0.8962


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9879 - Validation IoU: 0.8988


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9887 - Validation IoU: 0.9023


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9891 - Validation IoU: 0.8986


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9893 - Validation IoU: 0.8976


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9888 - Validation IoU: 0.8952


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9892 - Validation IoU: 0.9006


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9898 - Validation IoU: 0.8985


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9897 - Validation IoU: 0.8990


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9901 - Validation IoU: 0.9010


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9902 - Validation IoU: 0.8982


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9905 - Validation IoU: 0.9034


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9907 - Validation IoU: 0.9001


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9906 - Validation IoU: 0.9025


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9904 - Validation IoU: 0.9022


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9908 - Validation IoU: 0.9048


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9910 - Validation IoU: 0.9051


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9912 - Validation IoU: 0.9045


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9913 - Validation IoU: 0.9041


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9913 - Validation IoU: 0.9072


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9911 - Validation IoU: 0.9050


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9918 - Validation IoU: 0.9054


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9925 - Validation IoU: 0.9062


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9926 - Validation IoU: 0.9057


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9924 - Validation IoU: 0.9031


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9922 - Validation IoU: 0.9043


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9924 - Validation IoU: 0.9062


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9927 - Validation IoU: 0.9080


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9928 - Validation IoU: 0.9068


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9932 - Validation IoU: 0.9060


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9932 - Validation IoU: 0.9072


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9928 - Validation IoU: 0.9084


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9926 - Validation IoU: 0.9070


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9923 - Validation IoU: 0.9085


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9933 - Validation IoU: 0.9111


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9933 - Validation IoU: 0.9069


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9933 - Validation IoU: 0.9094


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9932 - Validation IoU: 0.9117


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9929 - Validation IoU: 0.9039


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9921 - Validation IoU: 0.9122


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9931 - Validation IoU: 0.9129


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9932 - Validation IoU: 0.9087


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9936 - Validation IoU: 0.9120


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9938 - Validation IoU: 0.9118


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9943 - Validation IoU: 0.9121


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9946 - Validation IoU: 0.9132


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9942 - Validation IoU: 0.9141


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9944 - Validation IoU: 0.9134


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9947 - Validation IoU: 0.9123


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9946 - Validation IoU: 0.9130


In [None]:
# Persist only the learned parameters (state_dict) of the final-epoch model.
# NOTE(review): the "ioulib9241" tag in the file name does not match the last
# logged validation IoU (0.9130), and the name contains spaces — confirm which
# evaluation the tag refers to before relying on it.
saved_model_path = 'trained_model_ Multi Diversity baseline 100 epoch with Decay1e-5 ioulib9241.pth'
model_weights = model.state_dict()
torch.save(model_weights, saved_model_path)
print(f"Model saved to {saved_model_path}")

Model saved to trained_model_ Multi Diversity baseline 100 epoch with Decay1e-5 ioulib9241.pth
