In [None]:
import os, pickle, numpy as np, pandas as pd, matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import LabelEncoder

import configs as data_configs

from configs import GlobalConfig, RunContext, syn_cfg, hd_sq_cfg, bsl_cfg, load_dataset_configs
from classify_utils import (
    within_subject_cv, parcel_within_subject_cv,
    loso_classification, parcel_loso_classification,
    extract_features, aggregate_runs, preload_runs,
    subject_stats_to_df, build_save_dir, save_results_df,
    apply_sma_pruning, load_sma_sens_channels, load_pickle
)

from plotting_utils import (
    plot_grouped_bars_by_dt, plot_grouped_bars_by_subject,
    loso_barplot, barplot_subsets, raincloud_subsets_runs
)

# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────

# Dataset selection: `data_types` lists the dataset(s) to process and
# `global_cfg` is the global configuration object that goes with them.
# Other combinations that have been used:
#   ['BS_Laura', 'HD_Squeezing']  with hd_sq_cfg
#   ['BS_Laura']                  with bsl_cfg
#   ['Syn_Finger_Tapping']        with syn_cfg
data_types, global_cfg = ['HD_Squeezing'], hd_sq_cfg

dataset_configs = load_dataset_configs(data_types, test=True, load_sensitivity=True)

# Sanity check: show the epochs-labels path of the first selected dataset.
print(
    "epo label path: ",
    dataset_configs[data_types[0]].epochs_labels_path(0, 0, 0, 0),
)

# ──────────────────────────────────────────────────────────────────────────────
# Main Pipeline
# ──────────────────────────────────────────────────────────────────────────────
if global_cfg.run_pipeline:
    # The LOSO sections below set ctx.g.sel_hrf_roi = None (ctx.g is built from
    # global_cfg). Remember the configured value so that every classifier's
    # channel-WS pass starts from it again, and so the config object is handed
    # back unchanged when the cell finishes (or fails) -- otherwise re-running
    # the notebook cell would silently lose the setting.
    base_sel_hrf_roi = getattr(global_cfg, "sel_hrf_roi", None)
    try:
        for data_type in data_types:
            ds = dataset_configs[data_type]
            subjects, base_path = ds.subjects, ds.base_path
            long_chs, probe_area = ds.long_channels, ds.probe_area
            ft_slices = ds.feature_slices
            result_path = os.path.join(global_cfg.result_root, data_type)

            subsets_data = load_pickle(os.path.join(base_path, 'subsets_data'))
            subset_keys = list(reversed(subsets_data.keys()))
            #subset_keys = ['subset_3', 'subset_2', 'full']

            # The BS_Laura parcel-sensitive channel list is only read by the
            # BS_Laura branch of channel LOSO; don't load it for other datasets.
            channel_roi_laura = (
                load_pickle(os.path.join(global_cfg.datasets_path, 'BS_Laura', "BS_Laura_YY_parcel_sens_channels"))
                if data_type == 'BS_Laura' else None
            )

            for int_scaling in global_cfg.int_scalings:
                for spatial_scaling in global_cfg.spatial_scalings:
                    print(f"\nIntensity Scaling: {int_scaling} | Spatial Scaling: {spatial_scaling}\n")
                    # Drop the reference to the previous combination's data so it can
                    # be garbage-collected before the next one is loaded.
                    data = None
                    data, clean_map = preload_runs(ds, subjects, base_path, int_scaling, spatial_scaling)

                    ctx = RunContext(
                        g=global_cfg, ds=ds, feature_types=["Slope"], reduce_features=False,
                        prune_channels=True, prune_chans_sma=True, data_type=data_type,
                        int_scaling=int_scaling, spatial_scaling=spatial_scaling,
                        subsets_data=subsets_data, clean_ch_map=clean_map,
                        sma_sens_channels=load_sma_sens_channels(base_path, data_type), ft_slices=ft_slices,
                    )

                    clean_sma_map = apply_sma_pruning(clean_map, ctx.sma_sens_channels) if ctx.prune_chans_sma else clean_map
                    #ctx.clean_ch_map = clean_sma_map

                    for clf_name, clf in global_cfg.classifiers.items():
                        ctx.clf_name, ctx.classifier = clf_name, clf

                        # ============== CHANNEL WS =========================
                        # Re-enable both switches the previous classifier's LOSO
                        # sections turned off, so every classifier is evaluated
                        # under the same WS settings.
                        ctx.prune_chans_sma = True
                        ctx.g.sel_hrf_roi = base_sel_hrf_roi
                        if global_cfg.run_ch_ws:
                            # scores[dt][subset][subject]     -> run-aggregated {"mean", "std"}
                            # run_scores[dt][subset][subject] -> list of valid per-run accuracies
                            scores = {dt: {k: {} for k in subset_keys} for dt in global_cfg.dt_conditions}
                            run_scores = {dt: {k: {} for k in subset_keys} for dt in global_cfg.dt_conditions}

                            for dt in global_cfg.dt_conditions:
                                ctx.dt = dt
                                for sub_key in subset_keys:
                                    ctx.sub_key, subset_chs = sub_key, subsets_data[sub_key]["all"]

                                    for subject in subjects:
                                        run_stats = []
                                        for run in data[subject]:
                                            ctx.subject, ctx.run = subject, run
                                            if "ss" in dt:
                                                # Subset-specific entry for this dt; no separate SS input.
                                                run_data, ss_run_data = data[subject][run][sub_key][dt], None
                                            else:
                                                # Full "all" data plus the optional "<dt>_ss_mean" entry.
                                                run_data = data[subject][run]["all"]
                                                ss_run_data = data[subject][run][sub_key].get(dt + "_ss_mean")

                                            prune_chs = clean_sma_map[subject][run] if ctx.prune_channels else None
                                            run_stat = within_subject_cv(run_data, ss_run_data, ctx, subject, subset_chs, prune_chs)
                                            run_stats.append(run_stat)

                                            subj_stat = run_stat.get(subject, None)
                                            if subj_stat is not None:
                                                acc = subj_stat.get("mean", np.nan)
                                                if acc is not None and not np.isnan(acc):
                                                    run_scores[dt][sub_key].setdefault(subject, []).append(float(acc))

                                        agg = aggregate_runs(run_stats)
                                        scores[dt][sub_key][subject] = agg.get(subject, {"mean": np.nan, "std": np.nan})

                                    save_dir = build_save_dir(result_path, space="channel", mode="ws", ctx=ctx)
                                    os.makedirs(save_dir, exist_ok=True)
                                    df = subject_stats_to_df(scores[dt][sub_key], include_all_row=True)
                                    save_results_df(df, save_dir, f"results_{dt}_{sub_key}.csv")

                            # Per-subset mean/std across subjects; only consumed by the
                            # (currently disabled) barplot_subsets call below.
                            avg_by_dt = {dt: [] for dt in global_cfg.dt_conditions}
                            std_by_dt = {dt: [] for dt in global_cfg.dt_conditions}
                            for dt in global_cfg.dt_conditions:
                                for sub_key in subset_keys:
                                    vals = [scores[dt][sub_key][s]["mean"] for s in subjects if s in scores[dt][sub_key]]
                                    avg_by_dt[dt].append(np.nanmean(vals) if vals else np.nan)
                                    std_by_dt[dt].append(np.nanstd(vals) if vals else 0)

                            optodes = [subsets_data[k]["n_optodes"] / probe_area for k in subset_keys]

                            raincloud_subsets_runs(
                                run_scores=run_scores, subset_keys=subset_keys, optodes_per_cm2=optodes,
                                dt_labels=global_cfg.dt_labels,
                                save_dir=build_save_dir(result_path, space="channel", mode="ws_raincloud", ctx=ctx),
                                spatial_scaling=ctx.spatial_scaling, int_scaling=ctx.int_scaling, save_plot=global_cfg.save_plot,
                            )

                            #barplot_subsets(avg_by_dt, std_by_dt, subset_keys, optodes, global_cfg.dt_labels, save_dir,
                            #                ctx.spatial_scaling, ctx.int_scaling, global_cfg.save_plot)

                            full_scores = {dt: {s: scores[dt].get('full', {}).get(s, {}).get('mean', np.nan) for s in subjects}
                                           for dt in global_cfg.dt_conditions}
                            full_stds = {dt: {s: scores[dt].get('full', {}).get(s, {}).get('std', np.nan) for s in subjects}
                                         for dt in global_cfg.dt_conditions}

                            plot_grouped_bars_by_dt(data_dict=full_scores, dt_labels_map=global_cfg.dt_labels, std_dict=full_stds,
                                                    title="Channel Space Classification Accuracy")
                            plot_grouped_bars_by_subject(data_dict=full_scores, dt_labels_map=global_cfg.dt_labels, std_dict=full_stds,
                                                         title="Channel Space Classification Accuracy")

                        # The remaining sections run without SMA pruning and without
                        # the HRF ROI (both are re-enabled for the next classifier).
                        ctx.prune_chans_sma = False
                        ctx.g.sel_hrf_roi = None
                        # ============== CHANNEL LOSO =======================
                        if global_cfg.run_ch_loso:
                            for sub_key in ["full"]:
                                ctx.sub_key = sub_key
                                subset_chs = subsets_data[sub_key]["all"]
                                for dt in global_cfg.dt_conditions:
                                    dt_before, dt_loso = dt, (dt + '_full') if 'ss' in dt else dt
                                    ctx.dt = dt
                                    # First collect (features, raw labels) per run; labels are
                                    # encoded afterwards with one shared encoder.
                                    raw_runs = {s: [] for s in subjects}
                                    for subject in subjects:
                                        for run in data[subject]:
                                            entry = data[subject][run]['full'][dt_loso] if 'ss' in dt_loso else data[subject][run]['all']
                                            epo, y_raw = entry["epochs"], entry["splits"][0]["y"]
                                            if ctx.data_type == 'BS_Laura':
                                                epo = epo.sel(channel=[c for c in channel_roi_laura if c in epo.channel.values])
                                            if ctx.dt == "long":
                                                epo = epo.sel(channel=[c for c in long_chs if c in epo.channel.values])
                                            chans = [c for c in subset_chs if c in epo.channel.values]
                                            if ctx.prune_chans_sma and ctx.sma_sens_channels:
                                                chans = [c for c in chans if c in ctx.sma_sens_channels]
                                            epo = epo.sel(channel=chans)
                                            X = extract_features(epo, ctx, long_chs=None, prune_chs=None)
                                            if ctx.prune_channels and ctx.g.prune_by_zeroing_loso:
                                                # Zero (instead of dropping) channels not in this run's
                                                # clean list -- presumably to keep feature columns
                                                # aligned across subjects; confirm.
                                                X = X.where(X.channel.isin(ctx.clean_ch_map[subject][run]), 0)
                                            X_np = X.values if hasattr(X, "values") else X
                                            raw_runs[subject].append((X_np, y_raw))

                                    # A single encoder fitted on every run's labels keeps the
                                    # class -> integer mapping identical across runs and subjects.
                                    # Encoding each run separately would shift the codes of any run
                                    # that lacks a class, mixing up labels between LOSO folds.
                                    label_pool = [np.asarray(y) for runs in raw_runs.values() for _, y in runs]
                                    encoder = LabelEncoder().fit(np.concatenate(label_pool)) if label_pool else None
                                    all_data = {
                                        s: [(feats, encoder.transform(labels)) for feats, labels in runs]
                                        for s, runs in raw_runs.items()
                                    }

                                    loso_stats = loso_classification(all_data, subjects, ctx, k=ctx.g.n_reduced_feat_loso)
                                    save_dir = build_save_dir(result_path, space="channel", mode="loso", ctx=ctx)
                                    os.makedirs(save_dir, exist_ok=True)
                                    ttl = f"Channel LOSO per Subject - {global_cfg.dt_labels[dt_before]} - {ctx.clf_name} - subset: {sub_key}"
                                    loso_barplot(loso_stats, ttl, save_dir,
                                                 f"{ctx.clf_name.replace(' ', '_')}_loso_accuracy_{sub_key}_{dt_before}.png",
                                                 global_cfg.save_plot)
                                    df = subject_stats_to_df(loso_stats, include_all_row=True)
                                    save_results_df(df, save_dir, f"results_{dt_before}_{sub_key}.csv")

                        # ============== PARCEL WS ==========================
                        if global_cfg.run_parcel_pipeline and global_cfg.run_parcel_ws:
                            dt_parcel = ["all_od", "all_od_ss_mean"]
                            dt_labels_parcel = {"all_od": "Parcel No SS Correction", "all_od_ss_mean": "Parcel SS corrected"}
                            parcel_subset = ds.parcel_subset
                            subset_name, parcels = list(parcel_subset.items())[0]

                            for sub_key in subset_keys:
                                ctx.sub_key = sub_key
                                subset_chs = subsets_data[sub_key]['all']
                                subset_ch_set = set(subset_chs)  # O(1) membership for the filters below
                                acc_by_dt, std_by_dt = {}, {}

                                for dt in dt_parcel:
                                    ctx.dt = dt
                                    acc_by_dt[dt], std_by_dt[dt] = {}, {}
                                    for subject in subjects:
                                        run_stats = []
                                        # Iterate the runs that were actually preloaded, like the
                                        # channel sections do.
                                        for run in data[subject]:
                                            run_data = data[subject][run]["full"][dt] if "ss" in dt else data[subject][run]["all_od"]
                                            clean_parcel_chs = [c for c in clean_map[subject][run] if c in subset_ch_set]
                                            stat = parcel_within_subject_cv(
                                                run_data=run_data, ctx=ctx,
                                                Adot=ds.Adot, B=ds.B,
                                                clean_chs=clean_parcel_chs, parcels=parcels,
                                                subject_id=subject, prune=ctx.prune_channels
                                            )
                                            run_stats.append(stat)
                                        agg = aggregate_runs(run_stats)
                                        acc_by_dt[dt][subject] = agg.get(subject, {}).get("mean", np.nan)
                                        std_by_dt[dt][subject] = agg.get(subject, {}).get("std", 0.0)

                                    save_dir = build_save_dir(result_path, space="parcel", mode="ws", ctx=ctx)
                                    os.makedirs(save_dir, exist_ok=True)
                                    stats_dt = {s: {"mean": acc_by_dt[dt].get(s, np.nan), "std": std_by_dt[dt].get(s, np.nan)} for s in subjects}
                                    df = subject_stats_to_df(stats_dt, include_all_row=True)
                                    save_results_df(df, save_dir, f"results_{dt}_{subset_name}_chset_{sub_key}.csv")

                                plot_grouped_bars_by_dt(data_dict=acc_by_dt, dt_labels_map=dt_labels_parcel, std_dict=std_by_dt,
                                                        title=f"Parcel Space WS Accuracy {ctx.data_type} (ch subset: {sub_key})")
                                plot_grouped_bars_by_subject(data_dict=acc_by_dt, dt_labels_map=dt_labels_parcel, std_dict=std_by_dt,
                                                             title=f"Parcel Space WS Accuracy {ctx.data_type} (ch subset: {sub_key})")

                        # ============== PARCEL LOSO ========================
                        if global_cfg.run_parcel_pipeline and global_cfg.run_parcel_loso:
                            dt_parcel = ["all_od", "all_od_ss_mean"]
                            dt_labels_parcel = {"all_od": "Parcel No SS Correction", "all_od_ss_mean": "Parcel SS corrected"}
                            parcel_subset = {"sensitive": ds.sensitive_parcels}
                            subset_name, parcels = list(parcel_subset.items())[0]

                            for sub_key in ["full"]:
                                ctx.sub_key = sub_key
                                subset_chs = subsets_data[sub_key]['all']
                                subset_ch_set = set(subset_chs)
                                # Clean channels restricted to this channel subset; this does
                                # not depend on dt, so build it once per subset.
                                clean_sub = {
                                    sbj: {run: [c for c in clean_map[sbj][run] if c in subset_ch_set]
                                          for run in clean_map[sbj]}
                                    for sbj in clean_map
                                }
                                for dt in dt_parcel:
                                    # Keep ctx.dt in sync (as every other section does) so the
                                    # classifier and build_save_dir don't see a stale value.
                                    ctx.dt = dt
                                    dt_before, dt_loso = dt, (dt + '_full') if 'ss' in dt else dt
                                    data_parcel = {
                                        s: {r: (data[s][r]['full'][dt_loso] if "ss" in dt_loso else data[s][r]['all_od']) for r in data[s]}
                                        for s in data
                                    }
                                    loso_stats = parcel_loso_classification(
                                        data=data_parcel, subjects=subjects, ctx=ctx,
                                        Adot=ds.Adot, B=ds.B, clean_ch_map=clean_sub,
                                        parcels=parcels, prune=ctx.prune_channels
                                    )
                                    save_dir = build_save_dir(result_path, space="parcel", mode="loso", ctx=ctx)
                                    os.makedirs(save_dir, exist_ok=True)
                                    ttl = f"Parcel LOSO per Subject - {dt_labels_parcel[dt_before]} - {ctx.clf_name} - ch subset: {sub_key}"
                                    loso_barplot(loso_stats, ttl, save_dir,
                                                 f"{ctx.clf_name.replace(' ', '_')}_loso_accuracy_{sub_key}_{dt_before}.png",
                                                 global_cfg.save_plot)
                                    df = subject_stats_to_df(loso_stats, include_all_row=True)
                                    save_results_df(df, save_dir, f"results_{dt_before}_{subset_name}_chset_{sub_key}.csv")
    finally:
        # Hand the shared config back exactly as configured, even after an error.
        global_cfg.sel_hrf_roi = base_sel_hrf_roi
