# Channel Quality Assessment and Pruning

# In [None]:

# %%
import cedalion
import cedalion.sigproc.quality as quality
import matplotlib.pyplot as plt
from cedalion import units
import numpy as np
import os
import pickle
from cedalion.sigproc.quality import repair_amp
import warnings
warnings.filterwarnings("ignore")
import configs 

# %%
# Datasets to process. To switch datasets, comment out the active line and
# uncomment one of the alternatives (all of them assign `data_types`, which is
# the only name the rest of the script reads).
data_types = ['Syn_Finger_Tapping']
#data_types = ['HD_Squeezing']
#data_types = ['Stroop_J']
#data_types = ['Stroop_D']
#data_types = ['BS_Laura']
#data_types = ['Syn_Ft']

# When True, clean-channel pickles and plots are written to disk.
save = True

data_path = configs.data_path_prefix
data_configs = configs.load_dataset_configs(data_types, test=False)


def _load_channel_roi(roi_path):
    """Load a pickled channel-ROI selection from *roi_path*.

    Returns the unpickled object, or None (after printing a message) when the
    file does not exist.
    """
    try:
        with open(roi_path, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # trusted, project-generated files.
            return pickle.load(f)
    except FileNotFoundError:
        print(f"could not find file {roi_path}")
        return None


# Optional ROI channel selections. They stay None unless the corresponding
# dataset is selected and its ROI file loads; the processing loop only applies
# an ROI selection when it is not None.
channel_roi_sens_laura = None
channel_roi_c3_c4 = None

if 'BS_Laura' in data_types:
    laura_sens_path = os.path.join(data_path, 'BS_Laura', "BS_Laura_YY_parcel_sens_channels")
    channel_roi_sens_laura = _load_channel_roi(laura_sens_path)

if 'Syn_Finger_Tapping' in data_types:
    nn22_roi_path = os.path.join(data_path, 'NN22_Resting_State', "NN_22_C3_C4_close_channels")
    channel_roi_c3_c4 = _load_channel_roi(nn22_roi_path)

print(data_types)
print(data_configs)

# Channel-quality thresholds (identical for every dataset, so defined once).
sci_threshold = 0.6         # threshold passed to quality.sci
psp_threshold = 0.1         # threshold passed to quality.psp
pc_clean_threshold = 0.5    # fraction of the time axis a channel must pass SCI & PSP on
clean_channel_ratio = 0.5   # fraction of (ROI) channels that must be clean to keep a subject
snr_thresh = 10             # threshold passed to quality.snr
lenient_clean_channel_ratio = 0.4  # looser ratio used for the final printed subject list
window_length = 5*units.s   # window length for the SCI / PSP computations

# For each dataset: per subject/run, compute SNR and SCI/PSP quality masks,
# save the list of channels passing SNR, and summarise clean-channel counts.
for data_type in data_types:
    data_config = data_configs[data_type]
    subject_list = data_config.all_subjects
    base_path = data_config.base_path
    snirf_path_template = data_config.snirf_path
    if data_type == 'Syn_Finger_Tapping':
        # This dataset is assessed using its resting-state SNIRF path instead.
        snirf_path_template = data_config.resting_snirf_path

    num_clean_channels = {}
    clean_subjects = []
    avg_n_clean_chan = []   # one entry per subject, in subject_list order
    avg_snrs = []
    std_snrs = []

    # Number of channels (after ROI selection) of the runs read for THIS
    # dataset. Tracked explicitly so a `rec` left over from a previous dataset
    # is never used, and so we can detect that no run could be read at all.
    channel_count = None

    for subject in subject_list:
        print(f"Subject: {subject}")

        num_clean_channels[subject] = []

        for run in data_config.runs(subject):
            print("Run: " , run)

            snirf_path = snirf_path_template.format(subject=subject, run=run)
            try:
                rec = cedalion.io.read_snirf(base_path + snirf_path)[0]
            except FileNotFoundError:
                print(f"File not found for {subject}, {run}. Skipping.")
                continue

            rec["amp"] = repair_amp(rec["amp"], median_len=0)

            print("Total channels before ROI filter: ", rec["amp"].channel.size)
            print("Total time points: ", rec["amp"].time.size)
            print("Last time stamp: ", rec["amp"].time[-1].values)

            # Restrict to the dataset's ROI channels when an ROI was loaded.
            if data_type == 'BS_Laura' and channel_roi_sens_laura is not None:
                rec["amp"] = rec["amp"].sel(channel=channel_roi_sens_laura)
            elif data_type == 'Syn_Finger_Tapping' and channel_roi_c3_c4 is not None:
                rec["amp"] = rec["amp"].sel(channel=channel_roi_c3_c4)

            channel_count = rec["amp"].channel.size

            # SNR: a channel passes only if every wavelength passes.
            snr, snr_mask = quality.snr(rec["amp"], snr_thresh)
            reduced_snr_mask = snr_mask.all(dim='wavelength')
            avg_snrs.append(snr.values.mean())
            std_snrs.append(snr.values.std())

            # SCI & PSP: fraction of the time axis on which both masks pass.
            sci, sci_mask = quality.sci(rec["amp"], window_length, sci_threshold)
            psp, psp_mask = quality.psp(rec["amp"], window_length, psp_threshold)
            sci_psp_mask = sci_mask & psp_mask
            perc_time_clean = sci_psp_mask.sum(dim="time") / sci_psp_mask.time.size

            clean_mask = perc_time_clean > pc_clean_threshold
            ncc = clean_mask.sum().values
            # Note: the saved channel list is SNR-based, while `ncc` counts
            # channels passing the SCI/PSP criterion.
            clean_channels_snr = list(rec["amp"].channel.values[reduced_snr_mask.values])
            print("Number of channels passing SNR threshold: ", len(clean_channels_snr))

            path = base_path + f'epochs_labels/clean_channels/{subject}/{run}/'
            filepath = path + "clean_channels.pkl"

            print("filepath: ", filepath)

            if save:
                os.makedirs(path, exist_ok=True)
                with open(filepath, 'wb') as f:
                    pickle.dump(clean_channels_snr, f)
                print("wrote clean channels to " + filepath)

            print(f"Number of clean channels: {ncc}")

            num_clean_channels[subject].append(ncc)

            # Round-trip check: reload the pickle and compare with the in-memory list.
            print("clean_channels_snr_a:", clean_channels_snr)
            print("length clean_channels_snr_a:", len(clean_channels_snr))
            try:
                with open(filepath, 'rb') as f:
                    clean_channels_snr_loaded = pickle.load(f)
                print("clean_channels_snr_b:", clean_channels_snr_loaded)
                print("length clean_channels_snr_b:", len(clean_channels_snr_loaded))
            except FileNotFoundError:
                print("File not found:", filepath)

        # Average per subject (0 for subjects with no readable run, so that
        # avg_n_clean_chan stays aligned with subject_list).
        if num_clean_channels[subject]:
            avg_clean_channels = np.mean(num_clean_channels[subject])
            avg_n_clean_chan.append(avg_clean_channels)
            num_clean_channels[subject + "_average"] = avg_clean_channels
            print(f"Average number of clean channels: {avg_clean_channels}")

            # A run of this subject was read, so channel_count is set.
            print(f"Total channels in this rec (post-filter): {channel_count}")
            print("Clean Channel Ratio: " , avg_clean_channels / channel_count)
            print("")
            if avg_clean_channels > (channel_count * clean_channel_ratio):
                clean_subjects.append(subject)
        else:
            avg_n_clean_chan.append(0)
            num_clean_channels[subject + "_average"] = 0

    # clean subjects
    print(clean_subjects)

    # drop subjects
    print([sub for sub in subject_list if sub not in clean_subjects])

    if not channel_count:
        # No run could be read (or the ROI selection left no channels):
        # there is no valid denominator for the ratios below.
        print(f"No readable recordings with channels for {data_type}; skipping plots.")
        continue

    print(f"Plotting ratio based on channel count: {channel_count}")

    cc_ratio = [cc / channel_count for cc in avg_n_clean_chan]

    plt.figure(figsize=(15, 5))
    x = np.arange(len(subject_list))
    plt.bar(x, cc_ratio, label='N Clean Channel')

    plt.ylabel("Clean Channel Ratio")
    plt.xlabel("Subject")
    plt.title(f"Clean Channel Ratio - {data_type}")
    plt.xticks(x, subject_list, rotation=90)
    plt.ylim(0, 1)
    plt.tight_layout()

    # save plots
    if save:
        plot_path = configs.images_path_prefix + "signal_quality/"
        os.makedirs(plot_path, exist_ok=True)
        for ext in ("pdf", "png", "svg"):
            plt.savefig(f"{plot_path}cc_ratio_{data_type}.{ext}")
    plt.show()

    print("Clean Channel Ratios:", cc_ratio)

    # Subjects whose average clean-channel count exceeds the lenient ratio.
    print([
        sub
        for sub, n_clean in zip(subject_list, avg_n_clean_chan)
        if n_clean > channel_count * lenient_clean_channel_ratio
    ])