In [15]:
import numpy as np
import pandas as pd
import os, glob
import nibabel as nib
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import itertools

In [16]:
# Extract the subject identifier from a patch file path
def getFilename(full_dir):
    """Return the subject identifier encoded at the start of a file name.

    The final path component is taken regardless of whether the path uses
    backslashes or forward slashes (a bare file name is accepted too), and
    must consist of exactly five underscore-separated fields; the first
    field is returned.

    Parameters
    ----------
    full_dir : str
        Path (or bare name) of a file named ``<subject>_<..>_<..>_<..>_<..>``.

    Returns
    -------
    str
        The first underscore-separated field of the file name.

    Raises
    ------
    ValueError
        If the file name does not contain exactly five underscore-separated
        fields.
    """
    # Normalise Windows separators so the same code works for both path styles.
    filename = full_dir.replace('\\', '/').rsplit('/', 1)[-1]
    fields = filename.split('_')
    if len(fields) != 5:
        raise ValueError(
            f"expected 5 underscore-separated fields in file name {filename!r}, "
            f"got {len(fields)}"
        )
    return str(fields[0])

In [17]:
# Empty containers, one per image category; later cells append to them.
segm_byInst = []

# "Far" patches, one container per MRI modality.
(
    far_flair_byInst,
    far_t1_byInst,
    far_t1ce_byInst,
    far_t2_byInst,
    far_adc_byInst,
) = ([] for _ in range(5))

# "Near" patches, one container per MRI modality.
(
    Near_flair_byinst,
    Near_t1_byinst,
    Near_t1ce_byinst,
    Near_t2_byinst,
    Near_adc_byinst,
) = ([] for _ in range(5))

In [18]:
# Folders of the first institution (the strings look like placeholders to be
# replaced with real locations).
first_seg_dir = 'segmentation images directory'
first_far_dir = 'collected Far patches'
first_near_dir = 'collected Near patches'

# Segmentation files: glob, sort by name, then register under the institution lists.
first_segm_dir = glob.glob(os.path.join(first_seg_dir, "*_segm*.nii.gz"))
first_segm_dir.sort()
segm_byInst.append(first_segm_dir)

# Far patches, one sorted file list per modality.
first_far_flair_dir = glob.glob(os.path.join(first_far_dir, "*_flairPatch*.nii.gz"))
first_far_flair_dir.sort()
far_flair_byInst.append(first_far_flair_dir)

first_far_t1_dir = glob.glob(os.path.join(first_far_dir, "*_t1Patch*.nii.gz"))
first_far_t1_dir.sort()
far_t1_byInst.append(first_far_t1_dir)

first_far_t1ce_dir = glob.glob(os.path.join(first_far_dir, "*_t1cePatch*.nii.gz"))
first_far_t1ce_dir.sort()
far_t1ce_byInst.append(first_far_t1ce_dir)

first_far_t2_dir = glob.glob(os.path.join(first_far_dir, "*_t2Patch*.nii.gz"))
first_far_t2_dir.sort()
far_t2_byInst.append(first_far_t2_dir)

first_far_adc_dir = glob.glob(os.path.join(first_far_dir, "*_adcPatch*.nii.gz"))
first_far_adc_dir.sort()
far_adc_byInst.append(first_far_adc_dir)

# Near patches, one sorted file list per modality.
first_near_flair_dir = glob.glob(os.path.join(first_near_dir, "*_flairPatch*.nii.gz"))
first_near_flair_dir.sort()
Near_flair_byinst.append(first_near_flair_dir)

first_near_t1_dir = glob.glob(os.path.join(first_near_dir, "*_t1Patch*.nii.gz"))
first_near_t1_dir.sort()
Near_t1_byinst.append(first_near_t1_dir)

first_near_t1ce_dir = glob.glob(os.path.join(first_near_dir, "*_t1cePatch*.nii.gz"))
first_near_t1ce_dir.sort()
Near_t1ce_byinst.append(first_near_t1ce_dir)

first_near_t2_dir = glob.glob(os.path.join(first_near_dir, "*_t2Patch*.nii.gz"))
first_near_t2_dir.sort()
Near_t2_byinst.append(first_near_t2_dir)

first_near_adc_dir = glob.glob(os.path.join(first_near_dir, "*_adcPatch*.nii.gz"))
first_near_adc_dir.sort()
Near_adc_byinst.append(first_near_adc_dir)

In [23]:
# Folders of the second institution (the strings look like placeholders to be
# replaced with real locations).
second_seg_dir = 'segmentation images directory'
second_far_dir = 'collected Far patches'
second_near_dir = 'collected Near patches'

# Segmentation files: glob, sort by name, then register under the institution lists.
second_segm_dir = glob.glob(os.path.join(second_seg_dir, "*_segm*.nii.gz"))
second_segm_dir.sort()
segm_byInst.append(second_segm_dir)

# Far patches, one sorted file list per modality.
second_far_flair_dir = glob.glob(os.path.join(second_far_dir, "*_flairPatch*.nii.gz"))
second_far_flair_dir.sort()
far_flair_byInst.append(second_far_flair_dir)

second_far_t1_dir = glob.glob(os.path.join(second_far_dir, "*_t1Patch*.nii.gz"))
second_far_t1_dir.sort()
far_t1_byInst.append(second_far_t1_dir)

second_far_t1ce_dir = glob.glob(os.path.join(second_far_dir, "*_t1cePatch*.nii.gz"))
second_far_t1ce_dir.sort()
far_t1ce_byInst.append(second_far_t1ce_dir)

second_far_t2_dir = glob.glob(os.path.join(second_far_dir, "*_t2Patch*.nii.gz"))
second_far_t2_dir.sort()
far_t2_byInst.append(second_far_t2_dir)

second_far_adc_dir = glob.glob(os.path.join(second_far_dir, "*_adcPatch*.nii.gz"))
second_far_adc_dir.sort()
far_adc_byInst.append(second_far_adc_dir)

# Near patches, one sorted file list per modality.
second_near_flair_dir = glob.glob(os.path.join(second_near_dir, "*_flairPatch*.nii.gz"))
second_near_flair_dir.sort()
Near_flair_byinst.append(second_near_flair_dir)

second_near_t1_dir = glob.glob(os.path.join(second_near_dir, "*_t1Patch*.nii.gz"))
second_near_t1_dir.sort()
Near_t1_byinst.append(second_near_t1_dir)

second_near_t1ce_dir = glob.glob(os.path.join(second_near_dir, "*_t1cePatch*.nii.gz"))
second_near_t1ce_dir.sort()
Near_t1ce_byinst.append(second_near_t1ce_dir)

second_near_t2_dir = glob.glob(os.path.join(second_near_dir, "*_t2Patch*.nii.gz"))
second_near_t2_dir.sort()
Near_t2_byinst.append(second_near_t2_dir)

second_near_adc_dir = glob.glob(os.path.join(second_near_dir, "*_adcPatch*.nii.gz"))
second_near_adc_dir.sort()
Near_adc_byinst.append(second_near_adc_dir)

In [None]:
# You may add more institutions here, following the same pattern as the cells above

In [24]:
# Read a NIfTI file and hand back its voxel data together with its metadata
def LoadingImage(dir):
    """Load a NIfTI file with nibabel.

    Parameters
    ----------
    dir : path of the NIfTI file, passed straight to ``nib.load``.

    Returns
    -------
    tuple
        ``(image, header, imgaffine)`` where ``image`` is
        ``np.asarray`` of the loaded image's ``dataobj``, and ``header`` /
        ``imgaffine`` are the loaded image's ``header`` and ``affine``
        attributes.
    """
    nii = nib.load(dir)
    return np.asarray(nii.dataobj), nii.header, nii.affine

In [25]:
# Load images and stack them for training; each sample has shape (channels, 5, 5, 5)
def load_and_stack_image_dirs(dirs, expected_shape=(5, 5, 5)):
    """Load per-modality NIfTI patches and stack them channel-first.

    ``dirs`` holds one list of file paths per channel. Sample ``i`` is built
    by loading ``dirs[c][i]`` for every channel ``c`` (via
    ``nib.load(...).get_fdata()``) and stacking the results along a new
    leading axis.

    Parameters
    ----------
    dirs : sequence of sequences of paths
        One path list per channel; all lists must have the same length.
    expected_shape : tuple of int, optional
        Shape every loaded patch is checked against. A mismatch is reported
        on stdout (naming the offending file) but does not stop loading.

    Returns
    -------
    numpy.ndarray
        Array built from the list of stacked samples, i.e.
        ``(n_samples, n_channels, *patch_shape)`` when all patches share one
        shape. An empty array is returned when there is nothing to load.

    Raises
    ------
    ValueError
        If the per-channel path lists differ in length, which would silently
        drop or mispair files.
    """
    dirs = list(dirs)
    if not dirs:
        return np.array([])

    n_samples = len(dirs[0])
    for channel, paths in enumerate(dirs):
        if len(paths) != n_samples:
            raise ValueError(
                f"channel {channel} has {len(paths)} files but channel 0 has "
                f"{n_samples}; every channel needs one file per sample"
            )

    expected_shape = tuple(expected_shape)
    stacked_images = []
    for i in range(n_samples):
        images = []
        for paths in dirs:
            image = nib.load(paths[i]).get_fdata()
            if image.shape != expected_shape:
                # Report the file that is actually wrong, not just the first channel's.
                print(f'error: {paths[i]}')
                print(image.shape)
            images.append(image)
        # Each modality becomes one channel of the sample.
        stacked_images.append(np.stack(images, axis=0))

    return np.array(stacked_images)

In [26]:
# Simple 3D CNN network
class Simple3DCNN(nn.Module):
    """Small 3D convolutional classifier for multi-channel image patches.

    Two 3x3x3 convolutions (stride 1, padding 1, so the spatial size is
    preserved) are followed by a 128-unit fully connected layer, dropout and
    a final linear layer that outputs one logit per class.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    in_channels : int, optional
        Number of input channels. Defaults to 5.
    patch_size : int or tuple of three ints, optional
        Spatial size of the input patches; an int means a cubic patch.
        Defaults to 5, i.e. inputs of shape ``(N, 5, 5, 5, 5)``.
    """

    def __init__(self, num_classes, in_channels=5, patch_size=5):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size,) * 3
        depth, height, width = patch_size

        # Layer names and creation order are kept stable so existing state
        # dicts and seeded initialisations are unaffected.
        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        # The convolutions keep the spatial size, so the flattened feature
        # count is 32 channels times the patch volume. 128 is an arbitrary choice.
        self.fc1 = nn.Linear(32 * depth * height * width, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x``."""
        x = torch.relu(self.conv1(x))  # apply conv1, then ReLU
        x = torch.relu(self.conv2(x))  # apply conv2, then ReLU
        x = torch.flatten(x, start_dim=1)  # flatten everything except the batch axis
        x = torch.relu(self.fc1(x))  # apply first fully connected layer, then ReLU
        x = self.dropout(x)  # apply dropout
        x = self.fc2(x)  # apply second fully connected layer
        return x

In [27]:
# Labels for the institutions (presumably in the order their files were appended — TODO confirm).
subjects = [
    'first_institution',
    'second_institution',
]

In [None]:
# Number of institutions you use; it must equal the number of entries appended
# to each per-modality list.
n = 2

# Seed reused in every fold for patch subsampling, weight initialisation and
# batch shuffling, so a rerun reproduces the same curves.
seed = 42

# Prefix of the checkpoint files. The held-out institution is appended so folds
# do not overwrite each other. Replace with your real output location.
model_out_prefix = 'Save model in ouput directory'

# Guard against hidden notebook state: if a loading cell was executed twice (or
# n was not updated), the slicing below would keep copies of the held-out
# institution in the training data without any error.
for _list_name, _paths_by_inst in [
    ('segm_byInst', segm_byInst),
    ('far_flair_byInst', far_flair_byInst),
    ('far_t1_byInst', far_t1_byInst),
    ('far_t1ce_byInst', far_t1ce_byInst),
    ('far_t2_byInst', far_t2_byInst),
    ('far_adc_byInst', far_adc_byInst),
    ('Near_flair_byinst', Near_flair_byinst),
    ('Near_t1_byinst', Near_t1_byinst),
    ('Near_t1ce_byinst', Near_t1ce_byinst),
    ('Near_t2_byinst', Near_t2_byinst),
    ('Near_adc_byinst', Near_adc_byinst),
]:
    if len(_paths_by_inst) != n:
        raise ValueError(
            f'{_list_name} holds {len(_paths_by_inst)} institution(s) but n = {n}; '
            'restart the kernel and run each loading cell exactly once, or update n.'
        )

# Iterate by each institution and perform leave-one-site-out training.
# NOTE(review): the held-out institution is only excluded from training here;
# this cell does not evaluate the model on it.
for k in range(n):
    held_out = subjects[k] if k < len(subjects) else f'institution_{k}'
    print(f'Fold {k + 1}/{n}: training without {held_out}')

    # Fresh, identically seeded random state for every fold.
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)

    # Drop institution k from every per-institution list.
    segm_combine = segm_byInst[:k] + segm_byInst[k+1:]  # not used further in this cell

    far_flair_combine = far_flair_byInst[:k] + far_flair_byInst[k+1:]
    far_t1_combine = far_t1_byInst[:k] + far_t1_byInst[k+1:]
    far_t1ce_combine = far_t1ce_byInst[:k] + far_t1ce_byInst[k+1:]
    far_t2_combine = far_t2_byInst[:k] + far_t2_byInst[k+1:]
    far_adc_combine = far_adc_byInst[:k] + far_adc_byInst[k+1:]

    Near_flair_combine = Near_flair_byinst[:k] + Near_flair_byinst[k+1:]
    Near_t1_combine = Near_t1_byinst[:k] + Near_t1_byinst[k+1:]
    Near_t1ce_combine = Near_t1ce_byinst[:k] + Near_t1ce_byinst[k+1:]
    Near_t2_combine = Near_t2_byinst[:k] + Near_t2_byinst[k+1:]
    Near_adc_combine = Near_adc_byinst[:k] + Near_adc_byinst[k+1:]

    # Flatten the remaining institutions into one path list per modality.
    far_flair_dir = list(itertools.chain.from_iterable(far_flair_combine))
    far_t1_dir = list(itertools.chain.from_iterable(far_t1_combine))
    far_t1ce_dir = list(itertools.chain.from_iterable(far_t1ce_combine))
    far_t2_dir = list(itertools.chain.from_iterable(far_t2_combine))
    far_adc_dir = list(itertools.chain.from_iterable(far_adc_combine))

    rec_flair_dir = list(itertools.chain.from_iterable(Near_flair_combine))
    rec_t1_dir = list(itertools.chain.from_iterable(Near_t1_combine))
    rec_t1ce_dir = list(itertools.chain.from_iterable(Near_t1ce_combine))
    rec_t2_dir = list(itertools.chain.from_iterable(Near_t2_combine))
    rec_adc_dir = list(itertools.chain.from_iterable(Near_adc_combine))

    far_dirs = [far_flair_dir, far_t1_dir, far_t1ce_dir, far_t2_dir, far_adc_dir]
    rec_dirs = [rec_flair_dir, rec_t1_dir, rec_t1ce_dir, rec_t2_dir, rec_adc_dir]

    #(N,C,H,W,D)
    far_images = load_and_stack_image_dirs(far_dirs)
    rec_images = load_and_stack_image_dirs(rec_dirs)

    # Fail early with a readable message instead of deep inside train_test_split.
    if len(far_images) == 0 or len(rec_images) == 0:
        raise ValueError(
            f'No Far or Near patches available when holding out {held_out}; '
            'check the patch directories and that at least two institutions are loaded.'
        )

    # Balance the classes by randomly subsampling the larger one.
    if rec_images.shape[0] > far_images.shape[0]:
        indices = rng.choice(rec_images.shape[0], far_images.shape[0], replace=False)
        rec_images = rec_images[indices]
    else:
        indices = rng.choice(far_images.shape[0], rec_images.shape[0], replace=False)
        far_images = far_images[indices]

    all_images = np.concatenate((far_images, rec_images), axis=0)
    print(all_images.shape)

    # Labels: 0 for Far patches, 1 for Near patches.
    y = np.concatenate([np.zeros(len(far_images)), np.ones(len(rec_images))], axis=0)
    print(y.shape)

    # Split data into training and validation sets
    x_train, x_val, y_train, y_val = train_test_split(all_images, y, test_size=0.3, random_state=42, shuffle = True)

    # Convert to PyTorch tensors
    x_train = torch.from_numpy(x_train).float()
    y_train = torch.from_numpy(y_train).long()
    x_val = torch.from_numpy(x_val).float()
    y_val = torch.from_numpy(y_val).long()

    # Create DataLoaders
    trainset = TensorDataset(x_train, y_train)
    valset = TensorDataset(x_val, y_val)
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
    valloader = DataLoader(valset, batch_size=16, shuffle=False)

    # Define the model, loss function, and optimizer
    model = Simple3DCNN(num_classes=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

    # One checkpoint file per fold.
    checkpoint_path = f'{model_out_prefix}_{held_out}.pt'

    # Lists for storing losses and accuracies
    train_losses = []
    val_losses = []
    val_accuracies = []

    # Variables for early stopping
    best_val_loss = float('inf')  # set initial best validation loss to infinity
    patience = 10
    epochs_no_improve = 0

    # Train the model
    num_epochs = 1000
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        for i, (images, labels) in enumerate(trainloader):
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            train_loss += loss.item()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_losses.append(train_loss / len(trainloader))  # Compute average loss

        # Validation phase
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in valloader:
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            val_accuracy = correct / total
            val_accuracies.append(val_accuracy)  # Store validation accuracy
        val_losses.append(val_loss / len(valloader))  # Compute average loss

        # Early stopping check. The summed loss is compared; the number of
        # validation batches is constant, so this ranks epochs like the mean.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save the weights of the best epoch so far for this fold
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print("Early stopping")
                break

        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_losses[-1]}, Validation Loss: {val_losses[-1]}, Validation Accuracy: {val_accuracy * 100}%')

    # Plot training and validation losses
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 6))
    ax_loss.plot(train_losses, label='Training loss')
    ax_loss.plot(val_losses, label='Validation loss')
    ax_loss.set_title('Training and validation loss')
    ax_loss.set_xlabel('Epochs')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    # Plot validation accuracy
    ax_acc.plot(val_accuracies, label='Validation accuracy')
    ax_acc.set_title('Validation accuracy')
    ax_acc.set_xlabel('Epochs')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.legend()

    fig.suptitle(f'Held-out institution: {held_out}')
    plt.show()