## Set Up

### Imports

In [2]:
import pprint
from warnings import simplefilter

import pandas as pd
from IPython.display import Markdown, display
from statsmodels.stats.multitest import multipletests

simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
import json
import re
import textwrap
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import helpers
import matplotlib.pyplot as plt
import numpy as np
import pyperclip
import statsmodels.api as sm
from IPython.display import clear_output
from matplotlib import colormaps
from scipy import stats
from statsmodels.genmod.families import Poisson

# from reload_recursive import reload_recursive
from statsmodels.stats.mediation import Mediation
from statsmodels.stats.outliers_influence import variance_inflation_factor
from tqdm.notebook import tqdm

from mri_data import file_manager as fm

### Load Data

#### Clinical and Volumes

In [None]:
# Location of the BIDS dataset on the mounted drive.
drive_root = fm.get_drive_root()
dataroot = drive_root / "3Tpioneer_bids"

# Every input below lives under the same project directory, so derive all paths from
# one root instead of repeating the absolute prefix (this also matches how the
# HIPS-THOMAS cell builds its paths from `data_dir`).
data_dir = Path("/home/srs-9/Projects/ms_mri/data")
project_root = data_dir.parent
thalamus_dir = project_root / "analysis" / "thalamus"
fig_path = thalamus_dir / "figures_tables" / "choroid_associations"

# All tables are indexed by `subid` so that the join below aligns rows per subject.
choroid_volumes = pd.read_csv(
    data_dir / "choroid_aschoplex_volumes.csv", index_col="subid"
)
ventricle_volumes = pd.read_csv(
    project_root / "analysis" / "paper1" / "data0" / "ventricle_volumes.csv",
    index_col="subid",
)
csf_volumes = pd.read_csv(
    thalamus_dir / "data0" / "csf_volumes.csv",
    index_col="subid",
)
third_ventricle_width = pd.read_csv(
    thalamus_dir / "data0" / "third_ventricle_width.csv",
    index_col="subid",
)

tiv = pd.read_csv(data_dir / "tiv_data.csv", index_col="subid")

df = pd.read_csv(data_dir / "clinical_data_processed.csv", index_col="subid")
sdmt = pd.read_csv(thalamus_dir / "SDMT_sheet.csv", index_col="subid")

# Index-on-index join: the clinical table defines the rows, and each volume table
# contributes its columns for the matching `subid`. Only the SDMT column is taken
# from the SDMT sheet.
df = df.join(
    [
        choroid_volumes,
        ventricle_volumes,
        csf_volumes,
        third_ventricle_width,
        tiv,
        sdmt["SDMT"],
    ]
)
# Short analysis names for the joined volume columns. Later code in this cell reads
# df["allCSF"] and df["periCSF"], so the *whole* mapping has to be applied, not only
# the LV/CP entries.
rename_columns = {
    "ventricle_volume": "LV",
    "choroid_volume": "CP",
    "peripheral": "periCSF",
    "all": "allCSF",
    "third_ventricle": "thirdV",
    "third_ventricle_width": "thirdV_width",
}
#! need to fix the actual segmentation files
# Keys that are not present as columns are ignored by DataFrame.rename (default
# errors="ignore"). Reassigning instead of inplace=True keeps the cell re-runnable.
df = df.rename(columns=rename_columns)

# Peripheral share of total CSF, taken from the raw CSF table; the assignment aligns
# on the shared `subid` index.
df["periCSF_frac"] = csf_volumes["peripheral"] / csf_volumes["all"]

# Non-numeric SDMT entries become NaN rather than raising.
df["SDMT"] = pd.to_numeric(df["SDMT"], errors="coerce")

# Transformed variants of skewed variables.
df["thalamus_sqrt"] = np.sqrt(df["thalamus"])
# Cube root, as the column name says. The previous expression, sqrt(x ** 3), computed
# x ** 1.5 instead.
df["thalamus_curt"] = np.cbrt(df["thalamus"])
df["cortical_thickness_inv"] = 1 / df["cortical_thickness"]
df["LV_logtrans"] = np.log(df["LV"])

# these corrections should ultimately be made to the csf file
# NOTE(review): the thalamus transforms above are computed *before* this rescaling,
# so they are on the unscaled values - confirm that is intended.
for struct in ["brain", "white", "grey", "thalamus", "t2lv"]:
    df[struct] = df[struct] * 1000

# Ratios between CSF compartments.
df["CCF"] = df["LV"] / df["allCSF"]
df["peri_ratio"] = df["periCSF"] / df["LV"]


# Standardised copy of the table: numeric columns are z-scored column by column
# (NaNs are left out of the mean/std), non-numeric columns are carried over as-is.
numeric_cols = df.select_dtypes(include="number").columns
df_z = df.copy()
df_z[numeric_cols] = df_z[numeric_cols].apply(
    lambda column: stats.zscore(column, nan_policy="omit")
)

# Plotting palettes: a 20-step viridis map plus the project colour set from helpers.
viridis = colormaps["viridis"].resampled(20)

colors = helpers.get_colors()

# Boolean row masks over `df`, one per diagnosis label in dz_type2 / dz_type5.
MS_patients = df["dz_type2"].eq("MS")
nonMS_patients = df["dz_type2"].eq("!MS")
NIND_patients = df["dz_type5"].eq("NIND")
OIND_patients = df["dz_type5"].eq("OIND")
RMS_patients = df["dz_type5"].eq("RMS")
PMS_patients = df["dz_type5"].eq("PMS")

#### HIPS-THOMAS Volumes and Distances

In [None]:
# Three HIPS-THOMAS volume tables (combined, left and right, going by the file
# names), each indexed by `subid`.
df_thomas, df_thomas_left, df_thomas_right = (
    pd.read_csv(data_dir / csv_name, index_col="subid")
    for csv_name in (
        "hipsthomas_vols.csv",
        "hipsthomas_left_vols.csv",
        "hipsthomas_right_vols.csv",
    )
)

# Build an old-name -> new-name map from the combined table's headers: move the
# leading label number behind the structure name ("<num>-<name>" -> "<name>_<num>")
# and turn any remaining hyphens into underscores.
cols_orig = df_thomas.columns
new_colnames = {}
for col in cols_orig:
    new_col = re.sub(r"(\d+)-([\w-]+)", r"\2_\1", col).replace("-", "_")
    new_colnames[col] = new_col

# The same map is applied to all three tables.
df_thomas, df_thomas_left, df_thomas_right = (
    frame.rename(columns=new_colnames)
    for frame in (df_thomas, df_thomas_left, df_thomas_right)
)

# Thalamic nuclei (renamed column labels) that are summed into each regional group.
nuclei_groupings = dict(
    anterior=["AV_2"],
    ventral=["VA_4", "VLa_5", "VLP_6", "VPL_7"],
    posterior=["Pul_8", "LGN_9", "MGN_10"],
    medial=["MD_Pf_12", "CM_11"],
)


def combine_nuclei(df, groupings):
    """Sum groups of columns into one column per group.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one column per nucleus.
    groupings : dict
        Maps each output column name to the list of `df` column names that are
        added together to form it.

    Returns
    -------
    pandas.DataFrame
        A new frame with one column per key of `groupings`, in key order.
    """
    combined = pd.DataFrame()
    for label in groupings:
        members = groupings[label]
        # Element-wise addition of the member columns (NaN in any member propagates).
        combined[label] = sum(df[name] for name in members)
    return combined


# Append the grouped-nuclei sums (anterior / ventral / posterior / medial) to each of
# the three HIPS-THOMAS tables.
df_thomas, df_thomas_left, df_thomas_right = (
    frame.join(combine_nuclei(frame, nuclei_groupings))
    for frame in (df_thomas, df_thomas_left, df_thomas_right)
)


# Label numbers for two sets of structures, as ints and as strings.
thalamic_nuclei = [2, 4, 5, 6, 7, 8, 9, 10, 11, 12]
thalamic_nuclei_str = list(map(str, thalamic_nuclei))
deep_grey = [13, 14, 26, 27, 28, 29, 30, 31, 32]
deep_grey_str = list(map(str, deep_grey))


# Two lookups read from the same index file: label number -> structure name, and
# structure name -> label number.
hips_thomas_ref, hips_thomas_invref = (
    pd.read_csv(
        "/home/srs-9/Projects/ms_mri/data/hipsthomas_struct_index.csv",
        index_col=key_col,
    )[value_col]
    for key_col, value_col in (("index", "struct"), ("struct", "index"))
)

In [14]:
# Display the label-number -> structure-name lookup read from hipsthomas_struct_index.csv.
hips_thomas_ref

index
1     THALAMUS_1
2           AV_2
4           VA_4
5          VLa_5
6          VLP_6
7          VPL_7
8          Pul_8
9          LGN_9
10        MGN_10
11         CM_11
12      MD_Pf_12
13         Hb_13
14        MTT_14
26        Acc_26
27        Cau_27
28        Cla_28
29        GPe_29
30        GPi_30
31        Put_31
32         RN_32
33         GP_33
34        Amy_34
Name: struct, dtype: object

In [None]:
def zscore(df):
    """Return a copy of `df` with every numeric column standardised.

    Each numeric column is passed through `scipy.stats.zscore` with
    ``nan_policy="omit"``, so missing values are ignored when computing the mean and
    standard deviation and stay missing in the result. Non-numeric columns are
    copied unchanged, and the input frame is not modified.
    """
    numeric = df.select_dtypes(include="number")
    standardized = df.copy()
    standardized[numeric.columns] = numeric.apply(stats.zscore, nan_policy="omit")
    return standardized

## Checks

In [None]:
def screen_variable(data, var_name, output_dir="/mnt/user-data/outputs"):
    """Plot distribution diagnostics for one column and save them as a PNG.

    Draws a 2x2 panel for ``data[var_name]`` (histogram, normal Q-Q plot, boxplot and
    a text box of summary statistics), writes it to
    ``<output_dir>/<var_name>_screening.png`` and closes the figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing the column to screen.
    var_name : str
        Name of the column; also used in the panel titles and the file name.
    output_dir : str or pathlib.Path, optional
        Directory the PNG is written to. It is created if it does not exist.
        Defaults to the location the function originally wrote to.

    Returns
    -------
    tuple
        ``(skew, kurt)`` as returned by ``Series.skew()`` and ``Series.kurtosis()``.

    Raises
    ------
    KeyError
        If `var_name` is not a column of `data` (raised before any figure is
        created).
    """
    # Resolve the column and compute the statistics first, so a bad column name
    # fails before a figure is allocated.
    series = data[var_name]
    observed = series.dropna()  # the plots use non-missing values only
    summary = series.describe()
    skew = series.skew()
    kurt = series.kurtosis()

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        # Histogram
        axes[0, 0].hist(observed, bins=30, edgecolor='black')
        axes[0, 0].set_title(f'Histogram: {var_name}')
        axes[0, 0].set_xlabel(var_name)

        # Q-Q plot
        stats.probplot(observed, dist="norm", plot=axes[0, 1])
        axes[0, 1].set_title(f'Q-Q Plot: {var_name}')

        # Boxplot
        axes[1, 0].boxplot(observed)
        axes[1, 0].set_title(f'Boxplot: {var_name}')
        axes[1, 0].set_ylabel(var_name)

        # Summary stats
        axes[1, 1].axis('off')
        summary_text = f"""
    Mean: {summary['mean']:.2f}
    Median: {summary['50%']:.2f}
    Std: {summary['std']:.2f}
    Min: {summary['min']:.2f}
    Max: {summary['max']:.2f}
    
    Skewness: {skew:.2f}
    Kurtosis: {kurt:.2f}
    
    Rule of thumb:
    |Skew| < 1: OK
    |Skew| 1-2: Moderate
    |Skew| > 2: Severe
    """
        axes[1, 1].text(0.1, 0.5, summary_text, fontsize=10, family='monospace')

        fig.tight_layout()
        fig.savefig(
            out_dir / f'{var_name}_screening.png', dpi=300, bbox_inches='tight'
        )
    finally:
        # Close this specific figure even if plotting or saving failed; otherwise
        # repeated calls accumulate open figures in the kernel.
        plt.close(fig)

    return skew, kurt

In [None]:
# Column names treated as predictors.
# - "LV" ... "thirdV_width" are the target names of `rename_columns` in the loading cell.
# - "THALAMUS_1" is a renamed HIPS-THOMAS label; "medial", "posterior", "ventral" and
#   "anterior" are the keys of `nuclei_groupings`.
# - "t2lv", "brain", "white" and "grey" are columns rescaled (x1000) in the loading cell.
# NOTE(review): presumably these are the variables passed to screen_variable - confirm.
all_predictors = [
    "LV",
    "CP",
    "periCSF",
    "allCSF",
    "thirdV",
    "thirdV_width",
    "THALAMUS_1",
    "medial",
    "posterior",
    "ventral",
    "anterior",
    "t2lv",
    "brain",
    "white",
    "grey",
]

data