In [1]:
import os
import sys

import torch

import numpy as np
import pandas as pd

import nibabel as nib

import matplotlib.pyplot as plt

In [4]:
def split_train_test(df, split_ratio=0.8, random_state=None):
    """Shuffle and split ``df`` into train/val rows, separately within each ``Dataset``.

    Every distinct value of ``df['Dataset']`` is shuffled on its own and the first
    ``int(n * split_ratio)`` rows go to the training part, the rest to the
    validation part, so each dataset is represented in both parts in (roughly)
    the same proportion.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``Dataset`` column. It is not modified. Rows whose
        ``Dataset`` is NaN never match a group and are therefore absent from
        the result.
    split_ratio : float, default 0.8
        Fraction of each dataset assigned to the training part, in ``[0, 1]``.
    random_state : int, numpy RandomState/Generator or None, default None
        Seed for the shuffle. ``None`` keeps the previous behaviour (NumPy's
        global random state, i.e. a different split on every call).

    Returns
    -------
    pandas.DataFrame
        All training rows followed by all validation rows, with an added (or
        overwritten) ``mode`` column holding ``'train'`` or ``'val'``. The index
        is reset per dataset before splitting, so index labels repeat across
        datasets.

    Raises
    ------
    ValueError
        If ``split_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= split_ratio <= 1:
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio}")

    # Build a single generator from an integer seed so that successive datasets
    # draw different permutations while the whole call stays reproducible.
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state

    # Collect the pieces and concatenate once; concatenating inside the loop is
    # quadratic and starting from an empty DataFrame triggers pandas' deprecation
    # warning about empty entries in concat.
    train_parts = []
    test_parts = []
    for dataset in df['Dataset'].unique():
        dataset_df = df[df['Dataset'] == dataset]
        dataset_df = dataset_df.sample(frac=1, random_state=rng).reset_index(drop=True)  # shuffle
        n_train = int(len(dataset_df) * split_ratio)
        head = dataset_df.iloc[:n_train]
        tail = dataset_df.iloc[n_train:]
        if len(head):
            train_parts.append(head)
        if len(tail):
            test_parts.append(tail)

    # With no rows on one side, fall back to an empty frame that keeps df's columns.
    train_df = pd.concat(train_parts) if train_parts else df.iloc[0:0].copy()
    test_df = pd.concat(test_parts) if test_parts else df.iloc[0:0].copy()

    train_df['mode'] = 'train'
    test_df['mode'] = 'val'
    print(f"Total: {len(train_df)+len(test_df)} || Train Sample: {len(train_df)}, Test Sample: {len(test_df)}")
    merge_df = pd.concat([train_df, test_df], axis=0)

    return merge_df

## Final EDA

In [None]:
# Load the longitudinal phenotype table and expose its "SubID" column as
# "Subject"; the rename is chained so the frame is never mutated in place.
long_df = pd.read_csv(
    '/NFS/FutureBrainGen/data/long/long_old_HC_subj_phenotype_splited.csv'
).rename(columns={"SubID": "Subject"})
# Summary statistics of the loaded table.
long_df.describe()

In [None]:
# Number of rows per value of the 'mode' column
# (presumably the train/val/test split labels — confirm against the CSV).
long_df['mode'].value_counts()

In [None]:
# Column names, dtypes and non-null counts of the longitudinal table.
long_df.info()

In [None]:
# Load the cross-sectional phenotype table and show summary statistics.
# NOTE(review): `cross_df` is rebound to a different CSV in the
# cross-sectional EDA section further down.
cross_df = pd.read_csv('/NFS/FutureBrainGen/data/cross/cross_old_subj_phenotype_splited.csv')
cross_df.describe()

In [None]:
# Column names, dtypes and non-null counts of the cross-sectional table.
cross_df.info()

## EDA for Cross-Sectional Data

In [None]:
# Raw cross-sectional inclusion list.
cross_df = pd.read_csv('/NFS/FutureBrainGen/data/cross/CrossSectional_included_file_v2.csv')

# Keep rows with Group == 'HC' and Age >= 40, excluding the BGSP, BNU and
# RBP-L1 datasets. All conditions are combined into one mask instead of
# re-filtering the frame step by step.
hc_cross_df = cross_df.loc[
    cross_df['Group'].eq('HC')
    & ~cross_df['Dataset'].isin(['BGSP', 'BNU', 'RBP-L1'])
    & cross_df['Age'].ge(40)
]
hc_cross_df

In [None]:
# Rows remaining per source dataset after the filters above.
hc_cross_df.Dataset.value_counts()

In [None]:
# Age distribution of the filtered cross-sectional rows, in 10 bins.
plt.hist(hc_cross_df['Age'], bins=10)
plt.show()

In [None]:
# Per-dataset split with 85% of each dataset labelled 'train' and the rest 'val'.
# NOTE: split_train_test shuffles without a seed here, so the assignment
# differs between runs.
splited_hc_cross_df = split_train_test(hc_cross_df, 0.85)
splited_hc_cross_df

In [24]:
# hc_cross_df.to_csv("/NFS/FutureBrainGen/data/cross/cross_old_subj_phenotype.csv", index=False)
# splited_hc_cross_df.to_csv("/NFS/FutureBrainGen/data/cross/cross_old_subj_phenotype_splited.csv", index=False)

## Longitudinal EDA

In [None]:
# Longitudinal phenotype table; the first CSV column is used as the index.
long_df = pd.read_csv('/NFS/FutureBrainGen/data/long/long_phenotype_v2_clean_group.csv', index_col=0)

# Rows labelled 'HC' in both Group_B and Group_F ...
hc_long_df = long_df.loc[long_df['Group_B'].eq('HC') & long_df['Group_F'].eq('HC')]
# ... restricted to Age_B of at least 40.
old_hc_long_df = hc_long_df.loc[hc_long_df['Age_B'].ge(40)]
old_hc_long_df.head(3)

In [None]:
# Two-stage split built from split_train_test (which only knows 'train'/'val'):
#   1) 94% / 6% per dataset -> the 6% 'val' part becomes the held-out test set;
#   2) the 94% 'train' part is split 94% / 6% again into the final train / val.
# NOTE: the shuffles are unseeded here, so the assignment differs between runs.
splited_old_hc_long_df = split_train_test(old_hc_long_df, 0.94)
splited2_old_hc_long_df = split_train_test(splited_old_hc_long_df[splited_old_hc_long_df['mode']=='train'], 0.94)
# .assign returns a new frame, so relabelling does not write into a filtered
# slice (the original item assignment raised SettingWithCopyWarning).
splited2_old_hc_long_df_test = splited_old_hc_long_df[splited_old_hc_long_df['mode']=='val'].assign(mode='test')

In [92]:
# Attach the held-out test rows to the train/val rows.
# This cell rebinds its own input, so rows already labelled 'test' are dropped
# first; on the first run this is a no-op, and on a re-run it prevents the test
# rows from being appended a second time.
splited2_old_hc_long_df = pd.concat(
    [
        splited2_old_hc_long_df[splited2_old_hc_long_df['mode'] != 'test'],
        splited2_old_hc_long_df_test,
    ],
    axis=0,
)
splited2_old_hc_long_df

In [93]:
# splited2_old_hc_long_df.to_csv("/NFS/FutureBrainGen/data/long/long_old_HC_subj_phenotype_splited.csv", index=False)

In [94]:
# Overlaid Age_B / Age_F distributions of the filtered longitudinal rows.
fig, ax = plt.subplots()
ax.hist(old_hc_long_df['Age_B'], bins=10, alpha=0.7, color='navy', label='Baseline')
ax.hist(old_hc_long_df['Age_F'], bins=10, alpha=0.7, color='orange', label='Follow-up')
# Anchor the annotation in axes coordinates (0..1) instead of the data point
# (85, 410), so it stays inside the plot whatever the age range or bin heights.
ax.text(0.02, 0.97, f"Total Session: {len(old_hc_long_df)}", fontsize=8,
        ha='left', va='top', transform=ax.transAxes)
ax.set_xlabel('Age')
ax.set_ylabel('Count')
ax.legend(loc='upper right')
plt.show()

In [95]:
# Number of rows per value of the 'Interval' column.
interval_counts = old_hc_long_df['Interval'].value_counts().sort_index()
# Use the observed interval values as bar positions; the previous hard-coded
# np.arange(1, 11) only worked when there were exactly ten distinct values and
# labelled the bars correctly only when those values were 1..10.
# Assumes 'Interval' is numeric — TODO confirm.
interval_values = interval_counts.index.to_numpy()

fig, ax = plt.subplots()
ax.bar(interval_values, interval_counts.to_numpy(),
       edgecolor='black', color='skyblue', label='Interval')
ax.set_xticks(interval_values)
ax.set_xlabel('Interval')
ax.set_ylabel('Count')
ax.legend()
plt.show()

## Crop Image

In [25]:
# Directory listed below for image files, and the cross-sectional phenotype CSV path.
MRIPATH = '/NFS/FutureBrainGen/data/long/down_img_1.7mm/'
PHENO = '/NFS/FutureBrainGen/data/cross/CrossSectional_included_file.csv'
# os.listdir returns entries in arbitrary order; sort so that positional
# access such as temp[0] picks the same file on every run and machine.
MRILIST = sorted(os.listdir(MRIPATH))

# Alias of MRILIST (same list object) used by the cells below.
temp = MRILIST

In [None]:
# Load the first listed image and read its voxel data as a NumPy array.
img = nib.load(os.path.join(MRIPATH, temp[0]))
img = img.get_fdata()

# Convert the numpy array to a float32 PyTorch tensor and add a leading channel axis
img_data = torch.from_numpy(img).float()
img_data = img_data.unsqueeze(0)

if img_data.dim() != 4:
    raise ValueError(f"expected a 3D volume, got array of shape {tuple(img.shape)}")

# Get the original dimensions (shape without the channel axis)
d, h, w = img_data.shape[1:]

# Define the target crop size
target_d, target_h, target_w = (86, 106, 86)

# A volume smaller than the target would give a negative start index, which
# slicing interprets as "from the end" and silently returns a wrong crop.
if d < target_d or h < target_h or w < target_w:
    raise ValueError(
        f"volume of shape {(d, h, w)} is smaller than the crop size "
        f"{(target_d, target_h, target_w)}"
    )

# Calculate the start indices for cropping (crop from the center)
start_d = (d - target_d) // 2
start_h = (h - target_h) // 2
start_w = (w - target_w) // 2

img_data = img_data[:, start_d:start_d + target_d, start_h:start_h + target_h, start_w:start_w + target_w]

In [None]:
# Show one slice along each of the three spatial axes of the cropped volume
# (indices 42, 42 and 53), side by side and without tick marks.
fig, ax = plt.subplots(1, 3)
planes = (
    img_data[0, :, :, 42],
    img_data[0, 42, :, :],
    img_data[0, :, 53, :],
)
for panel, plane in zip(ax, planes):
    panel.imshow(plane)
    panel.set_xticks([])
    panel.set_yticks([])

plt.show()