In [None]:
import glob
import pandas as pd
import numpy as np
from IPython import display

# for dataset
import torch
import torchio as tio
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET
from sklearn.preprocessing import MinMaxScaler
from sklearn.impute import KNNImputer
from sklearn.cluster import DBSCAN

import import_ipynb
from utils import *

# Binary training target per research group: AD and pMCI map to 1., CN and sMCI to 0.
CLASS = dict(CN=0., AD=1., sMCI=0., pMCI=1.)

# Clinical feature columns read from csv/ADNIMERGE.csv for every subject.
# Order matters: the two categorical features (PTGENDER, APOE4) must stay first,
# because the rest of this notebook treats CLINICALS[2:] as the continuous
# features to impute, normalize and summarize.
# A larger variant of this list additionally used 'ADASQ4', 'MMSE' and 'CDR'.
CLINICALS = [
    'PTGENDER', 'APOE4',                     # categorical
    'AGE', 'PTEDUCAT',
    'CDRSB', 'ADAS11', 'ADAS13',
    'RAVLT_immediate', 'RAVLT_learning',
    'RAVLT_forgetting', 'RAVLT_perc_forgetting',
]

In [None]:
class ADNIDataset(Dataset): # for torch DataLoader
    """Map-style torch Dataset over a subject table such as ``DataBuilder.get`` returns.

    Each row of ``subjects`` (read positionally with ``.iloc``) must provide ``mri``
    (image path), ``label``, ``task`` and the clinical columns.

    An item is the tuple ``(mri, clinical, label, task)``:
      * ``mri``: the transformed volume, or a ``torch.zeros(1)`` placeholder when
        ``modality == 'cdata'`` (clinical-only runs never load the image);
      * ``clinical``: float32 numpy vector of the clinical columns;
      * ``label`` / ``task``: taken as-is from the row.

    Args:
        subjects: pandas DataFrame of subjects.
        train: whether this is the training split. Kept for API compatibility;
            both splits share one preprocessing pipeline so that train and eval
            intensities are normalized identically.
        modality: ``'cdata'`` skips image loading; any other value loads the MRI.
        clinicals: clinical column names to return; defaults to ``CLINICALS``.
        percentiles: intensity percentiles clipped before rescaling to [0, 1].
        size: target shape passed to ``tio.CropOrPad``.
    """

    def __init__(self, subjects, train=False, modality='multimodality',
                 clinicals=None, percentiles=(0.5, 99.5), size=96):
        self.subjects = subjects
        self.modality = modality
        self.train = train
        self.clinicals = list(CLINICALS if clinicals is None else clinicals)
        # Train and eval must use the same intensity normalization; a different
        # lower clipping percentile at eval time shifts the input distribution
        # the model was trained on.
        self.transform = tio.Compose([
            tio.RescaleIntensity((0, 1), percentiles=percentiles),
            tio.CropOrPad(size),
        ])

    def __len__(self):
        """Number of subjects (rows) in the table."""
        return len(self.subjects)

    def __getitem__(self, index):
        """Return ``(mri, clinical, label, task)`` for the subject at position ``index``."""
        s = self.subjects.iloc[index]

        if self.modality == 'cdata':
            # Placeholder keeps the tuple layout identical across modalities.
            mri = torch.zeros(1)
        else:
            # Prepend a channel axis before the torchio transform.
            # Presumably nii2img returns a 3-D volume, giving (1, W, H, D) -- TODO confirm.
            mri = np.array([nii2img(s['mri'])], dtype=np.float32)
            mri = self.transform(mri)

        clinical = s[self.clinicals].to_numpy(dtype=np.float32)
        return (mri, clinical, s['label'], s['task'])

In [None]:
class DataBuilder():
    """Build the per-subject table (label, task, MRI path, clinical features).

    Expected layout (the csv paths are relative to the working directory):
      sdir/<sid>*strip2.nii.gz    MRI image(s) matched per subject
      sdir/xml/*.xml              metadata XML files, one subject per file
      ./csv/ADNIMERGE.csv         clinical features, indexed by PTID
      ./csv/DXSUM_pMCI.csv        VISCODE2 / Phase used to split MCI into sMCI / pMCI

    Typical use: ``DataBuilder(sdir).build()``, then ``get(label)`` per group.
    """

    def __init__(self, sdir):
        self.sdir = sdir
        print('Subjects DIR:', sdir, '\n')
        self.xmls = sorted(glob.glob(sdir + '/xml/' + '*.xml'))
        # Empty template; its columns define the row layout produced by getSubject.
        self.subjects = pd.DataFrame(columns=['label', 'task', 'mri', *CLINICALS])
        self.csv_demographic = CSV('csv/ADNIMERGE.csv')
        self.csv_dxsum = CSV('csv/DXSUM_pMCI.csv')

    def getSubject(self, xml):
        '''Parse one metadata XML file and return ``(sid, row)``.

        ``row`` is ``[label, task, mri_path, *clinical_values]``, matching the
        columns of ``self.subjects``. ``task`` is 1 for CN/AD and 2 otherwise.

        Raises:
            FileNotFoundError: no ``<sid>*strip2.nii.gz`` image exists in ``sdir``.
        '''
        root = ET.parse(xml).getroot()
        subject = root.find('project').find('subject')

        sid = subject.find('subjectIdentifier').text
        label = self.checkMCI(sid, subject.find('researchGroup').text)

        # sorted() makes the pick deterministic when several images match.
        matches = sorted(glob.glob(self.sdir + '/' + sid + '*strip2.nii.gz'))
        if not matches:
            raise FileNotFoundError(
                'no *strip2.nii.gz image for subject {} in {}'.format(sid, self.sdir))
        mri = matches[0]

        task = 1 if label in ('AD', 'CN') else 2
        # np.float was removed in NumPy 1.24; builtin float is the same dtype.
        # errors='ignore' keeps the values as-is when a non-numeric entry
        # (e.g. PTGENDER 'Male'/'Female') is present.
        cdata = self.csv_demographic.getItems(sid, CLINICALS).astype(float, errors='ignore').tolist()
        return sid, [label, task, mri, *cdata]

    def checkMCI(self, sid, label):
        '''Refine an XML research-group label using csv/DXSUM_pMCI.csv.

        If that CSV has an 'ADNI1' Phase and a visit code for ``sid``, the label is
        decided from the code alone: 'bl', or a numeric part <= 36 (e.g. 'm24' -> 24),
        gives 'pMCI'; a numeric part > 36 gives 'sMCI'. Otherwise any label
        containing 'MCI' becomes 'sMCI' and other labels are returned unchanged.

        NOTE(review): the CSV branch overrides the XML label even for non-MCI
        subjects; confirm DXSUM_pMCI.csv only lists MCI subjects.
        '''
        vis = self.csv_dxsum.getItem(sid, 'VISCODE2')
        phase = self.csv_dxsum.getItem(sid, 'Phase')

        # getItem yields None for unknown subjects and NaN for empty cells;
        # neither carries a usable visit code.
        if phase == 'ADNI1' and not pd.isna(vis):
            if vis == 'bl':
                return 'pMCI'
            # Assumes codes shaped like 'm48'; other shapes raise ValueError here.
            return 'sMCI' if int(vis[1:]) > 36 else 'pMCI'

        return 'sMCI' if 'MCI' in label else label

    @staticmethod
    def _to_numeric_or_keep(column):
        '''Return ``column`` converted to numbers, or unchanged if any entry is non-numeric.

        Replacement for ``pd.to_numeric(column, errors='ignore')``, deprecated in pandas 2.2.
        '''
        try:
            return pd.to_numeric(column)
        except (ValueError, TypeError):
            return column

    def build(self):
        '''Read every XML, assemble ``self.subjects``, impute it, and return ``self``.

        The table is rebuilt from scratch; a subject seen in several XML files
        keeps the row from the last one.
        '''
        rows = {}
        for xml in self.xmls:
            sid, data = self.getSubject(xml)
            rows[sid] = data
        # Construct once instead of growing the DataFrame row by row.
        subjects = pd.DataFrame(list(rows.values()), index=list(rows.keys()),
                                columns=self.subjects.columns)
        self.subjects = subjects.apply(self._to_numeric_or_keep)
        self.subjects = self.imputation()
        return self

    def get(self, label):
        '''Return normalized subjects of group ``label``, encoded for training.

        The ``label`` column is replaced by ``CLASS[label]`` and PTGENDER is encoded
        'Male' -> 0., 'Female' -> 1. (other values are left unchanged).
        '''
        subjects = self.normalize()
        # .copy() so the column assignments below do not write into a view.
        subjects = subjects.loc[subjects['label'] == label].copy()
        subjects['label'] = CLASS[label]
        subjects['PTGENDER'] = subjects['PTGENDER'].replace({'Male': 0., 'Female': 1.})
        return subjects

    def imputation(self):
        '''Return a copy of ``self.subjects`` with missing continuous clinical values
        (``CLINICALS[2:]``) filled by 5-nearest-neighbour imputation.

        PTGENDER and APOE4 are not imputed.
        '''
        subjects = self.subjects.copy()
        cdata = subjects[CLINICALS[2:]]
        imputed = KNNImputer(n_neighbors=5).fit_transform(cdata)
        subjects.update(pd.DataFrame(imputed, index=cdata.index, columns=cdata.columns),
                        overwrite=True)
        return subjects

    def normalize(self):
        '''Return a copy of ``self.subjects`` with each continuous clinical column
        (``CLINICALS[2:]``) divided by its maximum.

        This is max-scaling, not min-max scaling: results lie in [min/max, 1], which
        is [0, 1] only for non-negative columns. Columns whose max is 0 are left
        unscaled instead of becoming inf/NaN.
        '''
        subjects = self.subjects.copy()
        cdata = subjects[CLINICALS[2:]]
        col_max = cdata.max(axis=0)
        cdata = cdata / col_max.where(col_max != 0, 1.0)
        subjects.update(cdata)
        return subjects

    def print(self):
        '''Display a per-group demographic summary table of ``self.subjects``.

        Columns: subject count, Male/Female counts, APOE4 0/1/2 counts, and
        mean±std of every continuous clinical feature.
        '''
        groups = self.subjects.groupby('label')
        info = groups.agg(
            Total  = ('label',    'count'),
            Gender = ('PTGENDER', lambda x: '{}/{}'.format((x=='Male').sum(),(x=='Female').sum())),
            APOE4  = ('APOE4',    lambda x: '{:2d}/{:2d}/{:2d}'.format((x==0).sum(),(x==1).sum(),(x==2).sum())),
        )

        for c in CLINICALS[2:]:
            stats = groups[c].agg(['mean', 'std'])
            summary = stats['mean'].map('{:.1f}'.format) + '±' + stats['std'].map('{:.1f}'.format)
            info = info.join(summary.rename(c).to_frame())

        # Scoped option: do not leak a global display setting into the rest of the session.
        with pd.option_context('display.float_format', '{:,.2f}'.format):
            display.display(info)
        

In [None]:
class CSV():
    """Lookup wrapper around a CSV table indexed by a subject-ID column.

    Args:
        filename: path (or buffer) accepted by ``pd.read_csv``.
        index: column used as the row index; defaults to ``'PTID'``.
    """

    def __init__(self, filename, index='PTID'):
        self.df = pd.read_csv(filename).set_index(index)

    def getItems(self, sid, items):
        """Return the values of columns ``items`` for subject ``sid``.

        Raises KeyError if the subject or a column is missing.
        """
        return self.df.loc[sid, items]

    def getItem(self, sid, item):
        """Return the single value at (``sid``, ``item``), or None if either is absent."""
        try:
            return self.df.at[sid, item]
        except KeyError:
            # Only "not found" means "no value"; any other error is a real bug and propagates.
            return None

In [None]:
if __name__ == '__main__':
    import os

    # Data root; set ADNI_SDIR to run elsewhere (the default is the original lab path).
    SDIR = os.environ.get('ADNI_SDIR', '/media/mbl/HDD/Screening_PREPSD3')
    dataset = DataBuilder(SDIR).build()
    subjects = pd.concat([dataset.get(label) for label in ('CN', 'AD', 'sMCI', 'pMCI')])
    dataset.print()
    print('Total subjects:', len(subjects))