In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# **Create a directory inside My drive/Colab Notebooks "attention"** 

In [2]:
cd gdrive/My\ Drive/Colab\ Notebooks/attention

/content/gdrive/My Drive/Colab Notebooks/attention


In [20]:
!ls
!pip install tqdm shap imblearn

50epochs_pt	      model_on_train_002.pt	test_ouput_shapley_4.nii
5epochs_pt_3d_slices  model_on_train_005.pt	test_ouput_shapley.nii
BaseLoader.py	      model_on_train_013.pt	train_model.py
BaseLoader_v2.py      __pycache__		vox_net.pt
data		      README.md			VoxResNetPytorch.py
images_output	      roc_auc_82.pt		VoxResNetPytorch_v2.py
Main.ipynb	      test_ouput_shapley_2.nii
Main_last.ipynb       test_ouput_shapley_3.nii


In [0]:
import shap
import pandas as pd
import numpy as np
import os
import nibabel as nib
import nibabel.processing
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data

from torchvision.datasets.folder import *
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold,train_test_split
from torch.autograd import Variable,Function,gradcheck
from tqdm import tqdm_notebook as tqdm

In [0]:
class Fast3dResize(object):
    def __call__(self, image):
        image = image.cuda()
        resized = resample3d(image,inp_space=(193,193,229),out_space=(512,512,512))
        return resized

def resample3d(inp,inp_space,out_space=(1,1,1)):
    # Infer new shape
 
    out = resample1d(inp,inp_space[2],out_space[2]).permute(0,2,1)
    out = resample1d(out,inp_space[1],out_space[1]).permute(2,1,0)
    out = resample1d(out,inp_space[0],out_space[0]).permute(2,0,1)
    return out

def resample1d(inp,inp_space,out_space=1):
    #Output shape
    out_shape = list(np.int64(inp.size()[:-1]))+[int(np.floor(inp.size()[-1]*inp_space/out_space))] #Optional for if we expect a float_tensor
    out_shape = [int(item) for item in out_shape]
    # Get output coordinates, deltas, and t (chord distances)
    torch.cuda.set_device(inp.get_device())
    
    # Output coordinates in real space
    coords = torch.cuda.FloatTensor(range(out_shape[-1]))*out_space
    delta = coords.fmod(inp_space).div(inp_space).repeat(out_shape[0],out_shape[1],1)
    t = torch.cuda.FloatTensor(4,out_shape[0],out_shape[1],out_shape[2]).zero_()
    t[0] = 1
    t[1] = delta
    t[2] = delta**2
    t[3] = delta**3

    
    # Nearest neighbours indices
    nn = coords.div(inp_space).floor().long()    

    # Stack the nearest neighbors into P, the Points Array
    P = torch.cuda.FloatTensor(4,out_shape[0],out_shape[1],out_shape[2]).zero_()
    for i in range(-1,3):
        P[i+1] = inp.index_select(2,torch.clamp(nn+i,0,inp.size()[-1]-1))    
    
    #Take catmull-rom  spline interpolation:
    return 0.5*t.mul(torch.cuda.FloatTensor([[ 0,  2,  0,  0],
                            [-1,  0,  1,  0],
                            [ 2, -5,  4, -1],
                            [ -1, 3, -3,  1]]).mm(P.view(4,-1))\
                                                              .view(4,
                                                                    out_shape[0],
                                                                    out_shape[1],
                                                                    out_shape[2]))\
                                                              .sum(0)\
                                                              .squeeze()

In [0]:
class Base_loader_memory(data.Dataset):
    """ Some documentation"""
    """ Loader for work with mri data"""

    def __init__(self,labels, root_dir, transform=None,loader=default_loader,load_ds_memory=True,cuda=False,normalize_data=False,scaler_=False,slice_=95,resize=False):
        """
        Args:
        """
        self.all_data = load_dataset_in_memory(labels,root_dir,Tensor = False,slice_=slice_,resize=resize)
        if normalize_data == True:
          self.mean = self.all_data.mean()
          self.std = self.all_data.std()
          self.all_data -= self.mean
          self.all_data = self.all_data/self.std
        if scaler_ == True:
          self.all_data -= scaler_[0]
          self.all_data = self.all_data/scaler_[1]
        if cuda:
          self.all_data=self.all_data.cuda()
        self.root = os.listdir(root_dir)
        print(self.root)
        self.root_path = root_dir
        
        self.labels_ = labels
        self.labels = labels
        
        self.classes = labels.target
        self.class_to_idx = labels.participant_id
        self.transform = transform
        self.loader = loader
        self.size = len(self.labels)
    def refit(labels):
        self.labels = self.labels_[labels]
    def __getitem__(self, idx):
        img = self.all_data[idx]
        target = self.labels['target'].iloc[idx]
        #path = os.path.join(self.root_path,file_name+'_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz')
        
        #img = load_nii_to_array(path)
        if self.transform is not None:
            img = self.transform(img)
        return img,target
        
    
    def __len__(self):
        return len(self.labels)
      
def load_dataset_in_memory(y_test,root,Tensor = True,slice_=95,resize=False):
    files=[]
    for q in tqdm(y_test.iterrows()):
        ind,att = q
        file_name = root+'/'+att['participant_id']+'_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz'
        #print(att['target'])
        files.append((load_nii_to_array(file_name,resize=resize,slice_=slice_,).astype(np.float32)))
    if Tensor:
        return torch.tensor(np.array(files))
    else:
        return np.array(files)      

In [0]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class Identity(nn.Module):
    def __init__(self,):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Reshape(nn.Module):
    def __init__(self, shape):
        nn.Module.__init__(self)
        self.shape = shape
    def forward(self, input):
        return input.view((-1,) + self.shape)


def conv3x3x3(in_planes, out_planes, stride=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes,track_running_stats=False)
        self.relu_1 = nn.ReLU(inplace=False)
        self.relu_2 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes,track_running_stats=False)
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu_1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu_2(out)

        return out


class VoxResNet(nn.Module):
    def __init__(self, input_shape=(128, 128, 128), num_classes=2, n_filters=32, stride=2, n_blocks=4, dropout=0, n_fc_units=128):
        super(self.__class__, self).__init__()
        self.model = nn.Sequential()
    
        self.model.add_module("conv3d_1", nn.Conv3d(1, n_filters, kernel_size=3, padding=1, stride=stride)) # n * (x/s) * (y/s) * (z/s)
        self.model.add_module("batch_norm_1", nn.BatchNorm3d(n_filters,track_running_stats=False))
        self.model.add_module("activation_1", nn.ReLU(inplace=False))
        self.model.add_module("conv3d_2", nn.Conv3d(n_filters, n_filters, kernel_size=3, padding=1)) # n * (x/s) * (y/s) * (z/s)
        self.model.add_module("batch_norm_2", nn.BatchNorm3d(n_filters,track_running_stats=False))
        self.model.add_module("activation_2", nn.ReLU(inplace=False))

        # 1
        self.model.add_module("conv3d_3", nn.Conv3d(n_filters, 2 * n_filters, kernel_size=3, padding=1, stride=2)) # 2n * (x/2s) * (y/2s) * (z/2s)
        self.model.add_module("block_1", BasicBlock(2 * n_filters, 2 * n_filters))
        self.model.add_module("block_2", BasicBlock(2 * n_filters, 2 * n_filters))
        self.model.add_module("batch_norm_3", nn.BatchNorm3d(2 * n_filters,track_running_stats=False))
        self.model.add_module("activation_3", nn.ReLU(inplace=False))

    # 2
        if n_blocks >= 2:
            self.model.add_module("conv3d_4", nn.Conv3d(2 * n_filters, 2 * n_filters, kernel_size=3, padding=1, stride=2)) # 2n * (x/4s) * (y/4s) * (z/4s)
            self.model.add_module("block_3", BasicBlock(2 * n_filters, 2 * n_filters))
            self.model.add_module("block_4", BasicBlock(2 * n_filters, 2 * n_filters))
            self.model.add_module("batch_norm_4", nn.BatchNorm3d(2 * n_filters))
            self.model.add_module("activation_4", nn.ReLU(inplace=False))

        # 3
        if n_blocks >= 3:
            self.model.add_module("conv3d_5", nn.Conv3d(2 * n_filters, 4 * n_filters, kernel_size=3, padding=1, stride=2)) # 4n * (x/8s) * (y/8s) * (z/8s)
            self.model.add_module("block_5", BasicBlock(4 * n_filters, 4 * n_filters))
            self.model.add_module("block_6", BasicBlock(4 * n_filters, 4 * n_filters))
            self.model.add_module("batch_norm_5", nn.BatchNorm3d(4 * n_filters))
            self.model.add_module("activation_5", nn.ReLU(inplace=False))

        # 4
        if n_blocks >= 4:
            self.model.add_module("conv3d_6", nn.Conv3d(4 * n_filters, 4 * n_filters, kernel_size=3, padding=1, stride=2)) # 4n * (x/16s) * (y/16s) * (z/16s)
            self.model.add_module("block_7", BasicBlock(4 * n_filters, 4 * n_filters))
            self.model.add_module("block_8", BasicBlock(4 * n_filters, 4 * n_filters))
            self.model.add_module("batch_norm_6", nn.BatchNorm3d(4 * n_filters,track_running_stats=False))
            self.model.add_module("activation_6", nn.ReLU(inplace=False))

        # self.model.add_module("max_pool3d_1", nn.MaxPool3d(kernel_size=3)) # (b/2)n * (x/(2^b)sk) * (y/(2^b)sk) * (z/(2^b)sk) ?

        self.model.add_module("flatten_1", Flatten())
        # self.model.add_module("fully_conn_1", nn.Linear(2 ** ((n_blocks + 1) // 2) * n_filters * np.prod(np.array(input_shape) // (stride * 2 ** n_blocks)), n_fc_units))
        if n_blocks == 3:
            self.model.add_module("fully_conn_1", nn.Linear(4 * n_filters * (input_shape[0] // (8 * stride)) * (input_shape[1] // (8 * stride)) * (input_shape[2] // (8 * stride)), n_fc_units))
            
        if n_blocks == 4:
            self.model.add_module("fully_conn_1", nn.Linear(3840, n_fc_units)) #3840
            
        self.model.add_module("activation_6", nn.ReLU(inplace=False))
        self.model.add_module("dropout_1", nn.Dropout(dropout))

        self.model.add_module("fully_conn_2", nn.Linear(n_fc_units, num_classes))

    def forward(self, x):
        return self.model(x)

In [0]:
def train_model(y_train,y_test, epochs_cnt = 10,minibatch_size=50,lr=1e-6,momentum=0.9,num_workers = 1, loader_trainer=None,
                loader_test=None,
               debug = True,dropout = 0.1,validate_ =True,early_break = False,continue_train=False,model=None,sampler_=None):

  tensorize = transforms.ToTensor()
  #resize = Fast3dResize()
  composed = transforms.Compose([tensorize])
  
  #loader_train = Base_loader_(y_train,data_folder,transform = composed)
  loader_train = loader_trainer
  loader_test = loader_test
  
  if continue_train ==False:  
    vox_net = VoxResNet(dropout=dropout)#Deleted drop out
    vox_net.cuda()
  
    vox_net.train()
  else:
    vox_net = model
    vox_net.train()
#   optimizer = torch.optim.Adam(vox_net.parameters(), lr= 5e-7,betas=(0.9, 0.999), eps=1e-08,weight_decay=1e-4)#, momentum = momentum,  nesterov= True)
  optimizer = torch.optim.Adam(vox_net.parameters(), lr= 5e-5,betas=(0.9, 0.999), eps=1e-08,weight_decay=1e-4)#
  criterion = nn.CrossEntropyLoss().cuda()

  train_loader = torch.utils.data.DataLoader(loader_train, batch_size = minibatch_size,
                                             shuffle = True, num_workers = num_workers,pin_memory=True)
  
  test_loader = torch.utils.data.DataLoader(loader_test, batch_size = minibatch_size, shuffle = False, 
                                               num_workers = num_workers,pin_memory=False)
  
  losses = []
  scores = []
  #try:
  for epoch in range(epochs_cnt):
      loss = 0
      train_loss = 0
      if debug:
        print(epochs_cnt)
      for batch_idx, (inputs, labels) in enumerate(train_loader):
          if(batch_idx % 5 == 0):
              print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 10))
              train_loss = 0

          print(inputs.shape)
          inputs, labels = inputs.view(size=(inputs.shape[0],1,inputs.shape[1],\
                                             inputs.shape[2],inputs.shape[3])).cuda(), \
          torch.tensor(labels).cuda()#labels.view(minibatch_size).cuda(0)

          optimizer.zero_grad()

          outputs = vox_net(inputs)


          loss = criterion(outputs, labels)
          loss.backward()
          train_loss += loss.data.item()

          if debug == True:
            print(outputs)
            print(labels)
            print(epoch,loss.data.item())

          optimizer.step()
          if early_break:
            break
      losses.append(train_loss/10)
      if validate_ == True:
        score = validate(vox_net,y_test,test_loader=test_loader,create_torch_loader=False)
        vox_net.train()
        #return vox_net
        print(score)
        scores.append(score)
      if early_break:
        break
  print('Finished Training')
  return vox_net,losses,scores
#   except:
#     return vox_net

def validate(model,y_test,test_loader,minibatch_size = 50,num_workers = 1,return_predictions=False,create_torch_loader=True):
  acc = []
  print('start_validation')
  model.eval()

  y_target_labels =[]
  y_target_predictions = []
  tensorize = transforms.ToTensor()
  #resize = Fast3dResize()
  composed = transforms.Compose([tensorize])
  if create_torch_loader:
    test_loader = torch.utils.data.DataLoader(test_loader, batch_size = minibatch_size, shuffle = False, 
                                               num_workers = num_workers,pin_memory=False)
  #loader_train = Base_loader_(y_train,data_folder,transform = composed)
  #loader_test = Base_loader_(y_test,data_folder,transform = composed)
  
  for inputs,y in test_loader:
    print(inputs.shape, y.shape,y)
    x = inputs.view(size=(inputs.shape[0],1,inputs.shape[1],inputs.shape[2],inputs.shape[3])).cuda()
    val = model(x)
    args = val.max(1)[1]
    print('output',val)
    print(y)
    #print(np.mean(args.cpu().numpy()==y.cpu().numpy()))
    acc.append( np.mean ( args.cpu().numpy() == y.cpu().numpy() ))
    y_target_labels.append(y.cpu().numpy().tolist())
    y_target_predictions.append(args.cpu().numpy().tolist())
  print(np.mean(acc) )
  
  #output_ = 0
  #reemovNestings(y_target_predictions)
  y_target_predictions = flatten(y_target_predictions)
  
  
  y_target_labels = flatten(y_target_labels)
  print(y_target_labels,y_target_predictions)
  #print(y_tr,y_pr =score[1],score[2]
  print('roc_auc',roc_auc_score(np.array(y_target_labels),np.array(y_target_predictions)))
  model.train()
  if return_predictions == True:
    return np.mean(acc),np.array(y_target_labels),np.array(y_target_predictions)
    
  return np.mean(acc) 

def flatten(aList):
    new_list = []
    found_list = True
    while found_list:
      found_list = False
      for elem in aList:
        if isinstance(elem, list):
          for subelem in elem:
            new_list.append(subelem)
          found_list=True
        else:
          new_list.append(elem)
    
      new_list, aList = aList, new_list
      new_list.clear()
    return aList


def kfold_validate():
  
    kf = KFold(n_splits=5,shuffle=True)
    scores_ = []
    roc_scores = []
    batch_size=50
   
    for itteration, indexes in enumerate(kf.split(y)):
      print('itteration #', itteration)
      train_index, test_index  = indexes
      y_train = y.iloc[train_index]
      y_test = y.iloc[test_index]
      y_train.reset_index(inplace=True)
      y_test.reset_index(inplace=True)
      y_train.drop(columns=['index'],inplace=True)
      y_test.drop(columns=['index'],inplace=True)
      
      
      r = RandomOverSampler()
      x, y_resampled = r.fit_resample(y_train, y_train.target.reshape(-1,1))
      x = pd.DataFrame(x)
      x.columns = y_train.columns
      
      
      loader_test= Base_loader_memory(y_test,data_folder,transform = composed,cuda=False,normalize_data=True,scaler_ = False)
      loader_train = Base_loader_memory(x,data_folder,transform = composed,cuda=False,normalize_data=True)
      
      model,loss,scores = train_model(y_train,y_test,minibatch_size=batch_size,continue_train=False,model=ouput[0],epochs_cnt=50,
                    validate_=True, sampler_=None,loader_trainer=loader_train,loader_test=loader_test,dropout=False)
      
      test_loader = torch.utils.data.DataLoader(loader_test, batch_size = batch_size, shuffle = False, 
                                               num_workers = 1,pin_memory=False)
      #loader_test = Base_loader_(y_test,data_folder,transform = composed)
      #test_loader = torch.utils.data.DataLoader(loader_test, batch_size = batch_size, shuffle = True, num_workers = 1)
      
      
     
      
      score,y_true,y_pred = validate(model,y_test,loader_test=test_loader,return_predictions=True)
      
      roc_scores.append(roc_auc_score(y_true,y_pred))
#       validate(model,y_test)
      scores_.append(score)
      print(scores_)
      print(roc_scores)
    return scores_,roc_scores

In [0]:
def load_nii_to_array(nii_path,voxel_size = [2, 2, 2],crop=True,resize=False,slice_=False):
    file_ = nib.load(nii_path)
    if resize==False:
      resampled_img = file_
    else:
      resampled_img = nibabel.processing.resample_to_output(file_, voxel_size)
    if crop:
      if slice_:
        return resampled_img.get_data()[20:175,slice_,0:165]
      else:
        return resampled_img.get_data()[20:175,:,0:165]
      
    else:
      return resampled_img.get_data()[20:175,:,0:165]
    
def preproc_data(data_file,return_y = True):
    #Encode with 1 SCHZ, and with 0 normal
    
    data_file = data_file[(participants_pandas['label']=='SCHZ') | (participants_pandas['label']=='CONTROL')]
    data_file['target'] = 0
    
    data_file.target[data_file['diagnosis'] == 'CONTROL'] = 0 
    data_file.target[data_file['diagnosis'] == 'SCHZ'] = 1
    
    if return_y == True:
        y = data_file[['participant_id','target']]
        return y
    return data_file

  
def create_loaders(y_train,y_test,oversampling = True,slice_=None,resize=False):
  
  if oversampling:
    r = RandomOverSampler()
    r = RandomOverSampler()
    x, y_resampled = r.fit_resample(y_train, y_train.target.reshape(-1,1))
    x = pd.DataFrame(x)
    x.columns = y_train.columns
    y_train = x
  tensorize = transforms.ToTensor()
  #resize = Fast3dResize()
  composed = transforms.Compose([tensorize,
                              ])
  
  loader_test= Base_loader_memory(y_test,data_folder,transform = composed,cuda=False,normalize_data=False,scaler_ = False,slice_=slice_,resize=resize)
  loader_train = Base_loader_memory(y_train,data_folder,transform = composed,cuda=False,normalize_data=False,slice_=slice_,resize=resize)
#   test_loader = torch.utils.data.DataLoader(loader_test, batch_size = batch_size, shuffle = False, 
#                                                  num_workers = 1,pin_memory=False)
  
  return loader_test,loader_train

In [24]:
data_folder = 'data/sMRI' 
targets_file ='data/LA5study_targets.csv' 

participants_pandas = pd.read_csv(targets_file)

y = preproc_data(participants_pandas)
participants_pandas.head()


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  self._update_inplace(new_data)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/in

Unnamed: 0,participant_id,diagnosis,age,gender,bart,bht,dwi,pamenc,pamret,rest,...,stopsignal,T1w,taskswitch,ScannerSerialNumber,ghost_NoGhost,label,Bipolar/Control,Bipolar/NotBipolar,ADHD/Control,Schz/Control
0,sub-10159,CONTROL,30,F,1.0,,1.0,,,1.0,...,1.0,1.0,1.0,35343.0,No_ghost,CONTROL,0.0,0,0.0,0.0
1,sub-10171,CONTROL,24,M,1.0,1.0,1.0,,,1.0,...,1.0,1.0,1.0,35343.0,No_ghost,CONTROL,0.0,0,0.0,0.0
2,sub-10189,CONTROL,49,M,1.0,,1.0,,,1.0,...,1.0,1.0,1.0,35343.0,No_ghost,CONTROL,0.0,0,0.0,0.0
3,sub-10193,CONTROL,40,M,1.0,,1.0,,,,...,,1.0,,35343.0,No_ghost,CONTROL,0.0,0,0.0,0.0
4,sub-10206,CONTROL,21,M,1.0,,1.0,,,1.0,...,1.0,1.0,1.0,35343.0,No_ghost,CONTROL,0.0,0,0.0,0.0


In [46]:
y_train,y_test = train_test_split(y,random_state=42,shuffle = True,test_size = 0.15)
y_train_zero = y_train[y_train.target == 0]
y_train_first = y_train[y_train.target == 1]

print(len(y_train[y_train.target == 1]))


y_train_sampled = pd.concat((y_train_first,y_train_zero[:len(y_train_first)]))

y_train.reset_index(inplace=True)
y_test.reset_index(inplace=True)
y_train.drop(columns=['index'],inplace=True)
y_test.drop(columns=['index'],inplace=True)

40


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  if sys.path[0] == '':
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  del sys.path[0]


In [81]:
loader_test,loader_train = create_loaders(y_train,y_test,oversampling=True,slice_=95,resize=False)

  y = column_or_1d(y, warn=True)


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


['sub-10159_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10171_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10189_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10193_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10206_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10217_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10225_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10227_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10228_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10235_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10249_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10269_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10271_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10273_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10274_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10280_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10290_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10292_T

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


['sub-10159_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10171_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10189_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10193_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10206_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10217_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10225_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10227_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10228_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10235_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10249_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10269_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10271_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10273_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10274_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10280_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10290_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz', 'sub-10292_T

In [91]:
output = train_model(y_train,y_test,minibatch_size=50,continue_train=False,epochs_cnt=5,model=None,
                    validate_=True, sampler_=None,loader_trainer=loader_train,loader_test=loader_test,dropout=0.2)

5
[1,     1] loss: 0.000
torch.Size([50, 1, 155, 165])




tensor([[ 0.0201, -0.1198],
        [ 0.2465, -0.1175],
        [ 0.5893, -0.5130],
        [ 0.1104,  0.1927],
        [ 0.1198,  0.1468],
        [ 0.3966,  0.1979],
        [ 0.2794,  0.0329],
        [ 0.0214,  0.0830],
        [ 0.1547,  0.0454],
        [ 0.3196,  0.3733],
        [ 0.1814,  0.1595],
        [-0.0020,  0.0062],
        [ 0.0965, -0.2622],
        [ 0.0856,  0.0638],
        [ 0.1336,  0.0994],
        [-0.0425,  0.0369],
        [ 0.2582,  0.0212],
        [ 0.3377,  0.1653],
        [ 0.0417,  0.1057],
        [ 0.1986, -0.1031],
        [ 0.2958,  0.0681],
        [ 0.1847,  0.1414],
        [ 0.4438,  0.1967],
        [-0.1828, -0.2636],
        [ 0.4364,  0.2272],
        [ 0.0907, -0.0107],
        [ 0.1496,  0.1412],
        [ 0.0575,  0.2370],
        [ 0.1136,  0.1356],
        [ 0.0440,  0.0212],
        [ 0.1998,  0.1684],
        [ 0.0933,  0.2307],
        [ 0.0555,  0.1303],
        [ 0.3081,  0.2634],
        [ 0.0609,  0.0633],
        [ 0.2144,  0



tensor([[-2.0757e-05,  1.7136e-01],
        [-1.9424e-01,  4.5498e-01],
        [-1.2649e-01,  3.0602e-01],
        [-4.8126e-01,  5.9188e-01],
        [-1.5851e-01,  3.4400e-01],
        [-3.4183e-01,  6.4031e-01],
        [ 2.2905e+00, -2.3656e+00],
        [-4.3485e-01,  5.9602e-01],
        [ 4.2389e-01, -3.5515e-01],
        [-7.8451e-02,  2.6723e-01],
        [ 3.6143e-01, -2.3812e-01],
        [ 7.3809e-01, -4.2188e-01],
        [ 2.8778e-01,  3.7084e-02],
        [ 3.1468e-01, -3.1188e-01],
        [ 1.7199e-01,  8.0337e-02],
        [-2.2126e-01,  6.1305e-01],
        [ 3.2132e-01, -2.3341e-01],
        [-2.0384e-01,  4.4684e-01],
        [-2.8017e-01,  4.7891e-01],
        [-2.9774e-01,  7.2284e-01],
        [ 3.1420e-03,  1.2404e-01],
        [ 1.7057e-01, -2.0364e-01],
        [-1.8840e-01,  5.6392e-01],
        [ 3.3544e-01, -1.8109e-01],
        [-2.9505e-02,  2.7110e-01],
        [ 4.4507e-01, -3.5868e-01],
        [-3.5931e-01,  5.1014e-01],
        [ 1.5416e+00, -1.643



tensor([[-0.7705,  0.7647],
        [-0.0769,  0.2977],
        [-0.7404,  0.8420],
        [ 1.7070, -1.7855],
        [ 0.4965, -0.5890],
        [ 0.6139, -0.1237],
        [-0.7798,  0.8461],
        [-0.2005,  0.6371],
        [-0.3680,  0.5240],
        [-0.3918,  0.7236],
        [ 0.2830, -0.1176],
        [-0.5766,  0.9610],
        [-0.6121,  0.7213],
        [-0.2275,  0.5454],
        [-0.5702,  0.6452],
        [-0.5480,  0.8637],
        [ 0.3659, -0.1683],
        [-0.5169,  0.3200],
        [ 0.1580, -0.0519],
        [-0.5297,  0.5730],
        [-0.3256,  0.4308],
        [ 2.0103, -2.1042],
        [-0.9589,  0.7062],
        [ 0.1357,  0.1517],
        [ 0.8685, -0.4660],
        [ 3.4947, -2.7838],
        [-0.6804,  1.1550],
        [ 0.5071, -0.3952],
        [-0.4077,  0.5741],
        [ 0.4966, -0.3036],
        [-0.5808,  0.9877],
        [ 0.3406, -0.1040],
        [ 1.8559, -1.8042],
        [ 3.1925, -3.2827],
        [ 0.4457, -0.3141],
        [ 2.8197, -2



tensor([[ 0.3681, -0.2103],
        [ 1.2200, -1.0095],
        [-0.5493,  0.8203],
        [ 2.3176, -2.0032],
        [ 1.2564, -1.0411],
        [-0.7873,  0.7392],
        [ 0.9001, -0.8195],
        [ 0.4682, -0.5313],
        [ 0.5012, -0.1672],
        [-0.4972,  0.6218],
        [ 1.0333, -1.3174],
        [-1.1127,  1.5984],
        [ 0.2347, -0.0187],
        [-0.9570,  1.4305],
        [ 1.7421, -1.4777],
        [ 1.0016, -0.5806],
        [-0.4271,  0.8057],
        [ 1.7110, -1.6369],
        [ 0.9579, -0.7846],
        [ 2.3375, -2.3496],
        [-0.8313,  0.9916],
        [ 2.5979, -2.8604],
        [ 0.7417, -0.6097],
        [ 1.1960, -1.0310],
        [-0.5396,  1.0263],
        [ 1.0707, -1.6067],
        [ 1.0871, -1.9474],
        [ 0.5170, -0.9420],
        [ 0.3028,  0.0656],
        [-0.3730,  0.7576],
        [-1.3806,  2.0385],
        [ 1.7173, -1.1483],
        [-0.5957,  0.6488],
        [ 1.2584, -1.3895],
        [-0.4833,  0.5450],
        [-1.5318,  1



tensor([[-0.8144,  1.0219],
        [ 1.4520, -1.4268],
        [ 2.8857, -2.6909],
        [ 0.5389, -0.4519],
        [ 2.1446, -2.7529],
        [ 1.7221, -1.8693],
        [-1.7040,  2.1120],
        [-0.7042,  1.3625],
        [ 1.7422, -1.4214],
        [-0.7021,  1.5094],
        [-1.5081,  2.0629],
        [-1.0929,  1.6963],
        [-1.0575,  1.3165],
        [ 1.6078, -1.6835],
        [-0.9175,  1.1507],
        [ 1.1228, -1.1873],
        [ 0.4896, -0.3156],
        [-0.5413,  0.6599],
        [-0.9528,  1.1061],
        [-0.9104,  1.1779],
        [ 1.9449, -1.7927],
        [-1.5539,  1.5623],
        [-1.5960,  1.7653],
        [ 2.4823, -3.2016],
        [-1.5647,  1.9440],
        [ 0.7974, -0.6920],
        [-0.7222,  0.9821],
        [-0.5807,  0.9274],
        [ 1.5242, -1.4801],
        [-1.0979,  1.3176],
        [-0.6431,  1.0241],
        [-1.1624,  1.3589],
        [-0.7003,  1.1826],
        [-1.0462,  1.6324],
        [-0.8768,  1.1075],
        [-0.7722,  1

In [0]:
def get_shap_values(y_train,picture,model,model_path,background_size = 20):
  files = load_dataset_in_memory(y_train,root = data_folder,)
  image_1 = load_nii_to_array(data_folder+'/'+picture+'_T1w_space-MNI152NLin2009cAsym_preproc.nii.gz',crop=True,resize=True)
  image_1 = torch.tensor(image_1)
  if model == False:
    model = torch.load(model_path)
  model.eval().cuda()
  files = files[:background_size].cuda()
  e = shap.DeepExplainer((model, model.model.conv3d_1), files.view(files.shape[0],1,1,files.shape[2]))
  q,indexes = e.shap_values(torch.Tensor(image_1).cuda().view(1,1,image_1.shape[0],image_1.shape[1],image_1.shape[2]), ranked_outputs=2)
  return q,indexes

In [0]:
def get_threshold_shap(shap_values):
  shap_values= shap_values/max(shap_values) * 255
  return shap_values*(shap_values > 55) 

In [0]:
def get_brain_mri(nifti_img,shap_values,save_name,img):
  empty_header = nib.Nifti1Header()
  another_img = nib.Nifti1Image(shap_values, nifti_img.affine, empty_header)
  nib.save(img, save_name)

In [0]:
import cv2
def save_cv2(shap_values,sign = 'le'):
  for i in range(0, shap_values.shape[5]):
    if sign == 'le':
      image = shap_values[0,0,0,:,i,:]
    else:
      image = shap_values[1,0,0,:,:,i]*( shap_values[1,0,0,:,:,i]<0 )
    pos = image/image.max()*256
    cv2.imwrite('images_output/s_C001Z%03d'%i+'.tif', pos)