In [24]:
# Classification of ASD vs Controls based on different atlases.
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
from torch.autograd import Variable
import torch
import torch.nn.functional as F
from pprint import pprint
from sklearn.utils import shuffle
from scipy.stats import mode
from sklearn.metrics import accuracy_score
import os.path as osp
import os
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import KFold
import torch.utils.data as data_utils
from sklearn.metrics import confusion_matrix

In [25]:
# GPU used as the "current" CUDA device by the `.cuda()` / torch.device("cuda")
# calls below; change it here to run on a different card.
GPU_INDEX = 1
# Guarded so the notebook can still be imported / run on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX)

In [26]:
class Abide1DConvNet(nn.Module):
    """1-D convolutional classifier over ROI time courses.

    Each ROI is treated as an input channel of ``nn.Conv1d``, so ``x`` is
    expected as ``(batch, nROIS, time)``.  The network is a single temporal
    convolution + ReLU, global average pooling over time, dropout and a
    linear layer.

    Args:
        nROIS: number of input channels (ROIs).
        n_filters: number of convolution filters (64 in the original model).
        kernel_size: temporal width of the convolution (14 in the original).
        n_classes: number of output classes.
        dropout: dropout probability applied before the linear layer.

    ``forward`` returns raw, unnormalized logits of shape
    ``(batch, n_classes)``; use ``predict_proba`` for probabilities.
    """

    def __init__(self, nROIS=2, n_filters=64, kernel_size=14, n_classes=2, dropout=0.2):
        super(Abide1DConvNet, self).__init__()
        # Attribute names are unchanged so saved state_dict keys still match.
        self.conv1 = nn.Conv1d(nROIS, n_filters, kernel_size)
        self.avg = nn.AdaptiveAvgPool1d(1)
        self.drop2 = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(n_filters, n_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # (batch, n_filters, 1) -> (batch, n_filters); works for any n_filters
        # and keeps the batch axis even for a batch of one.
        x = self.avg(x).flatten(1)
        x = self.drop2(x)
        # Return logits: nn.CrossEntropyLoss applies log-softmax itself, so a
        # softmax here would be applied twice and squash the gradients.
        return self.linear1(x)

    def predict_proba(self, x):
        """Class probabilities: softmax over the logits from ``forward``."""
        return F.softmax(self.forward(x), dim=1)

In [27]:
def validate_model(net, val_data_loader, criterion=None):
    """Evaluate ``net`` on a validation loader.

    Args:
        net: model mapping a batch of inputs to per-class scores of shape
            ``(batch, n_classes)``.  Inputs are moved to the device of the
            model's parameters.
        val_data_loader: iterable of ``(inputs, one_hot_labels)`` batches;
            the class index is taken as ``argmax`` over dim 1 of the labels.
        criterion: loss called as ``criterion(output, class_indices)``.
            Defaults to an unweighted ``nn.CrossEntropyLoss``.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses (float)
        and the fraction of correctly classified samples (float).

    Raises:
        ValueError: if the loader yields no batches.

    The model's train/eval mode is restored before returning.
    """
    if criterion is None:
        # Created here rather than as a default argument so that defining the
        # function does not require CUDA.
        criterion = nn.CrossEntropyLoss()

    device = next(net.parameters()).device
    was_training = net.training
    net.eval()

    total_loss = 0.0
    n_batches = 0
    label_parts = []
    pred_parts = []
    try:
        # No gradients are needed for evaluation; this also stops the summed
        # loss from keeping every batch's autograd graph alive.
        with torch.no_grad():
            for tc, dx in val_data_loader:
                tc = tc.to(device=device, dtype=torch.float32)
                target = dx.to(device).argmax(dim=1)

                output = net(tc)
                total_loss += criterion(output, target).item()
                n_batches += 1

                label_parts.append(target.cpu().numpy())
                pred_parts.append(output.argmax(dim=1).cpu().numpy())
    finally:
        net.train(was_training)

    if n_batches == 0:
        raise ValueError("validation data loader is empty")

    labels = np.concatenate(label_parts)
    predictions = np.concatenate(pred_parts)
    # Equal to trace(confusion_matrix) / total, but also valid when only one
    # class occurs in the validation split.
    accuracy = float(np.mean(labels == predictions))

    return total_loss / n_batches, accuracy

In [28]:
def train_model(train_data, val_data, exp_dir, nepochs, verbose=True, device=None):
    """Train an ``Abide1DConvNet`` and return it with its best weights.

    Args:
        train_data: ``TensorDataset`` of ``(time_courses, one_hot_labels)``;
            the first axis of each time course is used as the ROI/channel
            count of the network.
        val_data: dataset of the same form, used to select the best epoch.
        exp_dir: directory where ``model.pt`` (the whole module, via
            ``torch.save``) is written every time validation loss improves.
        nepochs: number of training epochs.
        verbose: currently unused; kept for call compatibility.
        device: torch device to train on; defaults to the current CUDA device
            when available, otherwise CPU.

    Returns:
        The network loaded with the weights of the epoch with the lowest
        validation loss, or ``None`` if ``nepochs`` is 0.

    Raises:
        ValueError: if the training split does not contain both classes.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    # Sample order does not affect evaluation, so no shuffling for validation.
    val_data_loader = DataLoader(val_data, batch_size=32, shuffle=False)

    nrois = train_data[0][0].shape[0]

    # Class weights for the (binary) loss: each class is weighted by the share
    # of the *other* class in the training set, i.e. the reversed class
    # frequencies, so the minority class gets the larger weight.
    class_idx = np.argmax(train_data.tensors[1].numpy(), 1)
    class_counts = np.bincount(class_idx, minlength=2)
    if np.any(class_counts == 0):
        raise ValueError("training split must contain samples of every class")
    class_weights = torch.tensor(class_counts[::-1] / class_counts.sum(),
                                 dtype=torch.float, device=device)

    net = Abide1DConvNet(nROIS=nrois).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(net.parameters(), lr=.0005, weight_decay=0.02)

    best_val_loss = None
    best_state = None

    for i_epoch in range(nepochs):
        # validate_model puts the net in eval mode, so re-enable training mode
        # (dropout) at the start of every epoch, not just once.
        net.train()
        epoch_loss = 0.0
        n_batches = 0
        for tc, dx in train_data_loader:
            tc = tc.to(device=device, dtype=torch.float32)
            target = dx.to(device).argmax(dim=1)

            output = net(tc)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() so the graph of every batch is not kept alive.
            epoch_loss += loss.item()
            n_batches += 1

        # Mean over all batches (the previous code divided by the last index).
        epoch_train_loss = epoch_loss / n_batches
        epoch_val_loss, _ = validate_model(net, val_data_loader, criterion=criterion)
        epoch_val_loss = float(epoch_val_loss)

        if best_val_loss is None or epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(net, os.path.join(exp_dir, "model.pt"))
            # Snapshot the weights: `net` keeps training after this epoch, so
            # keeping a reference alone would return the final-epoch weights.
            best_state = {k: v.detach().clone() for k, v in net.state_dict().items()}

    if best_state is None:
        return None
    net.load_state_dict(best_state)
    return net

In [29]:
def test_model(net,test_data):
    test_data_loader = DataLoader(test_data, batch_size=4, shuffle=True)
    net.cuda().eval()
    tot_acc = 0.0
    tot_spec = 0.0
    labels = np.empty([],dtype=int)
    predictions = np.empty([],dtype=int)
    for i, (tc, dx) in enumerate(test_data_loader):

            tc = Variable(tc).type(torch.cuda.FloatTensor)
            dx = Variable(dx).type(torch.cuda.LongTensor)

            # forward pass
            output = net(tc).cuda()
            labels = np.append(labels,torch.argmax(dx,1).cpu().numpy())
            predictions = np.append(predictions,torch.argmax(output,1).cpu().numpy())
            

    cm1 = confusion_matrix(labels[1:] , predictions[1:])
    total1=sum(sum(cm1))
    accuracy1=(cm1[0,0]+cm1[1,1])/total1
    sensitivity1 = cm1[0,0]/(cm1[0,0]+cm1[0,1])
    specificity1 = cm1[1,1]/(cm1[1,0]+cm1[1,1])

    return accuracy1,sensitivity1,specificity1

In [33]:
# Used for training the model and validating it based on k-fold validation scheme
def train_kfold(
    exp_dir='/data/agelgazzar/projects/models/cv/exp1/',
    atlas_name='schaefer_400',
    root_dir='/data_local/deeplearning/ABIDE_ML_inputs/',
    data_info_file='data_info.csv',
    dir1="bptf",
    dir2="no_nilearn_regress",
    nTime_min=200,
    zscore=True,
    folds=10,
    epochs=100,
    exclude_file="/data_local/deeplearning/ABIDE_LC/list_nogood.txt"):
    """K-fold cross-validation of the 1-D conv classifier for one atlas.

    Subjects are read from ``root_dir/data_info_file``.  Subjects listed in
    ``exclude_file`` (one SUB_ID per line) are dropped, as are subjects with
    ``nTimes <= nTime_min``.  Each remaining time course is cropped to its
    first ``nTime_min`` samples.  For each fold, the training part is split
    90/10 into train/validation; the model with the best validation loss is
    evaluated on the held-out fold.

    Args:
        exp_dir: root directory for per-fold model checkpoints.
        atlas_name: atlas substituted for ``ATLAS`` in the ``tc_file`` paths.
        root_dir, data_info_file: location of the subject CSV.
        dir1, dir2: substituted for ``BPTF`` / ``CONFOUNDS`` in the paths.
        nTime_min: number of time points kept per subject.
        zscore: whether to standardize the time courses (see note below).
        folds: number of CV folds.
        epochs: training epochs per fold.
        exclude_file: file of SUB_IDs to drop, or ``None`` to keep everyone.

    Returns:
        ``(accuracy, sensitivity, specificity)`` averaged over folds (also
        printed).

    Raises:
        ValueError: unknown atlas, or no subjects left after filtering.
    """
    valid_atlases = ('AAL', 'HO_cort_maxprob_thr25-2mm', 'schaefer_100', 'schaefer_400',
                     'JAMA_IC19', 'JAMA_IC52', "JAMA_IC7")
    if atlas_name not in valid_atlases:
        raise ValueError('atlas_name not found')

    print("preparing dataset ....")

    data_info = pd.read_csv(osp.join(root_dir, data_info_file))

    # Filter out badly preprocessed subjects.  Blank lines (including a
    # trailing newline or its absence) are ignored rather than assuming the
    # last line is empty.
    if exclude_file is not None:
        with open(exclude_file, "r") as fh:
            bad_ids = [int(line) for line in fh.read().split('\n') if line.strip()]
        data_info = data_info[~np.isin(data_info["SUB_ID"], bad_ids)]

    # Keep subjects with more than nTime_min volumes so every run can be cropped.
    data_info = data_info[data_info.nTimes > nTime_min]
    data_info = shuffle(data_info, random_state=1)
    if len(data_info) == 0:
        raise ValueError('no subjects left after filtering')

    def resolve(template):
        return template.replace('ATLAS', atlas_name).replace("BPTF", dir1).replace("CONFOUNDS", dir2)

    tc_files = data_info['tc_file'].values
    dx_groups = data_info['DX_GROUP'].values

    # The number of channels (ROIs) is taken from the first subject's file.
    nrois = pd.read_csv(resolve(tc_files[0])).values.shape[1]
    total_subjects = len(data_info)
    tc_data = np.zeros((total_subjects, nrois, nTime_min))
    labels = np.zeros(total_subjects, dtype=int)

    for i in range(total_subjects):
        # CSV is (time, roi); transpose to (roi, time) and crop.
        tc_vals = pd.read_csv(resolve(tc_files[i])).values.transpose()[:, :nTime_min]
        if zscore:
            # NOTE(review): this standardizes each *time point* across ROIs
            # (mean/std over axis 0), not each ROI over time — confirm intent.
            tc_vals = (tc_vals - tc_vals.mean(axis=0)) / tc_vals.std(axis=0)
        tc_data[i] = tc_vals
        labels[i] = dx_groups[i]

    # One-hot encoding.  NOTE(review): assumes DX_GROUP is already coded 0/1;
    # raw ABIDE codes (1/2) would index out of range here — verify the CSV.
    labels = np.eye(2)[labels]

    def make_dataset(idx):
        return data_utils.TensorDataset(torch.from_numpy(tc_data[idx]), torch.from_numpy(labels[idx]))

    # shuffle/random_state are keyword-only in current scikit-learn.
    kfold = KFold(n_splits=folds, shuffle=True, random_state=1)

    total_accuracy = 0.0
    total_sensitivity = 0.0
    total_specificity = 0.0

    for j, (train_index, test_index) in enumerate(kfold.split(tc_data), start=1):
        path = osp.join(exp_dir, "{}/{}/fold{}".format(nTime_min, atlas_name, j))
        os.makedirs(path, exist_ok=True)

        # Split the training fold into 90% training and 10% validation.
        train_split = int(0.9 * len(train_index))
        train_i = train_index[:train_split]
        val_i = train_index[train_split:]

        trained_network = train_model(make_dataset(train_i), make_dataset(val_i), path, epochs)
        test_accuracy, test_sens, test_spec = test_model(trained_network, make_dataset(test_index))

        total_accuracy += test_accuracy
        total_sensitivity += test_sens
        total_specificity += test_spec

    acc = total_accuracy / folds
    sens = total_sensitivity / folds
    spec = total_specificity / folds
    print("{} in {} nTime results are: {} acc., {} sens. and  {} spec. ".format(atlas_name, nTime_min, acc, sens, spec))
    return acc, sens, spec

            
            

            

In [34]:
# Used for training the model and validating it based on leave one site out validation scheme

def train_site_val(
    exp_dir='/data/agelgazzar/projects/models/cv/exp1/',
    atlas_name='schaefer_400',
    root_dir='/data_local/deeplearning/ABIDE_ML_inputs/',
    data_info_file='data_info.csv',
    dir1="bptf",
    dir2="no_nilearn_regress",
    nTime_min=200,
    zscore=True,
    folds=10,
    epochs=100,
    min_site_subjects=10,
    exclude_file=None,
    random_state=None):
    """Leave-one-site-out evaluation of the 1-D conv classifier.

    Subjects with ``nTimes <= nTime_min`` are dropped.  Only sites with more
    than ``min_site_subjects`` remaining subjects are kept.  Each kept site is
    held out in turn as the test set.  The other sites are shuffled and split
    90/10 into train/validation.

    Args:
        exp_dir: root directory for per-site model checkpoints.
        atlas_name: atlas substituted for ``ATLAS`` in the ``tc_file`` paths.
        root_dir, data_info_file: location of the subject CSV.
        dir1, dir2: substituted for ``BPTF`` / ``CONFOUNDS`` in the paths.
        nTime_min: number of time points kept per subject.
        zscore: whether to standardize the time courses (see note below).
        folds: unused; kept so existing calls that pass it keep working.
        epochs: training epochs per held-out site.
        min_site_subjects: sites with this many subjects or fewer are dropped.
        exclude_file: optional file of SUB_IDs (one per line) to drop.
        random_state: seed for the train/validation shuffle (``None`` =
            non-deterministic, as before).

    Returns:
        dict mapping each held-out site to ``(accuracy, sensitivity,
        specificity)`` (also printed per site).

    Raises:
        ValueError: unknown atlas, or no subjects left after filtering.
    """
    valid_atlases = ('AAL', 'HO_cort_maxprob_thr25-2mm', 'schaefer_100', 'schaefer_400',
                     'JAMA_IC19', 'JAMA_IC52', "JAMA_IC7")
    if atlas_name not in valid_atlases:
        raise ValueError('atlas_name not found')

    data_info = pd.read_csv(osp.join(root_dir, data_info_file))

    if exclude_file is not None:
        with open(exclude_file, "r") as fh:
            bad_ids = [int(line) for line in fh.read().split('\n') if line.strip()]
        data_info = data_info[~np.isin(data_info["SUB_ID"], bad_ids)]

    # Keep subjects with more than nTime_min volumes so every run can be cropped.
    data_info = data_info[data_info.nTimes > nTime_min]
    data_info = shuffle(data_info, random_state=1)

    # Drop sites with min_site_subjects (default 10) or fewer subjects.
    sites, counts = np.unique(data_info["SITE_ID"].values, return_counts=True)
    filtered_sites = sites[counts > min_site_subjects]
    data_info_filtered = data_info[np.isin(data_info["SITE_ID"], filtered_sites)]
    if len(data_info_filtered) == 0:
        raise ValueError('no subjects left after filtering')

    def resolve(template):
        return template.replace('ATLAS', atlas_name).replace("BPTF", dir1).replace("CONFOUNDS", dir2)

    tc_files = data_info_filtered['tc_file'].values
    dx_groups = data_info_filtered['DX_GROUP'].values
    site_ids = data_info_filtered["SITE_ID"].values

    # The number of channels (ROIs) is taken from the first subject's file.
    nrois = pd.read_csv(resolve(tc_files[0])).values.shape[1]
    total_subjects = len(data_info_filtered)
    tc_data = np.zeros((total_subjects, nrois, nTime_min))
    labels = np.zeros(total_subjects, dtype=int)

    for i in range(total_subjects):
        # CSV is (time, roi); transpose to (roi, time) and crop.
        tc_vals = pd.read_csv(resolve(tc_files[i])).values.transpose()[:, :nTime_min]
        if zscore:
            # NOTE(review): this standardizes each *time point* across ROIs
            # (mean/std over axis 0), not each ROI over time — confirm intent.
            tc_vals = (tc_vals - tc_vals.mean(axis=0)) / tc_vals.std(axis=0)
        tc_data[i] = tc_vals
        # Inside the loop: previously this line sat after the loop, so only
        # the last subject received its label and all others stayed class 0.
        labels[i] = dx_groups[i]

    # One-hot encoding.  NOTE(review): assumes DX_GROUP is already coded 0/1;
    # raw ABIDE codes (1/2) would index out of range here — verify the CSV.
    labels = np.eye(2)[labels]

    def make_dataset(idx):
        return data_utils.TensorDataset(torch.from_numpy(tc_data[idx]), torch.from_numpy(labels[idx]))

    rng = np.random.default_rng(random_state)
    results = {}

    # Leave-one-site-out cross-validation.
    for site in filtered_sites:
        test_index = np.where(site_ids == site)[0]
        train_index = np.where(site_ids != site)[0]
        rng.shuffle(train_index)

        path = osp.join(exp_dir, "{}/{}/{}".format(nTime_min, atlas_name, site))
        os.makedirs(path, exist_ok=True)

        # Split the remaining sites into 90% training and 10% validation.
        train_split = int(0.9 * len(train_index))
        train_i = train_index[:train_split]
        val_i = train_index[train_split:]

        trained_network = train_model(make_dataset(train_i), make_dataset(val_i), path, epochs)
        test_accuracy, test_sens, test_spec = test_model(trained_network, make_dataset(test_index))
        print("{} :acc {}, sens {},  spec {}".format(site, test_accuracy, test_sens, test_spec))
        results[site] = (test_accuracy, test_sens, test_spec)

    return results


            
            

            