In [7]:
import numpy as np
import pandas as pd
import os, glob
import nibabel as nib
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
# z-score normalizer (zero mean, unit variance). The inference loop refits it on each channel separately.
scaler = StandardScaler(copy=True, with_mean=True, with_std=True)

In [8]:
# Collect the test subjects of one institution (this cell handles a single institution only)
seg_dir = 'segmentation images directory'
baseline_dir = 'baseline images(t1,t1ce,t2,flair,adc)'

segm_dir = sorted(glob.glob(os.path.join(seg_dir, "*_seg*.nii.gz")))

# One sorted list per modality; entry i of every list must belong to the same subject as segm_dir[i].
flair_dir = sorted(glob.glob(os.path.join(baseline_dir, "*_flair*.nii.gz")))
t1_dir = sorted(glob.glob(os.path.join(baseline_dir, "*_t1_*.nii.gz")))  # '_t1_' so t1ce files are not matched
t1ce_dir = sorted(glob.glob(os.path.join(baseline_dir, "*_t1ce*.nii.gz")))
t2_dir = sorted(glob.glob(os.path.join(baseline_dir, "*_t2*.nii.gz")))
adc_dir = sorted(glob.glob(os.path.join(baseline_dir, "*_adc*.nii.gz")))

# Pairing is purely positional, so a missing or extra file would silently misalign subjects
# (or raise IndexError deep inside the inference loop). Fail early with a readable message.
file_counts = {
    "seg": len(segm_dir),
    "flair": len(flair_dir),
    "t1": len(t1_dir),
    "t1ce": len(t1ce_dir),
    "t2": len(t2_dir),
    "adc": len(adc_dir),
}
if len(set(file_counts.values())) > 1:
    raise ValueError(f"Mismatched number of files per modality: {file_counts}")

In [9]:
# Load a NIfTI file into a numpy array plus its header and affine
def LoadingImage(dir):
    """Read the NIfTI file at ``dir``.

    Returns:
        tuple: ``(image, header, affine)`` where ``image`` is the voxel data
        materialised with ``np.asarray(nifti.dataobj)``, and ``header`` /
        ``affine`` come straight from the loaded nibabel image.
    """
    nifti = nib.load(dir)
    return np.asarray(nifti.dataobj), nifti.header, nifti.affine

In [10]:
# Extract the subject identifier from a file path
def getFilename(full_dir):
    """Return the subject ID of ``full_dir``: its basename up to the first '_'.

    Accepts both Windows ('\\') and POSIX ('/') separators at any directory
    depth, and any number of '_'-separated fields after the subject ID.
    The basename is printed as a lightweight per-subject progress line.

    Args:
        full_dir: path to a file such as ``'.../SUBJECT_xx_seg.nii.gz'``.

    Returns:
        str: the text before the first underscore of the basename.
    """
    # Normalise separators so the last path component is found on any OS
    # (the previous split('\\') only worked with exactly one backslash).
    filename = full_dir.replace('\\', '/').rsplit('/', 1)[-1]
    print(filename)
    # Only the first field is the subject; don't require a fixed field count.
    subject = filename.split('_', 1)[0]
    return subject

In [11]:
# Simple 3D CNN network (same as training network)
class Simple3DCNN(nn.Module):
    """Two conv layers + two fully connected layers over 5-channel 3D patches.

    Attribute names and their registration order must match the training
    network so ``load_state_dict`` finds every key. The ``fc1`` input size of
    32 * 5 * 5 * 5 assumes 5x5x5 spatial patches (both convs use kernel 3 with
    padding 1 and stride 1, so the spatial size is unchanged).

    Args:
        num_classes: number of output logits produced by ``fc2``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv3d(5, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 5 * 5 * 5, 128)  # hidden width 128 is arbitrary
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map an (N, 5, D1, D2, D3) batch to (N, num_classes) raw logits."""
        feature_maps = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        flat = torch.flatten(feature_maps, start_dim=1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(self.dropout(hidden))

In [None]:
# Continuous case: write a per-voxel probability map for each subject
MODEL_WEIGHTS = 'Load the model weights that you trained'
# Directory that receives one '<subject>_probmap.nii.gz' per subject
OUTPUT_DIR = 'save the result probaility map back into original space'
TARGET_LABEL = 2   # segmentation label whose voxels get a prediction
PATCH_RADIUS = 2   # 5x5x5 patches, matching fc1's 32 * 5 * 5 * 5 input
BATCH_SIZE = 4096  # patches per forward pass; lower it if memory is tight

# Load the model. Inference below runs on CPU, so remap weights that were saved from a GPU.
# NOTE: torch.load unpickles the checkpoint -- only load files you trust.
model = Simple3DCNN(num_classes=2)
model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location='cpu'))
model.eval()

# Modalities are paired with segm_dir purely by index, so the counts must agree.
modality_dirs = [flair_dir, t1_dir, t1ce_dir, t2_dir, adc_dir]
if any(len(paths) != len(segm_dir) for paths in modality_dirs):
    raise ValueError(
        f"file counts differ: seg={len(segm_dir)}, "
        f"modalities={[len(paths) for paths in modality_dirs]}"
    )

os.makedirs(OUTPUT_DIR, exist_ok=True)
r = PATCH_RADIUS

for i in range(len(segm_dir)):
    seg_img, seg_hdr, seg_imgaffine = LoadingImage(segm_dir[i])

    # Stack the five modalities into a (C, H, W, D) array
    img = np.stack([nib.load(paths[i]).get_fdata() for paths in modality_dirs], axis=0)
    if img.shape[1:] != seg_img.shape:
        raise ValueError(
            f"{segm_dir[i]}: image shape {img.shape[1:]} != segmentation shape {seg_img.shape}"
        )

    # Z-normalize each channel separately
    img = np.stack(
        [scaler.fit_transform(channel.reshape(-1, 1)).reshape(channel.shape) for channel in img],
        axis=0,
    )

    # Target voxels whose whole patch lies inside the volume; bounds come from the
    # actual volume shape instead of a hard-coded 240x240x155. Other voxels stay 0.
    coords = np.argwhere(seg_img == TARGET_LABEL)
    inside = np.all((coords >= r) & (coords < np.array(seg_img.shape) - r), axis=1)
    coords = coords[inside]

    output = np.zeros_like(seg_img, dtype=float)
    filename = getFilename(segm_dir[i])

    # Batched inference: one forward pass per BATCH_SIZE patches instead of one per voxel
    with torch.no_grad():
        for start in range(0, len(coords), BATCH_SIZE):
            batch = coords[start:start + BATCH_SIZE]
            patches = np.stack([
                img[:, x - r:x + r + 1, y - r:y + r + 1, z - r:z + r + 1]
                for x, y, z in batch
            ])  # (N, C, 5, 5, 5)
            logits = model(torch.from_numpy(patches).float())
            # NOTE(review): sigmoid per logit is kept from the original. If the network was
            # trained with CrossEntropyLoss over the 2 logits, the matching probability is
            # torch.softmax(logits, dim=1)[:, 1] -- confirm against the training code.
            probs = torch.sigmoid(logits)[:, 1].numpy()
            output[batch[:, 0], batch[:, 1], batch[:, 2]] = probs

    # Save in the segmentation's space. Passing seg_hdr makes the image use the header's
    # dtype (likely an integer label type -- confirm), which would cast/quantize the
    # probabilities on save, so force float32 explicitly.
    nifti_img = nib.Nifti1Image(output, seg_imgaffine, seg_hdr)
    nifti_img.set_data_dtype(np.float32)
    # One file per subject; previously every subject overwrote the same path.
    nib.save(nifti_img, os.path.join(OUTPUT_DIR, f'{filename}_probmap.nii.gz'))