# In [None]:
import os
import numpy as np
import scipy.io as sio
from glob import glob
from scipy.optimize import curve_fit

def quadratic(x, a, b, c):
    """Evaluate the second-order polynomial ``a*x**2 + b*x + c`` at ``x``.

    ``x`` may be a scalar or a NumPy array; the arithmetic is applied
    element-wise in the latter case.
    """
    second_order_term = a * x**2
    first_order_term = b * x
    return second_order_term + first_order_term + c

def detect_smile_or_sad_shape(data, threshold=0.7):
    """Report whether ``data`` is well described by a quadratic curve.

    A second-order polynomial is least-squares fitted to ``data`` over an
    evenly spaced axis on [0, 1], and the coefficient of determination (R^2)
    of that fit is compared against ``threshold``.

    Args:
        data: One-dimensional sequence of samples (anything accepted by
            ``len()`` and ``np.asarray``).
        threshold: R^2 value the fit has to exceed for the shape to count as
            detected.

    Returns:
        ``True`` when R^2 exceeds ``threshold`` and the fitted quadratic
        coefficient is non-zero. ``False`` otherwise, including when the fit
        cannot be carried out or when the signal is constant.

    NOTE(review): the curvature check is an exact comparison with zero, so a
    near-linear signal that fits well is likely to be flagged as well --
    confirm this is the intended behaviour.
    """
    x = np.linspace(0, 1, len(data))

    # Only the conversion and the fit are expected to fail. Catch the concrete
    # failure types instead of a bare ``except`` so that KeyboardInterrupt and
    # SystemExit still propagate.
    try:
        y = np.asarray(data, dtype=float)
        popt, _ = curve_fit(quadratic, x, y)
    except (ArithmeticError, RuntimeError, TypeError, ValueError):
        return False

    total_ss = np.sum((y - np.mean(y)) ** 2)
    if total_ss == 0:
        # Constant signal: R^2 is undefined (division by zero), so there is
        # no shape to detect.
        return False

    residual_ss = np.sum((y - quadratic(x, *popt)) ** 2)
    r_squared = 1 - residual_ss / total_ss

    # popt[0] is the coefficient of x**2.
    return bool(r_squared > threshold and popt[0] != 0)

def collect_all_data(directory, all_data):
    """Accumulate, per channel index, the columns not flagged as quadratic-shaped.

    Every ``*.mat`` file directly inside ``directory`` is loaded and its
    ``'detrended_data'`` array inspected. Files whose second dimension is not
    48 are reported and skipped. For the remaining files, each column for
    which ``detect_smile_or_sad_shape`` returns a falsy value is appended to
    ``all_data[channel]``.

    Args:
        directory: Folder that is searched for ``.mat`` files.
        all_data: Mapping from channel index to a list of columns; it is
            updated in place.

    Returns:
        The same ``all_data`` mapping that was passed in.
    """
    mat_pattern = os.path.join(directory, '*.mat')

    for filename in glob(mat_pattern):
        corrected_data = sio.loadmat(filename)['detrended_data']
        n_channels = corrected_data.shape[1]

        if n_channels == 48:
            print(f"Processing file: {filename}, Shape: {corrected_data.shape}")

            for channel in range(n_channels):
                column = corrected_data[:, channel]
                if detect_smile_or_sad_shape(column):
                    # Flagged columns are left out of the collection.
                    continue
                all_data.setdefault(channel, []).append(column)
        else:
            print(f"Warning: Unexpected number of channels in {filename}. Shape: {corrected_data.shape}")

    return all_data

def normalize_and_save_data(directory, max_values, save_dir):
    """Divide each channel of every ``.mat`` file by its maximum and save the result.

    Every ``*.mat`` file directly inside ``directory`` is loaded, each column
    of its ``'detrended_data'`` array is divided by ``max_values[channel]``,
    and the scaled array is written to ``save_dir`` under the same file name,
    stored under the key ``'corrected_data'``. Files whose second dimension is
    not 48 are reported and skipped.

    Args:
        directory: Folder that is searched for ``.mat`` files.
        max_values: Per-channel divisors, indexed by channel number (a dict or
            any other indexable container).
        save_dir: Output folder; created if it does not exist.

    Channels that have no entry in ``max_values``, or whose entry is zero, are
    left unscaled and a warning is printed, rather than aborting the run or
    writing inf/nan values.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    for filename in glob(os.path.join(directory, '*.mat')):
        mat_data = sio.loadmat(filename)
        corrected_data = mat_data['detrended_data']

        if corrected_data.shape[1] != 48:
            print(f"Warning: Unexpected number of channels in {filename}. Shape: {corrected_data.shape}")
            continue

        # In-place true division cannot store its float result in an integer
        # array, so integer-typed recordings are promoted to float first.
        if not np.issubdtype(corrected_data.dtype, np.inexact):
            corrected_data = corrected_data.astype(float)

        for channel in range(corrected_data.shape[1]):
            try:
                scale = max_values[channel]
            except (KeyError, IndexError):
                scale = None

            if scale is None or scale == 0:
                print(f"Warning: No usable maximum for channel {channel} in {filename}; leaving it unscaled.")
                continue

            corrected_data[:, channel] /= scale

        save_path = os.path.join(save_dir, os.path.basename(filename))
        sio.savemat(save_path, {'corrected_data': corrected_data})
        print(f"Saved normalized data to: {save_path}")

# Base directory that the input sub-directories below are resolved against.
base_dir = ''

# Input sub-directory for each dataset. Keys follow the '<task>_<group>'
# pattern; the output folder is derived from the key further down.
subdirs = {
    'rest_hc': 'hc',
    'rest_mci': 'mci',
    # 'stroop_mci': 'stroop_seg_mrg/baseline1_2/baseline1_2_onlydatas/MCI',
    # 'stroop_hc': 'stroop_seg_mrg/baseline1_2/baseline1_2_onlydatas/HC',
    # 'nback_mci': 'nback_seg_mrg/baseline1_2/baseline1_2_onlydatas/MCI',
    # 'nback_hc': 'nback_seg_mrg/baseline1_2/baseline1_2_onlydatas/HC'
}

all_data = {}

# First pass: collect the retained columns of every dataset, per channel.
for key, subdir in subdirs.items():
    full_path = os.path.join(base_dir, subdir)
    print(f"Collecting data from: {full_path}")
    all_data = collect_all_data(full_path, all_data)

# Find the maximum absolute value for each channel across all collected data.
max_values = {}
for channel, data in all_data.items():
    max_values[channel] = np.max(np.abs(np.concatenate(data)))

print(f"Maximum absolute values for each channel: {max_values}")

# Second pass: normalize every dataset with the shared maxima and save it.
save_base_dir = ''
for key, subdir in subdirs.items():
    full_path = os.path.join(base_dir, subdir)

    # Derive the output folder from the key for every entry, so save_dir is
    # never left unset (or stale from the previous iteration) when the
    # stroop/nback entries above are re-enabled.
    task = key.split('_', 1)[0]
    group = 'mci' if 'mci' in key else 'hc'
    save_dir = os.path.join(save_base_dir, task, group)

    print(f"Normalizing and saving data from: {full_path} to {save_dir}")
    normalize_and_save_data(full_path, max_values, save_dir)
