In [1]:
from torch.utils.data import Dataset, DataLoader,WeightedRandomSampler
import os 
import nibabel as nib
import pandas as pd
import pickle
from random import choice
import numpy as np 
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from monai.losses import GeneralizedDiceLoss,DiceCELoss,DiceFocalLoss,GeneralizedDiceFocalLoss
from monai.metrics import compute_meandice
from monai.networks.blocks import Convolution, ResidualUnit
from collections import OrderedDict
from torch.optim.lr_scheduler import StepLR
from bayesian_torch.layers import Conv3dFlipout
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
import gc

In [2]:
# Index CSV of the pre-processed training set.
TRAIN_DATA_DIR = './pre_process_data/train'
data_df_path = TRAIN_DATA_DIR + '/train_index.csv'

In [3]:
data_df = pd.read_csv(data_df_path)
n_rows, n_cols = data_df.shape
(n_rows, n_cols)

(264, 3)

In [4]:
# Hold out 10% of the training index as a validation split. A fixed random_state makes
# the split reproducible across kernel restarts (the original call was unseeded, so every
# re-run trained and validated on a different partition).
SPLIT_SEED = 42
data_df = pd.read_csv(data_df_path)
data_df['index'] = range(len(data_df))
train_df = data_df.sample(frac=0.9, random_state=SPLIT_SEED)
val_df = data_df[~data_df['index'].isin(train_df['index'])]
# the two splits must partition the data exactly
assert len(train_df) + len(val_df) == len(data_df)
print(len(train_df))
print(len(val_df))

238
26


In [5]:
class simpletrain(Dataset):
    """Map-style dataset of pre-processed training volumes.

    Each row of ``dataset`` (a pandas DataFrame, indexed positionally) must provide
    ``images_path`` and ``seg_path`` columns pointing at pickle files.
    ``__getitem__`` unpickles both and returns ``(images, seg)`` exactly as stored.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only use this on files
    produced by the project's own pre-processing step.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        images = self._load_pickle(row.images_path)
        seg = self._load_pickle(row.seg_path)
        return images, seg

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

In [6]:
def Double_conv_encoder_block_in (num_in_filter,args):
    """Build the input stage of the 3D U-Net encoder: (conv -> activation -> norm) x 2.

    Channels grow ``num_in_filter -> 2*num_in_filter -> 4*num_in_filter``. No pooling is
    applied in this stage.

    Args:
        num_in_filter: number of input channels of the block.
        args: configuration dict. Reads ``input_dim``, ``kernal_size``, ``dilation_size``,
            ``stride_size``, ``padding_size``, ``bias``, ``act_type`` and ``norm_type``
            (the ``pool_*`` keys are not used by this stage).

    Returns:
        ``nn.Sequential`` with children ``conv1``, ``<act_type>_1``, ``<norm_type>_1``,
        ``conv2``, ``<act_type>_2``, ``<norm_type>_2``. These child names appear in
        ``state_dict`` keys, so they must stay stable for saved checkpoints to load.
    """
    input_dim = args['input_dim']
    kernal_size = args['kernal_size']
    dilation_size = args['dilation_size']
    stride_size = args['stride_size']
    padding_size = args['padding_size']
    bias = args['bias']
    act_type = args['act_type']
    norm_type = args['norm_type']

    conv_kwargs = dict(strides=stride_size, kernel_size=kernal_size, dilation=dilation_size,
                       bias=bias, conv_only=True, padding=padding_size)
    conv1 = Convolution(input_dim, num_in_filter, num_in_filter * 2, **conv_kwargs)
    conv2 = Convolution(input_dim, num_in_filter * 2, num_in_filter * 4, **conv_kwargs)

    # Unknown activation names fall back to ReLU.
    activations = {'PReLU': nn.PReLU, 'SELU': nn.SELU, 'CELU': nn.CELU}
    # NOTE(review): one activation instance is registered twice below, so a PReLU slope is
    # shared by both positions. Left unchanged to preserve the model; confirm it is intended.
    act_layer = activations.get(act_type, nn.ReLU)()

    if norm_type == 'GroupNorm':
        # GroupNorm(8, C) requires C to be divisible by 8.
        norm_layer_1 = nn.GroupNorm(8, num_in_filter * 2)
        norm_layer_2 = nn.GroupNorm(8, num_in_filter * 4)
    elif norm_type == 'BatchNorm':
        # Previously tested `act_type == 'BatchNorm'`, which can never be true, so a
        # BatchNorm configuration silently fell through to InstanceNorm.
        norm_layer_1 = nn.LazyBatchNorm3d()
        norm_layer_2 = nn.LazyBatchNorm3d()
    else:
        norm_layer_1 = nn.LazyInstanceNorm3d()
        norm_layer_2 = nn.LazyInstanceNorm3d()

    return nn.Sequential(OrderedDict([
        ('conv1', conv1),
        (act_type + '_1', act_layer),
        (norm_type + '_1', norm_layer_1),
        ('conv2', conv2),
        (act_type + '_2', act_layer),
        (norm_type + '_2', norm_layer_2),
    ]))

In [7]:
def Double_conv_encoder_block_en (num_in_filter,args):
    """Build a downsampling encoder stage: pool, then (conv -> activation -> norm) x 2.

    Channels go ``2*num_in_filter -> 2*num_in_filter -> 4*num_in_filter``. The input width
    assumes the previous stage emitted ``2*num_in_filter`` channels. That holds when filter
    counts double per stage, as produced by ``create_encoder_maps``.

    Args:
        num_in_filter: base filter count of this stage.
        args: configuration dict. Reads ``input_dim``, ``kernal_size``, ``pool_kernal_size``,
            ``dilation_size``, ``stride_size``, ``pool_stride_size``, ``padding_size``,
            ``bias``, ``act_type``, ``norm_type`` and ``pool_type`` ('max' or 'avg').

    Returns:
        ``nn.Sequential`` with children ``pool``, ``conv1``, ``<act_type>_1``,
        ``<norm_type>_1``, ``conv2``, ``<act_type>_2``, ``<norm_type>_2``. These child names
        appear in ``state_dict`` keys.

    Raises:
        AssertionError: if ``pool_type`` is not 'max' or 'avg'.
    """
    input_dim = args['input_dim']
    kernal_size = args['kernal_size']
    pool_kernel_size = args['pool_kernal_size']
    dilation_size = args['dilation_size']
    stride_size = args['stride_size']
    pool_stride_size = args['pool_stride_size']
    padding_size = args['padding_size']
    bias = args['bias']
    act_type = args['act_type']
    norm_type = args['norm_type']
    pool_type = args['pool_type']

    assert pool_type in ['max', 'avg'], 'pool_type must be "max" or "avg"'
    if pool_type == 'avg':
        # nn.AvgPool3d has no `dilation` argument; passing dilation=1 (as before) raised TypeError.
        pooling = nn.AvgPool3d(kernel_size=pool_kernel_size, stride=pool_stride_size, padding=0, ceil_mode=False)
    else:
        pooling = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride_size, padding=0, dilation=1, ceil_mode=False)

    conv_kwargs = dict(strides=stride_size, kernel_size=kernal_size, dilation=dilation_size,
                       bias=bias, conv_only=True, padding=padding_size)
    conv1 = Convolution(input_dim, num_in_filter * 2, num_in_filter * 2, **conv_kwargs)
    conv2 = Convolution(input_dim, num_in_filter * 2, num_in_filter * 4, **conv_kwargs)

    # Unknown activation names fall back to ReLU.
    activations = {'PReLU': nn.PReLU, 'SELU': nn.SELU, 'CELU': nn.CELU}
    # NOTE(review): one activation instance is registered twice below (shared PReLU slope).
    act_layer = activations.get(act_type, nn.ReLU)()

    if norm_type == 'GroupNorm':
        # GroupNorm(8, C) requires C to be divisible by 8.
        norm_layer_1 = nn.GroupNorm(8, num_in_filter * 2)
        norm_layer_2 = nn.GroupNorm(8, num_in_filter * 4)
    elif norm_type == 'BatchNorm':
        # Previously tested `act_type == 'BatchNorm'`, which can never be true.
        norm_layer_1 = nn.LazyBatchNorm3d()
        norm_layer_2 = nn.LazyBatchNorm3d()
    else:
        norm_layer_1 = nn.LazyInstanceNorm3d()
        norm_layer_2 = nn.LazyInstanceNorm3d()

    return nn.Sequential(OrderedDict([
        ('pool', pooling),
        ('conv1', conv1),
        (act_type + '_1', act_layer),
        (norm_type + '_1', norm_layer_1),
        ('conv2', conv2),
        (act_type + '_2', act_layer),
        (norm_type + '_2', norm_layer_2),
    ]))

In [8]:
class Double_conv_encoder_block_de(nn.Module):
    """Decoder stage of the 3D U-Net: (conv -> activation -> norm) x 2.

    Convolutions:
    - ``conv1`` reduces the ``num_in_filter`` input channels (skip features concatenated
      with upsampled features) to ``int(num_in_filter / 3)``.
    - ``conv2`` keeps that width.
    - When ``args['baysian']`` is True both are ``Conv3dFlipout`` layers built with the
      ``baysian_*`` prior/posterior settings; otherwise they are deterministic MONAI
      ``Convolution`` layers.

    ``forward(x)`` returns ``(features, kl)``. ``kl`` is the sum of the KL terms reported by
    the two convolutions, or 0 for deterministic convolutions.
    """

    def __init__(self, num_in_filter,args):
        super(Double_conv_encoder_block_de, self).__init__()
        self.args=args
        self.num_in_filter=num_in_filter
        input_dim = self.args['input_dim']
        kernal_size = self.args['kernal_size']
        dilation_size = self.args['dilation_size']
        stride_size = self.args['stride_size']
        padding_size = self.args['padding_size']
        bias = self.args['bias']
        act_type = self.args['act_type']
        norm_type = self.args['norm_type']
        baysian = self.args['baysian']
        prior_mean = self.args['baysian_prior_mean']
        prior_variance = self.args['baysian_prior_variance']
        posterior_mu_init = self.args['baysian_post_mu']
        posterior_rho_init = self.args['baysian_post_rho']

        out_filter = int(num_in_filter / 3)

        if baysian == True:
            flipout_kwargs = dict(kernel_size=kernal_size, stride=stride_size, dilation=dilation_size,
                                  bias=bias, padding=padding_size,
                                  prior_mean=prior_mean, prior_variance=prior_variance,
                                  posterior_mu_init=posterior_mu_init, posterior_rho_init=posterior_rho_init)
            self.conv1 = Conv3dFlipout(num_in_filter, out_filter, **flipout_kwargs)
            self.conv2 = Conv3dFlipout(out_filter, out_filter, **flipout_kwargs)
        else:
            conv_kwargs = dict(strides=stride_size, kernel_size=kernal_size, dilation=dilation_size,
                               bias=bias, conv_only=True, padding=padding_size)
            self.conv1 = Convolution(input_dim, num_in_filter, out_filter, **conv_kwargs)
            self.conv2 = Convolution(input_dim, out_filter, out_filter, **conv_kwargs)

        # Unknown activation names fall back to ReLU; the same instance is used after both convs.
        activations = {'PReLU': nn.PReLU, 'SELU': nn.SELU, 'CELU': nn.CELU}
        self.act_layer = activations.get(act_type, nn.ReLU)()

        if norm_type == 'GroupNorm':
            self.norm_layer_1 = nn.GroupNorm(8, out_filter)
            self.norm_layer_2 = nn.GroupNorm(8, out_filter)
        elif norm_type == 'BatchNorm':
            # Previously tested `act_type == 'BatchNorm'`, which can never be true.
            self.norm_layer_1 = nn.LazyBatchNorm3d()
            self.norm_layer_2 = nn.LazyBatchNorm3d()
        else:
            self.norm_layer_1 = nn.LazyInstanceNorm3d()
            self.norm_layer_2 = nn.LazyInstanceNorm3d()

    @staticmethod
    def _conv_with_kl(conv, x):
        """Apply ``conv`` and return ``(output, kl)``.

        A Bayesian conv returns an ``(output, kl)`` tuple, which is passed through. A plain
        convolution returns a tensor, and kl is reported as 0. The original forward always
        tuple-unpacked the result. In the non-Bayesian configuration that raised an error
        or silently split the batch dimension.
        """
        out = conv(x)
        if isinstance(out, tuple):
            return out
        return out, 0

    def forward(self, x):
        kl_sum = 0
        x, kl = self._conv_with_kl(self.conv1, x)
        kl_sum += kl
        x = self.act_layer(x)
        x = self.norm_layer_1(x)
        x, kl = self._conv_with_kl(self.conv2, x)
        kl_sum += kl
        x = self.act_layer(x)
        x = self.norm_layer_2(x)
        return x, kl_sum

In [9]:
class UNet_3d_encoder(nn.Module):
    """Contracting path of the 3D U-Net.

    The stages are:
    - ``en_channel[0]`` builds the input stage (no pooling);
    - every later entry builds one pooled encoder stage.

    ``forward(x)`` returns ``(final_layer, skips)``:
    - ``final_layer`` is the deepest feature map;
    - ``skips`` holds the remaining stage outputs, deepest first, i.e. in the order the
      decoder consumes them.
    """

    def __init__(self, args,en_channel):
        super(UNet_3d_encoder, self).__init__()
        self.args=args
        # Unpack instead of `en_channel.pop(0)` so the caller's list is not mutated.
        first, *rest = en_channel
        encoders = [Double_conv_encoder_block_in(first, self.args)]
        encoders.extend(Double_conv_encoder_block_en(c, self.args) for c in rest)
        self.encoders=nn.ModuleList(encoders)

    def forward(self, x):
        encoders_features = []
        for encoder in self.encoders:
            x = encoder(x)
            # prepend so the list ends up deepest-first, aligned with the decoder stages
            encoders_features.insert(0, x)
        final_layer = encoders_features[0]
        encoders_features_o = encoders_features[1:]
        return final_layer, encoders_features_o

In [10]:
class UNet_3d_decoder(nn.Module):
    """Expanding path of the 3D U-Net.

    One ``Double_conv_encoder_block_de`` is built per entry of ``de_channel``; each entry is
    that stage's input channel count. ``forward`` returns ``(features, kl)``, where ``kl``
    sums the KL terms of all stages.
    """

    def __init__(self, args,de_channel):
        super(UNet_3d_decoder, self).__init__()
        self.args=args
        self.decoders = nn.ModuleList([Double_conv_encoder_block_de(c, self.args) for c in de_channel])

    def forward(self, x,encoders_features):
        """Decode ``x`` using skip tensors ``encoders_features`` (deepest first).

        Stages beyond ``min(len(self.decoders), len(encoders_features))`` are not run
        (``zip`` semantics).
        """
        all_kl = 0
        for decoder, encoder_features in zip(self.decoders, encoders_features):
            # nearest-neighbour upsample to the skip tensor's spatial size
            x = F.interpolate(x, size=encoder_features.size()[2:], mode='nearest')
            # skip features first, then upsampled features, along the channel dimension
            x = torch.cat((encoder_features, x), dim=1)
            x, kl = decoder(x)
            all_kl += kl
        # Previously returned only the last stage's `kl`, dropping every other stage's KL term.
        return x, all_kl

In [11]:
def create_encoder_maps(init_channel_number, number_of_layers):
    """Return per-stage encoder filter counts, doubling from ``init_channel_number``.

    Example: ``create_encoder_maps(4, 4) -> [4, 8, 16, 32]``.
    """
    maps = []
    width = init_channel_number
    for _ in range(number_of_layers):
        maps.append(width)
        width *= 2
    return maps

In [12]:
def create_decoder_maps(init_channel_number,encoder_maps_list_r):
    """Return the input channel count of each decoder stage.

    Each stage concatenates two consecutive encoder widths, both scaled by
    ``init_channel_number``: entry ``i`` is
    ``encoder_maps_list_r[i] * init + encoder_maps_list_r[i + 1] * init``.
    The result therefore has one fewer entry than ``encoder_maps_list_r``, and is empty for
    an empty or single-element list.

    Example: ``create_decoder_maps(4, [32, 16, 8, 4]) -> [192, 96, 48]``.
    """
    # Pair each entry with its successor instead of indexing until a bare `except:` fires;
    # the old loop also silently swallowed genuine errors (e.g. a None entry).
    return [
        deeper * init_channel_number + shallower * init_channel_number
        for deeper, shallower in zip(encoder_maps_list_r, encoder_maps_list_r[1:])
    ]

In [13]:
class Unet(nn.Module):
    """Two-branch 3D U-Net: a trainable encoder/decoder pair plus a second ("_pre") pair.

    Both branches receive the same input. Their decoder outputs are summed element-wise and
    mapped to ``args['class_nu']`` channels by a 1x1x1 convolution.

    ``forward`` returns ``(logits, kl)``, where ``kl`` adds both decoders' KL terms.
    """

    def __init__(self, encoder,decoder,encoder_pre,decoder_pre, args):
        super().__init__()
        self.args = args
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_pre = encoder_pre
        self.decoder_pre = decoder_pre
        n_classes = self.args['class_nu']
        # NOTE(review): the input width class_nu**2 is not derived from the decoder; it only
        # matches the decoder output (16 channels for class_nu=4) in this configuration.
        self.final_conv = nn.Conv3d(n_classes ** 2, n_classes, kernel_size=(1, 1, 1), stride=(1, 1, 1))

    def forward(self, images ):
        # Both encoders run before both decoders (call order kept as originally written).
        encoded_new = self.encoder(images)
        encoded_pre = self.encoder_pre(images)
        decoded_new, kl_new = self.decoder(*encoded_new)
        decoded_pre, kl_pre = self.decoder_pre(*encoded_pre)
        logits = self.final_conv(decoded_new + decoded_pre)
        return logits, kl_new + kl_pre

In [14]:
def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxDxHxW label image to NxCxDxHxW, where each label gets converted to its corresponding one-hot vector
    :param input: 4D integer (int64) label tensor (NxDxHxW)
    :param C: number of channels/labels
    :param ignore_index: if given, voxels with this label get ``ignore_index`` in every channel
    :return: 5D float output (NxCxDxHxW) on the same device as ``input``
    """
    assert input.dim() == 4

    # expand the input tensor to Nx1xDxHxW before scattering
    input = input.unsqueeze(1)
    # result shape NxCxDxHxW
    shape = list(input.size())
    shape[1] = C

    mask = None
    if ignore_index is not None:
        # remember where ignore_index sits, then zero it so scatter_ gets a valid index
        mask = input.expand(shape) == ignore_index
        input = input.clone()
        input[input == ignore_index] = 0

    # Allocate directly on the label's device (previously built on CPU, then copied over).
    result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
    if mask is not None:
        # bring back the ignore_index in the result
        result[mask] = ignore_index
    return result

In [15]:
# Checkpoints of the pre-operative model, loaded into the "_pre" branch further below.
PRE_MODEL_DIR = './model/3d_unet_preoperative/'
load_en_path_pre = PRE_MODEL_DIR + '3DUNet_preoperative_encoder_baysian.pt'
load_de_path_pre = PRE_MODEL_DIR + '3DUNet_preoperative_decoder_baysian.pt'

In [16]:
# Model / training configuration read by the builder helpers. The key spellings
# ('kernal', 'baysian') are what the helpers look up, so they are kept as-is.
# NOTE(review): 'slice_nu', 'window_size' and 'batch_size' are not read by any code
# visible in this notebook (the DataLoaders use batch_size=1).
args = {
    'epoch': 30,
    'input_dim': 3,
    'kernal_size': 3,
    'slice_nu': 155,
    'window_size': 10,
    'pool_kernal_size': 2,
    'dilation_size': 1,
    'stride_size': 1,
    'pool_stride_size': 2,
    'padding_size': 1,
    'bias': True,
    'act_type': 'PReLU',
    'norm_type': 'GroupNorm',
    'pool_type': 'max',
    'class_nu': 4,
    'batch_size': 2,
    'baysian': True,
    'baysian_prior_mean': 0.0,
    'baysian_prior_variance': 1.0,
    'baysian_post_mu': 0,
    'baysian_post_rho': -3.0,
}

In [17]:
en_channel = create_encoder_maps(4, 4)
# deepest-first copy, as expected by create_decoder_maps
en_channel_r = list(reversed(en_channel))
print(en_channel)

[4, 8, 16, 32]


In [18]:
de_channel = create_decoder_maps(init_channel_number=4, encoder_maps_list_r=en_channel_r)
print(de_channel)

[192, 96, 48]


In [19]:
# GPU index used on this host for training and cache clean-up; adjust per machine.
device = 14
torch.cuda.set_device(torch.device('cuda', device))

In [20]:
# Two identical encoder/decoder branches: a new trainable one, and a "_pre" one initialised
# from the pre-operative checkpoints. Each constructor gets its own copy of the channel list.
encoder = UNet_3d_encoder(args, en_channel.copy())
decoder = UNet_3d_decoder(args, de_channel.copy())
encoder_pre = UNet_3d_encoder(args, en_channel.copy())
decoder_pre = UNet_3d_decoder(args, de_channel.copy())

model = Unet(encoder, decoder, encoder_pre, decoder_pre, args)
# map_location='cpu': checkpoints written from another GPU index would otherwise fail to
# load on machines without that device. The model is moved to CUDA in a later cell anyway.
model.encoder_pre.load_state_dict(torch.load(load_en_path_pre, map_location='cpu'))
model.decoder_pre.load_state_dict(torch.load(load_de_path_pre, map_location='cpu'))

<All keys matched successfully>

In [21]:
# Freeze every parameter of the "_pre" branch (encoder_pre / decoder_pre).
layer_counter = 0
for name, module in model.named_children():
    print(name)
    if '_pre' not in name:
        continue
    for layer in module.children():
        layer.requires_grad_(False)
        print('Layer "{}" in module "{}" was frozen!'.format(layer_counter, name))
        layer_counter += 1

encoder
decoder
encoder_pre
Layer "0" in module "encoder_pre" was frozen!
decoder_pre
Layer "1" in module "decoder_pre" was frozen!
final_conv


In [22]:
# List each parameter tensor as frozen or trainable and count both kinds.
freezed_num, pass_num = 0, 0
for name, param in model.named_parameters():
    is_frozen = not param.requires_grad
    print(('freeze_layer_' if is_frozen else 'train_layer_') + name)
    if is_frozen:
        freezed_num += 1
    else:
        pass_num += 1

train_layer_encoder.encoders.0.conv1.conv.weight
train_layer_encoder.encoders.0.conv1.conv.bias
train_layer_encoder.encoders.0.PReLU_1.weight
train_layer_encoder.encoders.0.GroupNorm_1.weight
train_layer_encoder.encoders.0.GroupNorm_1.bias
train_layer_encoder.encoders.0.conv2.conv.weight
train_layer_encoder.encoders.0.conv2.conv.bias
train_layer_encoder.encoders.0.GroupNorm_2.weight
train_layer_encoder.encoders.0.GroupNorm_2.bias
train_layer_encoder.encoders.1.conv1.conv.weight
train_layer_encoder.encoders.1.conv1.conv.bias
train_layer_encoder.encoders.1.PReLU_1.weight
train_layer_encoder.encoders.1.GroupNorm_1.weight
train_layer_encoder.encoders.1.GroupNorm_1.bias
train_layer_encoder.encoders.1.conv2.conv.weight
train_layer_encoder.encoders.1.conv2.conv.bias
train_layer_encoder.encoders.1.GroupNorm_2.weight
train_layer_encoder.encoders.1.GroupNorm_2.bias
train_layer_encoder.encoders.2.conv1.conv.weight
train_layer_encoder.encoders.2.conv1.conv.bias
train_layer_encoder.encoders.2.PReLU

In [23]:
model = model.to(torch.device('cuda'))  # current CUDA device, i.e. the one selected above

In [24]:
# Optimise only the parameters left trainable after freezing the "_pre" branch. Adam would
# skip the frozen ones anyway (their .grad stays None), but listing them was misleading.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.005, weight_decay=1e-5)
# Multiply the learning rate by 0.1 every 10 scheduler steps (stepped once per epoch).
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [25]:
train_data, val_data = simpletrain(train_df), simpletrain(val_df)
# one volume per batch; only the training loader is shuffled
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1)

In [26]:
# Generalized Dice + focal loss applied to softmax-activated logits. The focal term
# weights class index 3 seven times more than the other classes.
criterion = GeneralizedDiceFocalLoss(softmax=True, lambda_gdl=1.0, gamma=3.0, lambda_focal=1.0,
                                     focal_weight=[1.0, 1.0, 1.0, 7.0])
# Alternative tried earlier: GeneralizedDiceLoss(softmax=True, include_background=True)

In [27]:
def dice_coef2(y_true, y_pred):
    """Dice coefficient 2*|A∩B| / (|A| + |B|) of two binary masks.

    Returns 1 when both masks are empty (perfect agreement on "nothing present").
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()
    denom = np.sum(truth) + np.sum(guess)
    if denom == 0:
        return 1
    overlap = np.sum(truth * guess)
    return 2. * overlap / denom

In [28]:
def dice_return(newseg,indices ):
    """Per-class Dice between two label volumes ``newseg`` and ``indices``.

    Labels 0..3 are background, edema, enhancing and non-enhancing. Each class is binarised
    in both volumes and scored with ``dice_coef2``. Dice is symmetric, so argument order does
    not change the result.

    Prints the four scores and returns them as
    ``[background, edema, enhancing, non_enhancing]``.
    """
    scores = [
        dice_coef2(np.where(newseg == label, 1, 0), np.where(indices == label, 1, 0))
        for label in range(4)
    ]
    print(*scores)
    return scores
        

In [None]:
# Train the new branch (the "_pre" branch is frozen) and keep the checkpoint with the
# lowest validation loss.
all_train_loss = []
all_val_loss = []

current_val_loss = 0.0
# Best validation loss so far; +inf guarantees the first epoch is always saved.
prev_val_loss = float('inf')

n_classes = args['class_nu']
final_act = nn.Softmax(dim=1)
for epoch in tqdm(range(1, args['epoch'] + 1)):
    # ---- training ----
    model.train()  # set once per epoch; validation below switches to eval()
    epoch_train_loss = 0
    for images, seg in train_loader:
        seg = expand_as_one_hot(seg.long(), n_classes, ignore_index=None)
        images = images.cuda()
        seg = seg.cuda()
        optimizer.zero_grad()
        pred, train_kl_loss = model(images)
        train_loss = criterion(pred, seg)
        # NOTE(review): the KL term is added unscaled (the original divided by 1). Bayesian
        # training often scales it, e.g. by the number of training samples — confirm.
        train_loss = train_loss + (train_kl_loss / 1)
        train_loss.backward()
        optimizer.step()
        epoch_train_loss += train_loss.item()
    scheduler.step()
    # Average over all batches. The original divided by the last `enumerate` index
    # (len - 1), overstating the mean loss.
    epoch_train_loss = epoch_train_loss / len(train_loader)
    print('current_train_loss {:8.5f}'.format(epoch_train_loss))
    all_train_loss.append(epoch_train_loss)

    # ---- validation ----
    model.eval()
    epoch_val_loss = 0
    eval_dice = []
    with torch.no_grad():
        for images, seg in val_loader:
            seg_1 = seg.cpu().numpy().astype('uint8')
            seg = expand_as_one_hot(seg.long(), n_classes, ignore_index=None)
            images = images.cuda()
            seg = seg.cuda()
            pred, _ = model(images)
            # The KL term is not included here, unlike the training loss.
            val_loss = criterion(pred, seg)
            epoch_val_loss += val_loss.item()
            pred = final_act(pred).squeeze()
            _, indices = pred.max(0)
            indices = indices.cpu().numpy().astype('uint8')
            # (ground truth, prediction); Dice is symmetric so the scores are unaffected
            eval_dice.append(dice_return(seg_1.squeeze(), indices))
    eval_dsc = pd.DataFrame(eval_dice, columns=['background','flair','enh','non_enh'])
    epoch_val_loss = epoch_val_loss / len(val_loader)
    all_val_loss.append(epoch_val_loss)

    print('current_train_loss {:8.5f},current_val_loss {:8.5f}'
          .format(epoch_train_loss, epoch_val_loss))
    current_val_loss = epoch_val_loss

    if current_val_loss < prev_val_loss:
        torch.save(model.encoder.state_dict(),
                   './model/3d_unet_postoperative/3DUNet_postoperative_encoder_baysian_low_7.pt')
        # Saved next to the other post-operative artefacts. Previously this file alone went
        # to ./model/3d_unet_preoperative/.
        torch.save(model.decoder.state_dict(),
                   './model/3d_unet_postoperative/3DUNet_postoperative_decoder_baysian_low_7.pt')
        torch.save(model.final_conv.state_dict(),
                   './model/3d_unet_postoperative/3DUNet_postoperative_final_conv_baysian_low_7.pt')
        torch.save(model.state_dict(),
                   './model/3d_unet_postoperative/3DUNet_postoperative_whole_model_baysian_low_7.pt')
        prev_val_loss = current_val_loss
        eval_dsc.to_csv('./model/all_loss/3d_unet_postoperative_dsc_baysian.csv', index=False)

        print("Improved and model saved in ", epoch)
    else:
        print("No improvement in ", epoch)

# Release the training model and cached GPU memory before the prediction section.
del model
gc.collect()
with torch.cuda.device(device):
    torch.cuda.empty_cache()
    

# prediction

In [55]:
import os

In [56]:
# Same configuration as used for training; the model must be rebuilt with identical
# settings for the saved state_dicts to load. Key spellings are read verbatim by the helpers.
args = {
    'epoch': 30,
    'input_dim': 3,
    'kernal_size': 3,
    'slice_nu': 155,
    'window_size': 10,
    'pool_kernal_size': 2,
    'dilation_size': 1,
    'stride_size': 1,
    'pool_stride_size': 2,
    'padding_size': 1,
    'bias': True,
    'act_type': 'PReLU',
    'norm_type': 'GroupNorm',
    'pool_type': 'max',
    'class_nu': 4,
    'batch_size': 2,
    'baysian': True,
    'baysian_prior_mean': 0.0,
    'baysian_prior_variance': 1.0,
    'baysian_post_mu': 0,
    'baysian_post_rho': -3.0,
}

In [57]:
en_channel = create_encoder_maps(4, 4)
# deepest-first copy, as expected by create_decoder_maps
en_channel_r = list(reversed(en_channel))
print(en_channel)

[4, 8, 16, 32]


In [58]:
de_channel = create_decoder_maps(init_channel_number=4, encoder_maps_list_r=en_channel_r)
print(de_channel)

[192, 96, 48]


In [59]:
# Index CSV of the pre-processed test set.
TEST_DATA_DIR = './pre_process_data/test'
test_df_path = TEST_DATA_DIR + '/test_index.csv'
test_df = pd.read_csv(test_df_path)

In [60]:
# GPU index used on this host for inference and cache clean-up; adjust per machine.
device = 14
torch.cuda.set_device(torch.device('cuda', device))

In [61]:
test_df.head(n=5)

Unnamed: 0,pt_id,images_path,seg_path
0,34114852.0,./pre_process_data/test/images/34114852_images...,./pre_process_data/test/segmentation/34114852_...
1,34169940.0,./pre_process_data/test/images/34169940_images...,./pre_process_data/test/segmentation/34169940_...
2,34259182.0,./pre_process_data/test/images/34259182_images...,./pre_process_data/test/segmentation/34259182_...
3,34290186.0,./pre_process_data/test/images/34290186_images...,./pre_process_data/test/segmentation/34290186_...
4,35069039.2,./pre_process_data/test/images/35069039.2_imag...,./pre_process_data/test/segmentation/35069039....


In [62]:
class simpletest(Dataset):
    """Map-style dataset of pre-processed test volumes.

    Each row of ``dataset`` (a pandas DataFrame, indexed positionally) must provide
    ``images_path`` and ``seg_path`` pickle paths plus a ``pt_id`` patient identifier.
    ``__getitem__`` returns ``(images, seg, str(pt_id))``.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only load pickles produced by
    the project's own pre-processing step.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        images = self._load_pickle(row.images_path)
        seg = self._load_pickle(row.seg_path)
        return images, seg, str(row.pt_id)

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

In [63]:
test_data = simpletest(test_df)
# one patient per batch, in CSV order
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

In [64]:

path = './model/3d_unet_postoperative/'

# Whole-model checkpoints from the "baysian_low" runs, in name order.
model_list = sorted(
    fname for fname in os.listdir(path)
    if 'baysian_low' in fname and 'whole_model' in fname
)
print(model_list)

['3DUNet_postoperative_whole_model_baysian_low_3.pt', '3DUNet_postoperative_whole_model_baysian_low_5.pt', '3DUNet_postoperative_whole_model_baysian_low_7.pt']


In [73]:
num_monte_carlo = 20  # stochastic forward passes per test volume

In [None]:
# For every selected checkpoint, run Monte Carlo inference on the test set and write, per
# patient, the argmax segmentation plus per-class variance maps as NIfTI files.
NIFTI_AFFINE = [[-1.,  0.,  0., -0.], [0.,  1.,  0., -0.], [0.,  0.,  1.,  0.], [0.,  0.,  0.,  1.]]
for i in range(len(model_list)):
    predict_folder = model_list[i].replace('.pt', '')
    os.makedirs('./Model_prediction/' + predict_folder, exist_ok=True)
    save_folder_path = './Model_prediction/' + predict_folder + '/'

    encoder = UNet_3d_encoder(args, en_channel.copy())
    decoder = UNet_3d_decoder(args, de_channel.copy())
    encoder_pre = UNet_3d_encoder(args, en_channel.copy())
    decoder_pre = UNet_3d_decoder(args, de_channel.copy())
    model = Unet(encoder, decoder, encoder_pre, decoder_pre, args)
    load_path = './model/3d_unet_postoperative/' + model_list[i]
    # map_location='cpu' so checkpoints written from another GPU index still load
    model.load_state_dict(torch.load(load_path, map_location='cpu'))
    model.cuda()
    model.eval()  # set once, not on every Monte Carlo pass

    for images, _, pt_id in test_loader:
        images = images.cuda()
        with torch.no_grad():
            # NOTE(review): assumes the Bayesian layers draw fresh weight noise on every call,
            # also in eval mode; otherwise all passes are identical and the variance is 0.
            seg_results = torch.stack([model(images)[0].cpu() for _ in range(num_monte_carlo)])
        # (MC, batch=1, C, D, H, W) -> (MC, C, D, H, W). Index the batch dim explicitly: a bare
        # squeeze() also dropped the MC dim when num_monte_carlo == 1.
        seg_results = seg_results[:, 0]
        # NOTE(review): variance and mean are taken over raw logits (no softmax) — confirm.
        uncertainity_temp = np.var(seg_results.detach().numpy(), 0)
        seg_layer_avg = torch.mean(seg_results, dim=0)
        _, indices = seg_layer_avg.max(0)
        indices = indices.numpy().astype('uint8')

        patient = str(pt_id[0])
        outputs = [
            (indices, '_segment_pred_kg.nii.gz'),
            (uncertainity_temp[0], '_segment_pred_new_kg_uncertainity_back.nii.gz'),
            (uncertainity_temp[1], '_segment_pred_new_kg_uncertainity_flair.nii.gz'),
            (uncertainity_temp[2], '_segment_pred_new_kg_uncertainity_con.nii.gz'),
            (uncertainity_temp[3], '_segment_pred_new_kg_uncertainity_non_enh.nii.gz'),
            (uncertainity_temp.mean(0), '_segment_pred_new_kg_uncertainity_all.nii.gz'),
        ]
        for volume, suffix in outputs:
            nib.save(nib.Nifti1Image(volume, NIFTI_AFFINE), save_folder_path + patient + suffix)
        print('Saved example_segment_pred' + '_' + patient)
    del model
    gc.collect()
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
            
        

Saved example_segment_pred_34114852
Saved example_segment_pred_34169940
Saved example_segment_pred_34259182
Saved example_segment_pred_34290186
Saved example_segment_pred_35069039.2
Saved example_segment_pred_35256564
Saved example_segment_pred_35520522
Saved example_segment_pred_36104836
Saved example_segment_pred_36477034
Saved example_segment_pred_38216722
Saved example_segment_pred_39597056
Saved example_segment_pred_39621144
Saved example_segment_pred_39991662
Saved example_segment_pred_40481407
Saved example_segment_pred_40486118
Saved example_segment_pred_40552846
Saved example_segment_pred_41245644
Saved example_segment_pred_41263165
Saved example_segment_pred_41336169
Saved example_segment_pred_45533978
Saved example_segment_pred_45947450
Saved example_segment_pred_46140963
Saved example_segment_pred_46292941
Saved example_segment_pred_46351381
Saved example_segment_pred_46481193
Saved example_segment_pred_46674661
Saved example_segment_pred_47136867
Saved example_segment_pred

In [75]:
# The prediction loop already deletes `model` after each checkpoint, so an unconditional
# `del model` raises NameError once that cell has run to completion.
if 'model' in globals():
    del model
gc.collect()
with torch.cuda.device(device):
    torch.cuda.empty_cache()
            