In [18]:
import sys
import numpy as np
import sklearn.cross_decomposition as skcd
import pandas as pd

sys.path.append("bciflow/")

In [19]:
from bciflow.datasets.menggu import menggu

In [20]:
def channel_selection(dataset, channels_selected):
    """Return a new dataset restricted to the requested channels.

    Parameters
    ----------
    dataset : dict
        Must provide ``"X"`` (indexed here as a 4-D array whose third axis
        is the channel axis), ``"y"`` and ``"ch_names"``.
    channels_selected : sequence of str
        Names of the channels to keep. Names that are absent from
        ``dataset["ch_names"]`` are silently ignored.

    Returns
    -------
    dict
        ``"X"`` sliced to the kept channels, ``"y"`` passed through
        unchanged, and ``"ch_names"`` listing the kept channel names.

    Notes
    -----
    Channels are kept in the order they appear in ``dataset["ch_names"]``,
    not in the order of ``channels_selected``; the returned ``"ch_names"``
    follows that same order so it always lines up with the channel axis of
    the returned ``"X"``. The input ``dataset`` is not modified.
    """
    ch_names = dataset["ch_names"]
    ch_idx = np.where(np.isin(ch_names, channels_selected))[0]

    new_dataset = {
        "X": dataset["X"][:, :, ch_idx, :],
        "y": dataset["y"],
        # Record the surviving names on the *new* dataset. Writing the
        # selection back into the input would leave the caller's dataset
        # with channel names that no longer match its own X.
        "ch_names": [ch_names[i] for i in ch_idx],
    }
    return new_dataset

In [21]:
def filter_dataset(dataset, targets, time_window):
    """Keep only the trials of a label range, cropped along the last axis.

    Parameters
    ----------
    dataset : dict
        Must provide ``"X"`` (indexed here as a 4-D array with trials on
        the first axis) and ``"y"`` (one label per trial).
    targets : sequence
        ``targets[0]`` and ``targets[1]`` are the first and last label to
        keep; the range is inclusive on both ends.
    time_window : sequence
        ``time_window[0]:time_window[1]`` is the half-open slice applied
        to the last axis of ``"X"``.

    Returns
    -------
    dict
        A new dict holding only the filtered ``"X"`` and ``"y"``; no other
        key of ``dataset`` is carried over.
    """
    wanted_labels = np.arange(targets[0], targets[1] + 1)
    keep = np.isin(dataset["y"], wanted_labels)
    start, stop = time_window[0], time_window[1]
    return {
        "X": dataset["X"][keep, :, :, start:stop],
        "y": dataset["y"][keep],
    }

In [22]:
def build_target(target_freq, sfreq, total_time, num_harmonics=3):
    """Build sine/cosine reference signals for one frequency.

    Parameters
    ----------
    target_freq : number
        Fundamental frequency of the reference.
    sfreq : number
        Sampling rate; sample ``n`` is evaluated at time ``n / sfreq``.
    total_time : int
        Number of samples to generate.
    num_harmonics : int, default 3
        Number of harmonics ``1 .. num_harmonics`` of ``target_freq``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2 * num_harmonics, total_time)``. For harmonic
        ``h`` (1-based), row ``2 * (h - 1)`` holds the sine and row
        ``2 * (h - 1) + 1`` holds the cosine.
    """
    reference = np.zeros((num_harmonics * 2, total_time))
    samples = np.arange(total_time)
    for pair, harmonic in enumerate(range(1, num_harmonics + 1)):
        # Same left-to-right arithmetic for both rows, evaluated once.
        phase = 2 * np.pi * target_freq * harmonic * samples / sfreq
        reference[2 * pair] = np.sin(phase)
        reference[2 * pair + 1] = np.cos(phase)
    return reference


def cca(X, sfreq, total_time=5000, targets=(2, 5), num_harmonics=3):
    """Classify each (trial, band) slice of ``X`` by its best-matching reference.

    One sine/cosine reference set is built per target with
    ``build_target``; each target value is passed to it directly as the
    reference frequency. For every slice ``X[trial, band]`` a
    one-component CCA is fitted against each reference set, and the
    target with the largest correlation between the two canonical
    variates is chosen.

    Parameters
    ----------
    X : numpy.ndarray
        Indexed here as ``X[trial, band, :, :]``; the last axis is paired
        sample-by-sample with the references, so it is expected to have
        ``total_time`` entries.
    sfreq : number
        Sampling rate forwarded to ``build_target``.
    total_time : int, default 5000
        Number of samples in each reference signal.
    targets : tuple or iterable, default (2, 5)
        An exact ``tuple`` is expanded to the inclusive integer range
        ``targets[0] .. targets[1]``; anything else is used as given.
    num_harmonics : int, default 3
        Number of harmonics forwarded to ``build_target``.

    Returns
    -------
    list
        One predicted target per (trial, band) pair, trial-major, so the
        list has ``X.shape[0] * X.shape[1]`` entries.
    """
    if type(targets) is tuple:
        targets = np.arange(targets[0], targets[1] + 1)

    references = np.array(
        [
            build_target(freq, sfreq, total_time, num_harmonics)
            for freq in targets
        ]
    )

    n_trials, n_bands = X.shape[0], X.shape[1]
    predictions = []

    for trial_idx in range(n_trials):
        for band_idx in range(n_bands):
            # scikit-learn expects (n_samples, n_features), hence the transposes.
            signal = X[trial_idx, band_idx, :, :].T
            scores = []
            for reference in references:
                model = skcd.CCA(n_components=1)
                model.fit(signal, reference.T)
                x_variate, y_variate = model.transform(signal, reference.T)
                scores.append(np.corrcoef(x_variate.T, y_variate.T)[0, 1])
            predictions.append(targets[np.argmax(scores)])

    return predictions

In [23]:
# dataset = menggu(subject=1, path='data/', depth=["high"])

In [24]:
# print("Example before channel selection")
# dataset.keys()
# print("X shape: ", dataset['X'].shape)
# print("y shape: ", dataset['y'].shape)
# print("sfreq: ", dataset['sfreq'])
# print("y_dict: ", dataset['y_dict'])
# print("events: ", dataset['events'])
# print("ch_names: ", dataset['ch_names'])
# print("tmin: ", dataset['tmin'])

In [25]:
# channels_selected = ["PZ", "PO3", "PO4", "PO5", "PO6", "POZ", "O1", "O2", "OZ"]
# dataset_ = channel_selection(dataset, channels_selected)
# print("X shape: ", dataset['X'].shape)
# print("X shape: ", dataset_['X'].shape)

In [26]:
# targets = (14, 21)
# time_window = [0, 1000]
# print("X shape: ", dataset_['X'].shape)
# dataset_ = filter_dataset(dataset_, targets, time_window)
# print("X shape: ", dataset_['X'].shape)

In [27]:
# output = cca(dataset_['X'], 1000, time_window[1] - time_window[0], targets, 1)

# output = np.array(output)
# accuracy = (output == dataset_['y']).sum() / len(dataset_['y'])

# print("Accuracy: ", accuracy)

In [28]:
# Experiment grid: subjects, sliding label ranges, depths, channels and
# analysis windows. All values are defined first, then echoed.
subjects = list(range(1, 31))

# Every run of `targets_window` consecutive labels within 1..64,
# stored as inclusive (first, last) pairs.
targets_window = 8
targets = [
    (first, first + targets_window - 1)
    for first in range(1, 64 + 1 - targets_window + 1)
]

depths = ["high", "low"]
channels_selected = ["PZ", "PO3", "PO4", "PO5", "PO6", "POZ", "O1", "O2", "OZ"]

# Windows all start at sample 0 and grow in steps of 500 samples.
time_windows = [[0, stop] for stop in range(500, 2001, 500)]

print("Subjects: ", subjects)
print("Targets: ", targets)
print("Depths: ", depths)
print("Channels selected: ", channels_selected)
print("Time windows: ", time_windows)

Subjects:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
Targets:  [(1, 8), (2, 9), (3, 10), (4, 11), (5, 12), (6, 13), (7, 14), (8, 15), (9, 16), (10, 17), (11, 18), (12, 19), (13, 20), (14, 21), (15, 22), (16, 23), (17, 24), (18, 25), (19, 26), (20, 27), (21, 28), (22, 29), (23, 30), (24, 31), (25, 32), (26, 33), (27, 34), (28, 35), (29, 36), (30, 37), (31, 38), (32, 39), (33, 40), (34, 41), (35, 42), (36, 43), (37, 44), (38, 45), (39, 46), (40, 47), (41, 48), (42, 49), (43, 50), (44, 51), (45, 52), (46, 53), (47, 54), (48, 55), (49, 56), (50, 57), (51, 58), (52, 59), (53, 60), (54, 61), (55, 62), (56, 63), (57, 64)]
Depths:  ['high', 'low']
Channels selected:  ['PZ', 'PO3', 'PO4', 'PO5', 'PO6', 'POZ', 'O1', 'O2', 'OZ']
Time windows:  [[0, 500], [0, 1000], [0, 1500], [0, 2000]]


In [29]:
def run(subject, depth):
    """Evaluate CCA accuracy for one subject and depth over the whole grid.

    Loads the recording with ``menggu``, keeps the channels listed in the
    module-level ``channels_selected``, and then, for every entry of the
    module-level ``targets`` and ``time_windows``, filters the data,
    classifies it with ``cca`` and records the fraction of predictions
    that equal the true labels.

    Parameters
    ----------
    subject : int
        Subject identifier forwarded to ``menggu``.
    depth : str
        Depth identifier; forwarded to ``menggu`` wrapped in a list.

    Returns
    -------
    list of list
        One ``[subject, depth, target, time_window, accuracy]`` row per
        (target, time_window) combination, in loop order.
    """
    recording = menggu(subject=subject, path="data/", depth=[depth])
    selected = channel_selection(recording, channels_selected)

    rows = []
    for target in targets:
        print("    Running for target: ", target)
        for time_window in time_windows:
            print("      Running for time window: ", time_window)
            windowed = filter_dataset(selected, target, time_window)
            print("        Dataset shape: ", windowed["X"].shape)

            n_samples = time_window[1] - time_window[0]
            # 1000 is passed as the sampling rate and 1 as the number of
            # harmonics; both are fixed for every run.
            predicted = np.array(
                cca(windowed["X"], 1000, n_samples, target, 1)
            )
            labels = windowed["y"]
            accuracy = (predicted == labels).sum() / len(labels)
            rows.append([subject, depth, target, time_window, accuracy])

    return rows

In [None]:
# Run every (subject, depth) combination and write each result table to
# its own CSV file named after that combination.
columns = ["subject", "depth", "target", "time_window", "accuracy"]
combinations = [(s, d) for s in subjects for d in depths]
for subject, depth in combinations:
    print(f"Running subject {subject} with depth {depth}")
    table = run(int(subject), depth)
    df = pd.DataFrame(table, columns=columns)
    df.to_csv(f"subject_{subject}_depth_{depth}.csv", index=False)

Running subject 1 with depth high
    Running for target:  (1, 8)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)
      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (2, 9)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)
      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (3, 10)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
   



      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (10, 17)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)




      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (11, 18)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)




      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (12, 19)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)




      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (13, 20)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)




      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (14, 21)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)




      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (15, 22)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)
      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (16, 23)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 9, 1000)
      Running for time window:  [0, 1500]
        Dataset shape:  (96, 1, 9, 1500)
      Running for time window:  [0, 2000]
        Dataset shape:  (96, 1, 9, 2000)
    Running for target:  (17, 24)
      Running for time window:  [0, 500]
        Dataset shape:  (96, 1, 9, 500)
      Running for time window:  [0, 1000]
        Dataset shape:  (96, 1, 