In [None]:
pip install pydicom nibabel numpy torch torchvision segmentation-models-pytorch scikit-learn

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install albumentations

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install pandas

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install opencv-python

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score


# Seed every RNG the notebook relies on; the random.shuffle split later in this cell
# depends on the Python ``random`` seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


num_classes = 9  # label ids are expected to lie in [0, num_classes)
batch_size = 4

# Dataset location. Set the KNEE_DATA_ROOT environment variable to use another copy
# instead of editing an absolute path; the default reproduces the original paths exactly.
data_root = os.environ.get("KNEE_DATA_ROOT", "/home/ealam/JHIR_Hip_Knee_Datasets/Knee")
img_root = os.path.join(data_root, "Images")
mask_root = os.path.join(data_root, "Annotations")
metadata_path = os.path.join(data_root, "segmentation_with_racegender.csv")


# Deterministic (augmentation-free) pipeline used for the train, validation and test sets:
# resize, normalize with the ImageNet red-channel mean/std repeated on all three channels,
# then convert the HWC numpy image to a CHW torch tensor.
# NOTE(review): A.Normalize divides by max_pixel_value (255 by default) before applying
# mean/std, but the dataset feeds raw DICOM pixel values that may exceed 8 bits -- confirm
# the intended intensity scaling.
test_augmentations = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(mean=(0.485, 0.485, 0.485), std=(0.229, 0.229, 0.229)),
    ToTensorV2(),
])

class MulticlassHipSegmentationDataset(Dataset):
    """DICOM images paired with NIfTI label maps for multiclass segmentation.

    Each item is ``(image, one_hot_mask, racegender)``:

    * ``image`` -- DICOM pixel data resized to 512x512, replicated to 3 channels and
      passed through ``transforms`` (a CHW tensor when the pipeline ends in ``ToTensorV2``).
    * ``one_hot_mask`` -- ``(num_classes, H, W)`` float32 one-hot encoding of the label
      map (or whatever ``preprocessing`` returns for it).
    * ``racegender`` -- the ``racegender`` value in ``metadata_df`` whose ``id`` equals the
      numeric file stem, or ``'Unknown'`` when no row matches.

    Despite its name, this notebook uses it for the knee dataset.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        """
        Args:
            img_root: directory holding the DICOM images.
            mask_root: directory holding the NIfTI annotations.
            metadata_df: DataFrame with at least ``id`` and ``racegender`` columns.
            paired_files: sequence of ``(image_file, mask_file)`` name pairs.
            num_classes: number of label classes; ids are expected in ``[0, num_classes)``.
            transforms: optional callable taking ``image=``/``mask=`` keywords and returning
                a dict with ``"image"``/``"mask"`` (albumentations convention).
            preprocessing: optional callable with the same convention, applied after
                one-hot encoding.
        """
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, resize and encode the ``idx``-th image/annotation pair.

        Raises:
            FileNotFoundError: if the annotation file for this pair does not exist.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # A None item cannot be batched by the default DataLoader collate_fn (it fails
            # with an opaque TypeError), so report the missing file explicitly.
            raise FileNotFoundError(f"Annotation not found for {image_file}: {mask_path}")

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the grayscale image to 3 channels for the 3-channel (ImageNet) encoder.
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Nearest-neighbour keeps class ids discrete: the default bilinear interpolation
        # blends neighbouring ids at boundaries, producing fractional values (no class) or
        # spurious intermediate class ids.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # The file stem is the numeric patient id (e.g. "123.dcm" -> 123).
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """Return a ``(num_classes, *mask.shape)`` float32 one-hot encoding of ``mask``.

        ``mask`` may be a NumPy array or a CPU tensor of class ids. Values are rounded to
        the nearest integer first, so float label maps with tiny representation error still
        match. Pixels whose id is outside ``[0, num_classes)`` get an all-zero vector.
        """
        labels = np.rint(np.asarray(mask))
        class_ids = np.arange(self.num_classes).reshape(-1, *([1] * labels.ndim))
        return (labels[np.newaxis] == class_ids).astype(np.float32)

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then mirror it left-right.

        The net effect is a transpose (``out[i, j] == mask[j, i]``). It aligns the NIfTI
        array orientation with the DICOM pixel array.
        """
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
available_masks = set(mask_files)  # O(1) membership tests instead of scanning the list per image

# Pair each image with the annotation sharing its stem ("<stem>.nii.gz");
# images without an annotation are dropped.
paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

assert paired_files, f"No image/annotation pairs found under {img_root} and {mask_root}"

# Shuffle (seeded above), then split ~70% train / 10% validation / remainder test.
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

# Training uses the same deterministic pipeline as evaluation (no data augmentation).
train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

def train_unetplusplus_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                             save_path="unetplusplus_model_knee_init.pth"):
    """Train a U-Net++ on the knee split, save its weights, and report test IoU.

    Reads the notebook globals ``train_set``, ``valid_pairs``, ``test_pairs``,
    ``img_root``, ``mask_root``, ``metadata_df``, ``num_classes``, ``batch_size`` and
    ``test_augmentations`` defined in this cell.

    IoU is sklearn's micro-averaged Jaccard over all pixels of a batch, averaged over
    batches. Micro averaging is likely dominated by the majority (background) class;
    consider ``average='macro'`` or per-class scores for a fuller picture.

    Args:
        num_epochs: number of passes over ``train_set``.
        encoder_name: ``segmentation_models_pytorch`` encoder (ImageNet-pretrained).
        device: torch device name or ``torch.device`` to train and evaluate on.
        save_path: file the final ``state_dict`` is written to.

    Returns:
        Mean per-batch IoU on the test split, as a float.
    """
    device = torch.device(device)

    model = smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    def batch_iou(target_onehot, logits):
        # Micro-averaged Jaccard between the argmax class maps of target and prediction.
        return jaccard_score(
            target_onehot.argmax(dim=1).cpu().numpy().flatten(),
            logits.argmax(dim=1).cpu().numpy().flatten(),
            average='micro'
        )

    def evaluate(loader):
        # Mean per-batch IoU of the model over `loader`, in eval mode without gradients.
        model.eval()
        scores = []
        with torch.no_grad():
            for images, masks, _racegender in loader:
                images, masks = images.to(device), masks.to(device)
                scores.append(batch_iou(masks, model(images)))
        return np.mean(scores)

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )

    # A shuffling DataLoader reshuffles on every pass, so one instance serves all epochs.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            # CrossEntropyLoss expects class-index targets, so collapse the one-hot masks.
            loss = criterion(outputs, masks.argmax(dim=1))
            loss.backward()
            optimizer.step()

            train_iou_list.append(batch_iou(masks, outputs))

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = evaluate(valid_loader)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    torch.save(model.state_dict(), save_path)
    print("Model saved successfully.")

    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    test_iou_avg = float(evaluate(test_loader))
    print(f"Test IoU: {test_iou_avg:.4f}")
    return test_iou_avg


# Launch the full U-Net++ run; per-epoch train/validation IoU is printed as it trains.
print("Training U-Net++ ..")
train_unetplusplus_model(num_epochs=100, encoder_name="resnet18")


Training U-Net++ ..


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0615 - Validation IoU: 0.0853


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.1074 - Validation IoU: 0.1087


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.1624 - Validation IoU: 0.1790


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.2186 - Validation IoU: 0.2174


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.2874 - Validation IoU: 0.2527


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.3821 - Validation IoU: 0.3361


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.4800 - Validation IoU: 0.4398


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.5562 - Validation IoU: 0.5219


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.6292 - Validation IoU: 0.5973


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.6868 - Validation IoU: 0.6635


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.7265 - Validation IoU: 0.7094


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.7533 - Validation IoU: 0.7294


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.7980 - Validation IoU: 0.7724


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.8413 - Validation IoU: 0.8204


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.8730 - Validation IoU: 0.8442


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.8872 - Validation IoU: 0.8598


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.9002 - Validation IoU: 0.8655


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.9084 - Validation IoU: 0.8761


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.9163 - Validation IoU: 0.8828


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.9226 - Validation IoU: 0.8902


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.9273 - Validation IoU: 0.8927


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.9300 - Validation IoU: 0.8984


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.9347 - Validation IoU: 0.9017


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.9355 - Validation IoU: 0.9064


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.9402 - Validation IoU: 0.9097


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.9411 - Validation IoU: 0.9102


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.9445 - Validation IoU: 0.9128


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.9468 - Validation IoU: 0.9128


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.9480 - Validation IoU: 0.9152


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.9482 - Validation IoU: 0.9145


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.9500 - Validation IoU: 0.9157


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9511 - Validation IoU: 0.9166


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9530 - Validation IoU: 0.9189


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9540 - Validation IoU: 0.9197


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9546 - Validation IoU: 0.9172


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9549 - Validation IoU: 0.9193


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9565 - Validation IoU: 0.9211


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9571 - Validation IoU: 0.9219


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9584 - Validation IoU: 0.9244


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9588 - Validation IoU: 0.9244


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9603 - Validation IoU: 0.9253


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9607 - Validation IoU: 0.9251


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9615 - Validation IoU: 0.9238


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9614 - Validation IoU: 0.9265


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9619 - Validation IoU: 0.9267


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9633 - Validation IoU: 0.9249


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9643 - Validation IoU: 0.9272


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9657 - Validation IoU: 0.9271


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9672 - Validation IoU: 0.9260


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9694 - Validation IoU: 0.9268


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9728 - Validation IoU: 0.9295


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9765 - Validation IoU: 0.9315


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9784 - Validation IoU: 0.9315


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9796 - Validation IoU: 0.9361


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9802 - Validation IoU: 0.9360


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9809 - Validation IoU: 0.9358


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9810 - Validation IoU: 0.9387


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9816 - Validation IoU: 0.9378


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9830 - Validation IoU: 0.9378


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9823 - Validation IoU: 0.9367


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9839 - Validation IoU: 0.9389


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9833 - Validation IoU: 0.9371


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9829 - Validation IoU: 0.9395


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9842 - Validation IoU: 0.9380


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9835 - Validation IoU: 0.9398


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9849 - Validation IoU: 0.9369


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9848 - Validation IoU: 0.9392


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9844 - Validation IoU: 0.9403


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9852 - Validation IoU: 0.9396


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9844 - Validation IoU: 0.9399


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9854 - Validation IoU: 0.9407


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9855 - Validation IoU: 0.9407


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9850 - Validation IoU: 0.9406


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9849 - Validation IoU: 0.9389


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9862 - Validation IoU: 0.9427


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9864 - Validation IoU: 0.9389


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9865 - Validation IoU: 0.9418


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9867 - Validation IoU: 0.9390


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9874 - Validation IoU: 0.9415


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9866 - Validation IoU: 0.9381


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9867 - Validation IoU: 0.9398


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9869 - Validation IoU: 0.9404


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9878 - Validation IoU: 0.9403


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9882 - Validation IoU: 0.9401


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9878 - Validation IoU: 0.9402


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9881 - Validation IoU: 0.9406


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9882 - Validation IoU: 0.9405


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9884 - Validation IoU: 0.9395


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9882 - Validation IoU: 0.9403


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9888 - Validation IoU: 0.9407


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9889 - Validation IoU: 0.9406


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9894 - Validation IoU: 0.9412


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9892 - Validation IoU: 0.9417


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9894 - Validation IoU: 0.9397


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9894 - Validation IoU: 0.9397


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9901 - Validation IoU: 0.9401


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9897 - Validation IoU: 0.9414


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9901 - Validation IoU: 0.9414


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9900 - Validation IoU: 0.9395


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9901 - Validation IoU: 0.9402
Model saved successfully.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score


# Seed every RNG the notebook relies on; the random.shuffle split later in this cell
# depends on the Python ``random`` seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


num_classes = 9  # label ids are expected to lie in [0, num_classes)
batch_size = 4

# Dataset location. Set the KNEE_DATA_ROOT environment variable to use another copy
# instead of editing an absolute path; the default reproduces the original paths exactly.
data_root = os.environ.get("KNEE_DATA_ROOT", "/home/ealam/JHIR_Hip_Knee_Datasets/Knee")
img_root = os.path.join(data_root, "Images")
mask_root = os.path.join(data_root, "Annotations")
metadata_path = os.path.join(data_root, "segmentation_with_racegender.csv")


# Deterministic (augmentation-free) pipeline used for the train, validation and test sets:
# resize, normalize with the ImageNet red-channel mean/std repeated on all three channels,
# then convert the HWC numpy image to a CHW torch tensor.
# NOTE(review): A.Normalize divides by max_pixel_value (255 by default) before applying
# mean/std, but the dataset feeds raw DICOM pixel values that may exceed 8 bits -- confirm
# the intended intensity scaling.
test_augmentations = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(mean=(0.485, 0.485, 0.485), std=(0.229, 0.229, 0.229)),
    ToTensorV2(),
])

class MulticlassHipSegmentationDataset(Dataset):
    """DICOM images paired with NIfTI label maps for multiclass segmentation.

    Each item is ``(image, one_hot_mask, racegender)``:

    * ``image`` -- DICOM pixel data resized to 512x512, replicated to 3 channels and
      passed through ``transforms`` (a CHW tensor when the pipeline ends in ``ToTensorV2``).
    * ``one_hot_mask`` -- ``(num_classes, H, W)`` float32 one-hot encoding of the label
      map (or whatever ``preprocessing`` returns for it).
    * ``racegender`` -- the ``racegender`` value in ``metadata_df`` whose ``id`` equals the
      numeric file stem, or ``'Unknown'`` when no row matches.

    Despite its name, this notebook uses it for the knee dataset.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        """
        Args:
            img_root: directory holding the DICOM images.
            mask_root: directory holding the NIfTI annotations.
            metadata_df: DataFrame with at least ``id`` and ``racegender`` columns.
            paired_files: sequence of ``(image_file, mask_file)`` name pairs.
            num_classes: number of label classes; ids are expected in ``[0, num_classes)``.
            transforms: optional callable taking ``image=``/``mask=`` keywords and returning
                a dict with ``"image"``/``"mask"`` (albumentations convention).
            preprocessing: optional callable with the same convention, applied after
                one-hot encoding.
        """
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, resize and encode the ``idx``-th image/annotation pair.

        Raises:
            FileNotFoundError: if the annotation file for this pair does not exist.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # A None item cannot be batched by the default DataLoader collate_fn (it fails
            # with an opaque TypeError), so report the missing file explicitly.
            raise FileNotFoundError(f"Annotation not found for {image_file}: {mask_path}")

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the grayscale image to 3 channels for the 3-channel (ImageNet) encoder.
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Nearest-neighbour keeps class ids discrete: the default bilinear interpolation
        # blends neighbouring ids at boundaries, producing fractional values (no class) or
        # spurious intermediate class ids.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # The file stem is the numeric patient id (e.g. "123.dcm" -> 123).
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """Return a ``(num_classes, *mask.shape)`` float32 one-hot encoding of ``mask``.

        ``mask`` may be a NumPy array or a CPU tensor of class ids. Values are rounded to
        the nearest integer first, so float label maps with tiny representation error still
        match. Pixels whose id is outside ``[0, num_classes)`` get an all-zero vector.
        """
        labels = np.rint(np.asarray(mask))
        class_ids = np.arange(self.num_classes).reshape(-1, *([1] * labels.ndim))
        return (labels[np.newaxis] == class_ids).astype(np.float32)

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then mirror it left-right.

        The net effect is a transpose (``out[i, j] == mask[j, i]``). It aligns the NIfTI
        array orientation with the DICOM pixel array.
        """
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
available_masks = set(mask_files)  # O(1) membership tests instead of scanning the list per image

# Pair each image with the annotation sharing its stem ("<stem>.nii.gz");
# images without an annotation are dropped.
paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

assert paired_files, f"No image/annotation pairs found under {img_root} and {mask_root}"

# Shuffle (seeded above), then split ~70% train / 10% validation / remainder test.
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

# Training uses the same deterministic pipeline as evaluation (no data augmentation).
train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

def train_linknet_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                        save_path="linknet_model_knee_init.pth"):
    """Train a LinkNet on the knee split, save its weights, and report test IoU.

    Reads the notebook globals ``train_set``, ``valid_pairs``, ``test_pairs``,
    ``img_root``, ``mask_root``, ``metadata_df``, ``num_classes``, ``batch_size`` and
    ``test_augmentations`` defined in this cell.

    IoU is sklearn's micro-averaged Jaccard over all pixels of a batch, averaged over
    batches. Micro averaging is likely dominated by the majority (background) class;
    consider ``average='macro'`` or per-class scores for a fuller picture.

    Args:
        num_epochs: number of passes over ``train_set``.
        encoder_name: ``segmentation_models_pytorch`` encoder (ImageNet-pretrained).
        device: torch device name or ``torch.device`` to train and evaluate on.
        save_path: file the final ``state_dict`` is written to.

    Returns:
        Mean per-batch IoU on the test split, as a float.
    """
    device = torch.device(device)

    model = smp.Linknet(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    def batch_iou(target_onehot, logits):
        # Micro-averaged Jaccard between the argmax class maps of target and prediction.
        return jaccard_score(
            target_onehot.argmax(dim=1).cpu().numpy().flatten(),
            logits.argmax(dim=1).cpu().numpy().flatten(),
            average='micro'
        )

    def evaluate(loader):
        # Mean per-batch IoU of the model over `loader`, in eval mode without gradients.
        model.eval()
        scores = []
        with torch.no_grad():
            for images, masks, _racegender in loader:
                images, masks = images.to(device), masks.to(device)
                scores.append(batch_iou(masks, model(images)))
        return np.mean(scores)

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )

    # A shuffling DataLoader reshuffles on every pass, so one instance serves all epochs.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            # CrossEntropyLoss expects class-index targets, so collapse the one-hot masks.
            loss = criterion(outputs, masks.argmax(dim=1))
            loss.backward()
            optimizer.step()

            train_iou_list.append(batch_iou(masks, outputs))

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = evaluate(valid_loader)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    torch.save(model.state_dict(), save_path)
    print("Model saved successfully.")

    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    test_iou_avg = float(evaluate(test_loader))
    print(f"Test IoU: {test_iou_avg:.4f}")
    return test_iou_avg


# Launch the full LinkNet run; per-epoch train/validation IoU is printed as it trains.
print("Training LinkNet ..")
train_linknet_model(num_epochs=100, encoder_name="resnet18")


Training LinkNet ..


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.1305 - Validation IoU: 0.0288


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.1534 - Validation IoU: 0.0288


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.1731 - Validation IoU: 0.0288


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.1908 - Validation IoU: 0.0282


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.2099 - Validation IoU: 0.0284


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.2295 - Validation IoU: 0.0430


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.2524 - Validation IoU: 0.0940


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.2749 - Validation IoU: 0.1765


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.3005 - Validation IoU: 0.2588


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.3249 - Validation IoU: 0.3162


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.3514 - Validation IoU: 0.3685


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.3770 - Validation IoU: 0.4123


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.4038 - Validation IoU: 0.4469


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.4303 - Validation IoU: 0.4742


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.4558 - Validation IoU: 0.4975


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.4805 - Validation IoU: 0.5202


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.5049 - Validation IoU: 0.5351


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.5299 - Validation IoU: 0.5404


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.5439 - Validation IoU: 0.5530


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.5614 - Validation IoU: 0.5512


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.5720 - Validation IoU: 0.5615


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.5831 - Validation IoU: 0.5779


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.5958 - Validation IoU: 0.5923


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.6072 - Validation IoU: 0.5928


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.6158 - Validation IoU: 0.5948


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.6217 - Validation IoU: 0.6023


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.6293 - Validation IoU: 0.6131


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.6320 - Validation IoU: 0.6158


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.6366 - Validation IoU: 0.6202


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.6403 - Validation IoU: 0.6190


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.6432 - Validation IoU: 0.6245


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.6460 - Validation IoU: 0.6251


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.6479 - Validation IoU: 0.6253


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.6492 - Validation IoU: 0.6311


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.6531 - Validation IoU: 0.6292


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.6521 - Validation IoU: 0.6283


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.6557 - Validation IoU: 0.6301


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.6549 - Validation IoU: 0.6327


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.6586 - Validation IoU: 0.6324


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.6586 - Validation IoU: 0.6315


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.6607 - Validation IoU: 0.6328


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.6620 - Validation IoU: 0.6339


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.6635 - Validation IoU: 0.6346


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.6648 - Validation IoU: 0.6375


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.6673 - Validation IoU: 0.6370


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.6682 - Validation IoU: 0.6417


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.6739 - Validation IoU: 0.6470


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.6777 - Validation IoU: 0.6536


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.6863 - Validation IoU: 0.6521


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.6908 - Validation IoU: 0.6643


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.6971 - Validation IoU: 0.6656


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.7027 - Validation IoU: 0.6733


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.7041 - Validation IoU: 0.6763


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.7100 - Validation IoU: 0.6830


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.7106 - Validation IoU: 0.6875


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.7159 - Validation IoU: 0.6888


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.7181 - Validation IoU: 0.6918


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.7216 - Validation IoU: 0.6923


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.7243 - Validation IoU: 0.6956


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.7261 - Validation IoU: 0.6948


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.7287 - Validation IoU: 0.6974


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.7292 - Validation IoU: 0.6996


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.7317 - Validation IoU: 0.6989


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.7318 - Validation IoU: 0.7008


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.7343 - Validation IoU: 0.7010


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.7341 - Validation IoU: 0.7011


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.7354 - Validation IoU: 0.7016


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.7370 - Validation IoU: 0.7022


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.7366 - Validation IoU: 0.7042


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.7383 - Validation IoU: 0.7015


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.7378 - Validation IoU: 0.7041


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.7395 - Validation IoU: 0.7047


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.7402 - Validation IoU: 0.7067


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.7399 - Validation IoU: 0.7073


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.7420 - Validation IoU: 0.7083


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.7416 - Validation IoU: 0.7091


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.7439 - Validation IoU: 0.7080


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.7432 - Validation IoU: 0.7084


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.7453 - Validation IoU: 0.7103


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.7461 - Validation IoU: 0.7111


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.7482 - Validation IoU: 0.7130


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.7487 - Validation IoU: 0.7132


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.7515 - Validation IoU: 0.7158


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.7530 - Validation IoU: 0.7165


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.7550 - Validation IoU: 0.7199


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.7578 - Validation IoU: 0.7195


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.7576 - Validation IoU: 0.7213


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.7620 - Validation IoU: 0.7211


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.7625 - Validation IoU: 0.7245


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.7647 - Validation IoU: 0.7236


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.7667 - Validation IoU: 0.7278


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.7688 - Validation IoU: 0.7282


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.7695 - Validation IoU: 0.7304


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.7733 - Validation IoU: 0.7323


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.7755 - Validation IoU: 0.7344


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.7755 - Validation IoU: 0.7335


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.7799 - Validation IoU: 0.7375


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.7787 - Validation IoU: 0.7406


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.7840 - Validation IoU: 0.7439


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.7848 - Validation IoU: 0.7452
Model saved successfully.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score


# Seed every RNG in play (Python `random` drives the train/val/test shuffle below,
# NumPy and torch drive weight init / DataLoader shuffling) for reproducibility.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)


num_classes = 9
batch_size = 4
img_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Images"
mask_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Annotations"
metadata_path = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/segmentation_with_racegender.csv"


# Grayscale is replicated to 3 channels, so the same ImageNet-style mean/std is used per channel.
# NOTE(review): A.Normalize scales by max_pixel_value (default 255.0); DICOM pixel values are
# never rescaled to 0-255 in the dataset, so confirm the bit depth of these images.
_GRAY_MEAN = (0.485,) * 3
_GRAY_STD = (0.229,) * 3
test_augmentations = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(mean=_GRAY_MEAN, std=_GRAY_STD),
        ToTensorV2(),
    ]
)

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI label-map dataset for multiclass segmentation.

    Each item is ``(image, one_hot_mask, racegender)``:
      * ``image`` -- the DICOM pixels resized to 512x512, replicated to 3 channels,
        then passed through ``transforms`` (if any).
      * ``one_hot_mask`` -- float32 array of shape ``(num_classes, H, W)``.
      * ``racegender`` -- value from ``metadata_df['racegender']`` for the numeric id
        in the image filename, or ``'Unknown'`` if the id is not present.

    Despite the name, this class is also used for the knee dataset in this notebook.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        """
        Args:
            img_root: directory containing the DICOM images.
            mask_root: directory containing the NIfTI annotation files.
            metadata_df: DataFrame with (at least) ``id`` and ``racegender`` columns.
            paired_files: list of ``(image_filename, mask_filename)`` tuples.
            num_classes: number of label classes; labels are expected to be 0..num_classes-1.
            transforms: optional albumentations-style callable applied to image + label map.
            preprocessing: optional callable applied to image + one-hot mask afterwards.
        """
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.paired_files)

    def __getitem__(self, idx):
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # NOTE(review): torch's default_collate cannot batch ``None``; this path only
            # works with a custom collate_fn that drops missing samples.
            return None

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the single grayscale channel so the 3-channel encoder accepts it.
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Label maps must be resized with nearest-neighbour interpolation. The default
        # (bilinear) blends neighbouring class ids into fractional values that match no
        # class in one_hot_encode, so those pixels would end up with an all-zero one-hot.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # Filenames are assumed to start with the numeric patient id (e.g. "123.<ext>").
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """Return a float32 one-hot encoding of shape ``(num_classes, *mask.shape)``.

        A pixel whose value is not exactly one of ``0..num_classes-1`` gets all zeros.
        Accepts NumPy arrays and CPU tensors (e.g. the mask produced by ToTensorV2).
        """
        mask = np.asarray(mask)
        class_ids = np.arange(self.num_classes).reshape((-1,) + (1,) * mask.ndim)
        return (mask[np.newaxis] == class_ids).astype(np.float32)

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then mirror it left-right.

        The composition is equivalent to a transpose; it is presumably used to bring
        the NIfTI array axes into the DICOM image orientation -- TODO confirm.
        """
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
# Hash-set membership keeps pairing linear; testing `in mask_files` on the list
# scanned every mask filename for every image.
available_masks = set(mask_files)

# Pair each image with the NIfTI mask that shares its filename stem.
paired_files = []
for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

# Shuffle with the seeded `random` module, then split 70% / 10% / remainder (~20%).
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

# Training data reuses the resize/normalize pipeline; no augmentation is applied.
train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

def _mean_batch_iou(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return the mean of per-batch IoU scores.

    IoU is sklearn's ``jaccard_score(average='micro')`` over flattened label maps,
    i.e. the same metric reported for training batches.
    """
    model.eval()
    batch_ious = []
    with torch.no_grad():
        for images, masks, _racegender in loader:
            images, masks = images.to(device), masks.to(device)
            predicted_masks = torch.argmax(model(images), dim=1)
            batch_ious.append(
                jaccard_score(
                    masks.argmax(dim=1).cpu().numpy().flatten(),
                    predicted_masks.cpu().numpy().flatten(),
                    average='micro',
                )
            )
    return np.mean(batch_ious)


def train_pspnet_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                       save_path='pspnet_model_knee_init.pth'):
    """Train a PSPNet on ``train_set``, report train/val IoU per epoch, save, then test.

    Relies on the module-level ``train_set``, ``valid_pairs``, ``test_pairs``,
    ``metadata_df``, ``img_root``, ``mask_root``, ``num_classes``, ``batch_size`` and
    ``test_augmentations`` defined earlier in this cell.

    Args:
        num_epochs: number of passes over the training set.
        encoder_name: segmentation_models_pytorch encoder (ImageNet-pretrained).
        device: torch device (string or ``torch.device``) to train on.
        save_path: where the final ``state_dict`` is written.

    NOTE(review): micro-averaged Jaccard on single-label pixels equals c / (2N - c)
    for c correct out of N pixels, so it is dominated by the most frequent class;
    consider ``average='macro'`` for per-structure IoU.
    """
    model = smp.PSPNet(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )

    device = torch.device(device)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )
    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )

    # Loaders are built once: shuffle=True already reshuffles on every epoch's iteration.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            # One-hot (B, C, H, W) -> class-index targets (B, H, W) for CrossEntropyLoss.
            targets = masks.argmax(dim=1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            predicted_masks = torch.argmax(outputs, dim=1)
            train_iou_list.append(
                jaccard_score(
                    targets.cpu().numpy().flatten(),
                    predicted_masks.cpu().numpy().flatten(),
                    average='micro',
                )
            )

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = _mean_batch_iou(model, valid_loader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    torch.save(model.state_dict(), save_path)
    print("Model saved successfully.")

    # Held-out evaluation; previously computed but never reported.
    test_iou_avg = _mean_batch_iou(model, test_loader, device)
    print(f"Test IoU: {test_iou_avg:.4f}")


# Launch PSPNet training with a ResNet-18 encoder for the default 100 epochs.
_arch_label = "PSPNet"
print(f"Training {_arch_label} ..")
train_pspnet_model(encoder_name="resnet18", num_epochs=100)


Training PSPNet ..


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0905 - Validation IoU: 0.1097


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.4125 - Validation IoU: 0.5950


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.6076 - Validation IoU: 0.6040


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.6184 - Validation IoU: 0.6040


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.6199 - Validation IoU: 0.6040


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.6311 - Validation IoU: 0.6041


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.6498 - Validation IoU: 0.6046


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.6838 - Validation IoU: 0.6078


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.7028 - Validation IoU: 0.6146


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.7240 - Validation IoU: 0.6282


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.7233 - Validation IoU: 0.6550


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.7326 - Validation IoU: 0.6856


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.7486 - Validation IoU: 0.7162


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.7667 - Validation IoU: 0.7455


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.7822 - Validation IoU: 0.7651


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.8000 - Validation IoU: 0.7664


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.8058 - Validation IoU: 0.7838


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.8262 - Validation IoU: 0.7912


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.8316 - Validation IoU: 0.7792


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.8389 - Validation IoU: 0.7918


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.8463 - Validation IoU: 0.8030


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.8569 - Validation IoU: 0.8034


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.8587 - Validation IoU: 0.8069


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.8674 - Validation IoU: 0.8247


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.8691 - Validation IoU: 0.8374


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.8815 - Validation IoU: 0.8498


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.8783 - Validation IoU: 0.8542


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.8875 - Validation IoU: 0.8535


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.8906 - Validation IoU: 0.8549


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.8984 - Validation IoU: 0.8575


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.8980 - Validation IoU: 0.8684


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9032 - Validation IoU: 0.8760


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9040 - Validation IoU: 0.8811


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9067 - Validation IoU: 0.8753


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9094 - Validation IoU: 0.8794


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9142 - Validation IoU: 0.8850


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9205 - Validation IoU: 0.8858


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9163 - Validation IoU: 0.8878


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9182 - Validation IoU: 0.8996


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9249 - Validation IoU: 0.9042


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9287 - Validation IoU: 0.9043


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9301 - Validation IoU: 0.9021


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9331 - Validation IoU: 0.9025


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9318 - Validation IoU: 0.9064


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9339 - Validation IoU: 0.9038


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9342 - Validation IoU: 0.9037


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9356 - Validation IoU: 0.9035


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9383 - Validation IoU: 0.9073


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9409 - Validation IoU: 0.9096


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9379 - Validation IoU: 0.9141


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9396 - Validation IoU: 0.9145


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9394 - Validation IoU: 0.9142


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9433 - Validation IoU: 0.9144


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9423 - Validation IoU: 0.9138


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9425 - Validation IoU: 0.9136


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9469 - Validation IoU: 0.9132


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9466 - Validation IoU: 0.9167


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9462 - Validation IoU: 0.9185


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9473 - Validation IoU: 0.9182


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9486 - Validation IoU: 0.9180


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9496 - Validation IoU: 0.9181


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9498 - Validation IoU: 0.9172


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9511 - Validation IoU: 0.9170


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9506 - Validation IoU: 0.9198


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9480 - Validation IoU: 0.9212


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9503 - Validation IoU: 0.9210


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9517 - Validation IoU: 0.9221


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9514 - Validation IoU: 0.9229


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9538 - Validation IoU: 0.9204


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9536 - Validation IoU: 0.9189


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9546 - Validation IoU: 0.9204


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9540 - Validation IoU: 0.9219


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9551 - Validation IoU: 0.9199


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9545 - Validation IoU: 0.9204


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9559 - Validation IoU: 0.9248


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9543 - Validation IoU: 0.9263


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9584 - Validation IoU: 0.9261


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9570 - Validation IoU: 0.9263


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9553 - Validation IoU: 0.9292


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9572 - Validation IoU: 0.9311


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9582 - Validation IoU: 0.9304


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9593 - Validation IoU: 0.9273


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9572 - Validation IoU: 0.9239


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9609 - Validation IoU: 0.9228


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9596 - Validation IoU: 0.9257


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9610 - Validation IoU: 0.9276


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9599 - Validation IoU: 0.9270


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9614 - Validation IoU: 0.9264


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9614 - Validation IoU: 0.9280


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9616 - Validation IoU: 0.9285


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9600 - Validation IoU: 0.9281


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9624 - Validation IoU: 0.9298


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9620 - Validation IoU: 0.9289


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9633 - Validation IoU: 0.9274


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9610 - Validation IoU: 0.9301


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9631 - Validation IoU: 0.9305


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9604 - Validation IoU: 0.9317


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9619 - Validation IoU: 0.9313


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9633 - Validation IoU: 0.9313


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9651 - Validation IoU: 0.9301
Model saved successfully.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score


# Seed every RNG in play (Python `random` drives the train/val/test shuffle below,
# NumPy and torch drive weight init / DataLoader shuffling) for reproducibility.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)


num_classes = 9
batch_size = 4
img_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Images"
mask_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Annotations"
metadata_path = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/segmentation_with_racegender.csv"


# Grayscale is replicated to 3 channels, so the same ImageNet-style mean/std is used per channel.
# NOTE(review): A.Normalize scales by max_pixel_value (default 255.0); DICOM pixel values are
# never rescaled to 0-255 in the dataset, so confirm the bit depth of these images.
_GRAY_MEAN = (0.485,) * 3
_GRAY_STD = (0.229,) * 3
test_augmentations = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(mean=_GRAY_MEAN, std=_GRAY_STD),
        ToTensorV2(),
    ]
)

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI label-map dataset for multiclass segmentation.

    Each item is ``(image, one_hot_mask, racegender)``:
      * ``image`` -- the DICOM pixels resized to 512x512, replicated to 3 channels,
        then passed through ``transforms`` (if any).
      * ``one_hot_mask`` -- float32 array of shape ``(num_classes, H, W)``.
      * ``racegender`` -- value from ``metadata_df['racegender']`` for the numeric id
        in the image filename, or ``'Unknown'`` if the id is not present.

    Despite the name, this class is also used for the knee dataset in this notebook.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        """
        Args:
            img_root: directory containing the DICOM images.
            mask_root: directory containing the NIfTI annotation files.
            metadata_df: DataFrame with (at least) ``id`` and ``racegender`` columns.
            paired_files: list of ``(image_filename, mask_filename)`` tuples.
            num_classes: number of label classes; labels are expected to be 0..num_classes-1.
            transforms: optional albumentations-style callable applied to image + label map.
            preprocessing: optional callable applied to image + one-hot mask afterwards.
        """
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.paired_files)

    def __getitem__(self, idx):
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            # NOTE(review): torch's default_collate cannot batch ``None``; this path only
            # works with a custom collate_fn that drops missing samples.
            return None

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the single grayscale channel so the 3-channel encoder accepts it.
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if annotation_data.ndim == 3:
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Label maps must be resized with nearest-neighbour interpolation. The default
        # (bilinear) blends neighbouring class ids into fractional values that match no
        # class in one_hot_encode, so those pixels would end up with an all-zero one-hot.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # Filenames are assumed to start with the numeric patient id (e.g. "123.<ext>").
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """Return a float32 one-hot encoding of shape ``(num_classes, *mask.shape)``.

        A pixel whose value is not exactly one of ``0..num_classes-1`` gets all zeros.
        Accepts NumPy arrays and CPU tensors (e.g. the mask produced by ToTensorV2).
        """
        mask = np.asarray(mask)
        class_ids = np.arange(self.num_classes).reshape((-1,) + (1,) * mask.ndim)
        return (mask[np.newaxis] == class_ids).astype(np.float32)

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise, then mirror it left-right.

        The composition is equivalent to a transpose; it is presumably used to bring
        the NIfTI array axes into the DICOM image orientation -- TODO confirm.
        """
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
# Hash-set membership keeps pairing linear; testing `in mask_files` on the list
# scanned every mask filename for every image.
available_masks = set(mask_files)

# Pair each image with the NIfTI mask that shares its filename stem.
paired_files = []
for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in available_masks:
        paired_files.append((image_file, mask_file))

# Shuffle with the seeded `random` module, then split 70% / 10% / remainder (~20%).
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

# Training data reuses the resize/normalize pipeline; no augmentation is applied.
train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

def train_fpn_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                    checkpoint_path="fpn_model_knee_init.pth"):
    """Train an FPN on ``train_set``, track validation IoU, save it, then score the test split.

    Relies on module-level names defined in this cell: ``train_set``, ``valid_pairs``,
    ``test_pairs``, ``img_root``, ``mask_root``, ``metadata_df``, ``num_classes``,
    ``batch_size``, ``test_augmentations`` and ``MulticlassHipSegmentationDataset``.

    Args:
        num_epochs: Number of training epochs.
        encoder_name: ``segmentation_models_pytorch`` encoder name (ImageNet weights).
        device: Device used for training and evaluation. Defaults to CPU as before;
            pass e.g. ``"cuda"`` to train on a GPU.
        checkpoint_path: File the final-epoch ``state_dict`` is written to.

    Returns:
        Mean per-batch test IoU (micro-averaged Jaccard).
    """
    device = torch.device(device)

    def _batch_iou(masks, outputs):
        """Micro-averaged Jaccard of one batch: one-hot ``masks`` vs. model logits ``outputs``."""
        # Both tensors carry classes along dim 1; argmax turns them into per-pixel label ids.
        true_ids = masks.argmax(dim=1).cpu().numpy().flatten()
        pred_ids = outputs.detach().argmax(dim=1).cpu().numpy().flatten()
        # NOTE(review): 'micro' pools every pixel (background included), so the score tracks
        # pixel accuracy; average='macro' would expose poorly segmented small structures.
        return jaccard_score(true_ids, pred_ids, average='micro')

    def _evaluate(net, loader):
        """Mean per-batch IoU of ``net`` over ``loader`` in eval mode, without gradients."""
        net.eval()
        batch_scores = []
        with torch.no_grad():
            for images, masks, _racegender in loader:
                images, masks = images.to(device), masks.to(device)
                batch_scores.append(_batch_iou(masks, net(images)))
        return np.mean(batch_scores)

    model = smp.FPN(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    # Masks are one-hot, CrossEntropyLoss wants class-index targets: hence argmax(dim=1) below.
    criterion = nn.CrossEntropyLoss()

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)
    # Built once: with shuffle=True the sampler reshuffles each time the loader is iterated.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        # The dataset only yields None for a missing mask, which default collation cannot
        # batch anyway; the pairing step keeps only images that have a mask.
        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks.argmax(dim=1))
            loss.backward()
            optimizer.step()

            train_iou_list.append(_batch_iou(masks, outputs))

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = _evaluate(model, valid_loader)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    # NOTE(review): this stores the last epoch, not the best-validation epoch.
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved successfully.")

    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    # Held-out score on the test split: report it instead of discarding it.
    test_iou_avg = _evaluate(model, test_loader)
    print(f"Test IoU: {test_iou_avg:.4f}")
    return test_iou_avg

# Launch FPN training with a ResNet-18 encoder for 100 epochs (the function's defaults,
# spelled out); the trailing ';' keeps any return value out of the cell output.
print("Training FPN ..")
train_fpn_model(num_epochs=100, encoder_name="resnet18");


Training FPN ..


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.3603 - Validation IoU: 0.6039


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.6481 - Validation IoU: 0.6040


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.6898 - Validation IoU: 0.6270


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.7210 - Validation IoU: 0.6551


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.7486 - Validation IoU: 0.6630


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.7596 - Validation IoU: 0.7021


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.7884 - Validation IoU: 0.7513


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.8094 - Validation IoU: 0.7446


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.8215 - Validation IoU: 0.7544


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.8369 - Validation IoU: 0.7970


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.8403 - Validation IoU: 0.7946


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.8481 - Validation IoU: 0.7782


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.8572 - Validation IoU: 0.8096


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.8722 - Validation IoU: 0.8257


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.8713 - Validation IoU: 0.8224


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.8856 - Validation IoU: 0.8339


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.8875 - Validation IoU: 0.8448


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.8878 - Validation IoU: 0.8495


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.8952 - Validation IoU: 0.8497


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.9061 - Validation IoU: 0.8508


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.9040 - Validation IoU: 0.8609


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.9107 - Validation IoU: 0.8561


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.9088 - Validation IoU: 0.8404


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.9048 - Validation IoU: 0.8518


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.9212 - Validation IoU: 0.8659


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.9209 - Validation IoU: 0.8519


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.9264 - Validation IoU: 0.8491


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.9222 - Validation IoU: 0.8553


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.9282 - Validation IoU: 0.8687


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.9337 - Validation IoU: 0.8684


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.9286 - Validation IoU: 0.8647


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9328 - Validation IoU: 0.8727


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9338 - Validation IoU: 0.8777


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9342 - Validation IoU: 0.8711


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9313 - Validation IoU: 0.8732


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9405 - Validation IoU: 0.8663


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9397 - Validation IoU: 0.8589


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9393 - Validation IoU: 0.8655


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9418 - Validation IoU: 0.8682


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9400 - Validation IoU: 0.8666


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9450 - Validation IoU: 0.8653


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9430 - Validation IoU: 0.8789


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9477 - Validation IoU: 0.8756


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9483 - Validation IoU: 0.8690


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9503 - Validation IoU: 0.8793


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9466 - Validation IoU: 0.8780


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9530 - Validation IoU: 0.8728


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9517 - Validation IoU: 0.8846


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9524 - Validation IoU: 0.8870


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9560 - Validation IoU: 0.8819


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9529 - Validation IoU: 0.8808


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9517 - Validation IoU: 0.8881


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9562 - Validation IoU: 0.8792


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9495 - Validation IoU: 0.8793


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9554 - Validation IoU: 0.8864


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9560 - Validation IoU: 0.8915


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9565 - Validation IoU: 0.8919


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9596 - Validation IoU: 0.8858


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9572 - Validation IoU: 0.8914


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9547 - Validation IoU: 0.8941


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9590 - Validation IoU: 0.8885


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9627 - Validation IoU: 0.8959


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9630 - Validation IoU: 0.8991


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9603 - Validation IoU: 0.8942


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9601 - Validation IoU: 0.8948


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9623 - Validation IoU: 0.9007


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9574 - Validation IoU: 0.8960


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9627 - Validation IoU: 0.8895


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9631 - Validation IoU: 0.8997


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9650 - Validation IoU: 0.9029


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9650 - Validation IoU: 0.8976


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9652 - Validation IoU: 0.8921


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9654 - Validation IoU: 0.9032


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9651 - Validation IoU: 0.9039


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9677 - Validation IoU: 0.8975


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9658 - Validation IoU: 0.9009


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9664 - Validation IoU: 0.9025


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9680 - Validation IoU: 0.9020


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9685 - Validation IoU: 0.9001


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9661 - Validation IoU: 0.8991


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9654 - Validation IoU: 0.9046


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9687 - Validation IoU: 0.9050


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9689 - Validation IoU: 0.9033


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9694 - Validation IoU: 0.9014


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9700 - Validation IoU: 0.9045


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9673 - Validation IoU: 0.9081


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9681 - Validation IoU: 0.9013


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9672 - Validation IoU: 0.9019


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9688 - Validation IoU: 0.9091


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9711 - Validation IoU: 0.9067


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9696 - Validation IoU: 0.9012


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9714 - Validation IoU: 0.9105


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9695 - Validation IoU: 0.9093


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9695 - Validation IoU: 0.9083


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9711 - Validation IoU: 0.9094


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9697 - Validation IoU: 0.9146


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9690 - Validation IoU: 0.9114


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9707 - Validation IoU: 0.9090


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9694 - Validation IoU: 0.9124


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9708 - Validation IoU: 0.9167
Model saved successfully.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score

# Seed Python, NumPy and PyTorch RNGs so the shuffle-based split below and the
# randomly initialised model weights are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


# Experiment configuration.
num_classes = 9  # one-hot channels per mask and output classes of the model
batch_size = 4
img_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Images"  # DICOM images
mask_root = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/Annotations"  # NIfTI (.nii.gz) label maps
metadata_path = "/home/ealam/JHIR_Hip_Knee_Datasets/Knee/segmentation_with_racegender.csv"  # per-image 'id' -> 'racegender'


# Deterministic resize + normalise + tensor conversion (no random augmentation);
# this cell uses it for the train, validation and test datasets alike.
test_augmentations = A.Compose([
    A.Resize(height=512, width=512),
    # 0.485 / 0.229 match the ImageNet red-channel statistics, repeated for the three
    # identical grey channels.
    # NOTE(review): A.Normalize divides by max_pixel_value (255 by default) first; raw
    # DICOM intensities may exceed 255 — confirm the intended scaling.
    A.Normalize(mean=(0.485, 0.485, 0.485), std=(0.229, 0.229, 0.229)),
    ToTensorV2(),
])

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI label-map dataset for multiclass segmentation.

    Used in this cell for the knee data despite the "Hip" in its name. Each item is
    ``(image, one_hot_mask, racegender)``: ``one_hot_mask`` has ``num_classes``
    channels, and ``racegender`` comes from the metadata row whose ``id`` matches the
    numeric image file stem (``'Unknown'`` if there is no such row).
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        """
        Args:
            img_root: Directory containing the DICOM image files.
            mask_root: Directory containing the NIfTI label-map files.
            metadata_df: DataFrame with ``id`` and ``racegender`` columns.
            paired_files: Sequence of ``(image_file, mask_file)`` names.
            num_classes: Number of label ids / one-hot channels.
            transforms: Optional albumentations-style callable applied to ``image`` and the
                integer ``mask`` before one-hot encoding.
            preprocessing: Optional callable applied after one-hot encoding.
        """
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, resize and one-hot encode the ``idx``-th image/mask pair.

        Returns:
            ``(image, one_hot_mask, racegender)``, or ``None`` if the mask file is missing.
            NOTE(review): DataLoader's default collation cannot batch ``None``; callers are
            expected to pass only pairs whose mask exists.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        if not os.path.exists(mask_path):
            return None

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        # NOTE(review): raw DICOM intensities, not rescaled; confirm the supplied
        # transforms normalise this range appropriately.
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the grey channel so a 3-channel (ImageNet) encoder can consume it.
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if len(annotation_data.shape) == 3:
            # Keep only the first slice of a 3-D label volume.
            annotation_data = annotation_data[:, :, 0]

        # Rotate-CW + mirror == transpose; presumably aligns the NIfTI axis order with the
        # DICOM (row, col) layout — TODO confirm.
        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Nearest-neighbour keeps integer class ids. The default bilinear interpolation
        # blends neighbouring labels at region boundaries into fractional values (which
        # match no class) or into wrong intermediate class ids.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # Assumes the part before the first '.' is a numeric id; int(float(...)) raises
        # ValueError otherwise.
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """Return a float32 array of shape (num_classes, *mask.shape); channel k is 1 where mask == k.

        Pixels whose value equals no class id in ``range(num_classes)`` stay 0 in every channel.
        """
        one_hot_mask = np.zeros((self.num_classes, *mask.shape), dtype=np.float32)
        for class_idx in range(self.num_classes):
            one_hot_mask[class_idx][mask == class_idx] = 1.0
        return one_hot_mask

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90 degrees clockwise then mirror it left-right (net effect: transpose)."""
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
# Set lookup keeps pairing linear; testing membership in the list was O(images * masks).
mask_file_set = set(mask_files)

# Pair every image "<id>.<ext>" with its label map "<id>.nii.gz"; images without a
# matching mask are dropped.
paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in mask_file_set:
        paired_files.append((image_file, mask_file))

# Fail fast: with no pairs every split is empty and training would silently report NaN IoU.
if not paired_files:
    raise RuntimeError(f"No image/mask pairs found under {img_root!r} and {mask_root!r}")

# `random` is seeded at the top of the cell, so this shuffle (and thus the split) is repeatable.
random.shuffle(paired_files)

# 70% train / 10% validation / remainder (~20%) test.
train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]

train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    # Same deterministic resize/normalise pipeline as evaluation: no training-time augmentation.
    transforms=test_augmentations
)

def train_pan_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                    checkpoint_path="pan_model_knee_init.pth"):
    """Train a PAN on ``train_set``, track validation IoU, save it, then score the test split.

    Relies on module-level names defined in this cell: ``train_set``, ``valid_pairs``,
    ``test_pairs``, ``img_root``, ``mask_root``, ``metadata_df``, ``num_classes``,
    ``batch_size``, ``test_augmentations`` and ``MulticlassHipSegmentationDataset``.

    Args:
        num_epochs: Number of training epochs.
        encoder_name: ``segmentation_models_pytorch`` encoder name (ImageNet weights).
        device: Device used for training and evaluation. Defaults to CPU as before;
            pass e.g. ``"cuda"`` to train on a GPU.
        checkpoint_path: File the final-epoch ``state_dict`` is written to.

    Returns:
        Mean per-batch test IoU (micro-averaged Jaccard).
    """
    device = torch.device(device)

    def _batch_iou(masks, outputs):
        """Micro-averaged Jaccard of one batch: one-hot ``masks`` vs. model logits ``outputs``."""
        # Both tensors carry classes along dim 1; argmax turns them into per-pixel label ids.
        true_ids = masks.argmax(dim=1).cpu().numpy().flatten()
        pred_ids = outputs.detach().argmax(dim=1).cpu().numpy().flatten()
        # NOTE(review): 'micro' pools every pixel (background included), so the score tracks
        # pixel accuracy; average='macro' would expose poorly segmented small structures.
        return jaccard_score(true_ids, pred_ids, average='micro')

    def _evaluate(net, loader):
        """Mean per-batch IoU of ``net`` over ``loader`` in eval mode, without gradients."""
        net.eval()
        batch_scores = []
        with torch.no_grad():
            for images, masks, _racegender in loader:
                images, masks = images.to(device), masks.to(device)
                batch_scores.append(_batch_iou(masks, net(images)))
        return np.mean(batch_scores)

    model = smp.PAN(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    # Masks are one-hot, CrossEntropyLoss wants class-index targets: hence argmax(dim=1) below.
    criterion = nn.CrossEntropyLoss()

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)
    # Built once: with shuffle=True the sampler reshuffles each time the loader is iterated.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        # The dataset only yields None for a missing mask, which default collation cannot
        # batch anyway; the pairing step keeps only images that have a mask.
        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks.argmax(dim=1))
            loss.backward()
            optimizer.step()

            train_iou_list.append(_batch_iou(masks, outputs))

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = _evaluate(model, valid_loader)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    # NOTE(review): this stores the last epoch, not the best-validation epoch.
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved successfully.")

    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    # Held-out score on the test split: report it instead of discarding it.
    test_iou_avg = _evaluate(model, test_loader)
    print(f"Test IoU: {test_iou_avg:.4f}")
    return test_iou_avg

# Launch PAN training with a ResNet-18 encoder for 100 epochs (the function's defaults,
# spelled out); the trailing ';' keeps any return value out of the cell output.
print("Training PAN ..")
train_pan_model(num_epochs=100, encoder_name="resnet18");


Training PAN ..


  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0249 - Validation IoU: 0.0225


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.0837 - Validation IoU: 0.0245


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.1546 - Validation IoU: 0.0341


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.2294 - Validation IoU: 0.0905


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.3119 - Validation IoU: 0.1421


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.3655 - Validation IoU: 0.1984


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.4288 - Validation IoU: 0.2881


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.4895 - Validation IoU: 0.3607


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.5661 - Validation IoU: 0.4472


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.6248 - Validation IoU: 0.5288


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.6629 - Validation IoU: 0.6087


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.7236 - Validation IoU: 0.6656


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.7536 - Validation IoU: 0.6824


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.7928 - Validation IoU: 0.7311


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.8061 - Validation IoU: 0.7728


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.8364 - Validation IoU: 0.7942


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.8517 - Validation IoU: 0.8054


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.8608 - Validation IoU: 0.8059


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.8708 - Validation IoU: 0.8081


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.8838 - Validation IoU: 0.8153


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.8877 - Validation IoU: 0.8213


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.8896 - Validation IoU: 0.8292


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.8988 - Validation IoU: 0.8402


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.9050 - Validation IoU: 0.8520


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.9081 - Validation IoU: 0.8543


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.9144 - Validation IoU: 0.8584


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.9128 - Validation IoU: 0.8586


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.9161 - Validation IoU: 0.8640


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.9241 - Validation IoU: 0.8674


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.9254 - Validation IoU: 0.8594


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.9284 - Validation IoU: 0.8718


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.9313 - Validation IoU: 0.8786


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.9317 - Validation IoU: 0.8824


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.9356 - Validation IoU: 0.8874


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.9363 - Validation IoU: 0.8867


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.9429 - Validation IoU: 0.8904


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.9392 - Validation IoU: 0.8887


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.9419 - Validation IoU: 0.8880


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.9429 - Validation IoU: 0.8961


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.9459 - Validation IoU: 0.8970


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.9490 - Validation IoU: 0.9026


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.9499 - Validation IoU: 0.9020


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.9512 - Validation IoU: 0.9043


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.9528 - Validation IoU: 0.9066


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9505 - Validation IoU: 0.9031


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9516 - Validation IoU: 0.9040


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9517 - Validation IoU: 0.9053


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9552 - Validation IoU: 0.9078


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9542 - Validation IoU: 0.9088


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9554 - Validation IoU: 0.9106


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9554 - Validation IoU: 0.9080


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9582 - Validation IoU: 0.9081


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9585 - Validation IoU: 0.9119


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9598 - Validation IoU: 0.9126


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9568 - Validation IoU: 0.9067


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9573 - Validation IoU: 0.9035


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9582 - Validation IoU: 0.9043


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9608 - Validation IoU: 0.9107


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9611 - Validation IoU: 0.9130


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9619 - Validation IoU: 0.9132


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9620 - Validation IoU: 0.9136


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9624 - Validation IoU: 0.9100


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9645 - Validation IoU: 0.9126


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9633 - Validation IoU: 0.9148


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9629 - Validation IoU: 0.9149


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9631 - Validation IoU: 0.9140


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9650 - Validation IoU: 0.9141


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9663 - Validation IoU: 0.9165


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9664 - Validation IoU: 0.9160


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9649 - Validation IoU: 0.9164


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9654 - Validation IoU: 0.9161


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9664 - Validation IoU: 0.9166


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9682 - Validation IoU: 0.9181


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9674 - Validation IoU: 0.9192


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9692 - Validation IoU: 0.9205


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9671 - Validation IoU: 0.9181


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9682 - Validation IoU: 0.9192


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9692 - Validation IoU: 0.9196


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9707 - Validation IoU: 0.9193


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9693 - Validation IoU: 0.9198


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9707 - Validation IoU: 0.9189


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9724 - Validation IoU: 0.9198


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9717 - Validation IoU: 0.9200


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9708 - Validation IoU: 0.9216


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9719 - Validation IoU: 0.9220


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9713 - Validation IoU: 0.9218


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9725 - Validation IoU: 0.9223


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9736 - Validation IoU: 0.9207


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9739 - Validation IoU: 0.9200


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9741 - Validation IoU: 0.9215


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9716 - Validation IoU: 0.9217


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9723 - Validation IoU: 0.9221


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9725 - Validation IoU: 0.9234


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9730 - Validation IoU: 0.9223


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9739 - Validation IoU: 0.9203


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9739 - Validation IoU: 0.9201


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9740 - Validation IoU: 0.9198


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9743 - Validation IoU: 0.9210


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9733 - Validation IoU: 0.9219


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9756 - Validation IoU: 0.9220
Model saved successfully.


In [None]:
import os
import numpy as np
import pydicom
import nibabel as nib
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import random
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import jaccard_score


# Seed every RNG the notebook draws from: Python's `random` drives the
# train/valid/test shuffle below; numpy and torch are seeded for model
# initialisation and DataLoader shuffling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


num_classes = 9  # label values 0..num_classes-1 are one-hot encoded by the dataset
batch_size = 4

# Dataset location. Override with the KNEE_DATA_ROOT environment variable
# instead of editing this absolute, user-specific default.
KNEE_DATA_ROOT = os.environ.get("KNEE_DATA_ROOT", "/home/ealam/JHIR_Hip_Knee_Datasets/Knee")
img_root = os.path.join(KNEE_DATA_ROOT, "Images")
mask_root = os.path.join(KNEE_DATA_ROOT, "Annotations")
metadata_path = os.path.join(KNEE_DATA_ROOT, "segmentation_with_racegender.csv")


# Deterministic pipeline (no random ops). In this cell it is used for the
# train, validation and test datasets alike:
#   1. resize to 512x512,
#   2. normalise all three (replicated grayscale) channels with the same
#      mean/std (0.485 / 0.229 are the ImageNet red-channel statistics),
#   3. convert the HWC numpy image to a CHW torch tensor.
# NOTE(review): A.Normalize divides by max_pixel_value (255 by default), but the
# dataset feeds raw DICOM pixel values as float32 — confirm they lie in 0..255.
test_augmentations = A.Compose(
    transforms=[
        A.Resize(512, 512),
        A.Normalize(mean=(0.485,) * 3, std=(0.229,) * 3),
        ToTensorV2(),
    ]
)

class MulticlassHipSegmentationDataset(Dataset):
    """Paired DICOM image / NIfTI label-map dataset for multiclass segmentation.

    Despite the class name, this notebook uses it for the knee data. Each item is
    ``(image, one_hot_mask, racegender)``. ``racegender`` is looked up in
    ``metadata_df`` by the numeric stem of the image file name, or is
    ``'Unknown'`` when no row matches.

    Args:
        img_root: Directory containing the DICOM images.
        mask_root: Directory containing the NIfTI annotations.
        metadata_df: DataFrame with at least ``id`` and ``racegender`` columns.
        paired_files: Sequence of ``(image_file, mask_file)`` name pairs.
        num_classes: Number of label values (0 .. num_classes-1) to one-hot encode.
        transforms: Optional albumentations-style callable invoked as
            ``transforms(image=..., mask=...)`` before one-hot encoding.
        preprocessing: Optional callable invoked as
            ``preprocessing(image=..., mask=one_hot_mask)`` after one-hot encoding.
    """

    def __init__(self, img_root, mask_root, metadata_df, paired_files, num_classes, transforms=None, preprocessing=None):
        self.img_root = img_root
        self.mask_root = mask_root
        self.metadata_df = metadata_df
        self.paired_files = paired_files
        self.num_classes = num_classes
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Load, align, resize and encode the pair at ``idx``.

        Returns:
            ``(image, one_hot_mask, racegender)``, or ``None`` if the mask file
            does not exist.
        """
        image_file, mask_file = self.paired_files[idx]
        mask_path = os.path.join(self.mask_root, mask_file)
        # NOTE: PyTorch's default collate cannot batch ``None``. A DataLoader over
        # data with missing masks needs a custom ``collate_fn`` that drops them.
        if not os.path.exists(mask_path):
            return None

        dicom_image = pydicom.dcmread(os.path.join(self.img_root, image_file))
        # Assumes a single-frame grayscale DICOM (2-D pixel_array) — TODO confirm.
        image = dicom_image.pixel_array.astype(np.float32)
        image = cv2.resize(image, (512, 512))

        # Replicate the grayscale plane into 3 identical channels -> (H, W, 3).
        image = np.stack([image] * 3, axis=-1)

        annotation_data = nib.load(mask_path).get_fdata()
        if len(annotation_data.shape) == 3:
            annotation_data = annotation_data[:, :, 0]

        annotation_data = self.calculate_flipped_rotated_mask(annotation_data)
        # Nearest-neighbour keeps label values integral. The default bilinear
        # interpolation blends neighbouring class ids at region borders, and
        # those fractional pixels match no class in ``one_hot_encode``.
        annotation_data = cv2.resize(annotation_data, (512, 512), interpolation=cv2.INTER_NEAREST)

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=annotation_data)
            image = transformed["image"]
            annotation_data = transformed["mask"]

        annotation_data_onehot = self.one_hot_encode(annotation_data)

        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=annotation_data_onehot)
            image = transformed["image"]
            annotation_data_onehot = transformed["mask"]

        # Assumes the file-name stem is numeric (e.g. "123.dcm"); a non-numeric
        # stem raises ValueError here.
        patient_id = int(float(image_file.split(".")[0]))
        racegender_info = self.metadata_df.loc[self.metadata_df['id'] == patient_id]['racegender'].values
        racegender = racegender_info[0] if racegender_info.size > 0 else 'Unknown'

        return image, annotation_data_onehot, racegender

    def one_hot_encode(self, mask):
        """One-hot encode a label map.

        Args:
            mask: 2-D label map, as a numpy array or a CPU torch tensor (which
                ``ToTensorV2`` produces).

        Returns:
            float32 array of shape ``(num_classes, *mask.shape)``. Channel ``c``
            is 1.0 where ``mask == c``. Pixels whose value is not an integer in
            ``[0, num_classes)`` are zero in every channel.
        """
        mask = np.asarray(mask)
        one_hot_mask = np.zeros((self.num_classes, *mask.shape), dtype=np.float32)
        for class_idx in range(self.num_classes):
            one_hot_mask[class_idx][mask == class_idx] = 1.0
        return one_hot_mask

    def calculate_flipped_rotated_mask(self, mask):
        """Rotate ``mask`` 90° clockwise, then flip it horizontally.

        The net effect is a transpose. This presumably aligns the NIfTI axis
        order with the DICOM row/column order — verify against the data.
        """
        rotated_mask = cv2.rotate(mask, cv2.ROTATE_90_CLOCKWISE)
        flipped_rotated_mask = cv2.flip(rotated_mask, 1)
        return flipped_rotated_mask

metadata_df = pd.read_csv(metadata_path)

image_files = sorted(os.listdir(img_root))
mask_files = sorted(os.listdir(mask_root))
# Set for O(1) membership tests; a list lookup per image made pairing O(n^2).
mask_file_set = set(mask_files)

# Pair each image "<stem>.<ext>" with its annotation "<stem>.nii.gz", keeping
# only images that have one.
paired_files = []

for image_file in image_files:
    image_id = os.path.splitext(image_file)[0]
    mask_file = f"{image_id}.nii.gz"
    if mask_file in mask_file_set:
        paired_files.append((image_file, mask_file))

if not paired_files:
    raise RuntimeError(
        f"No image/mask pairs found between {img_root!r} and {mask_root!r}"
    )

# Shuffle with Python's global `random` state (seeded earlier in this cell),
# then split 70% train / 10% validation / remainder test.
random.shuffle(paired_files)

train_size = int(0.7 * len(paired_files))
valid_size = int(0.1 * len(paired_files))
test_size = len(paired_files) - train_size - valid_size

train_pairs = paired_files[:train_size]
valid_pairs = paired_files[train_size:train_size + valid_size]
test_pairs = paired_files[train_size + valid_size:]
assert len(test_pairs) == test_size
assert len(train_pairs) + len(valid_pairs) + len(test_pairs) == len(paired_files)

# NOTE: the training set uses the deterministic `test_augmentations` pipeline
# (resize/normalise only), i.e. no training-time augmentation.
train_set = MulticlassHipSegmentationDataset(
    img_root, mask_root, metadata_df, train_pairs, num_classes,
    transforms=test_augmentations
)

def _batch_micro_iou(outputs, masks):
    """Micro-averaged Jaccard index for one batch.

    Args:
        outputs: Model logits, with class scores along dim 1.
        masks: One-hot target masks, with classes along dim 1.

    Returns:
        ``jaccard_score(..., average='micro')`` over every pixel of the batch,
        comparing ``argmax`` of the targets with ``argmax`` of the logits.

    NOTE(review): micro averaging pools all classes, including label 0
    (presumably background). The score is therefore dominated by whichever
    class covers the most pixels; a per-class or macro IoU may be more
    informative.
    """
    return jaccard_score(
        masks.argmax(dim=1).cpu().numpy().flatten(),
        torch.argmax(outputs, dim=1).cpu().numpy().flatten(),
        average='micro',
    )


def _mean_iou_over_loader(model, loader, device):
    """Mean of per-batch micro IoU of ``model`` over ``loader``.

    Puts the model in eval mode and runs without autograd. Returns NaN for an
    empty loader.
    """
    model.eval()
    batch_ious = []
    with torch.no_grad():
        for images, masks, _racegender in loader:
            images, masks = images.to(device), masks.to(device)
            batch_ious.append(_batch_micro_iou(model(images), masks))
    return float(np.mean(batch_ious)) if batch_ious else float("nan")


def train_manet_model(num_epochs=100, encoder_name="resnet18", device="cpu",
                      save_path="manet_model_knee_init.pth"):
    """Train an smp MAnet model, save its weights, and report test-split IoU.

    Relies on the module-level ``train_set``, ``valid_pairs``, ``test_pairs``,
    ``img_root``, ``mask_root``, ``metadata_df``, ``num_classes``,
    ``batch_size`` and ``test_augmentations`` defined earlier in this cell.

    Args:
        num_epochs: Number of passes over ``train_set``.
        encoder_name: smp encoder backbone, initialised with ImageNet weights.
        device: Torch device (or its name) to train on. Defaults to ``"cpu"``,
            as before.
        save_path: File the final ``state_dict`` is written to.

    Returns:
        Mean per-batch micro IoU on the test split.
    """
    model = smp.MAnet(
        encoder_name=encoder_name,
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )

    device = torch.device(device)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    # CrossEntropyLoss expects class indices, so the one-hot masks are argmax-ed.
    criterion = nn.CrossEntropyLoss()

    valid_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, valid_pairs, num_classes,
        transforms=test_augmentations
    )
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=2)
    # Built once: a shuffling DataLoader draws a fresh permutation each time it
    # is iterated, so there is no need to recreate it every epoch.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    for epoch in range(num_epochs):
        model.train()
        train_iou_list = []

        # The default collate cannot batch ``None`` items, so a missing mask
        # would fail inside the DataLoader rather than reach this loop.
        for images, masks, _racegender in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks.argmax(dim=1))
            loss.backward()
            optimizer.step()

            # `outputs` still holds this step's forward-pass logits.
            train_iou_list.append(_batch_micro_iou(outputs, masks))

        train_iou_avg = np.mean(train_iou_list)
        valid_iou_avg = _mean_iou_over_loader(model, valid_loader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train IoU: {train_iou_avg:.4f} - Validation IoU: {valid_iou_avg:.4f}")

    torch.save(model.state_dict(), save_path)
    print("Model saved successfully.")

    test_set = MulticlassHipSegmentationDataset(
        img_root, mask_root, metadata_df, test_pairs, num_classes,
        transforms=test_augmentations
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    test_iou_avg = _mean_iou_over_loader(model, test_loader, device)
    print(f"Test IoU: {test_iou_avg:.4f}")
    return test_iou_avg


# Train with the ResNet-18 encoder. The trailing ';' stops Jupyter from echoing
# the call's return value as cell output.
print("Training MAnet ..")
train_manet_model(encoder_name="resnet18");


Training MAnet ..


Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
  _warn_about_invalid_encoding(encoding, patched)


Epoch [1/100] - Train IoU: 0.0241 - Validation IoU: 0.0322


  _warn_about_invalid_encoding(encoding, patched)


Epoch [2/100] - Train IoU: 0.0337 - Validation IoU: 0.0409


  _warn_about_invalid_encoding(encoding, patched)


Epoch [3/100] - Train IoU: 0.0613 - Validation IoU: 0.0822


  _warn_about_invalid_encoding(encoding, patched)


Epoch [4/100] - Train IoU: 0.1183 - Validation IoU: 0.1663


  _warn_about_invalid_encoding(encoding, patched)


Epoch [5/100] - Train IoU: 0.2056 - Validation IoU: 0.2647


  _warn_about_invalid_encoding(encoding, patched)


Epoch [6/100] - Train IoU: 0.2810 - Validation IoU: 0.3540


  _warn_about_invalid_encoding(encoding, patched)


Epoch [7/100] - Train IoU: 0.3479 - Validation IoU: 0.4079


  _warn_about_invalid_encoding(encoding, patched)


Epoch [8/100] - Train IoU: 0.3957 - Validation IoU: 0.4560


  _warn_about_invalid_encoding(encoding, patched)


Epoch [9/100] - Train IoU: 0.4307 - Validation IoU: 0.4902


  _warn_about_invalid_encoding(encoding, patched)


Epoch [10/100] - Train IoU: 0.4667 - Validation IoU: 0.5197


  _warn_about_invalid_encoding(encoding, patched)


Epoch [11/100] - Train IoU: 0.4959 - Validation IoU: 0.5545


  _warn_about_invalid_encoding(encoding, patched)


Epoch [12/100] - Train IoU: 0.5183 - Validation IoU: 0.5847


  _warn_about_invalid_encoding(encoding, patched)


Epoch [13/100] - Train IoU: 0.5423 - Validation IoU: 0.6050


  _warn_about_invalid_encoding(encoding, patched)


Epoch [14/100] - Train IoU: 0.5672 - Validation IoU: 0.6217


  _warn_about_invalid_encoding(encoding, patched)


Epoch [15/100] - Train IoU: 0.5794 - Validation IoU: 0.6387


  _warn_about_invalid_encoding(encoding, patched)


Epoch [16/100] - Train IoU: 0.6025 - Validation IoU: 0.6508


  _warn_about_invalid_encoding(encoding, patched)


Epoch [17/100] - Train IoU: 0.6148 - Validation IoU: 0.6518


  _warn_about_invalid_encoding(encoding, patched)


Epoch [18/100] - Train IoU: 0.6252 - Validation IoU: 0.6667


  _warn_about_invalid_encoding(encoding, patched)


Epoch [19/100] - Train IoU: 0.6353 - Validation IoU: 0.6687


  _warn_about_invalid_encoding(encoding, patched)


Epoch [20/100] - Train IoU: 0.6442 - Validation IoU: 0.6768


  _warn_about_invalid_encoding(encoding, patched)


Epoch [21/100] - Train IoU: 0.6519 - Validation IoU: 0.6920


  _warn_about_invalid_encoding(encoding, patched)


Epoch [22/100] - Train IoU: 0.6600 - Validation IoU: 0.6985


  _warn_about_invalid_encoding(encoding, patched)


Epoch [23/100] - Train IoU: 0.6682 - Validation IoU: 0.7079


  _warn_about_invalid_encoding(encoding, patched)


Epoch [24/100] - Train IoU: 0.6769 - Validation IoU: 0.7144


  _warn_about_invalid_encoding(encoding, patched)


Epoch [25/100] - Train IoU: 0.6844 - Validation IoU: 0.7193


  _warn_about_invalid_encoding(encoding, patched)


Epoch [26/100] - Train IoU: 0.6910 - Validation IoU: 0.7282


  _warn_about_invalid_encoding(encoding, patched)


Epoch [27/100] - Train IoU: 0.6998 - Validation IoU: 0.7380


  _warn_about_invalid_encoding(encoding, patched)


Epoch [28/100] - Train IoU: 0.7143 - Validation IoU: 0.7479


  _warn_about_invalid_encoding(encoding, patched)


Epoch [29/100] - Train IoU: 0.7229 - Validation IoU: 0.7552


  _warn_about_invalid_encoding(encoding, patched)


Epoch [30/100] - Train IoU: 0.7300 - Validation IoU: 0.7569


  _warn_about_invalid_encoding(encoding, patched)


Epoch [31/100] - Train IoU: 0.7384 - Validation IoU: 0.7617


  _warn_about_invalid_encoding(encoding, patched)


Epoch [32/100] - Train IoU: 0.7442 - Validation IoU: 0.7677


  _warn_about_invalid_encoding(encoding, patched)


Epoch [33/100] - Train IoU: 0.7514 - Validation IoU: 0.7719


  _warn_about_invalid_encoding(encoding, patched)


Epoch [34/100] - Train IoU: 0.7540 - Validation IoU: 0.7702


  _warn_about_invalid_encoding(encoding, patched)


Epoch [35/100] - Train IoU: 0.7584 - Validation IoU: 0.7734


  _warn_about_invalid_encoding(encoding, patched)


Epoch [36/100] - Train IoU: 0.7609 - Validation IoU: 0.7737


  _warn_about_invalid_encoding(encoding, patched)


Epoch [37/100] - Train IoU: 0.7659 - Validation IoU: 0.7762


  _warn_about_invalid_encoding(encoding, patched)


Epoch [38/100] - Train IoU: 0.7725 - Validation IoU: 0.7830


  _warn_about_invalid_encoding(encoding, patched)


Epoch [39/100] - Train IoU: 0.7835 - Validation IoU: 0.7898


  _warn_about_invalid_encoding(encoding, patched)


Epoch [40/100] - Train IoU: 0.8034 - Validation IoU: 0.8000


  _warn_about_invalid_encoding(encoding, patched)


Epoch [41/100] - Train IoU: 0.8294 - Validation IoU: 0.8170


  _warn_about_invalid_encoding(encoding, patched)


Epoch [42/100] - Train IoU: 0.8541 - Validation IoU: 0.8262


  _warn_about_invalid_encoding(encoding, patched)


Epoch [43/100] - Train IoU: 0.8767 - Validation IoU: 0.8384


  _warn_about_invalid_encoding(encoding, patched)


Epoch [44/100] - Train IoU: 0.8913 - Validation IoU: 0.8449


  _warn_about_invalid_encoding(encoding, patched)


Epoch [45/100] - Train IoU: 0.9045 - Validation IoU: 0.8533


  _warn_about_invalid_encoding(encoding, patched)


Epoch [46/100] - Train IoU: 0.9130 - Validation IoU: 0.8622


  _warn_about_invalid_encoding(encoding, patched)


Epoch [47/100] - Train IoU: 0.9236 - Validation IoU: 0.8691


  _warn_about_invalid_encoding(encoding, patched)


Epoch [48/100] - Train IoU: 0.9310 - Validation IoU: 0.8790


  _warn_about_invalid_encoding(encoding, patched)


Epoch [49/100] - Train IoU: 0.9361 - Validation IoU: 0.8811


  _warn_about_invalid_encoding(encoding, patched)


Epoch [50/100] - Train IoU: 0.9429 - Validation IoU: 0.8897


  _warn_about_invalid_encoding(encoding, patched)


Epoch [51/100] - Train IoU: 0.9453 - Validation IoU: 0.8878


  _warn_about_invalid_encoding(encoding, patched)


Epoch [52/100] - Train IoU: 0.9499 - Validation IoU: 0.8998


  _warn_about_invalid_encoding(encoding, patched)


Epoch [53/100] - Train IoU: 0.9534 - Validation IoU: 0.8967


  _warn_about_invalid_encoding(encoding, patched)


Epoch [54/100] - Train IoU: 0.9542 - Validation IoU: 0.8994


  _warn_about_invalid_encoding(encoding, patched)


Epoch [55/100] - Train IoU: 0.9559 - Validation IoU: 0.9025


  _warn_about_invalid_encoding(encoding, patched)


Epoch [56/100] - Train IoU: 0.9593 - Validation IoU: 0.9018


  _warn_about_invalid_encoding(encoding, patched)


Epoch [57/100] - Train IoU: 0.9586 - Validation IoU: 0.9032


  _warn_about_invalid_encoding(encoding, patched)


Epoch [58/100] - Train IoU: 0.9620 - Validation IoU: 0.9080


  _warn_about_invalid_encoding(encoding, patched)


Epoch [59/100] - Train IoU: 0.9644 - Validation IoU: 0.9033


  _warn_about_invalid_encoding(encoding, patched)


Epoch [60/100] - Train IoU: 0.9647 - Validation IoU: 0.9075


  _warn_about_invalid_encoding(encoding, patched)


Epoch [61/100] - Train IoU: 0.9674 - Validation IoU: 0.9123


  _warn_about_invalid_encoding(encoding, patched)


Epoch [62/100] - Train IoU: 0.9677 - Validation IoU: 0.9109


  _warn_about_invalid_encoding(encoding, patched)


Epoch [63/100] - Train IoU: 0.9687 - Validation IoU: 0.9087


  _warn_about_invalid_encoding(encoding, patched)


Epoch [64/100] - Train IoU: 0.9700 - Validation IoU: 0.9131


  _warn_about_invalid_encoding(encoding, patched)


Epoch [65/100] - Train IoU: 0.9714 - Validation IoU: 0.9122


  _warn_about_invalid_encoding(encoding, patched)


Epoch [66/100] - Train IoU: 0.9717 - Validation IoU: 0.9173


  _warn_about_invalid_encoding(encoding, patched)


Epoch [67/100] - Train IoU: 0.9711 - Validation IoU: 0.9115


  _warn_about_invalid_encoding(encoding, patched)


Epoch [68/100] - Train IoU: 0.9733 - Validation IoU: 0.9174


  _warn_about_invalid_encoding(encoding, patched)


Epoch [69/100] - Train IoU: 0.9727 - Validation IoU: 0.9200


  _warn_about_invalid_encoding(encoding, patched)


Epoch [70/100] - Train IoU: 0.9759 - Validation IoU: 0.9210


  _warn_about_invalid_encoding(encoding, patched)


Epoch [71/100] - Train IoU: 0.9749 - Validation IoU: 0.9193


  _warn_about_invalid_encoding(encoding, patched)


Epoch [72/100] - Train IoU: 0.9751 - Validation IoU: 0.9199


  _warn_about_invalid_encoding(encoding, patched)


Epoch [73/100] - Train IoU: 0.9782 - Validation IoU: 0.9206


  _warn_about_invalid_encoding(encoding, patched)


Epoch [74/100] - Train IoU: 0.9779 - Validation IoU: 0.9175


  _warn_about_invalid_encoding(encoding, patched)


Epoch [75/100] - Train IoU: 0.9777 - Validation IoU: 0.9203


  _warn_about_invalid_encoding(encoding, patched)


Epoch [76/100] - Train IoU: 0.9779 - Validation IoU: 0.9202


  _warn_about_invalid_encoding(encoding, patched)


Epoch [77/100] - Train IoU: 0.9776 - Validation IoU: 0.9177


  _warn_about_invalid_encoding(encoding, patched)


Epoch [78/100] - Train IoU: 0.9796 - Validation IoU: 0.9190


  _warn_about_invalid_encoding(encoding, patched)


Epoch [79/100] - Train IoU: 0.9778 - Validation IoU: 0.9210


  _warn_about_invalid_encoding(encoding, patched)


Epoch [80/100] - Train IoU: 0.9768 - Validation IoU: 0.9086


  _warn_about_invalid_encoding(encoding, patched)


Epoch [81/100] - Train IoU: 0.9765 - Validation IoU: 0.9200


  _warn_about_invalid_encoding(encoding, patched)


Epoch [82/100] - Train IoU: 0.9796 - Validation IoU: 0.9199


  _warn_about_invalid_encoding(encoding, patched)


Epoch [83/100] - Train IoU: 0.9797 - Validation IoU: 0.9227


  _warn_about_invalid_encoding(encoding, patched)


Epoch [84/100] - Train IoU: 0.9769 - Validation IoU: 0.9201


  _warn_about_invalid_encoding(encoding, patched)


Epoch [85/100] - Train IoU: 0.9806 - Validation IoU: 0.9218


  _warn_about_invalid_encoding(encoding, patched)


Epoch [86/100] - Train IoU: 0.9799 - Validation IoU: 0.9246


  _warn_about_invalid_encoding(encoding, patched)


Epoch [87/100] - Train IoU: 0.9811 - Validation IoU: 0.9222


  _warn_about_invalid_encoding(encoding, patched)


Epoch [88/100] - Train IoU: 0.9810 - Validation IoU: 0.9243


  _warn_about_invalid_encoding(encoding, patched)


Epoch [89/100] - Train IoU: 0.9817 - Validation IoU: 0.9246


  _warn_about_invalid_encoding(encoding, patched)


Epoch [90/100] - Train IoU: 0.9820 - Validation IoU: 0.9247


  _warn_about_invalid_encoding(encoding, patched)


Epoch [91/100] - Train IoU: 0.9828 - Validation IoU: 0.9276


  _warn_about_invalid_encoding(encoding, patched)


Epoch [92/100] - Train IoU: 0.9836 - Validation IoU: 0.9237


  _warn_about_invalid_encoding(encoding, patched)


Epoch [93/100] - Train IoU: 0.9820 - Validation IoU: 0.9261


  _warn_about_invalid_encoding(encoding, patched)


Epoch [94/100] - Train IoU: 0.9838 - Validation IoU: 0.9283


  _warn_about_invalid_encoding(encoding, patched)


Epoch [95/100] - Train IoU: 0.9839 - Validation IoU: 0.9254


  _warn_about_invalid_encoding(encoding, patched)


Epoch [96/100] - Train IoU: 0.9838 - Validation IoU: 0.9281


  _warn_about_invalid_encoding(encoding, patched)


Epoch [97/100] - Train IoU: 0.9842 - Validation IoU: 0.9263


  _warn_about_invalid_encoding(encoding, patched)


Epoch [98/100] - Train IoU: 0.9841 - Validation IoU: 0.9287


  _warn_about_invalid_encoding(encoding, patched)


Epoch [99/100] - Train IoU: 0.9833 - Validation IoU: 0.9273


  _warn_about_invalid_encoding(encoding, patched)


Epoch [100/100] - Train IoU: 0.9852 - Validation IoU: 0.9279
Model saved successfully.
