In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
from glob import glob

import mne
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

from cocodelics.utils import COLOR_MAP, get_data, load_data

# Silence one specific FutureWarning: the scikit-learn notice that
# 'force_all_finite' was renamed to 'ensure_all_finite'. Only warnings of this
# category whose text starts with `message` and that are attributed to
# `sklearn.utils.deprecation` are suppressed; everything else still shows.
# NOTE(review): presumably emitted because a dependency (e.g. umap) still calls
# scikit-learn with the old keyword — confirm, and drop this filter once fixed.
warnings.filterwarnings(
    "ignore",
    message="'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.",
    category=FutureWarning,
    module="sklearn.utils.deprecation",
)

In [None]:
# Relative path handed to load_data as the location of the input data.
DATA_DIR = "../local_data/v1"

# Names handed to load_data as its second argument.
# NOTE(review): presumably every name listed here is skipped on load, so
# commenting an entry out (as done for some entries below) keeps that dataset in
# the analysis — confirm against cocodelics.utils.load_data.
# Despite the variable name, the entries look like dataset/condition names
# (with a "-pcb" counterpart each) rather than feature names — TODO confirm.
IGNORE_FEATURES = [
    # "lsd-Closed1",
    # "lsd-Closed1-pcb",
    "lsd-Closed2",
    "lsd-Closed2-pcb",
    "lsd-Music",
    "lsd-Music-pcb",
    # "lsd-Open1",
    # "lsd-Open1-pcb",
    "lsd-Open2",
    "lsd-Open2-pcb",
    "lsd-Video",
    "lsd-Video-pcb",
    # "ketamine",
    # "ketamine-pcb",
    # "psilocybin",
    # "psilocybin-pcb",
    # "perampanel",
    # "perampanel-pcb",
    # "tiagabine",
    # "tiagabine-pcb",
    "lsd-avg",
    "lsd-avg-pcb",
]

In [None]:
# load_data returns a name -> table mapping plus the feature and channel names
# that later cells use to pull values out of each table.
data, ft_names, ch_names = load_data(DATA_DIR, IGNORE_FEATURES)

# Echo what was loaded, one comma-separated line per collection.
for label, names in (
    ("Loaded datasets:", data.keys()),
    ("Feature names:", ft_names),
    ("Channel names:", ch_names),
):
    print(label, ", ".join(names))

In [None]:
# Build one array per dataset: get_data (imported from cocodelics.utils) is
# called once per feature and the results are joined along the last axis.
# NOTE(review): presumably each call yields a (rows, channels) block, giving a
# (rows, channels * features) matrix per dataset — confirm against get_data.
data_arrs = {
    name: np.concatenate([get_data(df, ft_name, ch_names) for ft_name in ft_names], axis=-1)
    for name, df in data.items()
}

In [None]:
# Fixed seed so the stochastic steps (t-SNE, and PCA whenever it picks a
# randomized/arpack solver) produce the same embedding on every full re-run.
RANDOM_STATE = 0

# Stack all datasets row-wise and record where each one ends. Splitting on
# cumulative offsets stays correct when datasets have different row counts;
# slicing with `i * len(df)` is only right when every dataset is the same size.
stacked = np.concatenate(list(data_arrs.values()), axis=0)
split_points = np.cumsum([len(arr) for arr in data_arrs.values()])[:-1]

# Step 1: linear reduction to 10 principal components.
pca = PCA(n_components=10, random_state=RANDOM_STATE)
pca_scores = pca.fit_transform(stacked)

# Cumulative explained variance, plotted against 1..n so the x axis really is
# the number of components kept (not a zero-based index).
cum_var = np.cumsum(pca.explained_variance_ratio_)
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(cum_var) + 1), cum_var)
ax.set(
    title="PCA explained variance",
    xlabel="Number of components",
    ylabel="Cumulative explained variance ratio",
)
plt.show()

# Step 2: non-linear reduction of the PCA scores to 2-D for plotting.
# Alternatives: UMAP(n_components=2, random_state=RANDOM_STATE) or PCA(n_components=2).
dr = TSNE(n_components=2, random_state=RANDOM_STATE)
embedding = dr.fit_transform(pca_scores)

# Hand each dataset back its own rows of the embedding, keyed by dataset name.
comps = dict(zip(data_arrs.keys(), np.split(embedding, split_points, axis=0)))

In [None]:
# Scatter the low-dimensional embedding, one series per dataset. A 3-D axes is
# used only when the reducer was configured with three components.
is_3d = dr.n_components == 3
fig = plt.figure(figsize=(13, 7))
ax = fig.add_subplot(111, projection="3d" if is_3d else None)

for name, comp in comps.items():
    # "-pcb" datasets share the colour of their base dataset and are told apart
    # by marker ("x" instead of "o").
    is_pcb = "-pcb" in name
    ax.scatter(
        *comp.T,
        label=name,
        color=COLOR_MAP[name.replace("-pcb", "")],
        marker="x" if is_pcb else "o",
    )

# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set(title=f"{type(dr).__name__} embedding", xlabel="Component 1", ylabel="Component 2")
if is_3d:
    ax.set_zlabel("Component 3")
ax.legend()
plt.show()