# Training and test program for a 3D U-Net on MP2RAGE MRI images
Training was done using the MPI-LEMON dataset of MP2RAGE images found at [MPI-LEMON MRI Download Page](https://fcon_1000.projects.nitrc.org/indi/retro/MPI_LEMON/downloads/download_MRI.html)

Before running, update and verify every file and folder path used in this notebook.

Details about hardware requirements will be specified in the README


## Import packages

In [None]:
import torch
import torch.nn as nn
import scipy.ndimage as ndi
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import os
import nibabel as nib
import numpy as np
from skimage.filters import threshold_otsu
from skimage.morphology import closing, remove_small_objects, disk, ball
from scipy.ndimage import binary_fill_holes, label
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

### If necessary: Install required packages

Only needed when running on Google Colab.

In [None]:
#!pip install nibabel

### If necessary: mount Google Drive
Only needed when running on Google Colab.

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

## Creating foreground mask using percentage threshold

In [None]:
def create_foreground_mask(volume, threshold_percentage=0.05):
    """Build a binary foreground mask from a 3D volume.

    Voxels brighter than ``threshold_percentage`` times the volume maximum
    are kept. The thresholded mask is then closed with a radius-3 ball,
    hole-filled, and stripped of small connected components
    (``min_size=100``).

    Returns:
        The mask as a ``np.uint8`` array.
    """
    # Everything at or below this intensity is treated as air/background.
    cutoff = threshold_percentage * np.max(volume)
    foreground = volume > cutoff

    # Morphological clean-up, applied in 3D.
    footprint = ball(radius=3)
    cleaned = remove_small_objects(
        binary_fill_holes(closing(foreground, footprint)),
        min_size=100,
    )

    return cleaned.astype(np.uint8)

def save_mask(mask, reference_nii, output_path):
    """Save ``mask`` as a NIfTI file aligned with ``reference_nii``.

    Args:
        mask: Array holding the mask; it is cast to ``np.uint8`` before saving.
        reference_nii: Loaded nibabel image whose affine and header are reused.
        output_path: Destination file path (e.g. ending in ``.nii.gz``).
    """
    # Work on a copy of the reference header and declare uint8 as the stored
    # data type. When an explicit header is supplied, nibabel writes the data
    # using the header's data type, so reusing the reference header unchanged
    # would store the mask with the source image's (wider) on-disk type.
    header = reference_nii.header.copy()
    header.set_data_dtype(np.uint8)

    mask_nifti = nib.Nifti1Image(mask.astype(np.uint8), affine=reference_nii.affine, header=header)
    nib.save(mask_nifti, output_path)

def create_mask(img_data):
    """Return a whole-head mask: the largest bright connected component.

    The volume is thresholded at 5% of its maximum, closed with a radius-3
    ball, and only the largest connected component is kept.

    Args:
        img_data: NumPy array holding the image volume.

    Returns:
        ``np.uint8`` array with the shape of ``img_data``; 1 inside the
        largest component and 0 elsewhere. All zeros when no voxel exceeds
        the threshold.
    """
    # 1) Threshold to remove air/background
    above_threshold = img_data > (0.05 * img_data.max())

    closed = closing(above_threshold, ball(radius=3))

    # 2) Keep only the largest connected component (the head)
    labels, n = ndi.label(closed)
    if n == 0:
        # Nothing is above threshold (e.g. an all-zero volume). Without this
        # guard, np.argmax would raise on the empty list of component sizes.
        return np.zeros(closed.shape, dtype=np.uint8)

    sizes = ndi.sum(closed, labels, range(1, n + 1))
    largest = np.argmax(sizes) + 1  # label ids start at 1
    whole_head_mask = (labels == largest).astype(np.uint8)
    return whole_head_mask

In [None]:
def process_all_nii(input_folder, output_folder):
    """Create and save a mask for every ``.nii.gz`` file in ``input_folder``.

    ``output_folder`` is created if it does not exist. Each mask is written
    there under the source file name with ``.nii.gz`` replaced by
    ``_mask.nii.gz``. A summary line is printed when all files are done.
    """
    os.makedirs(output_folder, exist_ok=True)
    volume_names = [name for name in os.listdir(input_folder) if name.endswith(".nii.gz")]

    for name in tqdm(volume_names, desc="Processing subjects"):
        mask_name = name.replace(".nii.gz", "_mask.nii.gz")

        # Keep the loaded image around: save_mask reuses its affine and header.
        source_image = nib.load(os.path.join(input_folder, name))
        head_mask = create_mask(source_image.get_fdata())

        save_mask(head_mask, source_image, os.path.join(output_folder, mask_name))

    print(f"\nDone! {len(volume_names)} mask(s) saved to: {output_folder}")

# Folder holding the inv2 volumes to mask, and the folder the masks are written to.
input_folder = "./lemon_data/inv2_volumes"
output_folder = "./lemon_data/inv2_masks"

# Runs as soon as this cell executes: writes one mask file per .nii.gz volume
# found in input_folder.
process_all_nii(input_folder, output_folder)

## Create dataset of all nii.gz files for training

In [None]:
class NiftiDataset3D(Dataset):
    """Dataset of paired NIfTI image volumes and masks.

    Images and masks are paired by position after sorting the file names of
    each directory, so both directories must sort into the same order.

    Args:
        image_dir: Directory containing the ``.nii`` / ``.nii.gz`` images.
        mask_dir: Directory containing the corresponding masks.
        transform: Optional callable invoked as ``transform(img, mask)`` that
            returns the transformed ``(img, mask)`` pair.

    Raises:
        ValueError: If the two directories hold a different number of NIfTI
            files, since positional pairing would then be wrong or incomplete.
    """

    _NIFTI_SUFFIXES = ('.nii', '.nii.gz')

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_paths = self._list_nifti(image_dir)
        self.mask_paths = self._list_nifti(mask_dir)
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} image(s) in {image_dir!r} but "
                f"{len(self.mask_paths)} mask(s) in {mask_dir!r}; "
                "every image needs exactly one mask."
            )
        self.transform = transform

    @classmethod
    def _list_nifti(cls, folder):
        """Return the sorted full paths of the NIfTI files directly inside ``folder``."""
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith(cls._NIFTI_SUFFIXES)
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` from disk as float32 tensors with a leading channel dimension."""
        img = nib.load(self.image_paths[idx]).get_fdata()
        mask = nib.load(self.mask_paths[idx]).get_fdata()

        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        if self.transform:
            img, mask = self.transform(img, mask)

        return img, mask
# Locations of the training images and their masks
inv2_dir = './lemon_data/inv2_volumes'  # Folder of .nii.gz volumes
mask_dir = './lemon_data/inv2_masks'    # Corresponding masks

# Build the inv2 dataset and its loader. The loader yields one volume per
# batch in shuffled order, using 4 worker processes and pinned host memory.
inv2_dataset = NiftiDataset3D(inv2_dir, mask_dir)
inv2_loader = DataLoader(inv2_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

## Create 3d U-net model

In [None]:
class ConvBlock3D(nn.Module):
    """U-Net convolution block: two (3x3x3 conv -> batch norm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so the parameter
        # names (block.0, block.1, block.3, block.4) stay the same.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv3d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        features = self.block(x)
        return features

class UNet3D(nn.Module):
    """3D U-Net with two encoder/decoder levels around a bottleneck.

    Channel widths are 32 -> 64 -> 128 (bottleneck) -> 64 -> 32, followed by a
    1x1x1 convolution to ``out_channels``. No activation is applied to the
    output, so ``forward`` returns raw scores.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels of the output map.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.enc1 = ConvBlock3D(in_channels, 32)
        self.pool1 = nn.MaxPool3d(2)

        self.enc2 = ConvBlock3D(32, 64)
        self.pool2 = nn.MaxPool3d(2)

        self.bottleneck = ConvBlock3D(64, 128)

        self.upconv2 = nn.ConvTranspose3d(128, 64, 2, stride=2)
        self.dec2 = ConvBlock3D(128, 64)

        self.upconv1 = nn.ConvTranspose3d(64, 32, 2, stride=2)
        self.dec1 = ConvBlock3D(64, 32)

        self.final = nn.Conv3d(32, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its spatial size equals that of ``skip``.

        Pooling floors odd sizes, so after the stride-2 up-convolution a
        feature map can be one voxel short of its skip connection along any
        axis. When the sizes already agree the tensor is returned untouched.
        """
        deficits = [s - u for s, u in zip(skip.shape[2:], upsampled.shape[2:])]
        if not any(deficits):
            return upsampled
        # F.pad expects (before, after) pairs starting from the LAST dimension.
        padding = []
        for deficit in reversed(deficits):
            padding += [deficit // 2, deficit - deficit // 2]
        return nn.functional.pad(upsampled, padding)

    def forward(self, x):
        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))

        # Decoder path: upsample, align with the skip connection, concatenate
        # along the channel dimension, then convolve.
        d2 = self._pad_to_match(self.upconv2(b), e2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self._pad_to_match(self.upconv1(d2), e1)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.final(d1)

Dice loss function

In [None]:
def dice_loss(pred, target, smooth=1e-5):
    """Soft Dice loss computed over all elements of ``pred`` and ``target``.

    ``pred`` holds raw logits (the sigmoid is applied here) and ``target`` is
    cast to float. ``smooth`` is added to numerator and denominator so the
    ratio stays defined when both inputs are empty.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    probabilities = torch.sigmoid(pred).reshape(-1)
    truth = target.reshape(-1).float()

    overlap = torch.sum(probabilities * truth)
    denominator = probabilities.sum() + truth.sum() + smooth
    dice = (2. * overlap + smooth) / denominator
    return 1 - dice

Two models: One trained with inv2 and one with t1map

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Inv2 model: a UNet3D with default channel counts (1 in, 1 out), moved to `device`
inv2model = UNet3D().to(device)
# Adam optimizer over the inv2 model's parameters, learning rate 1e-4.
# train() reads this module-level name.
optimizer = torch.optim.Adam(inv2model.parameters(), lr=1e-4)

## Train model

In [None]:
def train(epochs, model, loader, accumulate_grad_batches=4, optim=None):
    """Train ``model`` with Dice loss, mixed precision and gradient accumulation.

    Args:
        epochs: Number of passes over ``loader``.
        model: Network to train; its parameters must live on the module-level
            ``device``, since batches are moved there.
        loader: Sized iterable yielding ``(images, masks)`` batches.
        accumulate_grad_batches: Number of batches whose gradients are
            accumulated before each optimizer step.
        optim: Optimizer that owns ``model``'s parameters. Defaults to the
            module-level ``optimizer``; pass one explicitly when training a
            model other than the one that optimizer was built for.

    Raises:
        ValueError: If ``accumulate_grad_batches`` is smaller than 1 or
            ``loader`` is empty.
    """
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be at least 1")
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("loader is empty; there is nothing to train on")

    opt = optimizer if optim is None else optim
    # Mixed precision only applies on CUDA. Disabling it explicitly on CPU
    # avoids the warnings the CUDA AMP helpers emit before disabling themselves.
    use_amp = device.type == "cuda"

    model.train()
    scaler = GradScaler(enabled=use_amp)  # Loss scaling for mixed precision
    for epoch in range(epochs):
        total_loss = 0.0
        pending_grads = False  # True while accumulated gradients await a step
        opt.zero_grad()
        for i, (imgs, masks) in enumerate(loader):
            imgs, masks = imgs.to(device), masks.to(device)

            with autocast(enabled=use_amp):  # Autocast for mixed precision
                outputs = model(imgs)
                loss = dice_loss(outputs, masks)
                # Divide so the accumulated gradient is the group average.
                loss = loss / accumulate_grad_batches

            scaler.scale(loss).backward()  # Scale the loss and call backward()
            pending_grads = True

            if (i + 1) % accumulate_grad_batches == 0:
                scaler.step(opt)  # Optimizer step with scaler
                scaler.update()  # Update the scaler
                opt.zero_grad()
                pending_grads = False

            # Undo the division so the reported value is the plain Dice loss.
            total_loss += loss.item() * accumulate_grad_batches

        # Apply the gradients of a final, incomplete accumulation group.
        # Its losses were still divided by accumulate_grad_batches.
        if pending_grads:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()

        print(f"Epoch {epoch+1}: Loss = {total_loss/num_batches:.4f}")

### training
Skip the "load parameters" cell below on your first training run: no saved checkpoint exists yet.

In [None]:
# Load previously saved parameters into the inv2 model. The checkpoint file
# must already exist. map_location makes the stored tensors load onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
inv2model.load_state_dict(torch.load('./3d_segmentation_inv2.pth', map_location=torch.device('cpu')))

In [None]:
# Train the inv2 model for 5 epochs, then save its parameters.
# The checkpoint is written to the same file the "load parameters" cell reads
# ('./3d_segmentation_inv2.pth'), so a later session can resume from it.
train(5, inv2model, inv2_loader)
torch.save(inv2model.state_dict(), './3d_segmentation_inv2.pth')

## Predicting mask to test
Functions to predict a mask and to edit a volume based on that mask.

In [None]:
def predict(model, image_tensor):
    """Run ``model`` on a single volume and return its output on the CPU.

    A ``[D, H, W]`` or ``[1, D, H, W]`` tensor is expanded with leading
    singleton dimensions to ``[1, 1, D, H, W]``; tensors of any other rank
    are passed through unchanged. The model is switched to eval mode and
    gradients are disabled. The returned values are the model's raw output;
    no sigmoid or threshold is applied here.
    """
    model.eval()
    with torch.no_grad():
        if image_tensor.ndim in (3, 4):
            # Prepend the missing batch and/or channel dimensions.
            for _ in range(5 - image_tensor.ndim):
                image_tensor = image_tensor.unsqueeze(0)

        prediction = model(image_tensor.to(device))
        return prediction.cpu()

def maskEditing(image, mask, threshold=-0.5):
    """Return a copy of ``image`` with voxels where ``mask < threshold`` set to -1.

    Args:
        image: Volume to edit, as a NumPy array or a tensor. It is copied, so
            the caller's data is left untouched.
        mask: Array or tensor covering the volume, optionally with leading
            singleton dimensions (e.g. ``[1, 1, D, H, W]`` or ``[1, D, H, W]``).
        threshold: Voxels whose mask value is strictly below this are treated
            as background.

    Returns:
        A new ``torch.Tensor`` holding the edited volume.
    """
    new_img = image.clone() if torch.is_tensor(image) else torch.tensor(image.copy())
    if not torch.is_tensor(mask):
        mask = torch.tensor(mask)

    # Drop leading batch/channel singleton dimensions so the mask has the same
    # rank as the image.
    while mask.ndim > new_img.ndim and mask.size(0) == 1:
        mask = mask.squeeze(0)

    # Boolean indexing needs the index on the indexed tensor's device.
    mask = mask.to(new_img.device)

    # Vectorized replacement for a per-voxel Python loop. The slice is a view
    # restricted to the region the mask covers, so the assignment writes into
    # new_img.
    region = new_img[tuple(slice(0, extent) for extent in mask.shape)]
    region[mask < threshold] = -1  # set background voxels to -1

    return new_img

In [None]:
slice_idx = 125  # index of the slice (along the first axis) to display

# Load the full 3D volume and mask of sample 12. Indexing the dataset reads
# both NIfTI files from disk, so fetch the pair once and unpack it.
full_inv2_volume, full_mask_volume = inv2_dataset[12]  # each of shape [1, D, H, W]

# Load the corresponding t1map image
# NOTE(review): this is a Google Drive path, unlike the ./lemon_data paths used
# elsewhere in this notebook — confirm it matches your setup.
t1map = torch.from_numpy(nib.load("/content/drive/My Drive/lemon_data/t1map_volumes/sub-032394_ses-01_acq-mp2rage_T1map.nii.gz").get_fdata()).float()
t1map_slice = t1map[slice_idx, :, :]  # slice of the t1map image
print(type(t1map), type(full_inv2_volume), t1map.size())

# Raw model output (no sigmoid applied) for the t1map volume
t1map_mask = predict(inv2model, t1map)

# Binary mask: raw output > 0 is the same as sigmoid(output) > 0.5
binaryMask = (t1map_mask > 0).float()

# Set every voxel outside the binary mask (mask value < 0.5) to -1
edited_t1map = maskEditing(t1map.numpy(), binaryMask, threshold=0.5)  # Pass numpy array to maskEditing
edited_t1map_slice = edited_t1map[slice_idx, :, :]

plt.subplot(1, 4, 1)
plt.imshow(t1map_mask.squeeze().detach().numpy()[slice_idx, :, :], cmap="gray")  # raw model output
plt.subplot(1, 4, 2)
plt.imshow(binaryMask.squeeze().detach().numpy()[slice_idx, :, :], cmap="gray")  # binarized predicted mask
plt.subplot(1, 4, 3)
plt.imshow(t1map_slice, cmap="gray")  # original t1map image
plt.subplot(1, 4, 4)
plt.imshow(edited_t1map_slice, cmap="gray")  # t1map image with background set to -1