In [1]:
import numpy as np
import pandas as pd
import itertools
from collections import defaultdict
from typing import Dict, Tuple, Any, List
from scipy import stats
from statsmodels.stats.power import TTestIndPower
import warnings
import os
os.chdir('../../..')

from Scripts.Spectral_Analysis.Spectrum_Filter import filter_spectras
warnings.filterwarnings("ignore")

In [2]:
# Grouping spec: feature name -> {group label -> list of raw values that
# belong to the group}. A one-element list is collapsed to a scalar by
# _collapse_allowed_values before being forwarded to filter_spectras; longer
# lists are forwarded unchanged.
FEATURE_GROUPS = {
    "day_time": {
        "Day":      ["Day"],
        "Evening":  ["Evening"],
    },

    "stim_type": {
        "r": ["r"],
        "g": ["g"],
    },

    "stim_label": {
        "line":   [0, 1, 4, 6, 7, 11, 8],
        "figure": [2, 3, 5, 9, 10, 12],
    },

    "gender": {
        "m": ["m"],
        "f": ["f"],
    },

    "age": {
        "18-22": [18, 19, 20, 21, 22],
        # Each age appears exactly once (the list previously contained 25 twice).
        "23-29": [23, 24, 25, 26, 27, 28, 29],
        "30-35": [30, 31, 32, 33, 34, 35],
    },

    # NOTE: the label "r" is used both here and in "stim_type".
    "handiness": {
        "l": ["l"],
        "r": ["r"],
    },
}


# Band edges in Hz as (low, high). "Tetta" denotes the theta band; the key
# spelling is kept because it is emitted verbatim in the 'band' column of the
# result tables.
DEFAULT_BANDS = {
    'Delta': (1, 4),
    'Tetta': (4, 7),
    'Alpha': (7, 13),
    'Beta': (13, 30)
}

# Positional layout used by normalize_results for non-dict records:
# element i of a record is stored under _RESULT_FIELDS[i].
_RESULT_FIELDS = [
    "power",
    "phase",
    "s_id",
    "t_id",
    "gender",
    "handiness",
    "age",
    "label",
    "img",
    "task_type",
]

def normalize_results(raw_results: List[Any]) -> List[Dict[str, Any]]:
    """
    Convert the output of filter_spectras(...) into a list of dicts.

    Dict records are shallow-copied. Any other record is read positionally:
    its i-th element is stored under _RESULT_FIELDS[i], and fields beyond the
    record's length are simply left out.
    Downstream code only strictly needs the 'power' and 's_id' keys.
    """
    def _as_record(entry: Any) -> Dict[str, Any]:
        if isinstance(entry, dict):
            return entry.copy()
        return {
            field: entry[pos]
            for pos, field in enumerate(_RESULT_FIELDS)
            if pos < len(entry)
        }

    return [_as_record(entry) for entry in raw_results]

def make_band_cols_from_bands(
    bands: Dict[str, Tuple[float, float]],
    freqs: np.ndarray,
) -> Dict[str, np.ndarray]:
    """
    Map each band to the indices of the frequency bins it covers.

    bands: {"alpha": (8, 12), ...} with edges in Hz
    freqs: array of bin frequencies (length n_freqs)

    Both edges are inclusive, so a bin lying exactly on an edge shared by two
    bands is assigned to both. A band that covers no bin maps to an empty
    index array.

    Returns:
        {"alpha": array([...indices...]), "beta": array([...]), ...}
    """
    return {
        name: np.nonzero((freqs >= low) & (freqs <= high))[0]
        for name, (low, high) in bands.items()
    }

def aggregate_by_subject(
    results: List[Dict[str, Any]],
    band_cols: Dict[str, np.ndarray],
    n_channels: int,
) -> Dict[Tuple[int, str], Dict[str, Any]]:
    """
    Reduce per-trial spectra to one value per subject for every (channel, band).

    results: list of dicts as produced by normalize_results.
    band_cols: band name -> array of frequency-bin indices.
    n_channels: upper bound on the number of channels to process.

    Records lacking 'power' or 's_id', records whose 's_id' cannot be turned
    into an int, and records whose power is not 2-D (after averaging a 3-D
    array over its last axis) are skipped. Bands with no indices, or with an
    index outside the frequency axis, are skipped for that record.

    Returns subj_vecs[(ch, band_name)] = {'ids': [...], 'vals': np.array([...])},
    where each value is the NaN-ignoring mean over the subject's trials of the
    NaN-ignoring band mean.
    """
    # per_key[(ch, band_name)][subject_id] -> list of per-trial band means
    per_key = defaultdict(lambda: defaultdict(list))

    for trial in results:
        spectrum = trial.get("power", None)
        raw_sid = trial.get("s_id", None)
        if spectrum is None or raw_sid is None:
            continue

        try:
            subject = int(raw_sid)
        except Exception:
            continue

        # Bring the spectrum to a 2-D float32 matrix.
        if getattr(spectrum, "ndim", None) == 3:
            # 3-D input: average over the last axis first.
            matrix = spectrum.mean(axis=2).astype(np.float32, copy=False)
        else:
            matrix = np.asarray(spectrum, dtype=np.float32)

        if matrix.ndim != 2:
            continue

        n_rows, n_cols = matrix.shape
        for ch in range(min(n_channels, n_rows)):
            for band_name, freq_idx in band_cols.items():
                if freq_idx.size == 0 or np.max(freq_idx) >= n_cols:
                    continue
                trial_value = float(np.nanmean(matrix[ch, freq_idx]))
                per_key[(ch, band_name)][subject].append(trial_value)

    # Collapse each subject's trials to a single value.
    subj_vecs = {}
    for key in list(per_key.keys()):
        filled = [(sid, trials) for sid, trials in per_key[key].items() if trials]
        subj_vecs[key] = {
            'ids': [sid for sid, _ in filled],
            'vals': np.array(
                [float(np.nanmean(trials)) for _, trials in filled],
                dtype=float,
            ),
        }

    return subj_vecs

def build_two_condition_vectors(
    condA_subj: Dict[Tuple[int, str], Dict[str, Any]],
    condB_subj: Dict[Tuple[int, str], Dict[str, Any]],
    condA_name: str,
    condB_name: str,
    n_channels: int,
    band_names: List[str],
) -> Dict[Tuple[int, str], Dict[str, Any]]:
    """
    Put the per-subject values of two conditions side by side.

    An entry is produced for every channel in range(n_channels) and every band
    in band_names; a condition with no data for a key contributes an empty
    array and an empty id list.

    Returns:
        {(ch, band): {
            condA_name: np.array([...]),
            condB_name: np.array([...]),
            condA_name+'_ids': [...],
            condB_name+'_ids': [...],
        }}

    Note: when condA_name equals condB_name the keys coincide and the entries
    of condition B replace those of condition A.
    """
    def _values_and_ids(per_subject, key):
        entry = per_subject.get(key, {'ids': [], 'vals': np.array([], float)})
        values = np.asarray(entry.get('vals', np.array([], float)), dtype=float)
        return values, list(entry.get('ids', []))

    subject_vectors: Dict[Tuple[int, str], Dict[str, Any]] = {}

    for key in ((ch, band) for ch in range(n_channels) for band in band_names):
        vals_A, ids_A = _values_and_ids(condA_subj, key)
        vals_B, ids_B = _values_and_ids(condB_subj, key)
        subject_vectors[key] = {
            condA_name: vals_A,
            condB_name: vals_B,
            f"{condA_name}_ids": ids_A,
            f"{condB_name}_ids": ids_B,
        }

    return subject_vectors

def hedges_g(x, y) -> float:
    """
    Hedges' g for two independent groups (Cohen's d times the correction J).

    NaN entries are removed from each group before anything is computed, so
    the sample sizes, variances and means all describe the same observations.

    Returns:
        - NaN if either group has fewer than two non-NaN values, or if the
          pooled variance is not a positive finite number;
        - 0.0 if both groups have zero variance (treated as "no effect");
        - otherwise J * (mean(x) - mean(y)) / sqrt(pooled variance), with
          J = 1 - 3 / (4 * (nx + ny) - 9).
    """
    x = np.asarray(x, float).ravel()
    y = np.asarray(y, float).ravel()
    x = x[~np.isnan(x)]
    y = y[~np.isnan(y)]

    nx, ny = x.size, y.size
    if nx < 2 or ny < 2:
        return np.nan

    vx, vy = np.var(x, ddof=1), np.var(y, ddof=1)

    # Both variances are zero: report the effect as 0.
    if vx == 0 and vy == 0:
        return 0.0

    sp2 = ((nx - 1) * vx + (ny - 1) * vy) / (nx + ny - 2)
    if not np.isfinite(sp2) or sp2 <= 0:
        return np.nan

    d = (x.mean() - y.mean()) / np.sqrt(sp2)
    # nx + ny >= 4 here, so the denominator is at least 7.
    J = 1 - 3 / (4 * (nx + ny) - 9)
    return float(J * d)


def build_band_table_two_conditions(
    subject_vectors: Dict[Tuple[int, str], Dict[str, Any]],
    condA_name: str,
    condB_name: str,
    alpha: float = 0.05
) -> pd.DataFrame:
    """
    Build one row of statistics per (channel, band) pair.

    For each key of subject_vectors the values stored under condA_name and
    condB_name are compared with:
      - Welch's t-test (t_stat, p_value)
      - Hedges' g
      - power of a two-sided independent-samples t-test at significance level
        `alpha`, evaluated at the observed |g| and the observed group sizes

    NaN values are removed from both groups first, so the reported group
    sizes, the means, the test, the effect size and the power all use the same
    observations. Pairs where either group is left empty produce no row.

    The 'sig_alpha_0.05', 'sig_alpha_0.01' and 'sig_and_power' flags use the
    fixed thresholds 0.05 / 0.01 / 0.8 regardless of `alpha`.

    Returns a DataFrame sorted by p_value, then power (both ascending); it is
    empty when no pair had data in both groups.
    """
    power_calc = TTestIndPower()
    rows = []

    for (ch, band_name), vecs in subject_vectors.items():
        x = np.asarray(vecs.get(condA_name, []), float)
        y = np.asarray(vecs.get(condB_name, []), float)
        # Drop NaNs once so every statistic below sees the same sample.
        x = x[~np.isnan(x)]
        y = y[~np.isnan(y)]
        nA, nB = len(x), len(y)
        if nA == 0 or nB == 0:
            continue

        # Best effort: a failing test yields NaN statistics, not an abort.
        try:
            t_stat, p_val = stats.ttest_ind(x, y, equal_var=False)
        except Exception:
            t_stat, p_val = np.nan, np.nan

        g = hedges_g(x, y)

        # Power is only defined when an effect size could be computed.
        power = np.nan
        if np.isfinite(g):
            try:
                power = power_calc.solve_power(
                    effect_size=abs(g),
                    nobs1=nA,
                    ratio=nB / nA,
                    alpha=alpha,
                    alternative='two-sided'
                )
            except Exception:
                power = np.nan

        mean_A = float(np.mean(x))
        mean_B = float(np.mean(y))
        p_ok = bool(np.isfinite(p_val))
        power_ok = bool(np.isfinite(power))

        rows.append({
            'channel': ch,
            'band': band_name,

            f'n_{condA_name}': nA,
            f'n_{condB_name}': nB,

            f'mean_{condA_name}': mean_A,
            f'mean_{condB_name}': mean_B,

            f'delta_{condA_name}_minus_{condB_name}': mean_A - mean_B,

            't_stat': float(t_stat) if np.isfinite(t_stat) else np.nan,
            'p_value': float(p_val) if p_ok else np.nan,
            'hedges_g': float(g) if np.isfinite(g) else np.nan,
            'power': float(power) if power_ok else np.nan,

            'sig_alpha_0.05': bool(p_ok and p_val <= 0.05),
            'sig_alpha_0.01': bool(p_ok and p_val <= 0.01),
            'sig_and_power': bool(
                p_ok and power_ok and (p_val <= 0.05) and (power >= 0.8)
            ),
        })

    df = pd.DataFrame(rows)
    if not df.empty:
        df = df.sort_values(['p_value', 'power'],
                            ascending=[True, True]).reset_index(drop=True)
    return df



def run_stats_between_groups(
    data_A_raw,
    data_B_raw,
    group_A_name: str,
    group_B_name: str,
    band_cols: Dict[str, np.ndarray],
    n_channels: int,
    alpha: float = 0.05,
):
    """
    Compare two sets of raw filter_spectras results channel by channel and
    band by band.

    Pipeline: normalize_results -> aggregate_by_subject (for each group) ->
    build_two_condition_vectors -> build_band_table_two_conditions.

    Returns {'summary_table': DataFrame, 'meta': {'group_A': ..., 'group_B': ...}}.
    """
    # Normalise the record format, then reduce to one value per subject.
    per_subject_A, per_subject_B = (
        aggregate_by_subject(
            normalize_results(raw),
            band_cols=band_cols,
            n_channels=n_channels,
        )
        for raw in (data_A_raw, data_B_raw)
    )

    # Line the two groups up for every (channel, band) under consideration.
    paired = build_two_condition_vectors(
        condA_subj=per_subject_A,
        condB_subj=per_subject_B,
        condA_name=group_A_name,
        condB_name=group_B_name,
        n_channels=n_channels,
        band_names=list(band_cols.keys()),
    )

    # t-tests, effect sizes and power.
    summary = build_band_table_two_conditions(
        subject_vectors=paired,
        condA_name=group_A_name,
        condB_name=group_B_name,
        alpha=alpha,
    )

    meta = {'group_A': group_A_name, 'group_B': group_B_name}
    return {'summary_table': summary, 'meta': meta}


def _collapse_allowed_values(allowed_values: List[Any]):
    """
    Unwrap a single-category group: ["m"] becomes "m".
    Anything that does not have exactly one element (e.g. an age bin such as
    [18, 19, 20, ...]) is returned unchanged.
    """
    is_single = len(allowed_values) == 1
    return allowed_values[0] if is_single else allowed_values

def filter_group(
    npz_path: str,
    feature_name: str,
    allowed_values: List[Any],
    extra_kwargs: dict,
):
    """
    Fetch the trials of one group by calling filter_spectras.

    filter_spectras receives exec_spec_path=npz_path, the day_time_meta_path
    taken from extra_kwargs, and one keyword named after the feature whose
    value is the (possibly collapsed) list of allowed values.

    Raises:
        RuntimeError: extra_kwargs has no 'day_time_meta_path' (or it is None).
        ValueError: feature_name is not one of the supported features.
    """
    meta_path = extra_kwargs.get("day_time_meta_path", None)
    if meta_path is None:
        raise RuntimeError(
            "day_time_meta_path обязателен: передай его в extra_kwargs "
            "при вызове run_full_analysis_for_npz(...)"
        )

    # ["m"] -> "m", ["Day"] -> "Day", [18, 19, 20] stays a list
    val = _collapse_allowed_values(allowed_values)

    supported_features = (
        "day_time",
        "stim_type",
        "stim_label",
        "gender",
        "age",
        "handiness",
    )
    if feature_name not in supported_features:
        raise ValueError(f"Unknown feature: {feature_name}")

    return filter_spectras(
        exec_spec_path=npz_path,
        day_time_meta_path=meta_path,
        **{feature_name: val},
    )

def prepare_session_config_for_npz(
    npz_path: str,
    bands: Dict[str, Tuple[float, float]] = None,
    extra_kwargs: dict = None,
    freq_range: Tuple[float, float] = (2.0, 40.0),
):
    """
    Inspect one npz file and build the configuration shared by all analyses.

    All trials are loaded once through filter_spectras (no feature filter);
    the first record with a non-None 'power' determines the number of channels
    and of frequency bins.

    Parameters:
        npz_path: path forwarded to filter_spectras as exec_spec_path.
        bands: band name -> (low, high) in Hz; defaults to a copy of
            DEFAULT_BANDS.
        extra_kwargs: must contain 'day_time_meta_path'.
        freq_range: (first, last) frequency in Hz of the spectrum's frequency
            axis. The axis is not read from the file; it is reconstructed as
            n_freqs evenly spaced values over this range.
            NOTE(review): confirm the default 2-40 Hz grid matches how the
            spectra were computed.

    Returns:
        dict with keys 'npz_path', 'n_channels', 'band_cols', 'bands',
        'extra_kwargs'.

    Raises:
        RuntimeError: 'day_time_meta_path' is missing, or no record has 'power'.
        ValueError: the example 'power' is neither 2-D nor 3-D.
    """
    if bands is None:
        bands = DEFAULT_BANDS.copy()
    if extra_kwargs is None:
        extra_kwargs = {}

    meta_path = extra_kwargs.get("day_time_meta_path", None)
    if meta_path is None:
        raise RuntimeError(
            "day_time_meta_path обязателен: передай extra_kwargs={'day_time_meta_path': '...xlsx'}"
        )

    all_trials_raw = filter_spectras(
        exec_spec_path=npz_path,
        day_time_meta_path=meta_path,
    )
    all_trials_norm = normalize_results(all_trials_raw)

    # Any trial that carries a spectrum will do for reading the dimensions.
    example_power = next(
        (rec["power"] for rec in all_trials_norm if rec.get("power") is not None),
        None,
    )
    if example_power is None:
        raise RuntimeError("Не удалось найти ни одного trial с 'power' в этом npz (после filter_spectras).")

    # The first two axes are taken as (channels, frequencies).
    ndim = getattr(example_power, "ndim", None)
    if ndim == 3:
        n_channels, n_freqs, _ = example_power.shape
    elif ndim == 2:
        n_channels, n_freqs = example_power.shape
    else:
        raise ValueError(f"Неожиданная форма 'power': {getattr(example_power,'shape',None)}")

    f_first, f_last = freq_range
    freqs = np.linspace(f_first, f_last, n_freqs)

    # Frequency-bin indices for every EEG band.
    band_cols = make_band_cols_from_bands(bands, freqs)

    return {
        "npz_path": npz_path,
        "n_channels": n_channels,
        "band_cols": band_cols,
        "bands": bands,
        "extra_kwargs": extra_kwargs,  # carries day_time_meta_path for later filter calls
    }

def analyze_single_feature(
    session_cfg: dict,
    feature_name: str,
    alpha: float = 0.05,
):
    """
    Compare every pair of groups defined for one feature.

    Example for 'gender':
      - the groups are {'m': ['m'], 'f': ['f']}
      - filter_group() fetches the trials of each group
      - the groups are compared pairwise (m vs f)

    Returns dict { (groupA, groupB): {'summary_table': df, 'meta': {...}}, ... }
    """
    npz_path = session_cfg["npz_path"]
    band_cols = session_cfg["band_cols"]
    n_channels = session_cfg["n_channels"]
    extra_kwargs = session_cfg["extra_kwargs"]

    # Fetch the data of each group (one filter_spectras call per group).
    subsets = {}
    for group_label, allowed_values in FEATURE_GROUPS[feature_name].items():
        subsets[group_label] = filter_group(
            npz_path=npz_path,
            feature_name=feature_name,
            allowed_values=allowed_values,
            extra_kwargs=extra_kwargs,
        )

    # Every unordered pair of groups within the feature.
    return {
        (first, second): run_stats_between_groups(
            data_A_raw=subsets[first],
            data_B_raw=subsets[second],
            group_A_name=first,
            group_B_name=second,
            band_cols=band_cols,
            n_channels=n_channels,
            alpha=alpha,
        )
        for first, second in itertools.combinations(subsets, 2)
    }

def analyze_feature_pair(
    session_cfg: dict,
    feature_A: str,
    feature_B: str,
    alpha: float = 0.05,
):
    """
    Compare every group of one feature with every group of another.

    Example for ('day_time', 'stim_type'):
      - groups A: Day / Evening
      - groups B: r / g
      - comparisons: Day vs r, Day vs g, Evening vs r, Evening vs g

    Group names passed to the statistics are normally the plain labels. When a
    label of feature_A equals a label of feature_B (e.g. "r" exists in both
    'stim_type' and 'handiness'), the two names are qualified as
    "<feature>:<label>". With equal names the per-condition dict keys built
    downstream would coincide and group B would be compared with itself.
    Only the column names of that table and its 'meta' are affected; the keys
    of the returned dict stay (label_A, label_B).

    NOTE(review): groups taken from two different features are filtered
    independently from the same file and may contain the same subjects —
    confirm that an independent-samples test is what is intended here.

    Returns dict { (label_A, label_B): {'summary_table': df, ...}, ... }
    """
    npz_path = session_cfg["npz_path"]
    band_cols = session_cfg["band_cols"]
    n_channels = session_cfg["n_channels"]
    extra_kwargs = session_cfg["extra_kwargs"]

    groups_A = FEATURE_GROUPS[feature_A]
    groups_B = FEATURE_GROUPS[feature_B]

    subsets_A = {
        label_A: filter_group(
            npz_path=npz_path,
            feature_name=feature_A,
            allowed_values=allowed_values_A,
            extra_kwargs=extra_kwargs,
        )
        for label_A, allowed_values_A in groups_A.items()
    }

    subsets_B = {
        label_B: filter_group(
            npz_path=npz_path,
            feature_name=feature_B,
            allowed_values=allowed_values_B,
            extra_kwargs=extra_kwargs,
        )
        for label_B, allowed_values_B in groups_B.items()
    }

    results_for_pair = {}

    for label_A, data_A_raw in subsets_A.items():
        for label_B, data_B_raw in subsets_B.items():
            name_A, name_B = label_A, label_B
            if name_A == name_B:
                # Identical labels would collide as dict keys downstream.
                name_A = f"{feature_A}:{label_A}"
                name_B = f"{feature_B}:{label_B}"

            results_for_pair[(label_A, label_B)] = run_stats_between_groups(
                data_A_raw=data_A_raw,
                data_B_raw=data_B_raw,
                group_A_name=name_A,
                group_B_name=name_B,
                band_cols=band_cols,
                n_channels=n_channels,
                alpha=alpha,
            )

    return results_for_pair


def run_full_analysis_for_npz(
    npz_path: str,
    alpha: float = 0.05,
    feature_list_single=None,
    feature_pairs=None,
    extra_kwargs: dict = None,
    bands: Dict[str, Tuple[float, float]] = None,
):
    """
    Run all single-feature and feature-pair comparisons for one npz file.

    feature_list_single: features to analyse on their own; None means every
        key of FEATURE_GROUPS.
    feature_pairs: (feature_A, feature_B) pairs to analyse; None means every
        2-combination of the FEATURE_GROUPS keys.
    extra_kwargs: must contain 'day_time_meta_path', otherwise RuntimeError
        is raised before anything is loaded.

    Returns {"single_feature": {feature: ...}, "pair_feature": {(fA, fB): ...}}.
    """
    has_meta_path = extra_kwargs is not None and "day_time_meta_path" in extra_kwargs
    if not has_meta_path:
        raise RuntimeError(
            "run_full_analysis_for_npz: передай extra_kwargs={'day_time_meta_path': '...xlsx'}"
        )

    session_cfg = prepare_session_config_for_npz(
        npz_path=npz_path,
        bands=bands,
        extra_kwargs=extra_kwargs,
    )

    if feature_list_single is None:
        feature_list_single = list(FEATURE_GROUPS.keys())

    if feature_pairs is None:
        feature_pairs = list(itertools.combinations(FEATURE_GROUPS.keys(), 2))

    # Single features: groups of one feature against each other.
    single_results = {
        feat: analyze_single_feature(
            session_cfg=session_cfg,
            feature_name=feat,
            alpha=alpha,
        )
        for feat in feature_list_single
    }

    # Feature pairs: groups of one feature against groups of another.
    pair_results = {
        (featA, featB): analyze_feature_pair(
            session_cfg=session_cfg,
            feature_A=featA,
            feature_B=featB,
            alpha=alpha,
        )
        for (featA, featB) in feature_pairs
    }

    return {
        "single_feature": single_results,
        "pair_feature": pair_results,
    }



In [3]:
# Driver cell: run the whole analysis for one spectra file and show a few of
# the resulting tables.
npz_path = r'./Generated/Spectrums/psds_array_morlet.npz'
meta_path = r"./Supplementary/Experiment_Metadata.xlsx"

bands = {
    'Delta': (1, 4),
    'Tetta': (4, 7),
    'Alpha': (7, 13),
    'Beta': (13, 30)
}

extra_kwargs = {"day_time_meta_path": meta_path}

results = run_full_analysis_for_npz(
    npz_path=npz_path,
    alpha=0.05,
    extra_kwargs=extra_kwargs,
    bands=bands,
    feature_list_single=None,   # None => every feature on its own
    feature_pairs=None,         # None => every pair of features
)

single_results = results["single_feature"]
pair_results = results["pair_feature"]

# gender (m vs f)
gender_cmp = single_results.get("gender", {}).get(("m", "f"))
if gender_cmp is None:
    print("Нет сравнения ('m','f') в results['single_feature']['gender'].")
else:
    gender_table = gender_cmp["summary_table"]
    if gender_table.empty:
        print("gender_table пустая (недостаточно данных для m vs f).")
    else:
        print("GENDER: m vs f (top by p-value):")
        display(gender_table.sort_values("p_value").head(10))

        is_significant = gender_table["p_value"] <= 0.05
        is_powered = gender_table["power"] >= 0.8
        sig_gender = gender_table[is_significant & is_powered].sort_values("p_value")
        print("GENDER: significant effects (p<=0.05 & power>=0.8):")
        display(sig_gender.head(20))

# day_time (Day vs Evening)
day_time_cmp = single_results.get("day_time", {}).get(("Day", "Evening"))
if day_time_cmp is None:
    print("Нет пары ('Day','Evening') в single_feature['day_time'].")
else:
    dt_table = day_time_cmp["summary_table"]
    if dt_table.empty:
        print("Day vs Evening пусто (возможно, есть только Day или только Evening).")
    else:
        print("DAY_TIME: Day vs Evening (top by p-value):")
        display(dt_table.sort_values("p_value").head(10))

# feature pair: (day_time, stim_type) -> Day vs r
if ("day_time", "stim_type") not in pair_results:
    print("('day_time','stim_type') не было посчитано.")
else:
    block = pair_results[("day_time", "stim_type")]
    if ("Day", "r") not in block:
        print("Нет сравнения ('Day','r') в pair_feature[('day_time','stim_type')].")
    else:
        dr_table = block[("Day", "r")]["summary_table"]
        if dr_table.empty:
            print("Day vs r есть логически, но таблица пустая (недостаточно данных).")
        else:
            print("PAIR: Day vs r (top by p-value):")
            display(dr_table.sort_values("p_value").head(10))



GENDER: m vs f (top by p-value):


Unnamed: 0,channel,band,n_m,n_f,mean_m,mean_f,delta_m_minus_f,t_stat,p_value,hedges_g,power,sig_alpha_0.05,sig_alpha_0.01,sig_and_power
0,39,Beta,8,8,0.039274,0.07919,-0.039916,-2.895621,0.012404,-1.368839,0.721233,True,False,False
1,58,Beta,8,8,0.046096,0.10318,-0.057084,-2.84253,0.021782,-1.343741,0.705409,True,False,False
2,28,Beta,8,8,0.044466,0.103799,-0.059332,-2.761304,0.024612,-1.305344,0.680423,True,False,False
3,35,Tetta,8,8,0.147817,0.286991,-0.139174,-2.728179,0.024979,-1.289685,0.669984,True,False,False
4,21,Beta,8,8,0.032785,0.055815,-0.02303,-2.347028,0.037819,-1.109504,0.54214,True,False,False
5,7,Beta,8,8,0.027848,0.048271,-0.020423,-2.312942,0.040232,-1.093391,0.530266,True,False,False
6,26,Beta,8,8,0.041728,0.074063,-0.032334,-2.258304,0.040774,-1.067562,0.51118,True,False,False
7,61,Beta,8,8,0.040416,0.06508,-0.024664,-2.369223,0.042437,-1.119996,0.549851,True,False,False
8,48,Beta,8,8,0.058988,0.085606,-0.026618,-2.229778,0.043192,-1.054077,0.501201,True,False,False
9,60,Beta,8,8,0.04008,0.110848,-0.070767,-2.412956,0.043851,-1.14067,0.564988,True,False,False


GENDER: significant effects (p<=0.05 & power>=0.8):


Unnamed: 0,channel,band,n_m,n_f,mean_m,mean_f,delta_m_minus_f,t_stat,p_value,hedges_g,power,sig_alpha_0.05,sig_alpha_0.01,sig_and_power


DAY_TIME: Day vs Evening (top by p-value):


Unnamed: 0,channel,band,n_Day,n_Evening,mean_Day,mean_Evening,delta_Day_minus_Evening,t_stat,p_value,hedges_g,power,sig_alpha_0.05,sig_alpha_0.01,sig_and_power
0,1,Beta,9,7,0.069974,0.036717,0.033257,2.610242,0.027214,1.107933,0.534583,True,False,False
1,15,Beta,9,7,0.086674,0.058647,0.028027,2.426645,0.029371,1.109208,0.535515,True,False,False
2,39,Beta,9,7,0.072947,0.0416,0.031347,2.185808,0.04821,0.967199,0.431623,True,False,False
3,28,Beta,9,7,0.094699,0.04769,0.047009,2.21908,0.052552,0.940936,0.412706,False,False,False
4,20,Beta,9,7,0.112791,0.068724,0.044067,2.133686,0.055178,0.928558,0.403859,False,False,False
5,61,Beta,9,7,0.062034,0.040809,0.021225,2.12068,0.057874,0.915386,0.394497,False,False,False
6,9,Beta,9,7,0.176116,0.068118,0.107999,2.15806,0.060979,0.905557,0.387551,False,False,False
7,8,Beta,9,7,0.168995,0.075693,0.093302,1.96093,0.079419,0.834254,0.338375,False,False,False
8,60,Beta,9,7,0.098793,0.045471,0.053322,1.868953,0.094083,0.790054,0.309183,False,False,False
9,10,Beta,9,7,0.073309,0.047539,0.02577,1.815228,0.094124,0.797494,0.314018,False,False,False


Day vs r есть логически, но таблица пустая (недостаточно данных).
