In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import ToTensor
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from medpy.metric import dc, assd
import torch.nn.functional as F
import torchio as tio
import nibabel as nib

In [2]:
class UNet(nn.Module):
    """Four-level 2D U-Net for per-pixel classification.

    The encoder applies four (double conv -> 2x2 max-pool) stages, widening the
    features 64 -> 128 -> 256 -> 512, followed by a 1024-channel bottleneck.
    The decoder mirrors it. Each stage:
      - upsamples with a stride-2 transposed conv,
      - resizes the result to the matching encoder feature map,
      - concatenates the two along the channel axis,
      - applies a double conv.
    A final 1x1 conv produces ``out_channels`` logits per pixel.

    The submodule attribute names (conv1..conv10, pool1..pool4, up6..up9)
    determine the ``state_dict`` keys, so saved checkpoints depend on them.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder. Modules are created in a fixed order so that seeded weight
        # initialisation stays reproducible.
        self.conv1 = self.double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = self.double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = self.double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = self.double_conv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck.
        self.conv5 = self.double_conv(512, 1024)

        # Decoder: each double conv sees upsampled + skip features (2x width).
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv6 = self.double_conv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv7 = self.double_conv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv8 = self.double_conv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = self.double_conv(128, 64)

        # 1x1 classifier head.
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Return two (3x3 conv -> ReLU) layers; padding=1 keeps H and W unchanged."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) logits."""
        encoder = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        skips = []
        feats = x
        for conv, pool in encoder:
            feats = conv(feats)
            skips.append(feats)  # kept for the matching decoder stage
            feats = pool(feats)

        feats = self.conv5(feats)

        decoder = (
            (self.up6, self.conv6),
            (self.up7, self.conv7),
            (self.up8, self.conv8),
            (self.up9, self.conv9),
        )
        # Deepest skip first: conv4, conv3, conv2, conv1 outputs.
        for (up, conv), skip in zip(decoder, reversed(skips)):
            # Resize to the skip's exact H, W so odd-sized inputs still concatenate.
            upsampled = F.interpolate(
                up(feats), size=skip.size()[2:], mode='bilinear', align_corners=False
            )
            feats = conv(torch.cat([upsampled, skip], dim=1))

        return self.conv10(feats)

    def predict(self, x):
        """Return per-pixel class indices: the argmax of the logits over the channel dim."""
        return torch.max(self.forward(x), 1).indices


In [3]:
def min_max_normalization(slice_2d):
    """Linearly rescale an array of intensities to the range [0, 1].

    Args:
        slice_2d: array-like of intensities (in this notebook, one 2D slice of
            an MRI volume).

    Returns:
        A float array of the same shape. A constant input (for example an
        all-zero background slice, or a uniform non-zero slice) has no range
        to normalise over. In that case an all-zero array is returned instead
        of dividing by zero or raising.
    """
    values = np.asarray(slice_2d)
    if values.size == 0:
        # np.min/np.max raise on empty arrays; nothing to normalise.
        return values.astype(float)

    min_val = values.min()
    max_val = values.max()
    value_range = max_val - min_val
    if value_range > 0:
        return (values - min_val) / value_range

    # Constant slice: map to zeros. An assert here would crash inference on
    # uniform non-zero slices, and asserts are stripped under `python -O`.
    return np.zeros_like(values, dtype=float)

In [None]:
import os
import torch
import torchio as tio
import numpy as np
import nibabel as nib
from PIL import Image
from torchvision.transforms import Compose
from torch.autograd import Variable
from torchvision.transforms import ToTensor

# Inference device: use "cuda:2" when it exists. Otherwise fall back to the
# CPU instead of failing with an invalid-device-ordinal error on machines
# with fewer GPUs (or no CUDA at all).
device = "cuda:2" if torch.cuda.is_available() and torch.cuda.device_count() > 2 else "cpu"

# Two input channels (a T1c slice and a t2f slice) and four output classes.
model = UNet(in_channels=2, out_channels=4)
model = model.to(device)

import re

def atoi(text):
    """Convert a purely numeric chunk to int for natural sorting; return other text unchanged.

    Uses str.isdecimal(), which accepts exactly the Unicode decimal digits that
    int() can parse. str.isdigit() also accepts characters such as a
    superscript two, on which int() raises ValueError.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Sort key that orders embedded numbers numerically ("p2" before "p10").

    Splits ``text`` into alternating non-digit and digit chunks; the capturing
    group keeps the digit chunks in the result. Each chunk is then passed
    through ``atoi``, which turns the numeric ones into ints.
    """
    # Raw string: '\d' in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

In [13]:
# Axis of the 3D volume that is sliced for each supported view. The checkpoint
# for a view is loaded from f"unet_PED_{view}.pth".
ORIENTATION_AXES = {"sagittal": 0, "coronal": 1, "axial": 2}


def process_orientation(
    orientation,
    t1c_dir="pediatric_modalities/t1c",
    t2f_dir="pediatric_modalities/t2f",
    output_dir="PED_ax_op",
    original_seg_dir="ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/",
):
    """Segment every patient slice-by-slice along one view and save the 3D result.

    Steps:
      - Load the view's checkpoint into the global ``model`` (on the global
        ``device``).
      - Feed each 2D slice as a 2-channel image (T1c, t2f).
      - Restack the per-slice argmax label maps along the same axis.
      - Write a NIfTI file that reuses the affine and header of the patient's
        reference segmentation.

    Args:
        orientation: "axial", "coronal" or "sagittal".
        t1c_dir: directory of T1c volumes; every ``.nii``/``.nii.gz`` file in
            it is treated as one patient.
        t2f_dir: directory of t2f volumes, named like the T1c file with "t1c"
            replaced by "t2f".
        output_dir: where ``{t1c file name}_{orientation}_segmentation.nii.gz``
            is written; created if missing.
        original_seg_dir: root holding ``{patient_id}/{patient_id}-seg.nii.gz``,
            whose affine/header are copied to the output.

    Raises:
        ValueError: if ``orientation`` is not one of the supported views.
    """
    if orientation not in ORIENTATION_AXES:
        raise ValueError(
            f"Unknown orientation {orientation!r}; expected one of {sorted(ORIENTATION_AXES)}"
        )
    axis = ORIENTATION_AXES[orientation]

    # map_location lets a checkpoint saved on another GPU load onto `device`.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(f"unet_PED_{orientation}.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)

    patient_files = sorted(
        (name for name in os.listdir(t1c_dir) if name.endswith((".nii", ".nii.gz"))),
        key=natural_keys,
    )
    for patient_file in patient_files:
        print(f"Processing patient: {patient_file}")

        t1c_img = nib.load(os.path.join(t1c_dir, patient_file)).get_fdata()
        t2f_img = nib.load(os.path.join(t2f_dir, patient_file.replace("t1c", "t2f"))).get_fdata()

        seg_slices = []
        # Pure inference: skip building the autograd graph (saves memory and time).
        with torch.no_grad():
            for i in range(t1c_img.shape[axis]):
                t1c_slice = min_max_normalization(np.take(t1c_img, i, axis=axis))
                t2f_slice = min_max_normalization(np.take(t2f_img, i, axis=axis))
                # Shape (1, 2, H, W): a batch of one, channels = (T1c, t2f).
                batch = torch.from_numpy(
                    np.stack([t1c_slice, t2f_slice]).astype(np.float32)
                ).unsqueeze(0).to(device)
                preds = model(batch).argmax(dim=1)
                # Index the batch dim rather than squeeze(), so size-1 spatial dims survive.
                seg_slices.append(preds[0].cpu().numpy())

        # Re-insert the slice axis at the position it was taken from.
        seg_3d = np.stack(seg_slices, axis=axis)

        # The reference segmentation supplies the affine and header for the output.
        patient_id = patient_file.replace("-t1c", "").split(".")[0]
        reference_seg = nib.load(
            os.path.join(original_seg_dir, patient_id, f"{patient_id}-seg.nii.gz")
        )

        seg_nifti = nib.Nifti1Image(seg_3d, affine=reference_seg.affine, header=reference_seg.header)
        nib.save(seg_nifti, os.path.join(output_dir, f"{patient_file}_{orientation}_segmentation.nii.gz"))

# Segment all patients once per view; each call saves one volume per patient.
for orientation in ("coronal", "axial", "sagittal"):
    process_orientation(orientation=orientation)


Processing patient: BraTS-PED-00002-000-t1c.nii.gz
Processing patient: BraTS-PED-00003-000-t1c.nii.gz
Processing patient: BraTS-PED-00004-000-t1c.nii.gz


KeyboardInterrupt: 

In [None]:
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from scipy.stats import mode


def majority_vote(label_maps):
    """Fuse several label volumes voxel-wise by majority vote.

    Args:
        label_maps: sequence of equally shaped arrays of integer-valued labels
            (a float dtype such as the output of ``get_fdata()`` is fine).

    Returns:
        An array of the common shape holding the most frequent label for each
        voxel. Ties go to the smallest label, the same rule
        ``scipy.stats.mode`` uses.

    NumPy is used instead of ``scipy.stats.mode(...)[0][0]`` because the shape
    of that result depends on the SciPy version. With ``keepdims=False`` (the
    default in recent SciPy), ``[0][0]`` silently keeps only the first slice
    of the volume.
    """
    stacked = np.stack(label_maps)
    labels = np.unique(stacked)  # sorted ascending
    # Per-label vote counts; uint16 keeps memory small (supports up to 65535 maps).
    counts = np.stack([np.sum(stacked == label, axis=0, dtype=np.uint16) for label in labels])
    # argmax returns the first maximum, i.e. the smallest label on ties.
    return labels[counts.argmax(axis=0)]


# Patient whose three per-view predictions are fused.
# TODO: loop over all patients instead of a single hard-coded one.
patient_file = "BraTS-PED-00003-000-t1c.nii.gz"

# NOTE(review): process_orientation writes its outputs to "PED_ax_op", but
# these paths are relative to the working directory. Confirm the files are there.
img1 = nib.load(f"{patient_file}_axial_segmentation.nii.gz")
img2 = nib.load(f"{patient_file}_coronal_segmentation.nii.gz")
img3 = nib.load(f"{patient_file}_sagittal_segmentation.nii.gz")

# The per-view predictions already share the original volume grid, so no
# rescaling is needed before voting.
axial_preds = img1.get_fdata()
coronal_preds = img2.get_fdata()
sagittal_preds = img3.get_fdata()

final_segmentation = majority_vote([axial_preds, coronal_preds, sagittal_preds])

# Reuse the geometry (affine) and header of the axial prediction.
final_seg_nifti = nib.Nifti1Image(final_segmentation, affine=img1.affine, header=img1.header)
nib.save(final_seg_nifti, "final_segmentation.nii.gz")
