In [None]:
import os
import warnings
from glob import glob

import mne
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

warnings.filterwarnings(
    "ignore",
    message="'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.",
    category=FutureWarning,
    module="sklearn.utils.deprecation",
)

In [None]:
# Directory scanned for per-dataset "*.csv" files (files whose name starts with
# "aggregate" are skipped by the loading cell).
DATA_DIR = "../local_data/v1"
# True: keep one entry per dataset holding the relative change
#       (active - placebo) / (placebo + 1e-4).
# False: keep active rows (target == 1) and placebo rows (target == 0, key
#        suffixed with "-pcb") as separate datasets.
ACT_MINUS_PCB = True
# True: z-score every feature column over all rows of a file (active and
# placebo together) before the active/placebo split.
NORMALIZE = False

# Dataset keys (CSV file stem, optionally suffixed with "-pcb") to leave out of
# the analysis. Comment an entry out to include that dataset. The "-pcb"
# entries only have an effect when ACT_MINUS_PCB is False.
IGNORE_FEATURES = [
    "lsd-Closed1",
    "lsd-Closed1-pcb",
    "lsd-Closed2",
    "lsd-Closed2-pcb",
    "lsd-Music",
    "lsd-Music-pcb",
    # "lsd-Open1",
    # "lsd-Open1-pcb",
    "lsd-Open2",
    "lsd-Open2-pcb",
    "lsd-Video",
    "lsd-Video-pcb",
    # "ketamine",
    # "ketamine-pcb",
    # "psilocybin",
    # "psilocybin-pcb",
    # "perampanel",
    # "perampanel-pcb",
    # "tiagabine",
    # "tiagabine-pcb",
    "lsd-avg",
    "lsd-avg-pcb",
]
# Plot color per dataset; "-pcb" variants are drawn with their active
# dataset's color (and an "x" marker). Every dataset that gets loaded needs an
# entry here, otherwise the plotting cell raises KeyError.
COLOR_MAP = {
    "lsd-Closed1": "#c6b4e3",
    "lsd-Closed2": "#b19bdc",
    "lsd-Music": "#9b83d5",
    "lsd-Open2": "#856bcc",
    "lsd-Open1": "#6f54c3",
    "lsd-Video": "#593ebb",
    "lsd-avg": "#4327b2",
    "ketamine": "#1dbc7c",
    "psilocybin": "#bf00ee",
    "perampanel": "#bfa900",
    "tiagabine": "#e61a1a",
}

In [None]:
def get_feature_names(df):
    """Return the sorted, de-duplicated feature names encoded in ``df``'s columns.

    Feature columns are expected to be named ``feature-<name>.spaces-<channel>``
    (the same pattern ``get_data`` rebuilds). The name is everything between the
    ``feature-`` prefix and the ``.spaces-`` separator, so names containing dots
    are kept intact.

    Returns:
        list[str]: sorted so the order is deterministic across files and runs.
    """
    names = {
        col.replace("feature-", "", 1).split(".spaces-")[0]
        for col in df.columns
        if col.startswith("feature")
    }
    return sorted(names)


def get_channel_names(df):
    """Return the sorted, de-duplicated channel names encoded in ``df``'s columns.

    The channel is the text after ``.spaces-`` in each feature column. The
    previous ``col[-5:]`` only worked for channel names of exactly five
    characters; splitting on the separator works for any length.

    Returns:
        list[str]: sorted so the order is deterministic across files and runs.
    """
    channels = {col.split(".spaces-")[-1] for col in df.columns if ".spaces-" in col}
    return sorted(channels)


# Load every per-dataset CSV in DATA_DIR (skipping "aggregate*" files), check
# that all files share the same feature/channel columns, and store either the
# active-vs-placebo relative change or the raw active/placebo rows in `data`.
ft_names, ch_names = None, None
data = {}
for path in sorted(glob(os.path.join(DATA_DIR, "*.csv"))):
    # basename() is separator-agnostic; splitting on "/" breaks on Windows,
    # where glob returns backslash-joined paths.
    name = os.path.basename(path).split(".")[0]
    if name.startswith("aggregate"):
        continue

    df = pd.read_csv(path, index_col=0)

    if ft_names is None:
        ft_names = get_feature_names(df)
        ch_names = get_channel_names(df)
    else:
        # Compare as sets so the check does not depend on the order in which
        # the helpers return names.
        assert set(ft_names) == set(get_feature_names(df)), "Feature names do not match across datasets."
        assert set(ch_names) == set(get_channel_names(df)), "Channel names do not match across datasets."

    target = df["target"]
    df = df.drop(columns="target")

    if NORMALIZE:
        # Column-wise z-score over all rows (active and placebo together); the
        # epsilon avoids division by zero for constant columns.
        df = (df - df.mean()) / (df.std() + 1e-4)

    if ACT_MINUS_PCB:
        if name not in IGNORE_FEATURES:
            # Relative change (active - placebo) / placebo. pandas aligns the two
            # row subsets on the index read from CSV column 0, so this presumably
            # pairs each subject's active and placebo rows.
            # NOTE(review): confirm the index holds IDs shared by both conditions;
            # otherwise the aligned result is all-NaN.
            data[name] = (df[target == 1] - df[target == 0]) / (df[target == 0] + 1e-4)
    else:
        if name not in IGNORE_FEATURES:
            data[name] = df[target == 1]
        if name + "-pcb" not in IGNORE_FEATURES:
            data[name + "-pcb"] = df[target == 0]

if ft_names is None:
    # Fail with a clear message instead of a TypeError from join() below.
    raise FileNotFoundError(f"No non-aggregate CSV files found in {DATA_DIR!r}.")

print("Loaded datasets:", ", ".join(data.keys()))
print("Feature names:", ", ".join(ft_names))
print("Channel names:", ", ".join(ch_names))

In [None]:
def get_data(df, ft_name, avg_subjs=False, avg_chs=False, channels=None):
    """Return one feature's values across channels as a 2-D array.

    Args:
        df: DataFrame with columns named ``feature-<ft_name>.spaces-<channel>``.
        ft_name: feature name without the ``feature-`` prefix.
        avg_subjs: if True, average over rows (axis 0), keeping a length-1 axis.
        avg_chs: if True, average over channels (axis 1), keeping a length-1 axis.
        channels: channel names to select, in this order. Defaults to the
            module-level ``ch_names`` set by the loading cell, which keeps
            existing calls working.

    Returns:
        numpy.ndarray of shape (n_rows, n_channels); any averaged axis has size 1.

    Raises:
        KeyError: if an expected ``feature-<ft_name>.spaces-<channel>`` column is
            missing from ``df``.
    """
    if channels is None:
        channels = ch_names
    col_names = [f"feature-{ft_name}.spaces-{ch}" for ch in channels]
    # Named `values` (not `data`) to avoid shadowing the global datasets dict.
    values = df[col_names].to_numpy()
    if avg_subjs:
        values = values.mean(axis=0, keepdims=True)
    if avg_chs:
        values = values.mean(axis=1, keepdims=True)
    return values


# Build one 2-D array per dataset by placing every feature's per-channel
# columns side by side (feature blocks follow the order of ft_names).
data_arrs = {}
for name, df in data.items():
    data_arrs[name] = np.hstack([get_data(df, feat) for feat in ft_names])

In [None]:
# Alternative reducers (swap in to compare embeddings):
# dr = UMAP(n_components=2)
# dr = PCA(n_components=2)
# Fixed random_state so Restart & Run All reproduces the same embedding.
dr = TSNE(n_components=2, random_state=0)

# Embed all datasets jointly so they share one coordinate system, then split
# the result back per dataset. Split points are cumulative row counts, so
# datasets with different numbers of rows are sliced correctly (the previous
# `i * len(df)` offsets assumed every dataset had the same length).
arrays = list(data_arrs.values())
embedding = dr.fit_transform(np.concatenate(arrays, axis=0))
split_points = np.cumsum([len(arr) for arr in arrays])[:-1]
comps = dict(zip(data_arrs.keys(), np.split(embedding, split_points)))

In [None]:
# Scatter each dataset's embedding. Placebo datasets ("-pcb") reuse the color of
# their active counterpart and are drawn with an "x" marker instead of "o".
is_3d = dr.n_components == 3
fig = plt.figure(figsize=(13, 7))
ax = fig.add_subplot(111, projection="3d" if is_3d else None)
for name, comp in comps.items():
    is_placebo = "-pcb" in name
    ax.scatter(
        *comp.T,
        label=name,
        color=COLOR_MAP[name.replace("-pcb", "")],
        marker="x" if is_placebo else "o",
    )
ax.set_title(f"{type(dr).__name__} embedding of feature vectors")
ax.set_xlabel("component 1")
ax.set_ylabel("component 2")
if is_3d:
    ax.set_zlabel("component 3")
ax.legend()
plt.show()