In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import ToTensor
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from medpy.metric import dc, assd
import torch.nn.functional as F
import torchio as tio
import nibabel as nib

In [2]:
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Downsample path
        self.conv1 = self.double_conv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = self.double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = self.double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = self.double_conv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottom
        self.conv5 = self.double_conv(512, 1024)

        # Upsample path
        self.up6 = nn.ConvTranspose2d(1024, 512, kernel_size = 2, stride = 2)
        self.conv6 = self.double_conv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2)
        self.conv7 = self.double_conv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2)
        self.conv8 = self.double_conv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2)
        self.conv9 = self.double_conv(128, 64)

        # Output
        self.conv10 = nn.Conv2d(64, out_channels, kernel_size = 1)

    def double_conv(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1),
            nn.ReLU(inplace = True)
        )

    def forward(self, x):
        # Downsample path
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)

        # Bottom
        c5 = self.conv5(p4)

        # Upsample path
        # up6 = self.up6(c5)
        up6 = F.interpolate(self.up6(c5), size=c4.size()[2:], mode='bilinear', align_corners=False)
        merge6 = torch.cat([up6, c4], dim = 1)
        c6 = self.conv6(merge6)

        up7 = F.interpolate(self.up7(c6), size=c3.size()[2:], mode='bilinear', align_corners=False)
        merge7 = torch.cat([up7, c3], dim = 1)
        c7 = self.conv7(merge7)

        up8 = F.interpolate(self.up8(c7), size=c2.size()[2:], mode='bilinear', align_corners=False)
        merge8 = torch.cat([up8, c2], dim = 1)
        c8 = self.conv8(merge8)

        up9 = F.interpolate(self.up9(c8), size=c1.size()[2:], mode='bilinear', align_corners=False)
        merge9 = torch.cat([up9, c1], dim = 1)
        c9 = self.conv9(merge9)

        # Output
        out = self.conv10(c9)
        return out
    
    def predict(self, x):
        out = self.forward(x)
        _, preds = torch.max(out, 1)
        return preds


In [3]:
def min_max_normalization(slice_2d):
    max_val = np.max(slice_2d)
    min_val = np.min(slice_2d)
    # Only normalize if there is data
    if max_val - min_val > 0:
        slice_2d_normalized = (slice_2d - min_val) / (max_val - min_val)
    else:
        assert max_val == 0 and min_val == 0
        slice_2d_normalized = slice_2d
    return slice_2d_normalized

In [7]:
import os
import torch
import torchio as tio
import numpy as np
import nibabel as nib
from PIL import Image
from torchvision.transforms import Compose
from torch.autograd import Variable
from torchvision.transforms import ToTensor
device = "cuda:2"
# Create an instance of your model
model = UNet(in_channels=2, out_channels=4)
model.load_state_dict(torch.load("unet_PED_axial.pth"))
model = model.to(device)

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

# Define the directories
t1c_dir = "pediatric_modalities/t1c"
t2f_dir = "pediatric_modalities/t2f"
output_dir =  "PED_ax_op"
original_seg_dir = "ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/"

# Iterate over patient folders
for patient_dir in sorted(os.listdir(t1c_dir)):
    print(f"Processing patient: {patient_dir}")
    
    t1c_img = nib.load(os.path.join(t1c_dir, patient_dir)).get_fdata()
    t2f_img = nib.load(os.path.join(t2f_dir, patient_dir.replace("t1c", "t2f"))).get_fdata()
    
    
    # Placeholder for the output segmentation slices
    seg_slices = []
    
    
    for i in range(t1c_img.shape[-1]):
        t1c_img_slice = min_max_normalization(t1c_img[:, :, i])
        t2f_img_slice = min_max_normalization(t2f_img[:, :, i])
        
        t1c_tensor = torch.from_numpy(np.array(t1c_img_slice,dtype = np.float32)[None, ...])
        
        t2f_tensor = torch.from_numpy(np.array(t2f_img_slice,dtype = np.float32)[None, ...])
        
        # Convert to tensors and stack
        stack = torch.cat((t1c_tensor, t2f_tensor), dim=0).unsqueeze(0).to(device)
        output = model(stack)

        _, preds = torch.max(output, 1)
        
        # Save the segmentation slice
        seg_slices.append(preds.squeeze().detach().cpu().numpy())
    
    # Stack the segmentation slices to form a 3D image
    seg_3d = np.stack(seg_slices, axis=-1)
    
    # Load a original segmentation file to get the affine and header
    original_seg_file = nib.load(os.path.join(original_seg_dir, patient_dir.replace('-t1c', '').split('.')[0], f"{patient_dir.replace('-t1c', '').split('.')[0]}-seg.nii.gz"))
    
    # Resize and reorient the 3D predicted mask
    # target_dims = original_seg_file.shape
    # transform = tio.CropOrPad(target_dims)
    # seg_3d = torch.from_numpy(seg_3d)
    # seg_3d = seg_3d.unsqueeze(0)
    # seg_3d = transform(seg_3d)
    # seg_3d = seg_3d.squeeze(0).numpy()

    # Convert the 3D image to a Nifti1Image and save
    seg_nifti = nib.Nifti1Image(seg_3d, affine=original_seg_file.affine, header=original_seg_file.header)
    nib.save(seg_nifti, os.path.join(output_dir, f"{patient_dir}_segmentation.nii.gz"))


Processing patient: BraTS-PED-00002-000-t1c.nii.gz
Processing patient: BraTS-PED-00003-000-t1c.nii.gz


KeyboardInterrupt: 

In [None]:
import os
import torch
import torchio as tio
import numpy as np
import nibabel as nib
from PIL import Image
from torchvision.transforms import Compose
from torch.autograd import Variable
from torchvision.transforms import ToTensor
device = "cuda:2"
# Create an instance of your model
model = UNet(in_channels=2, out_channels=4)
model.load_state_dict(torch.load("unet_PED_coronal.pth"))
model = model.to(device)

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

# Define the directories
t1c_dir = "pediatric_modalities/t1c"
t2f_dir = "pediatric_modalities/t2f"
output_dir =  "PED_ax_op"
original_seg_dir = "ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/"

# Iterate over patient folders
for patient_dir in sorted(os.listdir(t1c_dir)):
    print(f"Processing patient: {patient_dir}")
    
    t1c_img = nib.load(os.path.join(t1c_dir, patient_dir)).get_fdata()
    t2f_img = nib.load(os.path.join(t2f_dir, patient_dir.replace("t1c", "t2f"))).get_fdata()
    
    
    # Placeholder for the output segmentation slices
    seg_slices = []
    
    
    for i in range(t1c_img.shape[-1]):
        t1c_img_slice = min_max_normalization(t1c_img[:, i, :])
        t2f_img_slice = min_max_normalization(t2f_img[:, i, :])
        
        t1c_tensor = torch.from_numpy(np.array(t1c_img_slice,dtype = np.float32)[None, ...])
        
        t2f_tensor = torch.from_numpy(np.array(t2f_img_slice,dtype = np.float32)[None, ...])
        
        # Convert to tensors and stack
        stack = torch.cat((t1c_tensor, t2f_tensor), dim=0).unsqueeze(0).to(device)
        # stack = torch.stack([ToTensor()(t1c_img), ToTensor()(t2f_img)]).unsqueeze(0).to(device)
        # print(stack.size())
        # stack = stack.view(stack.shape[0], stack.shape[1], stack.shape[3], stack.shape[4])  # Reshape here
        output = model(stack)

        _, preds = torch.max(output, 1)
        
        # Save the segmentation slice
        seg_slices.append(preds.squeeze().detach().cpu().numpy())
    
    # Stack the segmentation slices to form a 3D image
    seg_3d = np.stack(seg_slices, axis=1)
    
    # Load a original segmentation file to get the affine and header
    original_seg_file = nib.load(os.path.join(original_seg_dir, patient_dir.replace('-t1c', '').split('.')[0], f"{patient_dir.replace('-t1c', '').split('.')[0]}-seg.nii.gz"))
    
    # Resize and reorient the 3D predicted mask
    # target_dims = original_seg_file.shape
    # transform = tio.CropOrPad(target_dims)
    # seg_3d = torch.from_numpy(seg_3d)
    # seg_3d = seg_3d.unsqueeze(0)
    # seg_3d = transform(seg_3d)
    # seg_3d = seg_3d.squeeze(0).numpy()

    # Convert the 3D image to a Nifti1Image and save
    seg_nifti = nib.Nifti1Image(seg_3d, affine=original_seg_file.affine, header=original_seg_file.header)
    nib.save(seg_nifti, os.path.join(output_dir, f"{patient_dir}_segmentation.nii.gz"))


In [None]:
import os
import torch
import torchio as tio
import numpy as np
import nibabel as nib
from PIL import Image
from torchvision.transforms import Compose
from torch.autograd import Variable
from torchvision.transforms import ToTensor
device = "cuda:2"
# Create an instance of your model
model = UNet(in_channels=2, out_channels=4)
model.load_state_dict(torch.load("unet_PED_sagittal.pth"))
model = model.to(device)

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

# Define the directories
t1c_dir = "pediatric_modalities/t1c"
t2f_dir = "pediatric_modalities/t2f"
output_dir =  "PED_ax_op"
original_seg_dir = "ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/"

# Iterate over patient folders
for patient_dir in sorted(os.listdir(t1c_dir)):
    print(f"Processing patient: {patient_dir}")
    
    t1c_img = nib.load(os.path.join(t1c_dir, patient_dir)).get_fdata()
    t2f_img = nib.load(os.path.join(t2f_dir, patient_dir.replace("t1c", "t2f"))).get_fdata()
    
    
    # Placeholder for the output segmentation slices
    seg_slices = []
    
    
    for i in range(t1c_img.shape[-1]):
        t1c_img_slice = min_max_normalization(t1c_img[i, :, :])
        t2f_img_slice = min_max_normalization(t2f_img[i, :, :])
        
        t1c_tensor = torch.from_numpy(np.array(t1c_img_slice,dtype = np.float32)[None, ...])
        
        t2f_tensor = torch.from_numpy(np.array(t2f_img_slice,dtype = np.float32)[None, ...])
        
        # Convert to tensors and stack
        stack = torch.cat((t1c_tensor, t2f_tensor), dim=0).unsqueeze(0).to(device)
        # stack = torch.stack([ToTensor()(t1c_img), ToTensor()(t2f_img)]).unsqueeze(0).to(device)
        # print(stack.size())
        # stack = stack.view(stack.shape[0], stack.shape[1], stack.shape[3], stack.shape[4])  # Reshape here
        output = model(stack)

        _, preds = torch.max(output, 1)
        
        # Save the segmentation slice
        seg_slices.append(preds.squeeze().detach().cpu().numpy())
    
    # Stack the segmentation slices to form a 3D image
    seg_3d = np.stack(seg_slices, axis=0)
    
    # Load a original segmentation file to get the affine and header
    original_seg_file = nib.load(os.path.join(original_seg_dir, patient_dir.replace('-t1c', '').split('.')[0], f"{patient_dir.replace('-t1c', '').split('.')[0]}-seg.nii.gz"))
    
    # Resize and reorient the 3D predicted mask
    # target_dims = original_seg_file.shape
    # transform = tio.CropOrPad(target_dims)
    # seg_3d = torch.from_numpy(seg_3d)
    # seg_3d = seg_3d.unsqueeze(0)
    # seg_3d = transform(seg_3d)
    # seg_3d = seg_3d.squeeze(0).numpy()

    # Convert the 3D image to a Nifti1Image and save
    seg_nifti = nib.Nifti1Image(seg_3d, affine=original_seg_file.affine, header=original_seg_file.header)
    nib.save(seg_nifti, os.path.join(output_dir, f"{patient_dir}_segmentation.nii.gz"))


In [None]:
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from scipy.stats import mode

# Load your pre-trained models' predictions
axial_preds = img1.get_fdata()
coronal_preds = img2.get_fdata()
sagittal_preds = img3.get_fdata()

# Let's assume the dimensions of the original MRI are:
original_shape = (240, 240, 155)  # Adjust as per your data

# Rescale the predictions to the original dimensions:
rescaled_axial = axial_preds
rescaled_coronal = coronal_preds
rescaled_sagittal = sagittal_preds

# Combine the three predictions
combined_preds = np.stack([rescaled_axial, rescaled_coronal, rescaled_sagittal])

# Create the final segmentation by taking the majority vote
final_segmentation = mode(combined_preds, axis=0)[0][0]

# Create a Nifti image from the final segmentation
# In your case, you might want to use the affine and header from one of the loaded NIfTI images
final_seg_nifti = nib.Nifti1Image(final_segmentation, affine=img1.affine, header=img1.header)

# Save the final segmentation
nib.save(final_seg_nifti, "final_segmentation.nii.gz")


In [None]:
import os
import torch
import torchio as tio
import numpy as np
import nibabel as nib
from PIL import Image
from torchvision.transforms import Compose
from torch.autograd import Variable
from torchvision.transforms import ToTensor

device = "cuda:2"
model = UNet(in_channels=2, out_channels=4)
model = model.to(device)

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

In [13]:


def process_orientation(orientation):
    if orientation == "axial":
        model.load_state_dict(torch.load("unet_PED_axial.pth"))
    elif orientation == "coronal":
        model.load_state_dict(torch.load("unet_PED_coronal.pth"))
    else:  # orientation == "sagittal"
        model.load_state_dict(torch.load("unet_PED_sagittal.pth"))

    # Define the directories
    t1c_dir = "pediatric_modalities/t1c"
    t2f_dir = "pediatric_modalities/t2f"
    output_dir =  "PED_ax_op"
    original_seg_dir = "ASNR-MICCAI-BraTS2023-PED-Challenge-TrainingData/"

    # Iterate over patient folders
    for patient_dir in sorted(os.listdir(t1c_dir), key=natural_keys):
        print(f"Processing patient: {patient_dir}")

        t1c_img = nib.load(os.path.join(t1c_dir, patient_dir)).get_fdata()
        t2f_img = nib.load(os.path.join(t2f_dir, patient_dir.replace("t1c", "t2f"))).get_fdata()

        # Placeholder for the output segmentation slices
        seg_slices = []
        
        if orientation == "axial":
            shape = t1c_img.shape[-1]
        elif orientation == "coronal":
            shape = t1c_img.shape[1]
        else:
            shape = t1c_img.shape[0]
        
        for i in range(shape):
            if orientation == "axial":
                t1c_img_slice = min_max_normalization(t1c_img[:, :, i])
                t2f_img_slice = min_max_normalization(t2f_img[:, :, i])
            elif orientation == "coronal":
                t1c_img_slice = min_max_normalization(t1c_img[:, i, :])
                t2f_img_slice = min_max_normalization(t2f_img[:, i, :])
            else:  # orientation == "sagittal"
                t1c_img_slice = min_max_normalization(t1c_img[i, :, :])
                t2f_img_slice = min_max_normalization(t2f_img[i, :, :])

            t1c_tensor = torch.from_numpy(np.array(t1c_img_slice,dtype = np.float32)[None, ...])

            t2f_tensor = torch.from_numpy(np.array(t2f_img_slice,dtype = np.float32)[None, ...])

            # Convert to tensors and stack
            stack = torch.cat((t1c_tensor, t2f_tensor), dim=0).unsqueeze(0).to(device)
            output = model(stack)

            _, preds = torch.max(output, 1)

            # Save the segmentation slice
            seg_slices.append(preds.squeeze().detach().cpu().numpy())

        # Stack the segmentation slices to form a 3D image
        if orientation == "axial":
            seg_3d = np.stack(seg_slices, axis=-1)
        elif orientation == "coronal":
            seg_3d = np.stack(seg_slices, axis=1)
        else:  # orientation == "sagittal"
            seg_3d = np.stack(seg_slices, axis=0)

        # Load a original segmentation file to get the affine and header
        original_seg_file = nib.load(os.path.join(original_seg_dir, patient_dir.replace('-t1c', '').split('.')[0], f"{patient_dir.replace('-t1c', '').split('.')[0]}-seg.nii.gz"))

        # Convert the 3D image to a Nifti1Image and save
        seg_nifti = nib.Nifti1Image(seg_3d, affine=original_seg_file.affine, header=original_seg_file.header)
        nib.save(seg_nifti, os.path.join(output_dir, f"{patient_dir}_{orientation}_segmentation.nii.gz"))

# Run the function for each orientation
for orientation in ["coronal", "axial", "sagittal"]:
    process_orientation(orientation)


Processing patient: BraTS-PED-00002-000-t1c.nii.gz
Processing patient: BraTS-PED-00003-000-t1c.nii.gz
Processing patient: BraTS-PED-00004-000-t1c.nii.gz


KeyboardInterrupt: 

In [4]:
import shutil
import os

directory_path = 'PED_Sliced/axial/t2f/BraTS-PED-00002-000/.ipynb_checkpoints'

# Check if the directory exists
if os.path.exists(directory_path) and os.path.isdir(directory_path):
    # Delete the directory
    shutil.rmtree(directory_path)
    print(f"The directory '{directory_path}' has been deleted.")
else:
    print(f"The directory '{directory_path}' does not exist.")


The directory 'PED_Sliced/axial/t2f/BraTS-PED-00002-000/.ipynb_checkpoints' has been deleted.


In [9]:
import os
import torch
import numpy as np
from PIL import Image
import nibabel as nib

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Load the models
model_coronal = UNet(in_channels=2, out_channels=4).to(device)
model_coronal.load_state_dict(torch.load('unet_PED_coronal.pth'))
model_coronal.eval()

model_axial = UNet(in_channels=2, out_channels=4).to(device)
model_axial.load_state_dict(torch.load('unet_PED_axial.pth'))
model_axial.eval()

model_sagittal = UNet(in_channels=2, out_channels=4).to(device)
model_sagittal.load_state_dict(torch.load('unet_PED_sagittal.pth'))
model_sagittal.eval()

# Define the function to load the slices
def load_slices(t1c_folder, t2f_folder):
    t1c_files = sorted(os.listdir(t1c_folder))
    t2f_files = sorted(os.listdir(t2f_folder))

    slices = []
    for t1c_file, t2f_file in zip(t1c_files, t2f_files):
        t1c_slice = np.array(Image.open(os.path.join(t1c_folder, t1c_file)))
        t2f_slice = np.array(Image.open(os.path.join(t2f_folder, t2f_file)))
        
        # Stack the t1c and t2f slices along the channel dimension and normalize
        stacked_slice = np.stack((t1c_slice, t2f_slice), axis=0) / 255.0
        slices.append(stacked_slice)

    return slices

def get_prediction(model, data):
    with torch.no_grad():
        data = data.to(device)
        preds = model.predict(data)
        return preds.cpu().numpy()

from scipy.ndimage import zoom

def reconstruct_3D(coronal, axial, sagittal, target_shape):
    # Reshape each stack to the target shape
    zoom_factors_coronal = [target_dim / current_dim for target_dim, current_dim in zip(target_shape, np.stack(coronal, axis=1).shape)]
    zoom_factors_axial = [target_dim / current_dim for target_dim, current_dim in zip(target_shape, np.stack(axial, axis=0).shape)]
    zoom_factors_sagittal = [target_dim / current_dim for target_dim, current_dim in zip(target_shape, np.swapaxes(np.stack(sagittal, axis=2), 0, 1).shape)]
    
    stack_coronal = zoom(np.stack(coronal, axis=1), zoom_factors_coronal)
    stack_axial = zoom(np.stack(axial, axis=0), zoom_factors_axial)
    stack_sagittal = zoom(np.swapaxes(np.stack(sagittal, axis=2), 0, 1), zoom_factors_sagittal)

    assert stack_coronal.shape == stack_axial.shape == stack_sagittal.shape == target_shape

    stack_list = [stack_coronal, stack_axial, stack_sagittal]

    final_stack = np.empty(target_shape, dtype=np.int)

    for i in range(target_shape[0]):
        for j in range(target_shape[1]):
            for k in range(target_shape[2]):
                voxel_values = [stack[i, j, k] for stack in stack_list]
                final_stack[i, j, k] = max(set(voxel_values), key=voxel_values.count)

    return final_stack

# Root folders
axial_t1c_root = "PED_Sliced/axial/t1c/"
axial_t2f_root = "PED_Sliced/axial/t2f/"
sagittal_t1c_root = "PED_Sliced/sagittal/t1c/"
sagittal_t2f_root = "PED_Sliced/sagittal/t2f/"
coronal_t1c_root = "PED_Sliced/coronal/t1c/"
coronal_t2f_root = "PED_Sliced/coronal/t2f/"
seg_root = "pediatric_modalities/seg/"
output_folder = 'PED_Sliced/reconstructed/'

# For each patient folder
for patient_folder in sorted(os.listdir(axial_t1c_root)):
    axial_t1c_folder = os.path.join(axial_t1c_root, patient_folder)
    axial_t2f_folder = os.path.join(axial_t2f_root, patient_folder)
    sagittal_t1c_folder = os.path.join(sagittal_t1c_root, patient_folder)
    sagittal_t2f_folder = os.path.join(sagittal_t2f_root, patient_folder)
    coronal_t1c_folder = os.path.join(coronal_t1c_root, patient_folder)
    coronal_t2f_folder = os.path.join(coronal_t2f_root, patient_folder)
    seg_file = os.path.join(seg_root, patient_folder + '-seg.nii.gz')

    axial_slices = load_slices(axial_t1c_folder, axial_t2f_folder)
    sagittal_slices = load_slices(sagittal_t1c_folder, sagittal_t2f_folder)
    coronal_slices = load_slices(coronal_t1c_folder, coronal_t2f_folder)

    axial_preds = [get_prediction(model_axial, torch.tensor([slice], dtype=torch.float32)) for slice in axial_slices]
    sagittal_preds = [get_prediction(model_sagittal, torch.tensor([slice], dtype=torch.float32)) for slice in sagittal_slices]
    coronal_preds = [get_prediction(model_coronal, torch.tensor([slice], dtype=torch.float32)) for slice in coronal_slices]

    # reconstructed_3D_image = reconstruct_3D(coronal_preds, axial_preds, sagittal_preds)

    # Use the affine from the original segmentation
    seg_nii = nib.load(seg_file)
    target_shape = seg_nii.shape
    seg_affine = seg_nii.affine
    
    reconstructed_3D_image = reconstruct_3D(coronal_preds, axial_preds, sagittal_preds, target_shape)


    # Save the reconstructed 3D image
    output_filename = os.path.join(output_folder, patient_folder + '.nii.gz')
    img_nifti = nib.Nifti1Image(reconstructed_3D_image, seg_affine)
    nib.save(img_nifti, output_filename)


RuntimeError: sequence argument must have length equal to input rank

In [11]:
import numpy as np
import torch
from scipy.ndimage import zoom
import nibabel as nib
from glob import glob
from torchvision import transforms
from PIL import Image
import os

# Modify the path to the saved models according to your setup
model_axial_path = 'unet_PED_axial.pth'
model_sagittal_path = 'unet_PED_sagittal.pth'
model_coronal_path = 'unet_PED_coronal.pth'

# Load your models
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Load the models
model_coronal = UNet(in_channels=2, out_channels=4).to(device)
model_coronal.load_state_dict(torch.load('unet_PED_coronal.pth'))
model_coronal.eval()

model_axial = UNet(in_channels=2, out_channels=4).to(device)
model_axial.load_state_dict(torch.load('unet_PED_axial.pth'))
model_axial.eval()

model_sagittal = UNet(in_channels=2, out_channels=4).to(device)
model_sagittal.load_state_dict(torch.load('unet_PED_sagittal.pth'))
model_sagittal.eval()

# Load the slices for each plane
axial_t1c_root = "PED_Sliced/axial/t1c/"
axial_t2f_root = "PED_Sliced/axial/t2f/"
sagittal_t1c_root = "PED_Sliced/sagittal/t1c/"
sagittal_t2f_root = "PED_Sliced/sagittal/t2f/"
coronal_t1c_root = "PED_Sliced/coronal/t1c/"
coronal_t2f_root = "PED_Sliced/coronal/t2f/"

def load_slices(t1c_folder, t2f_folder):
    t1c_files = sorted(glob(os.path.join(t1c_folder, '*.png')))
    t2f_files = sorted(glob(os.path.join(t2f_folder, '*.png')))
    
    slices = []
    for t1c_file, t2f_file in zip(t1c_files, t2f_files):
        t1c_img = Image.open(t1c_file)
        t2f_img = Image.open(t2f_file)
        
        t1c_tensor = transforms.functional.to_tensor(t1c_img).unsqueeze(0)
        t2f_tensor = transforms.functional.to_tensor(t2f_img).unsqueeze(0)
        
        slice = torch.cat((t1c_tensor, t2f_tensor), dim=1)
        slices.append(slice.numpy())
        
    return slices

def get_prediction(model, data):
    data = torch.tensor(data, dtype=torch.float32).to(device)
    with torch.no_grad():
        preds = model.predict(data)
        return preds.cpu().numpy()

def reconstruct_3D(coronal, axial, sagittal, target_shape):
    # Reshape each stack to the target shape
    stack_coronal = zoom(np.stack(coronal, axis=1), target_shape)
    stack_axial = zoom(np.stack(axial, axis=0), target_shape)
    stack_sagittal = zoom(np.swapaxes(np.stack(sagittal, axis=2), 0, 1), target_shape)
    
    # Assert all stacks have the same shape
    assert stack_coronal.shape == stack_axial.shape == stack_sagittal.shape

    # Majority voting
    stack_list = [stack_coronal, stack_axial, stack_sagittal]
    reconstructed_3D_image = np.stack(stack_list).astype(int)
    reconstructed_3D_image = np.argmax(np.bincount(reconstructed_3D_image.reshape(-1), minlength=len(stack_list)).reshape(-1, *stack_coronal.shape), axis=0)
    
    return reconstructed_3D_image

# Getting 2D predictions for each plane
axial_slices = load_slices(axial_t1c_folder, axial_t2f_folder)
sagittal_slices = load_slices(sagittal_t1c_folder, sagittal_t2f_folder)
coronal_slices = load_slices(coronal_t1c_folder, coronal_t2f_folder)

axial_preds = [get_prediction(model_axial, slice) for slice in axial_slices]
sagittal_preds = [get_prediction(model_sagittal, slice) for slice in sagittal_slices]
coronal_preds = [get_prediction(model_coronal, slice) for slice in coronal_slices]

# The target shape for the 3D reconstructed image, set to the original image shape
target_shape = (240, 240, 155)

reconstructed_3D_image = reconstruct_3D(coronal_preds, axial_preds, sagittal_preds, target_shape)

# Saving the output 3D image as a nii.gz file
output_folder = "PED_output_folder"

# Use the affine from the original segmentation
seg_nii = nib.load("pediatric_modalities/seg/BraTS-PED-00002-000-seg.nii.gz")
seg_affine = seg_nii.affine

output_img = nib.Nifti1Image(reconstructed_3D_image, affine=seg_affine)
nib.save(output_img, os.path.join(output_folder, "reconstructed_3D_image.nii.gz"))


RuntimeError: sequence argument must have length equal to input rank