In [1]:
from torch.utils.data import Dataset, DataLoader,WeightedRandomSampler
import os 
import nibabel as nib
import pandas as pd
import pickle
from random import choice
import numpy as np 
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from monai.networks.blocks import Convolution, ResidualUnit
from collections import OrderedDict

In [2]:
data_df_path='./pre_process_data/train/train_index.csv'

In [3]:
data_df=pd.read_csv(data_df_path)

In [4]:
def slice_index (arg):
    slice_nu=arg['slice_nu']
    try_index= [i  for i in range(slice_nu)]
    choices_list = []
    window_size = arg['window_size']
    choose_index=try_index[0+window_size:slice_nu-window_size]
    ch = choice(choose_index)
    
    rn = try_index[ch:ch+window_size]
    choices_list.append(rn)
    ch_p = ch+1
    p_rn = try_index[ch_p:ch_p+window_size]
    if np.array(p_rn).max()>slice_nu:
        ch_p = ch-1
        p_rn = try_index[ch_p:ch_p+window_size]
    else:
        p_rn = p_rn 
    choices_list.append(p_rn)
    rn_u=list(set(rn).union(set(p_rn)))
    try_index_n=list(set(choose_index).difference(set(rn_u)))
    ch_n = choice(try_index_n[:-window_size+1])
    n_rn = try_index[ch_n:ch_n+window_size]
    choices_list.append(n_rn)
    return choices_list

In [5]:
def slice_data (images,slice_index):
    query_index=slice_index[0]
   
    positive_index=slice_index[1]
    negative_index=slice_index[2]
    query=images[:,:,:,query_index]
    positive=images[:,:,:,positive_index]
    negative=images[:,:,:,negative_index]
    
    return query, positive, negative


In [6]:
class simpletrain(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, dataset,arg):
        
        self.data = dataset
        self.arg = arg


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        data_index = self.data.iloc[idx]
        images_path=data_index.images_path
        with open(images_path, 'rb') as f:
            images = pickle.load(f)
        input_index=slice_index(self.arg)
        
        
        
        query,positive,negative=slice_data (images,input_index)
        positive_label=np.array(1.0)
        negative_label=np.array(0.0)
        
        return (query,positive,negative,positive_label,negative_label)
        

In [7]:
def Double_conv_encoder_block_in (num_in_filter,args):
    input_dim=args['input_dim'] ; kernal_size=args['kernal_size']; pool_kernel_size=args['pool_kernal_size'];
    dilation_size=args['dilation_size']; stride_size=args['stride_size'] ; pool_stride_size=args['pool_stride_size'];
    padding_size=args['padding_size']; bias=args['bias'] ; act_type= args['act_type'] ; norm_type=args['norm_type'] ; 
    pool_type=args['pool_type'] ;
    
    conv1=Convolution(input_dim,num_in_filter,num_in_filter*2, strides=stride_size, 
                      kernel_size=kernal_size,  dilation=dilation_size,
                         bias=bias, conv_only=True,  padding=padding_size)
    conv2=Convolution(input_dim,num_in_filter*2,num_in_filter*4, strides=stride_size, 
                      kernel_size=kernal_size,  dilation=dilation_size,
                         bias=bias, conv_only=True,  padding=padding_size)
    
    if act_type=='PReLU':
        act_layer=nn.PReLU()
    elif act_type=='SELU':
        act_layer=nn.SELU()
    elif act_type=='CELU':
        act_layer=nn.CELU()
    else:
        act_layer=nn.ReLU()
    
    if norm_type=='GroupNorm':
        norm_layer_1=nn.GroupNorm(8,num_in_filter*2)
        norm_layer_2=nn.GroupNorm(8,num_in_filter*4)
    elif act_type=='BatchNorm':
        norm_layer_1=nn.LazyBatchNorm3d()
        norm_layer_2=nn.LazyBatchNorm3d()
    else: 
        norm_layer_1=nn.LazyInstanceNorm3d()
        norm_layer_2=nn.LazyInstanceNorm3d()
    
    encoder_block=nn.Sequential(OrderedDict(
        [('conv1', conv1),
        (act_type+'_1', act_layer),
        (norm_type+'_1', norm_layer_1),
        ('conv2', conv2),
        (act_type+'_2', act_layer),
        (norm_type+'_2', norm_layer_2)]))
    
    return encoder_block
    

In [8]:
def Double_conv_encoder_block_en (num_in_filter,args):
    input_dim=args['input_dim'] ; kernal_size=args['kernal_size']; pool_kernel_size=args['pool_kernal_size'];
    dilation_size=args['dilation_size']; stride_size=args['stride_size'] ; pool_stride_size=args['pool_stride_size'];
    padding_size=args['padding_size']; bias=args['bias'] ; act_type= args['act_type'] ; norm_type=args['norm_type'] ; 
    pool_type=args['pool_type'] ;
    
    assert pool_type in ['max', 'avg']
    if pool_type == 'avg':
        pooling = nn.AvgPool3d(kernel_size=pool_kernel_size, stride=pool_stride_size, padding=0, dilation=1, ceil_mode=False)
    else:
        pooling = nn.MaxPool3d(kernel_size=pool_kernel_size, stride=pool_stride_size, padding=0, dilation=1, ceil_mode=False)
      
    
    conv1=Convolution(input_dim,num_in_filter*2,num_in_filter*2, strides=stride_size, 
                      kernel_size=kernal_size,  dilation=dilation_size,
                         bias=bias, conv_only=True,  padding=padding_size)
    conv2=Convolution(input_dim,num_in_filter*2,num_in_filter*4, strides=stride_size, 
                      kernel_size=kernal_size,  dilation=dilation_size,
                         bias=bias, conv_only=True,  padding=padding_size)
    
    if act_type=='PReLU':
        act_layer=nn.PReLU()
    elif act_type=='SELU':
        act_layer=nn.SELU()
    elif act_type=='CELU':
        act_layer=nn.CELU()
    else:
        act_layer=nn.ReLU()
    
    if norm_type=='GroupNorm':
        norm_layer_1=nn.GroupNorm(8,num_in_filter*2)
        norm_layer_2=nn.GroupNorm(8,num_in_filter*4)
    elif act_type=='BatchNorm':
        norm_layer_1=nn.LazyBatchNorm3d()
        norm_layer_2=nn.LazyBatchNorm3d()
    else: 
        norm_layer_1=nn.LazyInstanceNorm3d()
        norm_layer_2=nn.LazyInstanceNorm3d()
    
    encoder_block=nn.Sequential(OrderedDict(
        [('pool', pooling),
        ('conv1', conv1),
        (act_type+'_1', act_layer),
        (norm_type+'_1', norm_layer_1),
        ('conv2', conv2),
        (act_type+'_2', act_layer),
        (norm_type+'_2', norm_layer_2)]))
    
    return encoder_block

In [9]:
class UNet_3d_encoder(nn.Module):
    def __init__(self, args,en_channel):
        super(UNet_3d_encoder, self).__init__()
        self.args=args
        encoders = []
        
        input_en=Double_conv_encoder_block_in (en_channel[0],self.args)
        encoders.append(input_en)
        en_channel.pop(0)
        for i in en_channel:
            encoder=Double_conv_encoder_block_en (i,self.args)
            encoders.append(encoder)
        self.encoders=nn.ModuleList(encoders)
        
        
    def forward(self, x):
        # encoder part
        encoders_features = []
        for encoder in self.encoders:
            x = encoder(x)
            # reverse the encoder outputs to be aligned with the decoder
            encoders_features.insert(0, x)
        final_layer= encoders_features[0]
        encoders_features_o = encoders_features[1:]
        
        return final_layer , encoders_features_o

In [10]:
class encoder_mlp(nn.Module):
    def __init__(self, encoder, args):
        super(encoder_mlp, self).__init__()
        self.encoder = encoder
        self.mlp_out_list=args['mlp_out_list']
        self.lin_1=nn.LazyLinear(self.mlp_out_list[0])
        self.lin_2=nn.LazyLinear(self.mlp_out_list[1])
        self.lin_3=nn.LazyLinear(self.mlp_out_list[2])
        self.lin_4=nn.LazyLinear(self.mlp_out_list[3])
    def forward(self, images ):
        out,_=self.encoder(images)
        out_size_list=out.size()
        reposrnu=out_size_list[1]*out_size_list[2]*out_size_list[3]*out_size_list[4]
        out=out.view(-1,reposrnu)
        out=self.lin_1(out)
        out=self.lin_2(out)
        out=self.lin_3(out)
        out=self.lin_4(out)
        
        return out
        

In [11]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)

        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

In [12]:
def create_encoder_maps(init_channel_number, number_of_layers):
    return [init_channel_number * 2 ** k for k in range(number_of_layers)]

In [13]:
args={}
args['slice_nu']=155
args['window_size']=10
args['batch_size']=20
args['mlp_out_list']=[2048,1024,512,256]
args['epoch']=300
args['input_dim']=3
args['kernal_size']=3
args['pool_kernal_size']=2
args['dilation_size']=1
args['stride_size']=1
args['pool_stride_size']=2
args['padding_size']=1
args['bias']=False
args['act_type']='PReLU'
args['norm_type']='GroupNorm'
args['pool_type']='max'

In [14]:
data_df['index']=range(len(data_df))

In [15]:
train_df=data_df.sample(frac=0.8)
val_df=data_df[~data_df['index'].isin(train_df['index'])]

In [16]:
train_data=simpletrain(train_df,args)
train_loader=DataLoader(train_data,batch_size=args['batch_size'])
val_data=simpletrain(val_df,args)
val_loader=DataLoader(val_data,batch_size=1)

In [None]:
next(iter(val_loader))[0].size()

In [17]:
en_channel=create_encoder_maps(4, 4)
en_channel_r=en_channel.copy()
en_channel_r.reverse()
print(en_channel)

[4, 8, 16, 32]


In [18]:
model_en=UNet_3d_encoder(args,en_channel)
model=encoder_mlp(model_en,args)



In [19]:
device=13
torch.cuda.set_device(device)

In [20]:
model=model.cuda()

In [21]:
optimizer=optim.Adam(model.parameters(),lr=0.005,weight_decay=1e-2)

In [None]:
all_train_loss=[]
all_val_loss=[]
loss_fn=ContrastiveLoss().cuda()
loss_fn.requires_grad = True
for epoch in tqdm(range(1, args['epoch'] + 1)):
    epoch_train_loss=0
    for count, batch in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        q_img=batch[0].cuda()
        p_img=batch[1].cuda()
        n_img=batch[2].cuda()
        p_lab=batch[3].cuda()
        n_lab=batch[4].cuda()
        q_out=model(q_img)
        p_out=model(p_img)
        n_out=model(n_img)
        p_loss=loss_fn(q_out,p_out,p_lab.float())
        n_loss=loss_fn(q_out,n_out,n_lab.float())
        train_loss=p_loss+n_loss
        train_loss.backward()
        optimizer.step()
        record_loss=train_loss.cpu().data.numpy().item()
        epoch_train_loss+=record_loss
    epoch_train_loss=epoch_train_loss/count
    #print('current_train_loss {:8.5f}'.format(epoch_train_loss))
    all_train_loss.append(epoch_train_loss)
    if epoch% 30==0:
        model.eval()
        epoch_val_loss=0
        for count, batch in enumerate(val_loader):
            q_img=batch[0].cuda()
            p_img=batch[1].cuda()
            n_img=batch[2].cuda()
            p_lab=batch[3].cuda()
            n_lab=batch[4].cuda()
            with torch.no_grad():
                q_out=model(q_img)
                p_out=model(p_img)
                n_out=model(n_img)
                p_loss=loss_fn(q_out,p_out,p_lab)
                n_loss=loss_fn(q_out,n_out,n_lab)
                val_loss=p_loss+n_loss
                record_loss=val_loss.cpu().data.numpy().item()
                epoch_val_loss+=record_loss
        epoch_val_loss=epoch_val_loss/count
        all_val_loss.append(epoch_val_loss)
        model_name='3DUNet_encoder_'+str(epoch)+'.pt'
        torch.save(model.encoder.state_dict(), './model/3DUNet_encoder_10_pre/'+model_name)
        print('current_train_loss {:8.5f},current_val_loss {:8.5f}'.format(epoch_train_loss,epoch_val_loss))
            
            
        

 40%|████      | 120/300 [1:27:32<2:14:00, 44.67s/it]

current_train_loss 461.02841,current_val_loss 366.92928


 50%|█████     | 150/300 [1:48:24<1:50:08, 44.05s/it]

current_train_loss 321.01111,current_val_loss 267.61401


 52%|█████▏    | 156/300 [1:52:33<1:40:46, 41.99s/it]

In [None]:
model_en=UNet_3d_encoder(args,en_channel)
model_en.load_state_dict(torch.load('./model/3DUNet_encoder_10_pre/3DUNet_encoder_270.pt'))

In [None]:
model.encoder.load_state_dict(torch.load('./model/3DUNet_encoder_10_pre/3DUNet_encoder_270.pt'))

In [None]:
import gc
#del model
del model_en
gc.collect()
with torch.cuda.device(device):
    torch.cuda.empty_cache()

In [None]:
all_train_loss_df=pd.DataFrame(all_train_loss,columns=['loss'])
all_train_loss_df['epoch']=range(len(all_train_loss_df))
all_train_loss_df['epoch']=all_train_loss_df['epoch']+1
all_val_loss_df=pd.DataFrame(all_val_loss,columns=['loss'])
all_val_loss_df['epoch']=range(len(all_val_loss_df))
all_val_loss_df['epoch']=(all_val_loss_df['epoch']+1)*30
all_train_loss_df.to_csv('./model/all_loss/train_loss_3DUNet_encoder_pre_slide_10.csv')
all_val_loss_df.to_csv('./model/all_loss/val_loss_3DUNet_encoder_pre_slide_10.csv')

In [None]:
all_train_loss_df

In [None]:
all_val_loss_df

In [None]:
os.listdir('./model/Modified3DUNet_encoder_10_pre/')

In [None]:
model_en

In [None]:
model_en=Modified3DUNet_encoder(in_channels=4)
model_en.load_state_dict(torch.load('./model/Modified3DUNet_encoder_10_pre/Modified3DUNet_encoder_300.pt'))

In [None]:
model_en.state_dict()

In [None]:
model(batch[0].cuda()).size()

In [None]:
size_list=batch[0].size()

In [None]:
size_list[0]

In [None]:
q_img.size()

In [None]:
epoch_train_loss/count

In [None]:
loss(q_img,p_img,p_lab.float())

In [None]:
loss_fn=ContrastiveLoss()

In [None]:
n_out

In [None]:
n_loss

In [None]:
train_loss.backward()

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [None]:
count_parameters(model)

In [None]:
model(images[0]).size()

In [None]:
encoder_out=model_en(images[0]).view(10,-1)

In [None]:
lin1=nn.LazyLinear(4000)

In [None]:
lin1(encoder_out).size()

In [None]:
train_data=simpletrain(train_df,arg)
train_loader=DataLoader(train_data,batch_size=10)

In [None]:
from monai.networks.blocks.convolutions import Convolution

In [None]:
images=next(iter(train_loader))

In [None]:
images[3]

In [None]:
conv = Convolution(
    dimensions=3,
    in_channels=4,
    out_channels=8,
    adn_ordering="ADN",
    act=("prelu", {"init": 0.2}),
    dropout=0.1,
    norm=("instance"),
)
print(conv)

In [None]:
images[0]

In [None]:
conv(images[0]).size()

In [None]:

    ch_n = choice(try_index_n[:-window_size+1])
    n_rn = try_index_n[ch_n:ch_n+window_size]

In [None]:
class ContrastiveLoss(torch.nn.Module):
    

      def __init__(self, margin=2.0):
            super(ContrastiveLoss, self).__init__()
            self.margin = margin

      def forward(self, output1, output2, label):
            # Find the pairwise distance or eucledian distance of two output feature vectors
            euclidean_distance = F.pairwise_distance(output1, output2)
            # perform contrastive loss calculation with the distance
            loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
            (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

            return loss_contrastive


In [None]:
train_data=simpletrain(train_df,arg)


In [None]:
from monai.networks.blocks.convolutions import Convolution

In [None]:
conv = blocks.Convolution(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    adn_ordering="ADN",
    act=("prelu", {"init": 0.2}),
    dropout=0.1,
    norm=("layer", {"normalized_shape": (10, 10, 10)}),
)

In [None]:
train_df.

In [None]:
all_pt=[]
all_modality=[]
all_route=[]
pt_list=os.listdir(train_data_path)
for i in 
os.listdir(train_data_path+pt_list[0])

In [None]:
nib.load(os.path.join(data_folder, data_id) + "_t1.nii.gz").get_fdata()

In [None]:
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import alias, deprecated_arg, export

In [None]:
from monai.networks.nets import UNet

In [None]:
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16),
    strides=(2, 1),
    num_res_units=0
)


In [None]:
from pytorch_model_summary import summary

In [None]:
list(net.children())

In [None]:
net

In [None]:
print(summary(net,torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))

In [None]:
net.model.

In [None]:
monai.networks.blocks.UpSample

In [None]:
conv = blocks.Convolution(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    adn_ordering="ADN",
    act=("prelu", {"init": 0.2}),
    dropout=0.1,
    norm=("layer", {"normalized_shape": (10, 10, 10)}),
)
print(conv)

In [None]:
blocks.UpSample