## Setup

### Imports

In [9]:
# import pprint
from warnings import simplefilter

import pandas as pd
from IPython.display import Markdown, display
from statsmodels.stats.multitest import multipletests

simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
import json
import re
import textwrap
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from statsmodels.stats.multitest import multipletests

import helpers
import matplotlib.pyplot as plt
import numpy as np
import pyperclip
import statsmodels.api as sm
from IPython.display import clear_output
from matplotlib import colormaps
from scipy import stats
from statsmodels.genmod.families import Poisson

# from reload_recursive import reload_recursive
from statsmodels.stats.mediation import Mediation
from statsmodels.stats.outliers_influence import variance_inflation_factor
from tqdm.notebook import tqdm

from mri_data import file_manager as fm

### Load Data

#### Clinical and Volumes

In [2]:
drive_root = fm.get_drive_root()
dataroot = drive_root / "3Tpioneer_bids"
data_dir = Path("/home/srs-9/Projects/ms_mri/data")
fig_path = Path(
    "/home/srs-9/Projects/ms_mri/analysis/thalamus/figures_tables/choroid_associations"
)

choroid_volumes = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/data/choroid_aschoplex_volumes.csv", index_col="subid"
)
ventricle_volumes = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/analysis/paper1/data0/ventricle_volumes.csv",
    index_col="subid",
)
csf_volumes = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/analysis/thalamus/data0/csf_volumes2.csv",
    index_col="subid",
)
third_ventricle_width = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/analysis/thalamus/data0/third_ventricle_width.csv",
    index_col="subid",
)

tiv = pd.read_csv("/home/srs-9/Projects/ms_mri/data/tiv_data.csv", index_col="subid")

df = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/data/clinical_data_processed.csv", index_col="subid"
)
sdmt = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/analysis/thalamus/SDMT_sheet.csv", index_col="subid"
)
df = df.join([choroid_volumes, tiv, ventricle_volumes, sdmt["SDMT"]])
df["allCSF"] = csf_volumes["all"]
df["thirdV"] = csf_volumes["third_ventricle"]
#! need to fix the actual segmentation files
df["periCSF"] = csf_volumes["peripheral"] - df["thirdV"]

df["thirdV_width"] = third_ventricle_width["third_ventricle_width"]
df["periCSF_frac"] = csf_volumes["peripheral"] / csf_volumes["all"]
df.rename(columns={"ventricle_volume": "LV", "choroid_volume": "CP"}, inplace=True)

df["SDMT"] = pd.to_numeric(df["SDMT"], errors="coerce")
df["thalamus_sqrt"] = np.sqrt(df["thalamus"])
df["thalamus_curt"] = np.sqrt(df["thalamus"] ** 3)
df["cortical_thickness_inv"] = 1 / df["cortical_thickness"]
df["LV_logtrans"] = np.log(df["LV"])
df["PRL_log1p"] = np.log1p(df["PRL"])

# these corrections should ultimately be made to the csf file
for struct in ["brain", "white", "grey", "thalamus", "t2lv"]:
    df[struct] = df[struct] * 1000

df["CCF"] = df["LV"] / df["allCSF"]
df["peri_ratio"] = df["periCSF"] / df["LV"]


df_z = df.copy()
numeric_cols = df.select_dtypes(include="number").columns
df_z[numeric_cols] = df_z[numeric_cols].apply(stats.zscore, nan_policy="omit")

viridis = colormaps["viridis"].resampled(20)

colors = helpers.get_colors()

MS_patients = df["dz_type2"] == "MS"
nonMS_patients = df["dz_type2"] == "!MS"
NIND_patients = df["dz_type5"] == "NIND"
OIND_patients = df["dz_type5"] == "OIND"
RMS_patients = df["dz_type5"] == "RMS"
PMS_patients = df["dz_type5"] == "PMS"

#### HIPS-THOMAS Volumes and Distances

In [3]:
df_thomas = pd.read_csv(data_dir / "hipsthomas_vols.csv", index_col="subid")
df_thomas_left = pd.read_csv(data_dir / "hipsthomas_left_vols.csv", index_col="subid")
df_thomas_right = pd.read_csv(data_dir / "hipsthomas_right_vols.csv", index_col="subid")

cols_orig = df_thomas.columns
new_colnames = {}
for col in df_thomas.columns:
    new_col = re.sub(r"(\d+)-([\w-]+)", r"\2_\1", col)
    new_col = re.sub("-", "_", new_col)
    new_colnames[col] = new_col

df_thomas = df_thomas.rename(columns=new_colnames)
df_thomas_left = df_thomas_left.rename(columns=new_colnames)
df_thomas_right = df_thomas_right.rename(columns=new_colnames)

nuclei_groupings = {
    "anterior": ["AV_2"],
    "ventral": ["VA_4", "VLa_5", "VLP_6", "VPL_7"],
    "posterior": ["Pul_8", "LGN_9", "MGN_10"],
    "medial": ["MD_Pf_12", "CM_11"],
}


def combine_nuclei(df, groupings):
    df2 = pd.DataFrame()
    for group, nuclei in groupings.items():
        df2[group] = sum([df[nucleus] for nucleus in nuclei])
    return df2


df_thomas = df_thomas.join(combine_nuclei(df_thomas, nuclei_groupings))
df_thomas_left = df_thomas_left.join(combine_nuclei(df_thomas_left, nuclei_groupings))
df_thomas_right = df_thomas_right.join(
    combine_nuclei(df_thomas_right, nuclei_groupings)
)


thalamic_nuclei = [2, 4, 5, 6, 7, 8, 9, 10, 11, 12]
thalamic_nuclei_str = [str(i) for i in thalamic_nuclei]
deep_grey = [13, 14, 26, 27, 28, 29, 30, 31, 32]
deep_grey_str = [str(i) for i in deep_grey]


hips_thomas_ref = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/data/hipsthomas_struct_index.csv", index_col="index"
)["struct"]
hips_thomas_invref = pd.read_csv(
    "/home/srs-9/Projects/ms_mri/data/hipsthomas_struct_index.csv", index_col="struct"
)["index"]

### Functions

In [4]:
def zscore(df):
    df_z = df.copy()
    numeric_cols = df.select_dtypes(include="number").columns
    df_z[numeric_cols] = df_z[numeric_cols].apply(stats.zscore, nan_policy="omit")
    return df_z

#### Regression Functions

In [5]:
def run_regressions(
    model_data: pd.DataFrame,
    outcome: str,
    predictors: list[str],
    covariates: list[str] = None,
    robust_cov: str = "HC3",
    fdr_method: str = "fdr_bh",
    fdr_alpha: float = 0.05,
):
    if covariates is None:
        covariates = []
        
    def _get_val_by_name(obj, name, attr):
        import numpy as np

        vals = getattr(obj, attr)
        # pandas Series (has .get)
        if hasattr(vals, "get"):
            return vals.get(name, np.nan)
        # numpy array / list-like: map via model exog names
        try:
            exog_names = list(obj.model.exog_names)
        except Exception:
            exog_names = []
        if name in exog_names:
            idx = exog_names.index(name)
            try:
                return np.asarray(vals)[idx]
            except Exception:
                return np.nan
        return np.nan

    results = {}
    models = {}
    for predictor in predictors:
        independent_vars = [predictor] + covariates
        formula = f"{outcome} ~ {" + ".join(independent_vars)}"
        model = sm.OLS.from_formula(formula, model_data).fit()

        if robust_cov:
            rres = model.get_robustcov_results(cov_type=robust_cov)
        else:
            rres = model
        
        # confidence interval: conf_int() returns DataFrame when names available
        ci_df = rres.conf_int()
        if hasattr(ci_df, "loc") and predictor in ci_df.index:
            llci, ulci = float(ci_df.loc[predictor, 0]), float(ci_df.loc[predictor, 1])
        else:
            # fallback via exog_names -> index
            try:
                exog_names = list(rres.model.exog_names)
                idx = exog_names.index(predictor)
                ci_arr = np.asarray(ci_df)
                llci, ulci = float(ci_arr[idx, 0]), float(ci_arr[idx, 1])
            except Exception:
                llci = ulci = np.nan

        ci_str = f"[{llci:.3}, {ulci:.3}]" if not np.isnan(llci) else ""
        results[predictor] = {
            "beta": _get_val_by_name(rres, predictor, "params"),
            "p_fdr": None,
            "se": _get_val_by_name(rres, predictor, "bse"),
            "llci": llci,
            "ulci": ulci, 
            "ci_str": ci_str,
            "p_raw": _get_val_by_name(rres, predictor, "pvalues"),
            "R2": rres.rsquared_adj,
        }
        models[predictor] = model

    results = pd.DataFrame(results).T

    fdr_method: str = "fdr_bh"
    fdr_alpha = 0.05
    _, p_fdr_vals, _, _ = multipletests(
        results["p_raw"], alpha=fdr_alpha, method=fdr_method
    )
    results["p_fdr"] = p_fdr_vals

    return results


def run_regressions_refactored(
    model_data: pd.DataFrame,
    outcomes,
    predictors,
    covariates: list = [],
    robust_cov: str = "HC3",
    fdr_method: str = "fdr_bh",
    fdr_alpha: float = 0.05,
):
    """
    Run OLS for every (struct, predictor).
    Returns (results_by_struct, results_by_predictor)
    - results_by_struct: dict struct -> DataFrame indexed by predictor
    - results_by_predictor: dict predictor -> DataFrame indexed by struct
    Each DataFrame columns: coef, pval, se, llci, ulci, ci, R2, p_fdr, coef_sig
    """
    if covariates is None:
        covariates = []
    outcomes = list(outcomes)
    predictors = list(predictors)
    # container: per-struct dataframes
    results_by_struct = {}
    
    def _get_val_by_name(obj, name, attr):
        
        import numpy as np
        vals = getattr(obj, attr)
        # pandas Series (has .get)
        if hasattr(vals, "get"):
            return vals.get(name, np.nan)
        # numpy array / list-like: map via model exog names
        try:
            exog_names = list(obj.model.exog_names)
        except Exception:
            exog_names = []
        if name in exog_names:
            idx = exog_names.index(name)
            try:
                return np.asarray(vals)[idx]
            except Exception:
                return np.nan
        return np.nan

    for struct in outcomes:
        rows = []
        for pred in predictors:
            exog = [pred] + covariates
            formula = f"{struct} ~ {" + ".join(exog)}"
            try:
                res = sm.OLS.from_formula(formula, data=model_data).fit()
                if robust_cov:
                    rres = res.get_robustcov_results(cov_type=robust_cov)
                else:
                    rres = res

                coef = _get_val_by_name(rres, pred, "params")
                pval = _get_val_by_name(rres, pred, "pvalues")
                se = _get_val_by_name(rres, pred, "bse")

                # confidence interval: conf_int() returns DataFrame when names available
                ci_df = rres.conf_int()
                if hasattr(ci_df, "loc") and pred in ci_df.index:
                    llci, ulci = float(ci_df.loc[pred, 0]), float(ci_df.loc[pred, 1])
                else:
                    # fallback via exog_names -> index
                    try:
                        exog_names = list(rres.model.exog_names)
                        idx = exog_names.index(pred)
                        ci_arr = np.asarray(ci_df)
                        llci, ulci = float(ci_arr[idx, 0]), float(ci_arr[idx, 1])
                    except Exception:
                        llci = ulci = np.nan

                ci_str = f"[{llci:.3}, {ulci:.3}]" if not np.isnan(llci) else ""
                r2 = res.rsquared_adj
            except Exception as e:
                print(f"Error occurred while processing {pred} for {struct}: {e}")
                coef = pval = se = llci = ulci = np.nan
                ci_str = ""
                r2 = np.nan
                raise e
            rows.append(
                {
                    "predictor": pred,
                    "coef": coef,
                    "pval": pval,
                    "se": se,
                    "llci": llci,
                    "ulci": ulci,
                    "ci": ci_str,
                    "R2": r2,
                }
            )
        df_struct = pd.DataFrame(rows).set_index("predictor")
        # FDR across predictors for this struct
        pvals = df_struct["pval"].fillna(1.0).values 
        _, p_fdr_vals, _, _ = multipletests(pvals, alpha=fdr_alpha, method=fdr_method)
        df_struct.insert(2, "p_fdr", p_fdr_vals)
        df_struct["coef_sig"] = df_struct["coef"].where(df_struct["p_fdr"] < fdr_alpha, 0.0)
        results_by_struct[struct] = df_struct

    # build results_by_predictor for compatibility
    results_by_predictor = {}
    cols = next(iter(results_by_struct.values())).columns
    for pred in predictors:
        rows = []
        for struct in outcomes:
            row = results_by_struct[struct].loc[pred].to_dict()
            row["struct"] = struct
            rows.append(row)
        df_pred = pd.DataFrame(rows).set_index("struct")[cols]
        results_by_predictor[pred] = df_pred

    return results_by_struct, results_by_predictor

#### Functions for Aesthetics

In [None]:
def create_formula(outcome, )

## Analysis

In [8]:
model_data = df.join(df_thomas)[MS_patients]
model_data = zscore(model_data)

outcomes = [
    "medial",
    "posterior",
    "ventral",
    "anterior"
]

predictor = ["CP"]
covariates = ["THALAMUS_1", "age", "Female", "tiv"]

_, results = run_regressions_refactored(
    model_data,
    outcomes,
    predictor,
    covariates
)
results = results["CP"]
display_order = results["coef"].apply(np.abs).sort_values(ascending=False).index
display(Markdown(f"`struct ~ CP + {" + ".join(covariates)}`"))
display(results.loc[display_order, :])

`struct ~ CP + THALAMUS_1 + age + Female + tiv`

Unnamed: 0_level_0,coef,pval,p_fdr,se,llci,ulci,ci,R2,coef_sig
struct,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
ventral,0.129242,1.346145e-10,1.346145e-10,0.019667,0.090595,0.16789,"[0.0906, 0.168]",0.871857,0.129242
medial,-0.075815,0.0002140643,0.0002140643,0.020318,-0.115742,-0.035889,"[-0.116, -0.0359]",0.891332,-0.075815
posterior,-0.045265,0.01175638,0.01175638,0.017896,-0.080432,-0.010099,"[-0.0804, -0.0101]",0.903201,-0.045265
anterior,0.037002,0.3997702,0.3997702,0.043903,-0.049273,0.123277,"[-0.0493, 0.123]",0.561576,0.0


In [12]:
hips_thomas_ref[thalamic_nuclei]

index
2         AV_2
4         VA_4
5        VLa_5
6        VLP_6
7        VPL_7
8        Pul_8
9        LGN_9
10      MGN_10
11       CM_11
12    MD_Pf_12
Name: struct, dtype: object

In [14]:
model_data = df.join(df_thomas)[MS_patients]
model_data = zscore(model_data)

outcomes = hips_thomas_ref[thalamic_nuclei]

predictor = ["CP"]
covariates = ["THALAMUS_1", "age", "Female", "tiv"]

_, results = run_regressions_refactored(
    model_data,
    outcomes,
    predictor,
    covariates
)
results = results["CP"]
display_order = results["coef"].apply(np.abs).sort_values(ascending=False).index
display(Markdown(f"`struct ~ CP + {" + ".join(covariates)}`"))
display(results.loc[display_order, :])

`struct ~ CP + THALAMUS_1 + age + Female + tiv`

Unnamed: 0_level_0,coef,pval,p_fdr,se,llci,ulci,ci,R2,coef_sig
struct,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
VA_4,0.177934,4.109224e-09,4.109224e-09,0.029681,0.119607,0.236261,"[0.12, 0.236]",0.674796,0.177934
VLa_5,0.144004,0.0002241412,0.0002241412,0.038716,0.067923,0.220084,"[0.0679, 0.22]",0.587137,0.144004
VLP_6,0.112568,6.509353e-06,6.509353e-06,0.024677,0.064075,0.16106,"[0.0641, 0.161]",0.817064,0.112568
CM_11,-0.103807,0.007827863,0.007827863,0.038863,-0.180177,-0.027436,"[-0.18, -0.0274]",0.62244,-0.103807
VPL_7,0.088725,0.00186658,0.00186658,0.028358,0.032998,0.144452,"[0.033, 0.144]",0.751606,0.088725
LGN_9,-0.073721,0.0855337,0.0855337,0.042783,-0.157794,0.010353,"[-0.158, 0.0104]",0.558797,0.0
MGN_10,-0.07351,0.01898802,0.01898802,0.031227,-0.134874,-0.012145,"[-0.135, -0.0121]",0.705842,-0.07351
MD_Pf_12,-0.068005,0.001550602,0.001550602,0.021358,-0.109976,-0.026034,"[-0.11, -0.026]",0.871944,-0.068005
Pul_8,-0.040285,0.0299022,0.0299022,0.018496,-0.076631,-0.00394,"[-0.0766, -0.00394]",0.894306,-0.040285
AV_2,0.037002,0.3997702,0.3997702,0.043903,-0.049273,0.123277,"[-0.0493, 0.123]",0.561576,0.0
