# Here are the library you need to import

In [1]:

import time
import torch
import torch.nn.functional as F_tor
from torch import nn
from torch.utils.data import Dataset
import glob
import scipy.io
import os
import numpy as np
from timm.models.layers import DropPath
from natsort import natsorted
from monai.losses import DiceLoss, DiceCELoss,GeneralizedDiceLoss,DiceFocalLoss
from monai.metrics import DiceMetric
from monai.inferers import SlidingWindowInferer
# from CONVIT_function import *
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
    RandAffined,
    RandCropByLabelClassesd,
    SpatialPadd,
    RandAdjustContrastd,
    RandShiftIntensityd,
    ScaleIntensityd,
    NormalizeIntensityd,
    RandScaleIntensityd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    ScaleIntensityRangePercentilesd,
    ResizeWithPadOrCropd
)
from monai.transforms import (CastToTyped,
                              Compose, CropForegroundd, EnsureChannelFirstd, LoadImaged,
                              NormalizeIntensity, RandCropByPosNegLabeld,
                              RandFlipd, RandGaussianNoised,
                              RandGaussianSmoothd, RandScaleIntensityd,
                              RandZoomd, SpatialCrop, SpatialPadd, EnsureTyped)
from monai.networks.nets import UNETR,VNet,DynUNet,SwinUNETR
from MLP_mixer import *

# Build the data loader using the monai library

In [2]:
# Here are the dataloader hyper-parameters, including the batch size for training and testing,
#class number (how many organs + 1 background),
# patch size (the actual input img size), image spacing, and color channel (usually 1 for medical images)
# And GPU
BATCH_SIZE_TRAIN = 2
BATCH_SIZE_TEST = 1
actual_batch = 400
patch_num = 2
class_num = 14
img_size = (256,256,160)
patch_size = (64,64,64)
spacing = (1,1,1.5)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Here we use monai to process the nii data. If you use other format, check monai so it can read the data.
# load image using PILreader (read nii) -> add channel dimension to the image -> ensure orientation -> respacing all image to
# a same spacing -> intensity normalization -> padding or crop the boundary to ensure all images have same size

# Important notice 1: the Orientationd and Spacingd functions takes a lot of time. You can delete them in each transformation
# module depends on your requirements.

# Important notice 2:
# For CT which is quantative imaging, you must use this for normalization:
# ScaleIntensityRanged(
#     keys=["image"],
#     a_min=-self.min,
#     a_max=self.max,
#     b_min=0.0,
#     b_max=1.0,
#     clip=True,
# ),

# For MRI which is not quantative imaging, you must use this for normalization:
# ScaleIntensityRanged(
# ScaleIntensityd(
#     keys=["image"],
#     minv=0,
#     maxv=1.0,
# ),

class CustomDataset(Dataset):
    def __init__(self,imgs_path,mask_path,if_aug = False,if_train = True):
        self.imgs_path = imgs_path
        self.if_aug = if_aug
        self.if_train = if_train
        self.max = 1650
        self.min = -1024
        self.img_size = img_size
        self.patch_size = patch_size
        file_list = natsorted(glob.glob(self.imgs_path + "*"), key=lambda y: y.lower())
        mask_list = natsorted(glob.glob(mask_path + "*"), key=lambda y: y.lower())
        self.data = []
        self.label = []
        self.loader = LoadImaged(keys= ['image','label'],reader='nibabelreader')
        for img_path in file_list:
            class_name = img_path.split("/")[-1]
            self.data.append([img_path, class_name])
        for img_path in mask_list:
            class_name = img_path.split("/")[-1]
            self.label.append([img_path, class_name])
        self.class_num = class_num
        self.read_transforms = Compose(
                [
                    LoadImaged(keys=["image", "label"]),
                    AddChanneld(keys=["image", "label"]),
                    Orientationd(keys=["image", "label"], axcodes="RAS")])
        
        self.train_transforms = Compose(
                [
                    LoadImaged(keys=["image", "label"]),
                    AddChanneld(keys=["image", "label"]),
                    Orientationd(keys=["image", "label"], axcodes="RAS"),
                    Spacingd(
                        keys=["image", "label"],
                        pixdim=spacing,
                        mode=("bilinear", "nearest"),
                    ),
                    
                    # Normalization: choose the correct one based on the important notice 2 in the top of this block
#                     ScaleIntensityd(
#                         keys=["image"],
#                         minv=0,
#                         maxv=1.0,
#                     ),
                    ScaleIntensityRanged(
                        keys=["image"],
                        a_min=-self.min,
                        a_max=self.max,
                        b_min=0.0,
                        b_max=1.0,
                        clip=True,
                    ),
                    ResizeWithPadOrCropd(keys=["image", "label"],
                                    spatial_size = img_size,
                                    mode="constant",
                                    constant_values = 0,
                                    method = "end"),
                    RandCropByLabelClassesd(keys=["image", "label"],
                                            label_key="label",
                                            spatial_size = patch_size,
                                            num_classes=class_num,
                                            num_samples=patch_num),
                    
                    
                    # Augmentations: choose whatever you need
                    RandAffined(keys = ["image","label"], prob=0.5,
                        mode=("bilinear", "nearest"),
                        rotate_range = (0.2,0.2,0.2),
                        scale_range=((-0.3,0.3), (-0.3,0.3), (-0.3,0.3)),
                        padding_mode="border"),
                    RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.95, 1.05)),
                    RandGaussianNoised(keys=["image"], std=0.01, prob=0.3),
                    RandGaussianSmoothd(
                        keys=["image"],
                        sigma_x=(0.5, 1.5),
                        sigma_y=(0.5, 1.5),
                        sigma_z=(0.5, 1.5),
                        prob=0.15,
                    ),
                    RandShiftIntensityd(
                        keys=["image"],
                        offsets=0.10,
                        prob=0.50,
                    ),
                    ToTensord(keys=["image", "label"]),
                ]
            )
        self.train_transforms_noaug = Compose(
                [
                    LoadImaged(keys=["image", "label"]),
                    AddChanneld(keys=["image", "label"]),
                    Orientationd(keys=["image", "label"], axcodes="RAS"),
                    Spacingd(
                        keys=["image", "label"],
                        pixdim=spacing,
                        mode=("bilinear", "nearest"),
                    ),
                    # ScaleIntensityd(
                    #     keys=["image"],
                    #     minv=0,
                    #     maxv=1.0,
                    # ),
                    ScaleIntensityRanged(
                        keys=["image"],
                        a_min=-self.min,
                        a_max=self.max,
                        b_min=0.0,
                        b_max=1.0,
                        clip=True,
                    ),
                    ResizeWithPadOrCropd(keys=["image", "label"],
                                    spatial_size = img_size,
                                    mode="constant",
                                    constant_values = 0,
                                    method = "end"),
                    RandCropByLabelClassesd(keys=["image", "label"],
                                            label_key="label",
                                            spatial_size = patch_size,
                                            num_classes=class_num,
                                            num_samples=patch_num),
                    ToTensord(keys=["image", "label"]),
                ]
            )    
        
        self.val_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=spacing,
                    mode=("bilinear", "nearest"),
                ),
                # ScaleIntensityd(
                #         keys=["image"],
                #         minv=0,
                #         maxv=1.0,
                #     ),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-self.min,
                    a_max=self.max,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
                ToTensord(keys=["image", "label"]),
            ]
        )
        self.val_transforms_ori = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                Orientationd(keys=["image", "label"], axcodes="RAS"),
                Spacingd(
                    keys=["image", "label"],
                    pixdim=spacing,
                    mode=("bilinear", "nearest"),
                ),
                # ScaleIntensityd(
                #         keys=["image"],
                #         minv=0,
                #         maxv=1.0,
                #     ),
                ScaleIntensityRanged(
                        keys=["image"],
                        a_min=-self.min,
                        a_max=self.max,
                        b_min=0.0,
                        b_max=1.0,
                        clip=True,
                    ),
                ToTensord(keys=["image", "label"]),
            ]
        )
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx,):

        # aa = time.time() 
        img_path, class_name = self.data[idx]
        mask_path, class_name = self.label[idx]
        cao = {"image":img_path,'label':mask_path}

        if self.if_train is True:
            if self.if_aug is True:    
                affined_data_dict = self.train_transforms(cao)                    
                img = np.zeros([patch_num, self.patch_size[0], self.patch_size[1], self.patch_size[2]])
                label = np.zeros([patch_num, self.patch_size[0], self.patch_size[1], self.patch_size[2]])
                for i,after_l in enumerate(affined_data_dict):
                    img[i,:,:,:] = after_l['image']
                    label[i,:,:,:] = after_l['label']
    
                img_tensor = torch.unsqueeze(torch.from_numpy(img.copy()), 1).to(torch.float)
                label_tensor = torch.from_numpy(label.copy()).to(torch.int64)
                label_tensor = F_tor.one_hot(label_tensor,
                                    num_classes=self.class_num).permute(0,4,1,2,3).squeeze()
                
            elif self.if_aug is False:
                affined_data_dict = self.train_transforms_noaug(cao)      
                img_tensor = affined_data_dict['image'].to(torch.float)
                label_tensor = affined_data_dict['label'].to(torch.int64)                
                label_tensor = F_tor.one_hot(label_tensor,
                                    num_classes=self.class_num).permute(0,4,1,2,3).squeeze()  
        
        else:              
            datac_dict =  self.val_transforms(cao)
            ori_datac_dict =  self.val_transforms_ori(cao)
            label_tensor = (F_tor.one_hot(datac_dict['label'].to(torch.int64),
                                num_classes=self.class_num).permute(0,4,1,2,3)).squeeze()

            orilabel_tensor = (F_tor.one_hot(ori_datac_dict['label'].to(torch.int64),
                                num_classes=self.class_num).permute(0,4,1,2,3)).squeeze()

            
            img_tensor = {'data':datac_dict['image'].to(torch.float),
                            'ori_data':ori_datac_dict['image'].to(torch.float)}

            label_tensor = {'data':label_tensor.to(torch.int32),
                            'ori_data':orilabel_tensor.to(torch.int32)}

        return img_tensor, label_tensor

# Build the Token_MLP network

In [4]:
# in_channels: color channel for the input, usually 1 for medical images
# out_channels: number of the segmentation classes, # of organs + 1(background)
# depth: depth of the network
# feature_size: Token size, controling how dense of the information extracted by the token. Larger -> more information, but easier to overfit.
# hidden_size: Layer size, similar to the convolutional channel in CNNs. Larger -> more information, but easier to overfit.
# But notice, hidden_size = 512 means 64,128,256,512 since depth = 4.
# mlp_dim: MLP layer size in the MLP_Mixer, controling how much you want to learn from the token. Larger -> more information, but easier to overfit.

model =  MLP_MIXER(
    in_channels=1,
    out_channels=class_num,
    depth = 4,
    feature_size=512,
    hidden_size=512,
    mlp_dim=512,
).to(device)

# In case you want to use other famous segmentation networks. They are built from MONAI

In [5]:
# class ViTResNet(nn.Module):
#     def __init__(self, batch_size, num_classes=class_num, dim = 1024, num_tokens = 128, mlp_dim = 2048, heads = 16, depth = 16, emb_dropout = 0, dropout= 0.2):
#         super(ViTResNet, self).__init__()
#         self.L = num_tokens
#         self.cT = dim*1
#         self.mlp_dim = self.cT*2
        
#         # self.model = nnFormer(crop_size=(14,128,128),
#         #         embedding_dim=96,
#         #         input_channels=1, 
#         #         num_classes=class_num, 
#         #         conv_op=nn.Conv3d, 
#         #         patch_size=[1,4,4],
#         #         window_size=[[3,5,5],[3,5,5],[7,10,10],[3,5,5]],
#         #         down_stride=[[1,4,4],[1,8,8],[2,16,16],[4,32,32]],
#         #         depths=[2, 2, 2, 2],   
#         #         num_heads=[6, 12, 24, 48],
#         #         # patch_size=[2,4,4],
#         #         # window_size=[4,4,8,4],
#         #         deep_supervision=False)  
# #         self.model = SwinUNETR(
# #                 img_size = img_size,
# #                 in_channels=1,
# #                 out_channels = class_num,
# #                 depths = (2, 2, 2, 2),
# #                 num_heads= (3, 6, 12, 24),
# #                 feature_size = 96,
# #                 norm_name = "instance",
# #                 drop_rate = 0.0,
# #                 attn_drop_rate = 0.0,
# #                 dropout_path_rate = 0.2,
# #                 normalize = True,
# #                 use_checkpoint = False,
# #                 spatial_dims = 3,
# #             ).to(device)        
#         # self.model = DynUNet(
#         #         spatial_dims=3,
#         #         in_channels=1,
#         #         out_channels=class_num,
#         #         kernel_size=kernels,
#         #         strides=strides,
#         #         upsample_kernel_size=strides[1:],
#         #         norm_name="instance",
#         #         deep_supervision=False,
#         #         res_block = True
#         #     ) 
        

#         # self.model = VNet(
#         #         spatial_dims=3,
#         #         in_channels=1,
#         #         out_channels=class_num,
#         #         dropout_prob = 0
#         #     ).to(device)

#         # self.model = UNETR(
#         #     in_channels=1,
#         #     out_channels=class_num,
#         #     img_size=(96, 96, 96),
#         #     feature_size=16,
#         #     hidden_size=768,
#         #     mlp_dim=3072,
#         #     num_heads=12,
#         #     pos_embed='perceptron',
#         #     norm_name='instance',
#         #     conv_block=True,
#         #     res_block=True,
#         #     dropout_rate=0.0).to(device)
#     def forward(self, img, mask = None):
#         x_out = self.model(img)
#         return x_out
# model = ViTResNet(BATCH_SIZE_TRAIN).to(device)

# Build the loss functions for optimization and evaluation

In [6]:
loss_criterion1 =DiceCELoss(include_background=True,
                            to_onehot_y=False,
                            softmax=True,
                            squared_pred=False,
                            smooth_nr=1e-5,
                            smooth_dr=1e-5).to(device)
eval_criterion = DiceMetric(include_background=False,
                          reduction="mean",
                          get_not_nans=True)

# Build the optimizer

In [7]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
print('parameter number is '+str(pytorch_total_params))
torch.backends.cudnn.benchmark = True

lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print('Learning rate is '+str(lr))

parameter number is 58652448
Learning rate is 0.001


# Build the training function. Run the training function once = one epoch

In [8]:
def train(model, optimizer,data_loader1, loss_history):
    
    #1: set the model to training mode
    total_samples = len(data_loader1.dataset)
    model.train()
    loss_sum = 0  
    dice_loss_sum = 0
    count = 0
    alpha = 0.4
    total_time = 0
    
    #2: Loop the whole dataset, x1 (traindata) is the image batch 
    for i, (x1,y1)in enumerate(data_loader1):


        traindata = x1.to(device)
        traintarget = y1.to(device)
        
        #3: since they are patch-based form, reshape the images so the patch channel converge into the batch channel
        traindata = traindata.view(-1,traindata.shape[2],traindata.shape[3],traindata.shape[4],traindata.shape[5])
        traintarget = traintarget.view(-1,traintarget.shape[2],traintarget.shape[3],traintarget.shape[4],traintarget.shape[5])
        
        #4: Optimize the Token_MLP network
        aa = time.time()
        optimizer.zero_grad()
        output = model(traindata)
            
        loss1 = loss_criterion1(output, traintarget)
        loss = loss1
        loss.backward()
        loss_sum += loss.detach().cpu().numpy()
        count += 1
        optimizer.step()
        print('optimization time: '+ str(time.time()-aa))
        total_time += time.time()-aa
        if i % 1 == 0:
            print('[' +  '{:5}'.format(i * BATCH_SIZE_TRAIN) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader1)) + '%)]  Loss: ' +
                  '{:6.4f}'.format(loss.item()))

    loss_sum /= count
    loss_history.append(loss_sum)
    print("Total time per sample is: "+str(total_time))
    print('Averaged loss is: '+ str(loss_sum))
    print('Averaged  Dice loss is: '+ str(dice_loss_sum/(i+1)))   

# Build the testing function.

In [9]:
# Build the patch-based prediction. You need to decide two parameters.
overlap = 0.5 # Overlap ratio between the prediction patches
mode = 'gaussian'  # Overlap mode between the prediction patches
inferer = SlidingWindowInferer(img_size, patch_num,overlap,mode =mode)

def evaluate(model,epoch,path, data_loader, loss_history,acc_history):
    
    #1: set the model to eval mode
    model.eval()
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    count = 0
    prediction = []
    true = []
    img = []
    loss_all = []
    acc_all = []
    with torch.no_grad():
        thetime = 0
        
        #2: Loop the testing dataset
        for i, (data, target) in enumerate(data_loader):
            
            
            #3: The following code will automatically convert the network output into the segmentation map
            aa = time.time()
            testtarget = target['ori_data'].to(device)
            testdata = data['data'].to(device)
            reconstructed_output = inferer(testdata,model)
            reconstructed_output = F_tor.interpolate(reconstructed_output,
                                      size=target['ori_data'].shape[2:], mode='trilinear', align_corners=False)
            output = F_tor.softmax(reconstructed_output,dim=1)
            loss = loss_criterion1(reconstructed_output, testtarget)            
            
            output1 = output.argmax(1)
            output1 = F_tor.one_hot(output1,
                        num_classes=class_num).permute(0,4,1,2,3)
            
            #4: Evaluation
            acc = eval_criterion(output1, testtarget)
            print(acc.cpu().numpy())

            loss_all.append(loss.cpu().numpy())
            acc_all.append(acc.cpu().numpy())
            thetime += time.time()-aa
            print('optimization time: '+ str(time.time()-aa))            
            
            img.append(data['ori_data'].squeeze().cpu().numpy())
            true.append(testtarget.squeeze().argmax(0).cpu().numpy())                                                             
            prediction.append(output1.squeeze().argmax(0).cpu().numpy())
            thetarget = testtarget.squeeze().argmax(0).cpu().numpy()
            thepred = output1.squeeze().argmax(0).cpu().numpy()

        loss_history.append(np.nanmean(loss_all))
        acc_history.append(np.nanmean(acc_all))        
      
        #5: Save the predictions into mat files
        if np.nanmean(acc_all)> best_acc:
            data = {"Losshistory":loss_history,"img":img,
                    "true":true, "prediction":prediction}
            scipy.io.savemat(path+ 'validation_epoch'+str(epoch)+'.mat',data)
        return np.nanmean(acc_all)

# Set the data folder for data reading

In [10]:
# Enter your data folder
training_set1 = CustomDataset('./example_data/imagesTr/',
                              './example_data/labelsTr/',
                              if_aug = True,if_train = True)

test_set = CustomDataset('./example_data/imagesTs/',
                          './example_data/labelsTs/',
                          if_aug = False,if_train = False)

# Enter your data reader parameters
params = {'batch_size': BATCH_SIZE_TRAIN,
          'shuffle': True,
          'pin_memory': True,
          'drop_last': False}
testparams = {'batch_size': BATCH_SIZE_TEST, 
          'shuffle': False,
          'pin_memory': True,
          'drop_last': False}

train_loader1 = torch.utils.data.DataLoader(training_set1, **params)

test_loader = torch.utils.data.DataLoader(test_set, **testparams)

# Start the training and testing

In [None]:
# Enter your total number of epoch
N_EPOCHS = 500

# Enter the address you save the checkpoint and the prediction examples
path ="./output/example-1/"
PATH = path+'ViTRes1.pt' # Use your own path
best_acc = 0
if not os.path.exists(path):
  os.makedirs(path) 
train_loss_history, test_loss_history,test_acc_history = [], [], []

# Uncomment this when you resume the checkpoint
# model.load_state_dict(torch.load(PATH))
for epoch in range(0, N_EPOCHS):
    print('Epoch:', epoch)
    start_time = time.time()
    train(model, optimizer, train_loader1, train_loss_history)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    if epoch % 10 == 0:
        theacc = evaluate(model,epoch,path, test_loader, test_loss_history, test_acc_history)
        if theacc > best_acc:
            print('Save the latest best model')
            torch.save(model.state_dict(), PATH)
            best_acc = theacc
print('Execution time')

Epoch: 0




torch.Size([2, 128, 16, 16, 16])
torch.Size([2, 64, 32, 32, 32])
torch.Size([2, 14, 64, 64, 64])
optimization time: 5.861515045166016
Total time per sample is: 5.861515045166016
Averaged loss is: 3.757516384124756
Averaged  Dice loss is: 0.0
Execution time:  7.20 seconds
torch.Size([2, 128, 64, 64, 40])
torch.Size([2, 64, 128, 128, 80])


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\pshaoya\AppData\Local\Temp/ipykernel_25528/1108966507.py", line 20, in <module>
    theacc = evaluate(model,epoch,path, test_loader, test_loss_history, test_acc_history)
  File "C:\Users\pshaoya\AppData\Local\Temp/ipykernel_25528/2552549526.py", line 30, in evaluate
    reconstructed_output = inferer(testdata,model)
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\monai\inferers\inferer.py", line 192, in __call__
    return sliding_window_inference(  # type: ignore
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\monai\inferers\utils.py", line 176, in sliding_window_inference
    seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation
  File "C:\Users\pshaoya\AppData\Roaming\Python\Python38\site-packages\torch\nn

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\pshaoya\AppData\Local\Temp/ipykernel_25528/1108966507.py", line 20, in <module>
    theacc = evaluate(model,epoch,path, test_loader, test_loss_history, test_acc_history)
  File "C:\Users\pshaoya\AppData\Local\Temp/ipykernel_25528/2552549526.py", line 30, in evaluate
    reconstructed_output = inferer(testdata,model)
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\monai\inferers\inferer.py", line 192, in __call__
    return sliding_window_inference(  # type: ignore
  File "C:\Users\pshaoya\Anaconda3\envs\DL\lib\site-packages\monai\inferers\utils.py", line 176, in sliding_window_inference
    seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation
  File "C:\Users\pshaoya\AppData\Roaming\Python\Python38\site-packages\torch\nn