# Data Exploration

In [1]:
import os
import glob

root = '/mnt/project/'
# Keep the trailing slash: later cells build paths as `proj_root + 'src/...'`
# and `proj_root + 'data/...'`; without it those become '.../RegionBAEsrc/...'.
proj_root = root + 'UKB_Main/yihwang/RegionBAE/'
brain_root = os.path.join(root, 'Bulk/Brain MRI/T1/')

In [None]:
# Gather every file stored in the numbered T1 sub-directories 10 .. 59.
files = [
    found
    for bucket in range(10, 60)
    for found in glob.glob(f'{brain_root}/{bucket}/*')
]

print(len(files))
print(files[0])
print(files[-1])

### Visualizing 20263 MRI images

In [None]:
!pip install nibabel ipywidgets numpy cv2 matplotlib zipfile
!pipi install SimpleITK ants

In [None]:
%matplotlib inline
import nibabel as nib
import zipfile
import sys
import ants

helpers_path = os.path.join(proj_root, 'src', 'helpers.py')
# sys.path entries must be directories: appending the path of helpers.py itself
# does not make `import helpers` resolvable.
helpers_dir = os.path.dirname(os.path.abspath(helpers_path))
if helpers_dir not in sys.path:
    sys.path.append(helpers_dir)
from helpers import *

test_20263_zip = '/mnt/project/Bulk/Brain MRI/T1//59/5999855_20263_2_0.zip'

# The .mgz volumes live inside the zip archive, so extract it first.
# NOTE(review): extraction goes to a folder under the current working
# directory — confirm this is a writable location with enough space.
root_20263 = os.path.join(
    os.getcwd(), 'ukb_20263', os.path.splitext(os.path.basename(test_20263_zip))[0]
)
with zipfile.ZipFile(test_20263_zip) as archive:
    archive.extractall(root_20263)

data_20263 = [
    'orig.mgz',
    'brain.mgz',
    'brainmask.mgz',
    'FLAIR.mgz',
    'norm.mgz',
    'T1.mgz',
    'wm.mgz',
    'aseg.mgz'
]

for image in data_20263:
    # Search recursively because the folder layout inside the archive is not
    # fixed by this notebook.
    matches = sorted(glob.glob(os.path.join(root_20263, '**', image), recursive=True))
    if not matches:
        print(f'{image} not found under {root_20263}')
        continue
    path = matches[0]
    print(path)
    orig = nib.load(path)
    data = orig.get_fdata()
    explore_3D_array(data)

# Phecode

In [None]:
codes = set()
files_20252 = []
files_20263 = []
idx = list()  # unique subject IDs, in first-seen order
seen = set()
for file in files:
    # File names look like '5999855_20263_2_0.zip' (<eid>_<field>_...); parse the
    # basename so underscores in directory names cannot shift the fields.
    file_split = os.path.basename(file).split('_')
    eid = int(file_split[0])
    if eid not in seen:
        idx.append(eid)
        seen.add(eid)
    code = file_split[1]
    codes.add(code)
    if code == '20252':
        files_20252.append(file)
    else:
        files_20263.append(file)

print(codes)
print(files_20252[0])
print(files_20252[1])
print(files_20252[2])
print(f'Total number of samples with MRI data: {len(idx)}')
# The counts were previously printed as the literal text 'len(...)' (missing braces).
print(f'Total number of 20252 data: {len(files_20252)}')
print(f'Total number of 20263 data: {len(files_20263)}')

In [None]:
!pip install pandas

In [None]:
import pandas as pd

phe_b = 'PheCode_ICD10_withCovar_221012.txt'  # ~1.6 GB
phe_root = os.path.join(root, 'WGS/Pheno', phe_b)

# Stream the space-delimited table 10,000 lines at a time, then stitch the
# pieces together with a fresh 0..n-1 index.
chunk_size = 10000
chunks = list(pd.read_csv(phe_root, sep=' ', chunksize=chunk_size))
phecode = pd.concat(chunks, ignore_index=True)
phecode

In [None]:
# Keep only phenotype rows of subjects that have MRI files; renumber rows 0..n-1.
brain_mri_phe = phecode[phecode['IID'].isin(idx)].reset_index(drop=True)
brain_mri_phe

In [None]:
# Same table, indexed by subject ID (a new DataFrame; brain_mri_phe is untouched).
brain_mri_phe_id = brain_mri_phe.set_index('IID')
brain_mri_phe_id

In [None]:
mri_disease = dict()

# Number of leading non-phecode (covariate) columns in each row of
# `brain_mri_phe_id` (IID is already its index).
# NOTE(review): 13 is carried over from the original slice — confirm against the file header.
N_COVARIATE_COLS = 13


def code_filtering(data):
    """Return the distinct phecode "roots" of *data*, in first-seen order.

    Each item is split on '.', and the piece at index 1 is kept. Every item
    must therefore contain at least one '.' (IndexError otherwise).
    """
    return list(dict.fromkeys(item.split('.')[1] for item in data))


# Age/Sex of the first row per IID, looked up directly instead of re-filtering
# the whole table once per subject (which was O(n^2)).
first_rows = brain_mri_phe.drop_duplicates('IID').set_index('IID')

for eid in brain_mri_phe['IID']:
    age = first_rows.at[eid, 'Age']
    sex = first_rows.at[eid, 'Sex']
    sample = brain_mri_phe_id.loc[eid]
    # Drop the covariate columns BEFORE thresholding. Thresholding first and then
    # slicing [13:] (as before) moved the cut whenever any covariate was <= 0,
    # which dropped real phecodes and/or kept covariates.
    phe_values = sample.iloc[N_COVARIATE_COLS:]
    disease_codes = phe_values[phe_values > 0].index.tolist()
    filtered_disease_codes = code_filtering(disease_codes)
    disease = 1 if disease_codes else 0
    mri_disease[eid] = [age, sex, disease, filtered_disease_codes, disease_codes]

# Convert the dictionary into a DataFrame
mri_disease_df = pd.DataFrame([
    {'id': key, 'age': value[0], 'sex': value[1], 'disease': value[2], 'filtered_phe_codes': value[3], 'phe_codes': value[4]}
    for key, value in mri_disease.items()
])
mri_disease_df

In [None]:
mri_disease_df.disease.value_counts()  # subjects per disease flag (0 = no positive phecode)

In [None]:
def age_dist_plot(cn=None, disease=None):
    """Overlay normalized age histograms of the CN and disease groups.

    Either argument may be None (or empty) to skip that group. Lists, arrays
    and pandas Series are accepted; the checks avoid plain truthiness because
    `bool(Series)` raises ValueError.
    """
    if cn is not None and len(cn) > 0:
        plt.hist(cn, bins=30, alpha=0.5, label='CN', density=True, edgecolor='black')
    if disease is not None and len(disease) > 0:
        plt.hist(disease, bins=30, alpha=0.5, label='Disease', density=True, edgecolor='black')
    plt.title('Age Distribution by group')
    plt.xlabel('Age')
    plt.ylabel('Density')
    plt.legend(loc='upper right')
    plt.show()

cn_ages = mri_disease_df['age'][mri_disease_df['disease']==0]
disease_ages = mri_disease_df['age'][mri_disease_df['disease']==1]
age_dist_plot(cn_ages, disease_ages)

In [None]:
# Phecode metadata (phecode / name / group columns) from the workbook's second sheet.
phe_esm = pd.read_excel(root + '/yihwang/phecode_esm.xlsx', sheet_name=1)
phe_esm

In [None]:
disease_group = phe_esm.group.unique()  # distinct group names, in order of first appearance
disease_group

In [None]:
def id_filtering_group(group_name):
    """Return the IDs of subjects carrying any phecode of *group_name*.

    Also prints each phecode of the group with its name, and bar-plots the
    per-phecode subject counts. Reads the globals `phe_esm` and `mri_disease`.
    """
    group_phe = sorted(phe_esm['phecode'][phe_esm['group'] == group_name].tolist())
    print(group_phe)

    counts_per_code = {}
    for code in group_phe:
        label = phe_esm['name'][phe_esm['phecode'] == code].iloc[0]
        print(code, label)
        counts_per_code[code] = len(id_filtering_single(code))

    wanted = [str(code) for code in group_phe]
    id_disease_disorders = [
        subject_id
        for subject_id, record in mri_disease.items()
        if any(code in record[3] for code in wanted)
    ]
    print(f"\nTotal number of {group_name}: {len(id_disease_disorders)}")

    group_plot(counts_per_code)

    return id_disease_disorders


def id_filtering_single(disease_code):
    """Return the IDs in `mri_disease` whose filtered codes contain str(disease_code).

    The code is first looked up in `phe_esm`; an unknown code raises IndexError.
    """
    # Lookup kept as a validity check; the name itself is not used.
    phe_esm['name'][phe_esm['phecode'] == disease_code].iloc[0]

    target = str(disease_code)
    return [subject_id for subject_id, record in mri_disease.items() if target in record[3]]

def group_plot(data):
    """Bar-plot subject counts per phecode.

    data : dict mapping numeric phecode -> count. Nothing is drawn if empty.
    """
    if not data:
        print('group_plot: no data to plot')
        return

    x = list(data.keys())
    y = list(data.values())

    # Draw the plot
    plt.figure(figsize=(20, 10))
    # range() excludes its stop value, so +1 keeps a tick at the largest code;
    # int() lets float-typed phecodes work with range().
    plt.xticks(range(int(min(x)), int(max(x)) + 1, 1), fontsize=8)
    plt.bar(x, y, color='skyblue')  # bar chart
    plt.xlabel('Key')  # x-axis label
    plt.ylabel('Count')  # y-axis label
    plt.title('Count per Key')  # plot title
    plt.show()

In [None]:
# Group at position 6 of `disease_group` (order comes from phe_esm rows) — assumed neurological.
id_neurological_disorders = id_filtering_group(group_name=disease_group[6])

In [None]:
# Group at position 0 of `disease_group` — assumed mental disorders.
id_mental_disorders = id_filtering_group(group_name=disease_group[0])

In [None]:
# Group at position 2 of `disease_group` — assumed circulatory system.
id_circulatory_disorders = id_filtering_group(group_name=disease_group[2])

In [None]:
# Group at position 4 of `disease_group` — assumed endocrine/metabolic.
id_metabolic = id_filtering_group(group_name=disease_group[4])

In [None]:
id_diabete = id_filtering_single(disease_code=250)  # phecode 250
print(len(id_diabete))

In [None]:
# 550: hernia, 531: esophageal disease, 561: digestive-system symptoms,
# 564: digestive dysfunction, 535: stomach & duodenum inflammation
id_digestive = id_filtering_group(group_name=disease_group[1])

# UKBB Alzheimer's Disease Region-Specific Brain Age Prediction

In [2]:
# NOTE(review): phecode 0 looks like a placeholder — confirm the intended
# Alzheimer's disease phecode.
id_ad = id_filtering_single(0)
print(len(id_ad))
# id_ad is a plain list of subject IDs, so take the ages from mri_disease_df.
ad_ages = mri_disease_df.loc[mri_disease_df['id'].isin(id_ad), 'age'].tolist()
age_dist_plot(None, ad_ages)

NameError: name 'id_filtering_single' is not defined

In [None]:
def mk_csv(filtered_id, csv_name):
    """Save the `brain_img_df` rows whose 'id' is in *filtered_id*.

    Output: <proj_root>/data/ukbb_phe_<csv_name>.csv, written with the pandas
    index (renumbered 0..n-1) as its first column.
    """
    selected = brain_img_df.loc[brain_img_df['id'].isin(filtered_id)]
    selected = selected.reset_index(drop=True)
    selected.to_csv(os.path.join(proj_root, 'data/ukbb_phe_' + csv_name + '.csv'))


mk_csv(id_ad, 'ad')

### 1. Preprocessing

In [None]:
!pip install tqdm

In [None]:
# Builds <proj_root>/data/csv/<code>_region.csv listing, per subject, the paths
# of the registered T1 image, its brain mask and the 9 region crops (+ masks).
save_path = os.path.join(proj_root, 'data/ukbb_region/disease')
regions = ['caudate', 'cerebellum', 'frontal_lobe', 'insula', 'occipital_lobe', 'parietal_lobe', 'putamen', 'temporal_lobe', 'thalamus']
diseases = ['ad']
# Only the first n_subjects rows of each phenotype table are listed
# (the original hard-coded 50); shorter tables are listed completely.
n_subjects = 50

for code in diseases:
    # NOTE(review): `phecode_path` was never defined in this notebook; this
    # assumes the per-disease table written by `mk_csv(..., code)` — confirm.
    phecode_path = os.path.join(proj_root, f'data/ukbb_phe_{code}.csv')
    ukb_icd = pd.read_csv(phecode_path)

    csv_data = {
        'subjectID': [],
        'imgs': [],
        'mask': [],
        'age': []
    }
    for region in regions:
        csv_data[region] = []
        csv_data[f'{region}_mask'] = []

    for i in range(min(n_subjects, len(ukb_icd))):
        subject = ukb_icd['id'][i]
        subject_dir = f'{save_path}/{subject}'

        csv_data['subjectID'].append(subject)
        csv_data['age'].append(ukb_icd['Age'][i])
        csv_data['imgs'].append(f'{subject_dir}/T1w_registered.nii.gz')
        csv_data['mask'].append(f'{subject_dir}/T1w_brain_mask_registered.nii.gz')
        # Region files are numbered 1..9 in the order of `regions`. Every mask
        # now points at its own subject (the caudate mask used to be hard-coded
        # to subject 3068434).
        for region_idx, region in enumerate(regions, start=1):
            csv_data[region].append(f'{subject_dir}/region_{region_idx}.nii.gz')
            csv_data[f'{region}_mask'].append(f'{subject_dir}/region_{region_idx}_mask.nii.gz')

    df = pd.DataFrame(csv_data)
    csv_save_path = os.path.join(proj_root, f'data/csv/{code}_region.csv')
    os.makedirs(os.path.dirname(csv_save_path), exist_ok=True)
    df.to_csv(csv_save_path, index=False)

In [None]:
def process_subject(root, subject, atlas_template_path, save_root, regions, df, trial, tmp_path, skip,
                    write_region_masks=False):
    """Register a label atlas onto one subject's T1 image and save 9 region crops.

    Outputs in <save_root>/<subject>/: T1w_registered.nii.gz (T1 resampled to a
    128^3 grid), T1w_brain_mask_registered.nii.gz, region_<k>.nii.gz for k=1..9,
    and region_<k>_mask.nii.gz when *write_region_masks* is True.

    Parameters
    ----------
    root : directory containing <subject>/T1/T1_brain_to_MNI.nii.gz
    subject : subject folder name (str)
    atlas_template_path : label atlas whose labels 1..9 are the regions
    save_root : output root directory
    regions, df, trial : unused; kept so existing callers keep working
    tmp_path : `outprefix` for the files ANTs registration writes
    skip : running count of already-processed subjects
    write_region_masks : also save each dilated region mask (default False)

    Returns
    -------
    int
        *skip*, incremented by one if this subject's outputs already existed.
    """
    import gc

    img_dir = os.path.join(root, subject, 'T1/T1_brain_to_MNI.nii.gz')
    save_path = os.path.join(save_root, subject)
    print(subject, save_path, flush=True)

    region_ids = range(1, 10)
    expected = ['T1w_registered.nii.gz', 'T1w_brain_mask_registered.nii.gz']
    expected += [f'region_{k}.nii.gz' for k in region_ids]
    if write_region_masks:
        expected += [f'region_{k}_mask.nii.gz' for k in region_ids]

    # Skip subjects whose outputs are all present. The old check demanded
    # exactly 20 files, but region masks were not written, so it never matched.
    if all(os.path.exists(os.path.join(save_path, name)) for name in expected):
        skip += 1
        print(f"{skip}th Skipping {subject}, as save_path already exists.", flush=True)
        return skip

    os.makedirs(save_path, exist_ok=True)

    image = ants.image_read(img_dir)
    # use_voxels=1: (128, 128, 128) is the output grid size; interp_type=0: linear.
    image = ants.resample_image(image, (128, 128, 128), 1, 0)
    mask = ants.get_mask(image)

    ants.image_write(image, os.path.join(save_path, 'T1w_registered.nii.gz'))
    ants.image_write(mask, os.path.join(save_path, 'T1w_brain_mask_registered.nii.gz'))

    # NOTE(review): the label atlas itself is the moving image of the SyN
    # registration; registering an intensity template and applying its
    # transforms to the labels may be what was intended — confirm.
    atlas = ants.image_read(atlas_template_path)
    transformation = ants.registration(
        fixed=image,
        moving=atlas,
        type_of_transform='SyN',
        outprefix=tmp_path
    )
    # `warpedmovout` is linearly interpolated, which blends integer labels at
    # region borders; re-apply the transforms with a label-preserving interpolator.
    registered_atlas_ants = ants.apply_transforms(
        fixed=image,
        moving=atlas,
        transformlist=transformation['fwdtransforms'],
        interpolator='genericLabel'
    )
    del transformation, atlas
    gc.collect()

    image_array = image.numpy()
    for region_idx in region_ids:
        region_mask = registered_atlas_ants == region_idx
        region_mask_dilated = ants.morphology(region_mask, radius=4, operation='dilate', mtype='binary')

        extracted_region = image_array * region_mask_dilated.numpy()
        # from_numpy alone resets origin/spacing/direction to defaults; copy
        # them from the source image so the crop stays aligned with it.
        extracted_region_ants = ants.from_numpy(
            extracted_region,
            origin=image.origin,
            spacing=image.spacing,
            direction=image.direction
        )
        ants.image_write(extracted_region_ants, os.path.join(save_path, f'region_{region_idx}.nii.gz'))
        if write_region_masks:
            ants.image_write(region_mask_dilated, os.path.join(save_path, f'region_{region_idx}_mask.nii.gz'))

        del region_mask, region_mask_dilated, extracted_region, extracted_region_ants
        gc.collect()

    # Always hand the counter back; returning None here used to break the
    # caller's `skip = process_subject(...)` bookkeeping.
    return skip

In [None]:
import gc

# Path of the AD region CSV (built above); its 'subjectID' column lists the subjects to process.
ad_region_df = os.path.join(proj_root, 'data/csv/ad_region.csv')
img_save_root = os.path.join(proj_root, 'data/ukbb_region/disease')
os.makedirs(img_save_root, exist_ok=True)
tmp_path = os.path.join(proj_root, 'data/ukbb_region/tmp')  # ANTs output prefix

# NOTE(review): machine-specific absolute path — confirm it exists in this environment.
atlas_template_path = '/media/leelabsg-storage1/yein/research/data/template/MNI-maxprob-thr0-1mm.nii.gz'
regions = ['caudate', 'cerebellum', 'frontal_lobe', 'insula', 'occipital_lobe', 'parietal_lobe', 'putamen', 'temporal_lobe', 'thalamus']

# Iterate over the subjects in the CSV (the old loop took len() of the path string
# and read an unrelated `df['id']`).
subjects = pd.read_csv(ad_region_df)['subjectID'].astype(str).tolist()

skip = 0
for subject in tqdm(subjects):  # tqdm must already be in scope (e.g. via a star import)
    # Keyword arguments: the old positional call omitted `trial`, so tmp_path
    # landed in `trial`, skip in `tmp_path`, and `skip` itself was missing.
    result = process_subject(
        root=root,
        subject=subject,
        atlas_template_path=atlas_template_path,
        save_root=img_save_root,
        regions=regions,
        df=ad_region_df,
        trial=None,
        tmp_path=tmp_path,
        skip=skip,
    )
    if result is not None:
        skip = result
    gc.collect()

# Inference

In [None]:
!pip install scikit-learn torch torchsummary torchvision

In [None]:
from pathlib import Path
import pickle
import random

from sklearn.model_selection import KFold

# sys.path needs the directory that holds the project modules; appending each
# .py file's own path (as before) does not make them importable.
src_dir = os.path.abspath(os.path.join(proj_root, 'src'))
dataset_path = os.path.join(src_dir, 'dataset.py')
CNN_path = os.path.join(src_dir, 'CNN.py')
CNN_Trainer_path = os.path.join(src_dir, 'CNN_Trainer.py')
lr_scheduler = os.path.join(src_dir, 'lr_scheduler.py')
early_stopping_path = os.path.join(src_dir, 'early_stopping.py')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from dataset import *
from CNN import *
from CNN_Trainer import *
# NOTE(review): the file listed above is lr_scheduler.py, but the import is
# from a `learning_rate` module/package — confirm which one exists in src/.
from learning_rate import lr_scheduler as lr
from early_stopping import EarlyStopping

# Used directly by the training cell; imported explicitly rather than relying
# on the star imports above.
from torch import optim
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

import torch
from torch import nn
import time

In [None]:
# Settings & hyperparameters
IMGS = os.path.join(proj_root, 'data/csv/ad_region.csv')  # preprocessed global + 9 region MRI paths
dataset_df = pd.read_csv(IMGS)
DATASET = 'ukbb'
BATCH_SIZE = 4
EPOCHS = 40
RESULTS_FOLDER = os.path.join(proj_root, 'data/csv/test/ad')
MODEL_SAVE_FOLDER = os.path.join(proj_root, 'model')
INPUT_SIZE = (1, 128, 128, 128)
LEARNING_RATE = 0.0001
LEARNING_RATE_Scheduler = 1
N_WORKERS = 8
# Regions evaluated in test mode: the test loop calls REGIONS.items() and uses
# each value as the dataset column and model sub-folder, so this must be a
# mapping (it was the integer 0, which crashes there).
REGIONS = {0: 'imgs', 1: 'caudate', 2: 'cerebellum', 3: 'frontal_lobe', 4: 'insula',
           5: 'occipital_lobe', 6: 'parietal_lobe', 7: 'putamen', 8: 'temporal_lobe', 9: 'thalamus'}
ROI = 'imgs'
MODEL_LOAD_FOLDER = os.path.join(proj_root, 'model')
MODEL_LOAD = 1
MODEL_LOAD_EPOCH = 40
DATA_SIZE = len(dataset_df)
MODE = 'test'
PATIENCE = 0
ngpus = torch.cuda.device_count()
GPU = ngpus

# setting log
print("="* 20, " Setting ", "="* 20)
print("Dataset :                 ", DATASET)
print("Mode :                    ", MODE)
print("GPU :                     ", GPU)
print("Number of gpus :          ", ngpus)
print("Batch size :             ", BATCH_SIZE)
print("Data size:               ", DATA_SIZE)
print("Epochs :                 ", EPOCHS)
print("Learning Rate :          ", LEARNING_RATE)
print("Early Stopping Patience :", PATIENCE)
print("# of Workers  :          ", N_WORKERS)
print("Region of Interest :     ", ROI)
print("Model Save Path:         ", MODEL_SAVE_FOLDER)
print("Loaded Model Epoch:      ", MODEL_LOAD_EPOCH)
print("="* 50)

In [None]:
# Build the shuffled CV folds (fixed seed, so fold k always holds the same subjects).
N_SPLITS = 4
kf = KFold(n_splits=N_SPLITS, random_state=7, shuffle=True)
# Row indices of dataset_df taking part in the experiment.
dataset_indices = list(dataset_df.index)[:DATA_SIZE]

# Choose the device before building the model: calling .cuda() first (as
# before) fails on a CPU-only machine even though a CPU fallback was intended.
device = 'cuda' if GPU else 'cpu'
# Weight decay for the no-scheduler optimizer (the old code read an undefined
# `config`); define WEIGHT_DECAY in the settings cell to override.
weight_decay = globals().get('WEIGHT_DECAY', 0.0)
# Regions evaluated in test mode ({index: dataset column / model sub-folder});
# falls back to ROI alone if REGIONS is not a mapping.
region_map = REGIONS if isinstance(REGIONS, dict) else {0: ROI}

# enumerate() gives each fold its own cv_num; it previously stayed 0, so every
# fold loaded the fold-0 checkpoint and wrote into the same results folder.
for cv_num, (train_indices, valid_indices) in enumerate(kf.split(dataset_indices)):

    print('\n<<< KFold: {0}/{1} >>>'.format(cv_num + 1, N_SPLITS))

    # Datasets / loaders for this fold
    train_dataset = Region_Dataset(dataset_df, train_indices, ROI)
    valid_dataset = Region_Dataset(dataset_df, valid_indices, ROI)
    dataloader_train = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=RandomSampler(train_dataset),
                                  collate_fn=train_dataset.collate_fn,
                                  pin_memory=True,
                                  num_workers=N_WORKERS)
    dataloader_valid = DataLoader(valid_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=SequentialSampler(valid_dataset),
                                  collate_fn=valid_dataset.collate_fn,
                                  pin_memory=True,
                                  num_workers=N_WORKERS)

    print("Train Dataset & Validation Dataset size: ", len(dataloader_train), len(dataloader_valid))
    print(valid_indices)

    # Model: weight initialisation, then DataParallel for multi-GPU use.
    model = CNN(in_channels=1).to(device)
    model.apply(initialize_weights)
    model = torch.nn.DataParallel(model)

    if LEARNING_RATE_Scheduler == 0:
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=weight_decay)
        scheduler = None
    else:
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
        t_0 = int(len(train_indices) // BATCH_SIZE // 6)
        scheduler = lr.CustomCosineAnnealingWarmUpRestarts(optimizer, T_0=t_0, T_up=10, T_mult=2, eta_max=1e-3, gamma=0.5)

    # ------------------------ Train the model
    if MODE == 'train':
        # PATIENCE is the setting (the old code read an undefined `config.patience`).
        EARLY_STOPPING = None if PATIENCE == 0 else EarlyStopping(patience=PATIENCE, verbose=True)

        trainer = CNN_Trainer(model=model,
                              model_load_folder=MODEL_LOAD_FOLDER,
                              model_save_folder=MODEL_SAVE_FOLDER,
                              results_folder=RESULTS_FOLDER,
                              dataloader_train=dataloader_train,
                              dataloader_valid=dataloader_valid,
                              dataloader_test=None,
                              epochs=EPOCHS,
                              optimizer=optimizer,
                              early_stopping=EARLY_STOPPING,
                              scheduler=scheduler,
                              cv_num=cv_num,
                              region=ROI,
                              model_load=MODEL_LOAD)
        if MODEL_LOAD == 1:
            trainer.load(cv_num, MODEL_LOAD_EPOCH, GPU)  # resume from a pre-trained checkpoint

        train_start = time.time()
        trainer.train()  # fills the *_list attributes read below
        train_end = time.time()

        train_mse_list = trainer.train_mse_list
        train_mae_list = trainer.train_mae_list
        valid_mse_list = trainer.valid_mse_list
        valid_mae_list = trainer.valid_mae_list

        print("train_mse_list: ", train_mse_list)
        print("train_mae_list: ", train_mae_list)
        print("valid_mse_list: ", valid_mse_list)
        print("valid_mae_list: ", valid_mae_list)

        print(f"\nElapsed training time for cv {cv_num}: {(train_end - train_start) / 60:.0f} minutes")

    # ------------------------ Test the model
    else:
        pred_age_data = dict()
        true_age_data = dict()
        feature_data = dict()

        # 'test' -> RESULTS_FOLDER/<cv>; other test modes (e.g. 'test_tf') -> RESULTS_FOLDER_tf/<cv>
        if MODE == 'test':
            results_folder = os.path.join(RESULTS_FOLDER, str(cv_num))
        else:
            results_folder = os.path.join(RESULTS_FOLDER + '_tf', str(cv_num))

        # In-distribution evaluation uses this fold's held-out rows; otherwise all rows.
        # NOTE(review): for a disease cohort never used in training, evaluating all
        # rows with every fold model may be what pickle_load's averaging expects.
        use_valid_split = (DATASET == 'ukbb' and MODE == 'test') or (DATASET == 'adni' and MODE == 'test_tf')
        test_indices = valid_indices if use_valid_split else dataset_indices

        trainer = None
        for _, v in region_map.items():
            model_load_folder = os.path.join(MODEL_LOAD_FOLDER, v)

            # Build the loader for region `v` itself: reusing dataloader_valid
            # (built for ROI) fed every region model the ROI images.
            test_dataset = Region_Dataset(dataset_df, test_indices, v)
            dataloader_test = DataLoader(test_dataset,
                                         batch_size=BATCH_SIZE,
                                         sampler=SequentialSampler(test_dataset),
                                         collate_fn=test_dataset.collate_fn,
                                         pin_memory=True,
                                         num_workers=N_WORKERS)

            trainer = CNN_Trainer(model=model,
                                  model_load_folder=model_load_folder,
                                  model_save_folder=MODEL_SAVE_FOLDER,
                                  results_folder=results_folder,
                                  dataloader_train=None,
                                  dataloader_valid=None,
                                  dataloader_test=dataloader_test,
                                  epochs=EPOCHS,
                                  optimizer=optimizer,
                                  early_stopping=None,
                                  scheduler=scheduler,
                                  cv_num=cv_num,
                                  region=ROI,
                                  model_load=MODEL_LOAD)

            trainer.load(cv_num, MODEL_LOAD_EPOCH, GPU)  # pre-trained model for region v
            pred_ages, true_ages, features = trainer.test()

            pred_age_data.setdefault(v, []).extend(pred_ages)
            true_age_data.setdefault(v, []).extend(true_ages)
            feature_data.setdefault(v, []).extend(features)

        # Save this fold's results
        if trainer is not None:
            trainer.test_age_data_extraction(pred_age_data,
                                             true_age_data,
                                             feature_data)

# PAD Analysis

In [None]:
!pip install seaborn scipy

import seaborn as sns
import matplotlib.colors as mcolors
import scipy.stats as stats
from sklearn.linear_model import LinearRegression

# Region index -> key used in the result pickles ('imgs' is the whole-image / global entry).
regions = dict(enumerate(['imgs', 'caudate', 'cerebellum', 'frontal_lobe', 'insula',
                          'occipital_lobe', 'parietal_lobe', 'putamen', 'temporal_lobe', 'thalamus']))
root = RESULTS_FOLDER  # NOTE: rebinds the notebook-wide `root` used by earlier cells
age_diff_groups = {}
cn_size = DATA_SIZE
disease_l = 'mental'

In [None]:
def bias_correction(true, pred):
    """Remove linear age bias from *pred*.

    Fits pred = b0 + b1 * true by least squares and returns (pred - b0) / b1.
    *true* must be a numpy array (it is reshaped to a single-feature 2-D matrix).
    """
    fitted = LinearRegression()
    fitted.fit(true.reshape(-1, 1), pred)
    intercept, slope = fitted.intercept_, fitted.coef_[0]
    return (pred - intercept) / slope

def pickle_load(root):
    """Collect per-region true/predicted ages from the 4 CV result folders.

    Reads <root>/<cv>/true_ages.pkl and <root>/<cv>/pred_ages.pkl for cv 0..3;
    missing files are skipped. Each pickle maps region name -> list of ages.
    - true_ages: the list from the last fold that contains the region.
    - pred_ages: element-wise mean over every fold that contains the region.
    Prints the list lengths per region and returns
    {region: {'pred_ages': [...], 'true_ages': [...]}} for each name in `regions`.
    Only use on result files produced by this project (pickle is not safe for
    untrusted input).
    """
    region_ages = {name: {'pred_ages': [], 'true_ages': []} for name in regions.values()}

    def _read_fold(cv_num, filename):
        path = os.path.join(root, str(cv_num), filename)
        if not os.path.exists(path):
            return None
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    # True ages: later folds overwrite earlier ones.
    for cv_num in range(4):
        fold = _read_fold(cv_num, 'true_ages.pkl')
        for region, ages in (fold or {}).items():
            region_ages[region]['true_ages'] = ages

    # Predicted ages: keep every fold's list, averaged below.
    for cv_num in range(4):
        fold = _read_fold(cv_num, 'pred_ages.pkl')
        for region, ages in (fold or {}).items():
            region_ages[region]['pred_ages'].append(ages)

    for entry in region_ages.values():
        per_fold = entry['pred_ages']
        entry['pred_ages'] = [sum(column) / len(column) for column in zip(*per_fold)]

    for region, entry in region_ages.items():
        print(region, len(entry['true_ages']), len(entry['pred_ages']))

    print("=" * 30)
    return region_ages

def age_dist_plot(region, cn, disease):
    """Overlay density histograms of the true ages of CN and the disease group for *region*."""
    plt.figure(figsize=(12, 8))
    series = (
        (cn, 'CN', 0.6, '#2986cc'),
        (disease, disease_l, 0.4, '#cc0000'),
    )
    for group, label, alpha, color in series:
        plt.hist(group[region]['true_ages'], bins=30, alpha=alpha, label=label,
                 density=True, edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_dist_plot_single_group(region, group):
    """Overlay density histograms of true, predicted and bias-corrected ages of one group."""
    plt.figure(figsize=(12, 8))
    series = (
        ('true_ages', 'True', 0.6, '#2986cc'),
        ('pred_ages', 'Pred', 0.4, '#ffa833'),
        ('corrected_pred_ages', 'Corrected Pred', 0.4, '#cc0000'),
    )
    entry = group[region]
    for key, label, alpha, color in series:
        plt.hist(entry[key], bins=30, alpha=alpha, label=label,
                 density=True, edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_dist_re_counts(region, cn, re_cn):
    """Overlay raw-count histograms of true ages for CN and the resampled CN group."""
    plt.figure(figsize=(12, 8))
    # No density=True: the y-axis shows actual counts.
    for group, label, alpha, color in ((cn, 'CN', 0.6, '#2986cc'),
                                       (re_cn, 'Resampled_CN', 0.4, '#ffa833')):
        plt.hist(group[region]['true_ages'], bins=30, alpha=alpha, label=label,
                 edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Counts', fontsize=14)  # counts, not density
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_dist_plot_counts(region, cn, disease):
    """Overlay raw-count histograms of true ages for CN and the disease group."""
    plt.figure(figsize=(12, 8))
    for group, label, alpha, color in ((cn, 'CN', 0.6, '#2986cc'),
                                       (disease, disease_l, 0.4, '#cc0000')):
        plt.hist(group[region]['true_ages'], bins=30, alpha=alpha, label=label,
                 edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Counts', fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_dist_plot_pred(region, cn, disease):
    """Overlay density histograms of predicted ages for CN and the disease group."""
    plt.figure(figsize=(12, 8))
    for group, label, alpha, color in ((cn, 'CN', 0.6, '#2986cc'),
                                       (disease, disease_l, 0.4, '#cc0000')):
        plt.hist(group[region]['pred_ages'], bins=30, alpha=alpha, label=label,
                 density=True, edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_dist_plot_pred_counts(region, cn, disease):
    """Overlay raw-count histograms of predicted ages for CN and the disease group."""
    plt.figure(figsize=(12, 8))
    for group, label, alpha, color in ((cn, 'CN', 0.6, '#2986cc'),
                                       (disease, disease_l, 0.4, '#cc0000')):
        plt.hist(group[region]['pred_ages'], bins=30, alpha=alpha, label=label,
                 edgecolor='black', color=color)
    plt.title('Age Distribution by Group', fontsize=16)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Counts', fontsize=14)
    plt.legend(loc='upper right', fontsize=12)
    plt.show()

def age_plot(data, region):
    """Scatter true vs. predicted age for *region*, with a red y = x reference line."""
    entry = data[region]
    true, pred = entry['true_ages'], entry['pred_ages']
    label = 'global' if region == 'imgs' else region
    plt.figure(figsize=(7, 5))
    plt.scatter(true, pred, alpha=0.2)
    lo, hi = min(true), max(true)
    plt.plot([lo, hi], [lo, hi], color='red')
    plt.title(f'{label} ========== True Age and Predicted Age')
    plt.xlabel('True_Age')
    plt.ylabel('Predicted_Age')
    plt.show()

def age_plot_corrected(data, region):
    """Scatter true vs. bias-corrected predicted age for *region*, with a y = x line.

    Uses data[region]['corrected_pred_ages'] when present; otherwise the
    correction is computed on the fly with `bias_correction` (fitted on this
    region's own true/pred ages). The original plotted the raw 'pred_ages'
    under the "Corrected" title.
    """
    entry = data[region]
    true = entry['true_ages']
    if 'corrected_pred_ages' in entry:
        pred = entry['corrected_pred_ages']
    else:
        pred = bias_correction(np.asarray(true, dtype=float),
                               np.asarray(entry['pred_ages'], dtype=float))
    if region == 'imgs':
        region = 'global'
    plt.figure(figsize=(7, 5))
    plt.scatter(true, pred, alpha=0.2)
    plt.plot([min(true), max(true)], [min(true), max(true)], color='red')  # y=x line
    plt.title(f'{region} ========== True Age and Corrected Predicted Age')
    plt.xlabel('True_Age')
    plt.ylabel('Predicted_Age')
    plt.show()

def age_diff_plot(age_diff_dict, regions):
    """Horizontal bar chart of PAD per region on a symmetric coolwarm color scale.

    Bars follow the order of *age_diff_dict*; the *regions* argument is accepted
    but not used.
    """
    _, ax = plt.subplots(figsize=(12, 8))

    values = list(age_diff_dict.values())
    limit = max(abs(max(values)), abs(min(values)))
    cmap = plt.get_cmap('coolwarm')
    norm = mcolors.Normalize(vmin=-limit, vmax=limit)

    for name, pad in age_diff_dict.items():
        ax.barh(name, pad, color=cmap(norm(pad)))

    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    plt.colorbar(mappable, ax=ax).set_label('PAD (years)')

    ax.set_xlabel('Predicted Age Difference (years)')
    ax.set_title('Regional Predicted Age Difference (PAD)')
    plt.show()

# Predicted-age-difference helpers
def calculate_age_diff_avg(data, region):
    """Return mean(pred_ages - true_ages) for *region* of *data* (raw, uncorrected PAD)."""
    entry = data[region]
    return np.mean(np.array(entry['pred_ages']) - np.array(entry['true_ages']))

def calculate_corrected_age_diff_avg(data, region):
    """Return mean(corrected_pred_ages - true_ages) for *region*, rounded to 7 decimals."""
    entry = data[region]
    diffs = np.array(entry['corrected_pred_ages']) - np.array(entry['true_ages'])
    return np.round(np.mean(diffs), 7)

def calculate_age_diff_dist(data, region):
    """Overlay density histograms of bias-corrected predicted ages and true ages for *region*.

    Only plots; nothing is returned. (An unused mean-difference computation
    was removed, and the x-axis now says 'Age', since the plotted values are
    ages, not differences.)
    """
    entry = data[region]
    corrected_pred_ages = np.array(entry['corrected_pred_ages'])
    true_ages = np.array(entry['true_ages'])

    plt.hist(corrected_pred_ages, bins=30, alpha=0.5, label='Predicted', density=True, edgecolor='black', color='red')
    plt.hist(true_ages, bins=30, alpha=0.5, label='True', density=True, edgecolor='black', color='grey')
    plt.title('Age Difference Distribution')
    plt.xlabel('Age')
    plt.ylabel('Density')
    plt.legend(loc='upper right')
    plt.show()


def cn_resampling(cn, disease):
    """Resample CN subjects so that their true-age distribution matches the disease group's.

    For each region in `regions`:
    - fit a Gaussian KDE (bandwidth 1.0) to the disease group's true ages;
    - weight each CN subject by that density at its own true age;
    - draw len(disease) CN subjects without replacement (np.random.choice, global RNG).
    Each region is sampled independently.

    Returns {region: {'true_ages': [...], 'pred_ages': [...]}}.
    Raises ValueError if every CN weight is zero, or (from np.random.choice) if
    there are fewer CN subjects with nonzero weight than disease subjects.
    """
    # KernelDensity was used without being imported anywhere in the notebook.
    from sklearn.neighbors import KernelDensity

    new_cn_ages = dict()

    for _, reg in regions.items():
        print(reg, flush=True)

        cn_true_np = np.array(cn[reg]['true_ages'])
        cn_pred_np = np.array(cn[reg]['pred_ages'])
        disease_ages_np = np.array(disease[reg]['true_ages'])
        print(cn_true_np.size, disease_ages_np.size)

        # Estimate the disease group's age distribution with a kernel density estimate.
        kde = KernelDensity(kernel='gaussian', bandwidth=1.0).fit(disease_ages_np.reshape(-1, 1))

        # Weight each CN subject by that density at its own age.
        log_densities = kde.score_samples(cn_true_np.reshape(-1, 1))
        probabilities = np.exp(log_densities)
        print(f'log_densities: {log_densities}, probabilities: {probabilities}')

        # Normalize the weights into sampling probabilities.
        total = probabilities.sum()
        if not total > 0:
            raise ValueError(f'{reg}: all CN sampling weights are zero; cannot match the disease age distribution')
        probabilities /= total
        print(f'regularized probabilities: {probabilities}')

        # Indices of CN subjects whose ages follow the disease age distribution.
        sampled_indices = np.random.choice(np.arange(len(cn_true_np)), size=len(disease_ages_np), p=probabilities, replace=False)
        print(sampled_indices)
        print("="*30)

        print(cn_true_np.shape, cn_pred_np.shape, flush=True)

        sampled_true_np = cn_true_np[sampled_indices]
        sampled_pred_np = cn_pred_np[sampled_indices]

        # Check the result sizes
        print(f"Sampled size of True after Resampling: {len(sampled_true_np)}", flush=True)
        print(f"Sampled size of Pred after Resampling: {len(sampled_pred_np)}", flush=True)
        print("=" * 30)

        new_cn_ages[reg] = {
            'true_ages': list(sampled_true_np),
            'pred_ages': list(sampled_pred_np)
        }

    return new_cn_ages

In [None]:
# NOTE(review): `root` is RESULTS_FOLDER, which has no trailing separator, so
# this yields '<RESULTS_FOLDER>ukb_<disease_l>' (a sibling name, not a
# sub-folder) — confirm the intended results location.
disease_root = root + f'ukb_{disease_l}'
disease_ages = pickle_load(disease_root)
disease_len = len(disease_ages['imgs']['true_ages'])  # subjects with global-model results
print(disease_len)

In [None]:
# Compute the mean predicted age difference (PAD) for each region.
age_diff_dict = {}
age_diff_groups[disease_l] = dict()

# Group whose true/pred ages are used to fit the linear age-bias correction.
# NOTE(review): fitting on the disease group itself makes every region's mean
# corrected PAD ~0 by construction (a least-squares line passes through the
# means); point this at the CN results (pickle_load of the CN folder) instead.
bias_reference_ages = disease_ages

for key, region in regions.items():
    region_n = 'global' if region == 'imgs' else region
    entry = disease_ages[region]
    reference = bias_reference_ages[region]

    # pickle_load does not produce 'corrected_pred_ages', which the helpers
    # below read; fit pred = b0 + b1 * true on the reference group and apply
    # (pred - b0) / b1 to this group.
    reg = LinearRegression().fit(
        np.asarray(reference['true_ages'], dtype=float).reshape(-1, 1),
        np.asarray(reference['pred_ages'], dtype=float),
    )
    entry['corrected_pred_ages'] = (np.asarray(entry['pred_ages'], dtype=float) - reg.intercept_) / reg.coef_[0]

    age_plot(disease_ages, region)
    age_plot_corrected(disease_ages, region)
    age_diff_dict[region_n] = calculate_corrected_age_diff_avg(disease_ages, region)
    age_diff_groups[disease_l][region_n] = age_diff_dict[region_n]

# Visualize
age_diff_plot(age_diff_dict, regions.values())

In [None]:
# Mean corrected PAD per region for the CN group, hard-coded here
# (presumably copied from an earlier run — confirm the source).
age_diff_groups['CN'] = {
    'global': -0.3092825,
    'caudate': -0.0997129,
    'cerebellum': -0.0510616,
    'frontal_lobe': -0.0776628,
    'insula': -0.1058889,
    'occipital_lobe': -0.0443206,
    'parietal_lobe': -0.067503,
    'putamen': -0.086337,
    'temporal_lobe': -0.0768889,
    'thalamus': -0.0631945,
}

In [None]:
def multi_group_age_diff_plot(age_diff_dict, groups, regions):
    """Grouped horizontal bar chart of per-region PAD for two groups.

    age_diff_dict : {group: {region: PAD}}
    groups : group names; groups[0] is drawn blue and groups[1] red
    regions : y-axis order, listed bottom to top
    """
    _, ax = plt.subplots(figsize=(14, 8))

    palette = {groups[0]: '#2986cc', groups[1]: '#cc0000'}
    bar_width = 0.3
    index = np.arange(len(regions))

    # One row of bars per group, offset vertically by bar_width.
    for offset, group in enumerate(groups):
        pads = [age_diff_dict[group][name] for name in regions]
        bars = ax.barh(index + offset * bar_width, pads, bar_width,
                       label=group, color=palette[group], edgecolor='black')

        # Print each value on its bar, shifted 0.1 toward the axis.
        for bar in bars:
            w = bar.get_width()
            text_x = w + 0.1 if w < 0 else w - 0.1
            ax.text(text_x, bar.get_y() + bar.get_height() / 2, f'{w:.2f}',
                    va='center', ha='center', color='black', fontsize=12)

    # Center the region labels between the grouped bars.
    ax.set_yticks(index + bar_width * (len(groups) - 1) / 2)
    ax.set_yticklabels(regions, fontsize=14)
    ax.set_xlabel('Predicted Age Difference (years)', fontsize=14)
    ax.set_title('Regional Predicted Age Difference (PAD) by Group', fontsize=18)

    # Legend entries in reverse order of drawing.
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles[::-1], labels[::-1], fontsize=14)

    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    sns.despine(left=True, bottom=True)

    plt.tight_layout()
    plt.show()

# Y-axis labels: 'imgs' is shown as 'global'; the list is reversed so that
# 'global' ends up last, i.e. the top row of the horizontal bar chart.
regions_l = ['global'] + list(regions.values())[1:]
regions_l = regions_l[::-1]
multi_group_age_diff_plot(age_diff_groups, ['CN', disease_l], regions_l)