In [None]:
# Re-import edited project modules (e.g. cocodelics.utils) automatically before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
from glob import glob

import mne
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

from cocodelics.utils import COLOR_MAP, get_feature_data, load_data

warnings.filterwarnings(
    "ignore",
    message="'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.",
    category=FutureWarning,
    module="sklearn.utils.deprecation",
)

In [None]:
DATA_DIR = "../local_data/v1"

# Conditions (dataset names) excluded via load_data's ignore_features. Each condition is dropped
# together with its "-pcb" twin (presumably the placebo session).
# Available LSD conditions: Closed1, Closed2, Music, Open1, Open2, Video, avg.
# Other conditions that are currently kept: ketamine, psilocybin, perampanel, tiagabine.
_IGNORED_LSD_CONDITIONS = ["Closed2", "Music", "Open2", "Video", "avg"]
IGNORE_FEATURES = [
    f"lsd-{condition}{suffix}"
    for condition in _IGNORED_LSD_CONDITIONS
    for suffix in ("", "-pcb")
]

# ROIs excluded from loading. Other available ROIs: Central, Parietal, Temporal, Occipital
# (each _Left/_Right) and ROI_Midline.
IGNORE_ROIS = [f"ROI_Frontal_{side}" for side in ("Left", "Right")]

In [None]:
# NOTE(review): ignore_rois is passed while rois=False — confirm load_data honours it in that mode.
data, ft_names, ch_names, col_names = load_data(
    DATA_DIR,
    ignore_features=IGNORE_FEATURES,
    rois=False,
    ignore_rois=IGNORE_ROIS,
)

# Summarise what was loaded: dataset keys, feature names and channel names.
for heading, entries in (
    ("Loaded datasets:", data.keys()),
    ("Feature names:", ft_names),
    ("Channel names:", ch_names),
):
    print(heading, ", ".join(entries))

In [None]:
# One matrix per dataset: the blocks returned by get_feature_data for every feature name,
# concatenated along the last axis (column order: ft_names order, then whatever
# get_feature_data returns for ch_names — presumably one column per channel; confirm).
data_arrs = {}
for name, df in data.items():
    per_feature_blocks = [get_feature_data(df, ft_name, ch_names) for ft_name in ft_names]
    data_arrs[name] = np.concatenate(per_feature_blocks, axis=-1)

In [None]:
# Measure families, restricted to the "MeanEpochs" summary features and matched by name substring.
mean_epoch_features = [f for f in ft_names if "MeanEpochs" in f]
entropy_measures = [f for f in mean_epoch_features if "Entropy" in f]
# NOTE(review): "Fd" is a short substring — confirm it only matches fractal-dimension features.
fractal_measures = [f for f in mean_epoch_features if "Fd" in f or "detrendedFluctuation" in f]
complexity_measures = [f for f in mean_epoch_features if "Complexity" in f]


def get_indices(col_names, measure_list):
    """Return the positions in ``col_names`` whose name contains any entry of ``measure_list``.

    ``col_names`` is the full column-name list matching the columns of the per-condition arrays,
    so the result can be used directly as a column index into those arrays.

    Args:
        col_names: Iterable of column-name strings.
        measure_list: Iterable of substrings to look for.

    Returns:
        List of matching positions, in ascending order.
    """
    matches = []
    for position, column in enumerate(col_names):
        if any(measure in column for measure in measure_list):
            matches.append(position)
    return matches


def get_matching_cols(df, feature_list):
    """Return the columns of ``df`` whose name contains any entry of ``feature_list``.

    Each matching column is returned once, in ``df.columns`` order, even if it contains
    several of the requested substrings.

    Args:
        df: Object exposing a ``columns`` iterable of names (e.g. a pandas DataFrame).
        feature_list: Iterable of substrings to look for.

    Returns:
        List of matching column names without duplicates.
    """
    # any() stops at the first hit, so a column matching several features is not repeated.
    return [col for col in df.columns if any(feat in col for feat in feature_list)]


# Column positions of each measure family within the concatenated per-condition arrays.
entropy_idx, fractal_idx, complexity_idx = (
    get_indices(col_names, measures)
    for measures in (entropy_measures, fractal_measures, complexity_measures)
)

# Per-condition arrays restricted to one measure family's columns (same keys as data_arrs).
entropy_arrs, fractal_arrs, complexity_arrs = (
    {cond: arr[:, idx] for cond, arr in data_arrs.items()}
    for idx in (entropy_idx, fractal_idx, complexity_idx)
)

In [None]:
# PCA fit on the entropy features of all conditions pooled; plot the cumulative explained variance.
source_data_pca = entropy_arrs
pooled_pca_input = np.concatenate(list(source_data_pca.values()), axis=0)
# PCA requires n_components <= min(n_samples, n_features); clamp so small inputs don't raise.
n_pca_components = min(10, *pooled_pca_input.shape)
pca = PCA(n_components=n_pca_components)
pca_comps = pca.fit_transform(pooled_pca_input)
# Split the pooled projection back into one block of rows per condition, in dict order.
bounds = np.cumsum([0] + [len(arr) for arr in source_data_pca.values()])
pca_comps = {name: pca_comps[bounds[i] : bounds[i + 1], :] for i, name in enumerate(source_data_pca.keys())}

fig, ax = plt.subplots()
# x starts at 1: the k-th point is the variance explained by the first k components.
ax.plot(np.arange(1, pca.n_components_ + 1), np.cumsum(pca.explained_variance_ratio_))
ax.set_title("PCA explained variance")
ax.set_xlabel("Number of components")
ax.set_ylabel("Cumulative explained variance ratio")
plt.show()

# 2-D embedding of all features, pooled across conditions, for the scatter plot below.
# Alternatives: UMAP(n_components=2) or PCA(n_components=2).
# t-SNE is stochastic: fix random_state so Restart & Run All reproduces the same embedding.
dr = TSNE(n_components=2, random_state=42)

source_data = data_arrs
comps = dr.fit_transform(np.concatenate(list(source_data.values()), axis=0))
# Split the pooled (n_rows, 2) embedding back into one block of rows per condition, in dict order.
bounds = np.cumsum([0] + [len(arr) for arr in source_data.values()])
comps = {name: comps[bounds[i] : bounds[i + 1], :] for i, name in enumerate(source_data.keys())}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse


def plot_cov_ellipse(mean, cov, ax, n_std=2.0, facecolor="none", edgecolor="black", **kwargs):
    """Draw the ``n_std``-standard-deviation ellipse of a 2-D covariance matrix on ``ax``.

    Args:
        mean: Centre of the ellipse, length-2 array-like.
        cov: 2x2 covariance matrix.
        ax: Matplotlib Axes to draw on.
        n_std: Number of standard deviations spanned by each semi-axis.
        facecolor: Fill colour of the ellipse.
        edgecolor: Outline colour of the ellipse.
        **kwargs: Extra ``Ellipse`` properties; ``lw``/``linewidth`` (default 2) and
            ``alpha`` (default 0.25) may be overridden here.

    Returns:
        The ``Ellipse`` patch that was added to ``ax``.
    """
    # Eigen-decomposition gives the principal axes; sort so the largest variance comes first.
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    # Floating-point error can make a (near-)singular covariance yield tiny negative eigenvalues.
    vals = np.clip(vals, 0.0, None)
    # Orientation of the major axis, in degrees from the x-axis.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    # Ellipse width/height are full diameters: 2 * n_std * sigma along each principal axis.
    width, height = 2 * n_std * np.sqrt(vals)
    if "linewidth" not in kwargs:
        kwargs.setdefault("lw", 2)
    kwargs.setdefault("alpha", 0.25)
    ellip = Ellipse(
        xy=mean, width=width, height=height, angle=theta, edgecolor=edgecolor, facecolor=facecolor, **kwargs
    )
    ax.add_patch(ellip)
    return ellip


fig, ax = plt.subplots(figsize=(13, 7))
for name, comp in comps.items():
    # "-pcb" variants share the base condition's colour but are drawn with "x" markers.
    color = COLOR_MAP[name.replace("-pcb", "")]
    ax.scatter(*comp.T, label=name, color=color, marker="o" if "-pcb" not in name else "x")
    # Condition centroid.
    mean = comp.mean(axis=0)
    ax.scatter(*mean, color=color, s=150, marker="*", edgecolor="k", zorder=10)
    # Covariance ellipse (spread). With fewer than 3 points the 2x2 covariance is always
    # rank-deficient, so those conditions are skipped.
    if comp.shape[0] > 2:
        cov = np.cov(comp.T)
        plot_cov_ellipse(mean, cov, ax, edgecolor=color, facecolor=color)
ax.set_title("2-D embedding of all features")
ax.set_xlabel("Component 1")
ax.set_ylabel("Component 2")
ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse

# Alternatives: UMAP(n_components=2, n_neighbors=30, random_state=42) or PCA(n_components=2).
# t-SNE is stochastic: fix random_state so every panel is reproducible on re-run.
# NOTE(review): TSNE requires perplexity < number of pooled rows — confirm for small datasets.
dr = TSNE(n_components=2, perplexity=30, random_state=42)

categories = [
    ("Entropy", entropy_arrs),
    ("Fractality", fractal_arrs),
    ("Complexity", complexity_arrs),
    ("All Features", data_arrs),
]

fig, axs = plt.subplots(2, 2, figsize=(18, 12))  # Wider/taller for 2x2 clarity

# Flatten axs for easy zip
axs = axs.flatten()

for ax, (catname, features) in zip(axs, categories):
    # One embedding per category, fit on all conditions pooled, then split back per condition.
    embedded = dr.fit_transform(np.concatenate(list(features.values()), axis=0))
    bounds = np.cumsum([0] + [len(arr) for arr in features.values()])
    # Panel-local name: do not overwrite the module-level `comps` used by the earlier scatter plot.
    category_comps = {name: embedded[bounds[i] : bounds[i + 1], :] for i, name in enumerate(features.keys())}

    for name, comp in category_comps.items():
        # "-pcb" variants share the base condition's colour but are drawn with "x" markers.
        color = COLOR_MAP[name.replace("-pcb", "")]
        ax.scatter(*comp.T, label=name, color=color, marker="o" if "-pcb" not in name else "x")
        # Condition centroid.
        mean = comp.mean(axis=0)
        ax.scatter(*mean, color=color, s=150, marker="*", edgecolor="k", zorder=10)
        # Covariance ellipse only for 2-D embeddings with at least 3 points.
        if comp.shape[0] > 2 and comp.shape[1] == 2:
            cov = np.cov(comp.T)
            plot_cov_ellipse(mean, cov, ax, edgecolor=color, facecolor=color)
    ax.set_title(catname, fontsize=20)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.relim()
    ax.autoscale_view()
    ax.margins(0.1)

# Every category dict has the same condition keys, so one legend beside the last panel suffices.
axs[-1].legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=15)

plt.tight_layout()
plt.show()