In [None]:
import os
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import shutil
import time

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Which dataset variant to load ('long' here); selects the img folder below.
_type = 'long'
imgpath = os.path.join('/NFS/FutureBrainGen/data', _type, 'img')

In [None]:
# One file name per line: the images that failed visual QC.
bad_qc_path = '/NFS/FutureBrainGen/data/cross/bad_qc_files_cross.txt'
with open(bad_qc_path) as fh:
    raw_text = fh.read()
bad_mris = raw_text.splitlines()
len(bad_mris)

In [None]:
imglist = sorted(os.listdir(imgpath))
imgs = [os.path.join(imgpath, name) for name in imglist]
print("Original number of images: ", len(imgs))

# bad_mris holds bare file names (e.g. "wmOAS42041_MR_d3027_T1w.nii"), while
# imgs holds full paths, so a plain `path not in bad_mris` never matches.
# Check both the full path and the basename; a set keeps the lookup O(1).
bad_set = set(bad_mris)
imgs = [p for p in imgs if p not in bad_set and os.path.basename(p) not in bad_set]
print("Number of images after removing bad ones: ", len(imgs))

## QC for merged data

In [10]:
# Batch of merged images to QC in this run (an earlier batch used 7000:8000).
# NOTE(review): this replaces the bad-QC-filtered `imgs` from above with a raw
# listing of new_img; confirm the bad-QC filter is not needed here.
NEW_IMG_DIR = '/NFS/FutureBrainGen/data/long/new_img'
QC_START, QC_STOP = 2000, 3000
# os.listdir order is arbitrary, so sort before slicing to get a stable batch.
imgs = sorted(os.listdir(NEW_IMG_DIR))[QC_START:QC_STOP]

In [11]:
def img_genertor(img_list, num_split=100, drop_last=True):
    """Yield consecutive slices of ``img_list`` with ``num_split`` items each.

    Args:
        img_list: Sequence to split (anything supporting ``len`` and slicing).
        num_split: Items per chunk; must be positive.
        drop_last: When True (default, the original behaviour) a trailing
            chunk shorter than ``num_split`` is not yielded. Set False to also
            yield that final partial chunk.

    Yields:
        Slices of ``img_list`` in order.

    Raises:
        ValueError: if ``num_split`` is not positive (raised on first ``next``).
    """
    if num_split <= 0:
        raise ValueError(f"num_split must be positive, got {num_split}")
    total = len(img_list)
    stop = total - total % num_split if drop_last else total
    for start in range(0, stop, num_split):
        yield img_list[start:start + num_split]

# Materialise the chunks: a bare generator is exhausted after one loop, so
# re-running the plotting cell would otherwise silently show nothing.
img_list = list(img_genertor(imgs))

In [None]:
# One 10x10 grid per chunk; shows slice 50 along the third axis of each volume.
# The loop variable is `chunk` (not `imgs`) so the full image list is not overwritten.
for idx, chunk in enumerate(img_list):
    print(f"########################## {idx} Split ##########################")
    fig, axs = plt.subplots(10, 10, figsize=(30, 30))

    # Column-major fill, as before: chunk[j*10 + i] goes to axs[i][j].
    # zip stops at the end of the chunk, so a short chunk leaves blank cells
    # instead of raising IndexError.
    for ax, fname in zip(axs.T.flat, chunk):
        vol = nib.load(os.path.join('/NFS/FutureBrainGen/data/long/new_img', fname)).get_fdata()
        ax.imshow(vol[:, :, 50], cmap='gray')
        ax.set_title(os.path.basename(fname), fontsize=7)
    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # free memory when many grids are drawn

## Filtered phenotype data

In [5]:
# Date Format: 1953-12-03_11_47_07.0

folder_format = 'Sag_3D_MP-RAGE'

use_col = ['mcsa_id', 'imagingdate', 'calc_age_vis', 'male', 'CDRSUM', 'fcogimpr'] # FCOGIMPR: 0=NORMAL, 1=MCI, 3=DEMENTIA

# Build the cleaned frame without inplace/chained operations. A chained
# `df[col].replace(..., inplace=True)` modifies a temporary under pandas
# Copy-on-Write and leaves the frame unchanged; warnings are muted in this
# notebook, so that failure would be silent.
mcsa_phenotype = (
    pd.read_csv('/NFS/MRI/MCSA/phenotype/MCSA_Phenotype.csv')[use_col]
    .dropna(subset=use_col, axis=0)
    .rename(columns={'mcsa_id': 'Subject',
                     'calc_age_vis': "Age",
                     'male': "Sex(1=m,2=f)",
                     'fcogimpr': "Group",
                     'imagingdate': 'Date'})
)
# Recode `male` (1 kept as 1, 0 -> 2) to match the Sex(1=m,2=f) convention.
mcsa_phenotype['Sex(1=m,2=f)'] = mcsa_phenotype['Sex(1=m,2=f)'].replace({1: 1, 0: 2})
# NOTE(review): Group code 4 is excluded; its meaning is not documented here.
mcsa_phenotype = mcsa_phenotype[mcsa_phenotype['Group'] != 4].copy()
mcsa_phenotype['Group'] = mcsa_phenotype['Group'].replace({0: 'HC', 1: 'MCI', 3: 'Dementia'})

In [None]:
# Sanity check: Group should now hold labels (HC/MCI/Dementia), not raw numeric
# codes. Any leftover number here means the recoding did not take effect.
display(mcsa_phenotype['Group'].value_counts(dropna=False))
mcsa_phenotype.head()

In [None]:
MCSA_MRIPATH = '/NFS/MRI/MCSA/original/'
MCSA_PREPROC = '/NFS/MRI/MCSA/preprocess/'
CROSS_IMG_DIR = '/NFS/FutureBrainGen/data/cross/img/'

# NOTE(review): the last 5 sorted entries are dropped, presumably non-subject
# entries in this directory; confirm.
mcsa_subj_list = sorted(os.listdir(MCSA_MRIPATH))[:-5]

# Collect one row per subject and build the frame once at the end
# (pd.concat inside the loop is quadratic).
records = []
for subj in mcsa_subj_list:
    subjPATH = os.path.join(MCSA_MRIPATH, subj, folder_format)
    # Sorted so the session pick below is deterministic (os.listdir order is
    # arbitrary). Session names presumably start with the scan date (see the
    # "Date Format" note), which makes this the earliest session.
    mrsession = sorted(s for s in os.listdir(subjPATH) if s != '.DS_Store')

    mri_dir = os.path.join(MCSA_PREPROC, subj, 'mri')
    try:
        wm_files = sorted(name for name in os.listdir(mri_dir) if name.startswith('wm'))
        if not wm_files:
            raise FileNotFoundError(f'no wm* file in {mri_dir}')
        # Record a single file name, the same one that gets copied.
        mri_name = wm_files[0]
        if not os.path.exists(os.path.join(CROSS_IMG_DIR, mri_name)):
            shutil.copy(os.path.join(mri_dir, mri_name), CROSS_IMG_DIR)
            print(mri_name)
    except OSError as err:
        # Best effort: a missing preprocess output or failed copy leaves the
        # subject without a file name, but the reason is reported, not hidden.
        print(f'[{subj}] no image: {err}')
        mri_name = None

    date = mrsession[0].split('_')[0] if mrsession else None
    records.append({'Subject': subj, 'Date': date, 'File name': mri_name})

filter_df = pd.DataFrame(records, columns=['Subject', 'Date', 'File name'])

In [8]:
# Keep only subject/date pairs present in both the scan listing and the phenotype table.
temp_df = filter_df.merge(mcsa_phenotype, how='inner', on=['Subject', 'Date'])

In [None]:
# Inspect the merged MCSA frame: row count, columns, dtypes and non-null counts.
temp_df.info()

In [11]:
# temp_df.to_csv('/NFS/MRI/MCSA/phenotype/MCSA_Phenotype_filtered.csv', index=False)

In [None]:
# Tag rows with their source dataset and drop columns not kept in the merged table.
# errors='ignore' keeps the cell idempotent: re-running it after the columns
# are gone no longer raises KeyError.
temp_df = (
    temp_df
    .assign(Dataset='MCSA')
    .drop(columns=['Date', "CDRSUM"], errors='ignore')
)
temp_df

## Merge Phenotype

In [13]:
# Extra images rejected during visual QC, in addition to the bad-QC file.
# (A duplicate "wmOAS42041_MR_d3027_T1w.nii" entry was removed.)
add_bad_mris = [
    "wmAnnArbor_sub26099_scan_mprage_skullstripped.nii",
    "wmNewYork_sub44979_scan_mprage_skullstripped.nii",
    "wmOAS42041_MR_d3027_T1w.nii",
    "wmOAS42061_MR_d3014_T1w.nii",
    "wmOAS42074_MR_d3027_T1w.nii",
    "wmOAS42139_MR_d3025_T1w.nii",
    "wmOAS42137_MR_d3033_T1w.nii",
    "wmOAS42094_MR_d3027_T1w.nii",
    "wmOAS42160_MR_d3024_T1w.nii",
    "wmOAS42164_MR_d3021_T1w.nii",
    "wmOAS42168_MR_d2966_T1w.nii",
    "wmOAS42238_MR_d3022_T1w.nii",
    "wmOAS42214_MR_d3027_T1w.nii",
    "wmOAS42201_MR_d3003_T1w.nii",
    "wmOAS42248_MR_d2966_T1w.nii",
    "wmOAS42249_MR_d2890_T1w.nii",
    "wmOAS42263_MR_d3801_T1w.nii",
    "wmOAS42269_MR_d3037_T1w.nii",
    "wmOAS42274_MR_d3016_T1w.nii",
    "wmOAS42323_MR_d3010_T1w.nii",
    "wmOAS42365_MR_d3016_T1w.nii",
    "wmOAS42375_MR_d3042_T1w.nii",
    "wmOAS42413_MR_d3022_T1w.nii",
    "wmOAS42409_MR_d3019_T1w.nii",
    "wmOAS42394_MR_d3021_T1w.nii",
    "wmOAS42377_MR_d3036_T1w.nii",
    "wmOAS42418_MR_d3024_T1w.nii",
    "wmOAS42451_MR_d3007_T1w.nii",
    "wmOAS42455_MR_d2249_T1w.nii",
    "wmOAS42545_MR_d3017_T1w.nii",
    "wmOAS42528_MR_d3038_T1w.nii",
    "wmOAS42483_MR_d3056_T1w.nii",
    "wmOAS42555_MR_d2865_T1w.nii",
    "wmOAS42573_MR_d3001_T1w.nii",
    "wmOAS42574_MR_d3034_T1w.nii",
    "wmOAS42689_MR_d3263_T1w.nii",
    "wmOAS42670_MR_d3088_T1w.nii",
    "wmOAS42667_MR_d3023_T1w.nii",
    "wmOAS42717_MR_d3036_T1w.nii"
]

# Merge the hand-added rejects into the exclusion list. dict.fromkeys drops
# duplicates while keeping first-seen order, so re-running this cell is harmless.
bad_mris = list(dict.fromkeys(bad_mris + add_bad_mris))

In [None]:
# Size of the combined exclusion list (QC-file entries plus hand-added names).
len(bad_mris)

In [None]:
# Existing cross-sectional table that the MCSA rows are appended to.
cross_csv = '/NFS/FutureBrainGen/data/cross/CrossSectional_included_file.csv'
cross_phenotype = pd.read_csv(cross_csv)
cross_phenotype

In [None]:
# Stack the existing cross-sectional table and the new MCSA rows. ignore_index
# gives a fresh 0..n-1 index instead of repeated labels from both inputs.
cross_temp = pd.concat([cross_phenotype, temp_df], axis=0, ignore_index=True)
cross_temp

In [None]:
# Keep only rows that point at an actual image and that did not fail QC.
# Subjects with no wm* file get File name = None in the collection loop; those
# rows previously passed the isin() check and leaked into the included-file CSV.
has_file = cross_temp['File name'].notna()
not_bad = ~cross_temp['File name'].isin(bad_mris)
cross_temp = cross_temp[has_file & not_bad]
cross_temp

In [18]:
# Write the filtered table as a new version; the row index is not saved.
out_csv = '/NFS/FutureBrainGen/data/cross/CrossSectional_included_file_v2.csv'
cross_temp.to_csv(out_csv, index=False)

## AIBL

In [None]:
BASE = "/NFS/MRI/AIBL/preprocess/cat12"
# Keep subject directories only. A stray file such as .DS_Store would make the
# copy loop's os.listdir(<sub>/mri) fail. Sort for a deterministic order.
subjs = sorted(s for s in os.listdir(BASE) if os.path.isdir(os.path.join(BASE, s)))

In [5]:
# Copy every AIBL wm* image into the shared longitudinal image folder.
AIBL_DEST = '/NFS/FutureBrainGen/data/long/new_img/'
for sub in tqdm(subjs, desc='AIBL copy'):
    subjfolder = os.path.join(BASE, sub, "mri")
    if not os.path.isdir(subjfolder):
        print(f'No mri folder for {sub}')
        continue
    wm_files = [i for i in os.listdir(subjfolder) if i.startswith('wm')]

    for wm in wm_files:
        # Skip files that are already there (as the MCSA copy does), so a
        # re-run does not recopy the whole dataset.
        # NOTE(review): an updated source file will not overwrite an old copy.
        if not os.path.exists(os.path.join(AIBL_DEST, wm)):
            shutil.copy(os.path.join(subjfolder, wm), AIBL_DEST)


## Visual QC

In [6]:
import os
import nibabel as nib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
data = "MCSA"

DATAPATH = os.path.join("/NFS/MRI/", data)
# Derive the phenotype file from `data` instead of hard-coding MCSA. For "MCSA"
# this is the same file as before.
# NOTE(review): assumes other datasets follow the <name>_Phenotype.csv naming.
cross_df = pd.read_csv(os.path.join(DATAPATH, 'phenotype', f'{data}_Phenotype.csv'))
preproc_img_list = sorted(os.listdir(os.path.join(DATAPATH, 'preprocess')))

In [10]:
def find_mri_file(DATAPATH, subj_folder):
    """Return the ``wm*`` images of one subject's preprocessing output.

    Looks in ``<DATAPATH>/preprocess/<subj_folder>/mri`` for files whose names
    start with ``'wm'``.

    Args:
        DATAPATH: Dataset root directory (e.g. ``/NFS/MRI/MCSA``).
        subj_folder: Subject folder name under ``preprocess``.

    Returns:
        A sorted list of full paths when at least one ``wm*`` file exists.
        Otherwise the ``subj_folder`` string itself, after printing a message.
        Callers detect this case by checking for a ``str`` result.
    """
    IMGPATH = os.path.join(DATAPATH, 'preprocess', subj_folder, 'mri')

    try:
        names = os.listdir(IMGPATH)
    except OSError:
        # A missing or unreadable mri folder counts as "no images". Only
        # filesystem errors are caught, not typos or KeyboardInterrupt.
        names = []

    imgs = sorted(os.path.join(IMGPATH, mri) for mri in names if mri.startswith('wm'))

    if not imgs:
        print(f'No MRI found for {subj_folder}')
        return subj_folder
    return imgs

In [None]:
imgs = []
error_subjs = []

for subj_folder in tqdm(preproc_img_list):
    mris = find_mri_file(DATAPATH, subj_folder)

    if isinstance(mris, str):
        # find_mri_file returns the folder name when no wm* image was found.
        error_subjs.append(mris)
    else:
        # Add each image once. The former `for mri in mris: imgs.extend(mris)`
        # added k*k entries for a subject with k images.
        imgs.extend(mris)

In [12]:
imgs.sort()

# Split into chunks of 100 images (one 10x10 QC figure each). This replaces the
# hand-written imgs1..imgs8 slices, which silently ignored everything after 800.
# Only full chunks are kept so every figure's grid is filled; any remainder is
# reported rather than dropped silently.
chunk_size = 100
n_full = len(imgs) // chunk_size
img_list = [imgs[k * chunk_size:(k + 1) * chunk_size] for k in range(n_full)]
leftover = len(imgs) - n_full * chunk_size
if leftover:
    print(f"{leftover} image(s) after the last full chunk of {chunk_size} are not plotted")

In [None]:
# One 10x10 grid per chunk; shows slice 50 along the third axis of each volume.
# The loop variable is `chunk` (not `imgs`) so the full sorted image list is not
# overwritten by the last chunk.
for chunk in img_list:
    fig, axs = plt.subplots(10, 10, figsize=(20, 20))

    # Column-major fill, as before: chunk[j*10 + i] goes to axs[i][j]. zip stops
    # at the end of the chunk, so a short chunk leaves blank cells instead of
    # raising IndexError.
    for ax, path in zip(axs.T.flat, chunk):
        vol = nib.load(path).get_fdata()  # chunk entries are already full paths
        ax.imshow(vol[:, :, 50], cmap='gray')
        ax.set_title(f'{path}', fontsize=6)
    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # free memory when many grids are drawn