In [72]:
import itertools
import json
import os
import warnings

import mne
import numpy as np
from lempel_ziv_complexity import lempel_ziv_complexity
from scipy.sparse.linalg import eigs, ArpackError
from tqdm.notebook import tqdm

%matplotlib notebook

In [25]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Pass it as ``cls=NumpyEncoder`` to ``json.dumps``. Objects that are not
    NumPy values are delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Raises:
            TypeError: if ``obj`` is not one of the supported NumPy types.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.str_):
            return str(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

In [26]:
# Band names; the ratio tables below are derived from this ordering.
BANDS = ["delta", "theta", "alpha", "beta", "gamma"]

# Every (dividend, divisor) pair in which the dividend appears later in BANDS
# than the divisor.
RATIOS_TWO_BANDS = [
    (dividend, divisor)
    for position, dividend in enumerate(BANDS)
    for divisor in BANDS[:position]
]

# Every (dividend, divisor_1, divisor_2) triple in which the two divisors are
# adjacent in BANDS and both appear before the dividend.
RATIOS_THREE_BANDS = [
    (BANDS[top], BANDS[low], BANDS[low + 1])
    for top in range(2, len(BANDS))
    for low in range(top - 1)
]

# Bin settings passed to the spatial encoding used for Lempel-Ziv complexity.
POSSIBLE_N_BINS_LZC = ["median"] + [str(count) for count in range(3, 7)]

# Encoding parameters (kept as strings) used for the S / Lambda features.
POSSIBLE_PARAMS_S_LAMBDA = dict(
    spatial=[str(count) for count in (4, 6, 10, 14, 20, 24)],
    # 6, 24 -- presumably the number of ordinal patterns (3! and 4!); confirm.
    temporal=[str(width) for width in (3, 4)],
)

In [27]:
def spatial_encoding(signal, n_bins):
    """Encode a 1-D signal as a string of amplitude-bin symbols.

    The signal is min-max normalized to [0, 1] and each sample is replaced by
    the index of the bin it falls into. Indices 0-9 are written as digits and
    indices from 10 upwards as letters starting at ``"a"``.

    Args:
        signal: 1-D array-like of numbers (a NumPy array or a plain sequence).
        n_bins: ``"median"`` for two bins split at the median of the
            normalized signal, or an integer (or numeric string) giving the
            number of equal-width bins.

    Returns:
        A string with one character per sample.

    Raises:
        ValueError: if ``signal`` is empty (raised by ``np.min``).
    """
    signal = np.asarray(signal)
    lowest = np.min(signal)
    span = np.max(signal) - lowest

    if span == 0:
        # A flat signal would give 0/0 = NaN for every sample and therefore
        # meaningless symbols; map it to the bottom of the range instead.
        norm_signal = np.zeros(signal.shape, dtype=float)
    else:
        norm_signal = (signal - lowest) / span

    if n_bins == "median":
        bins = [0, np.median(norm_signal)]
    else:
        # Left edges of the equal-width bins; the top edge is dropped so that
        # the maximum sample (1.0) lands in the last bin.
        bins = np.linspace(0, 1, int(n_bins) + 1)[:-1]

    # np.digitize is 1-based for values at or above the first edge.
    codes = np.digitize(norm_signal, bins) - 1

    return ''.join(
        str(code) if code < 10 else chr(ord("a") + (code - 10))
        for code in codes
    )

In [28]:
def temporal_encoding(signal, window_size):
    """Encode a signal as a string of ordinal-pattern symbols.

    A window of ``window_size`` samples slides over the signal one sample at a
    time. For each window, the argsort of its samples is looked up in a table
    that assigns one letter (starting at ``"a"``) to every permutation, in the
    order produced by ``itertools.permutations``.

    Args:
        signal: 1-D sequence of numbers.
        window_size: window length, as an integer or a numeric string.

    Returns:
        A string with one character per window; empty when the signal is
        shorter than the window.
    """
    width = int(window_size)

    pattern_to_symbol = {}
    for rank, pattern in enumerate(itertools.permutations(range(width))):
        pattern_to_symbol[pattern] = chr(ord("a") + rank)

    n_windows = len(signal) - width + 1
    return "".join(
        pattern_to_symbol[tuple(np.argsort(signal[start:start + width]))]
        for start in range(n_windows)
    )

In [29]:
def signal_encoding(signal, encoding_type, encoding_param):
    """Encode ``signal`` with the encoder selected by ``encoding_type``.

    ``"spatial"`` selects ``spatial_encoding``; any other value selects
    ``temporal_encoding``. ``encoding_param`` is forwarded unchanged as the
    encoder's second argument.
    """
    encoder = spatial_encoding if encoding_type == "spatial" else temporal_encoding
    return encoder(signal, encoding_param)

In [30]:
def lzc(epochs, n_bins="median"):
    """Compute the Lempel-Ziv complexity of each epoch.

    Args:
        epochs: iterable of 1-D signals, one per epoch.
        n_bins: bin setting forwarded to ``spatial_encoding``.

    Returns:
        NumPy array with one complexity value per epoch, in input order.
    """
    complexities = [
        lempel_ziv_complexity(spatial_encoding(epoch, n_bins))
        for epoch in epochs
    ]
    return np.array(complexities)

In [51]:
def create_transition_matrix(symbols):
    """Estimate a first-order transition matrix from a symbol sequence.

    States are the distinct symbols in sorted order. Entry ``[i, j]`` is the
    fraction of transitions leaving state ``i`` that go to state ``j``.

    Args:
        symbols: sliceable sequence of hashable, sortable symbols (e.g. a str).

    Returns:
        ``(N, N)`` float array, where ``N`` is the number of distinct symbols.
        Every entry is at least ``1e-12``, so rows need not sum exactly to 1.
    """
    states = sorted(set(symbols))
    position = dict(zip(states, range(len(states))))
    n_states = len(position)

    counts = np.zeros((n_states, n_states))
    for current, following in zip(symbols, symbols[1:]):
        counts[position[current], position[following]] += 1

    # Floor the row totals so states with no outgoing transition do not divide by zero.
    row_totals = counts.sum(axis=1, keepdims=True)
    probabilities = counts / np.maximum(row_totals, 1e-12)

    # Floor the probabilities so they can be passed to a logarithm.
    return np.maximum(probabilities, 1e-12)

In [74]:
def _dominant_left_eigenvector(transition_matrix):
    """Return the real part of the dominant eigenvector of ``transition_matrix.T``, or None on failure."""
    n_states = transition_matrix.shape[0]

    # ARPACK (eigs) requires k < N - 1. For smaller matrices SciPy does not
    # return a single dominant pair, so vecs[:, 0] would not be reliable;
    # those go straight to the dense solver.
    if n_states >= 3:
        try:
            _, vecs = eigs(transition_matrix.T, k=1, which='LM')
            return np.real(vecs[:, 0])
        except ArpackError:
            print("Error while computing eigs, trying again")

    try:
        vals, vecs = np.linalg.eig(transition_matrix.T)
        return np.real(vecs[:, np.argmax(np.real(vals))])
    except Exception:
        print("Second error, skipping calculations")
        return None


def compute_S_and_Lambda(transition_matrix):
    """Compute S and Lambda for a transition matrix.

    With ``rho`` the dominant eigenvector of the transposed matrix, normalized
    to sum to 1, and ``p`` the matrix entries:

    * ``S = sum_ij rho_i * (-p_ij * log p_ij)``
    * ``Lambda = sum_ij rho_i * p_ij * log(p_ij)**2 - S**2``

    Args:
        transition_matrix: square 2-D array. Entries are passed to ``np.log``,
            so exact zeros produce NaN results.

    Returns:
        Tuple ``(S, Lambda)``. Both are ``np.nan`` when the matrix is empty,
        the eigen-decomposition fails, or the eigenvector cannot be normalized.
    """
    transition_matrix = np.asarray(transition_matrix, dtype=float)
    if transition_matrix.size == 0:
        return np.nan, np.nan

    rho = _dominant_left_eigenvector(transition_matrix)
    if rho is None:
        return np.nan, np.nan

    # Dividing by the sum also fixes the arbitrary sign of the eigenvector.
    total = rho.sum()
    if not np.isfinite(total) or total == 0:
        return np.nan, np.nan
    rho = rho / total

    log_matrix = np.log(transition_matrix)
    L1 = -transition_matrix * log_matrix
    L2 = transition_matrix * log_matrix ** 2

    S = np.sum(rho @ L1)

    L_square = np.sum(rho @ L2)
    Lambda = L_square - S ** 2

    return S, Lambda


In [75]:
def get_features(folder, channels_list, size=30):
    """Extract per-epoch features for every recording in ``preprocessed/{folder}``.

    For each file the recording is cut into fixed-length epochs and the
    following are computed: mean PSD per channel, PSD band ratios,
    Lempel-Ziv complexity, S / Lambda of the encoded signals, and the epoch
    annotations. Coherence and PLV values are read from the matching
    ``*_part1.json`` file in ``features/{folder}``. The result is written to
    ``features/{folder}/<name>_part2.json``; files whose result already exists
    are skipped.

    Args:
        folder: dataset sub-folder under ``preprocessed/`` and ``features/``.
        channels_list: channel prefixes for which PSD ratios are computed. The
            PSD dictionary must contain a ``"{channel}_{band}"`` entry for
            every band named in the ratio tables.
        size: epoch duration passed to ``mne.make_fixed_length_epochs``.

    Notes:
        * The last 8 characters of each file name are stripped to build the
          output name -- presumably the ``_eeg.fif`` suffix; confirm.
        * S and Lambda may be NaN, which ``json.dumps`` writes as ``NaN``.
    """
    input_dir = f"preprocessed/{folder}"
    output_dir = f"features/{folder}"

    for filename in os.listdir(input_dir):
        results_filename = f"{filename[:-8]}_part2.json"
        # A direct existence check avoids re-listing the output directory for every file.
        if os.path.exists(os.path.join(output_dir, results_filename)):
            print(filename, "skip")
            continue

        # The context manager closes the bar even when a step raises.
        with tqdm(total=8, desc=f"{filename}", unit="step") as pbar:
            resulting_features = {
                "lzc": {},
                "psd_ratio": {},
                "S": {},
                "Lambda": {},
                "annotations": {"main": []},
                "coh": {},
                "plv": {},
            }
            for encoding_type, param_values in POSSIBLE_PARAMS_S_LAMBDA.items():
                for params in param_values:
                    resulting_features["S"][f"{encoding_type}{params}"] = {}
                    resulting_features["Lambda"][f"{encoding_type}{params}"] = {}

            raw = mne.io.read_raw_fif(f"{input_dir}/{filename}", verbose=False)
            epochs = mne.make_fixed_length_epochs(raw, duration=size, preload=True, verbose=False)

            pbar.set_description(f"{filename} -> psd")
            pbar.update(1)
            # Average over the last axis, then transpose so that rows line up with raw.ch_names.
            psd_table = epochs.compute_psd(n_jobs=-1, verbose=False).get_data().mean(axis=2).T
            # The last "-" in a channel name becomes "_" (e.g. "A-B" -> "A_B").
            psd = {
                "_".join(ch_name.rsplit("-", 1)): psd_array
                for ch_name, psd_array in zip(raw.ch_names, psd_table)
            }
            resulting_features["psd"] = psd

            pbar.set_description(f"{filename} -> psd ratios")
            pbar.update(1)
            psd_ratio = resulting_features["psd_ratio"]
            for channel in channels_list:
                for dividend, divisor in RATIOS_TWO_BANDS:
                    psd_ratio[f"{channel}_{dividend}/{divisor}"] = (
                        psd[f"{channel}_{dividend}"] / psd[f"{channel}_{divisor}"]
                    )
                for dividend, divisor_1, divisor_2 in RATIOS_THREE_BANDS:
                    psd_ratio[f"{channel}_{dividend}/({divisor_1}+{divisor_2})"] = (
                        psd[f"{channel}_{dividend}"] / (
                            psd[f"{channel}_{divisor_1}"] + psd[f"{channel}_{divisor_2}"]
                        )
                    )

            pbar.set_description(f"{filename} -> lzc")
            pbar.update(1)
            # Fetch the epoch data once instead of once per (n_bins, channel) pair.
            epochs_data = epochs.get_data(copy=False)
            for n_bins in POSSIBLE_N_BINS_LZC:
                resulting_features["lzc"][n_bins] = {
                    channel_name: lzc(epochs_data[:, channel_n], n_bins=n_bins)
                    for channel_n, channel_name in enumerate(raw.ch_names)
                }

            pbar.set_description(f"{filename} -> s, lambda")
            pbar.update(1)
            for epoch in epochs:
                for channel_epoch_data, channel_name in zip(epoch, raw.ch_names):
                    for encoding_type, param_values in POSSIBLE_PARAMS_S_LAMBDA.items():
                        for params in param_values:
                            symbols = signal_encoding(channel_epoch_data, encoding_type, params)

                            transition_matrix = create_transition_matrix(symbols)
                            S, Lambda = compute_S_and_Lambda(transition_matrix)

                            key = f"{encoding_type}{params}"
                            resulting_features["S"][key].setdefault(channel_name, []).append(S)
                            resulting_features["Lambda"][key].setdefault(channel_name, []).append(Lambda)

            pbar.set_description(f"{filename} -> annotations")
            pbar.update(1)
            annotations_per_epoch = epochs.get_annotations_per_epoch()
            annotators_number = max(map(len, annotations_per_epoch))
            if annotators_number != 1:
                for annotator in range(annotators_number):
                    resulting_features["annotations"][str(annotator + 1)] = []
            for annotations in annotations_per_epoch:
                if len(annotations) == 1:
                    # A single annotation is stored under every key.
                    for annotator in resulting_features["annotations"]:
                        resulting_features["annotations"][annotator].append(annotations[0][2])
                else:
                    # The last character of the description selects the list and the
                    # first character is stored -- presumably "<label>...<annotator id>"; confirm.
                    for annotation in annotations:
                        resulting_features["annotations"][annotation[2][-1:]].append(annotation[2][:1])
                    resulting_features["annotations"]["main"].append(None)

            pbar.set_description(f"{filename} -> coherence, plv")
            pbar.update(1)
            with open(f"{output_dir}/{filename[:-8]}_part1.json", mode="r", encoding="utf-8") as file:
                coh_plv_data = json.load(file)
            for key_upper in ["coh", "plv"]:
                for key_lower in coh_plv_data[key_upper]:
                    # Each entry is reshaped to 6 values per epoch and averaged --
                    # presumably part 1 used 6 sub-windows per epoch; confirm.
                    resulting_features[key_upper][key_lower] = np.array(
                        coh_plv_data[key_upper][key_lower]
                    ).reshape(len(epochs), 6).mean(axis=1)

            pbar.set_description(f"{filename} -> saving")
            pbar.update(1)
            # Serialize before opening the file: a serialization error must not leave
            # an empty result file that a later run would treat as already done.
            serialized = json.dumps(resulting_features, cls=NumpyEncoder, indent=None)
            with open(f"{output_dir}/{results_filename}", mode="w", encoding="utf-8") as file:
                file.write(serialized)

            pbar.set_description(f"{filename} -> DONE")
            pbar.update(1)

In [76]:
# Run feature extraction on the recordings in preprocessed/isruc-sleep.
get_features(
    folder="isruc-sleep",
    channels_list=["F3", "C3", "O1", "F4", "C4", "O2"],
)

s1_10_eeg.fif skip
s1_1_eeg.fif skip
s1_2_eeg.fif skip
s1_3_eeg.fif skip
s1_4_eeg.fif skip
s1_5_eeg.fif skip
s1_6_eeg.fif skip
s1_7_eeg.fif skip
s1_8_eeg.fif skip
s1_9_eeg.fif skip
s3_10_eeg.fif skip
s3_1_eeg.fif skip
s3_2_eeg.fif skip
s3_3_eeg.fif skip
s3_4_eeg.fif skip
s3_5_eeg.fif skip
s3_6_eeg.fif skip
s3_7_eeg.fif skip
s3_8_eeg.fif skip
s3_9_eeg.fif skip


In [78]:
# Run feature extraction on the recordings in preprocessed/sleep_edf_database_expanded.
get_features(
    folder="sleep_edf_database_expanded",
    channels_list=["Fpz-Cz", "Pz-Oz"],
)

0_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

10_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

11_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

12_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

13_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

14_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

15_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

1_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

2_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

3_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

4_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

5_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

6_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

7_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

8_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

9_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

In [79]:
# Run feature extraction on the recordings in preprocessed/eegmat.
get_features(
    folder="eegmat",
    channels_list=["F3", "C3", "O1", "F4", "C4", "O2"],
)

0_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

10_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

11_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

12_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

13_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

14_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

15_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

16_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

17_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

18_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

19_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

1_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

20_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

21_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

22_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

23_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

24_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

25_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

26_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

27_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

28_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

29_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

2_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

30_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

31_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

32_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

33_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

34_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

35_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

3_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

4_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

5_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

6_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

7_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

8_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

9_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

In [80]:
# Run feature extraction on the recordings in preprocessed/spis.
get_features(
    folder="spis",
    channels_list=["F3", "C3", "O1", "F4", "C4", "O2"],
)

0_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

1_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

2_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

3_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

4_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

5_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

6_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

7_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

8_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

9_eeg.fif:   0%|          | 0/8 [00:00<?, ?step/s]

In [77]:
# MNIST
#
# Feature extraction for the files in preprocessed/mnist. Features from many
# recordings are accumulated into one dictionary and written to
# features/mnist/<index>_part2.json once MNIST_BATCH_SIZE labels have been
# collected. One label is stored per file (first annotation of the first
# epoch), while PSD / LZC / S / Lambda get one value per epoch -- this
# presumably relies on each file yielding exactly one 2-second epoch; confirm.

CHANNELS_MNIST = ["F3", "F4", "O1", "O2", "P7", "P8"]
# Number of recordings accumulated before a batch is written to disk.
MNIST_BATCH_SIZE = 1000


def _save_mnist_batch(features, batch_index):
    """Add the PSD band ratios to ``features`` and write it to ``features/mnist/{batch_index}_part2.json``."""
    psd = {name: np.array(values) for name, values in features["psd"].items()}
    for channel in CHANNELS_MNIST:
        for dividend, divisor in RATIOS_TWO_BANDS:
            features["psd_ratio"][f"{channel}_{dividend}/{divisor}"] = (
                psd[f"{channel}_{dividend}"] / psd[f"{channel}_{divisor}"]
            )
        for dividend, divisor_1, divisor_2 in RATIOS_THREE_BANDS:
            features["psd_ratio"][f"{channel}_{dividend}/({divisor_1}+{divisor_2})"] = (
                psd[f"{channel}_{dividend}"] / (
                    psd[f"{channel}_{divisor_1}"] + psd[f"{channel}_{divisor_2}"]
                )
            )

    with open(f"features/mnist/{batch_index}_part2.json", mode="w", encoding="utf-8") as file:
        file.write(json.dumps(features, cls=NumpyEncoder, indent=None))


resulting_features = None
results_filename = "mnist_part2.json"  # not used in this cell
last_batch_index = -1
iter_n = -1
for iter_n, filename in enumerate(tqdm(os.listdir("preprocessed/mnist"))):
    if resulting_features is None:
        resulting_features = {
            "psd": {},
            "lzc": {n_bins: {} for n_bins in POSSIBLE_N_BINS_LZC},
            "psd_ratio": {},
            "S": {},
            "Lambda": {},
            "annotations": {"main": []}
        }
        for encoding_type in POSSIBLE_PARAMS_S_LAMBDA:
            for params in POSSIBLE_PARAMS_S_LAMBDA[encoding_type]:
                resulting_features["S"][f"{encoding_type}{params}"] = {}
                resulting_features["Lambda"][f"{encoding_type}{params}"] = {}

    raw = mne.io.read_raw_fif(f"preprocessed/mnist/{filename}", verbose=False)
    epochs = mne.make_fixed_length_epochs(raw, duration=2, preload=True, verbose=False)

    # The per-channel lists can only be created once the channel names are known.
    if len(resulting_features["psd"]) == 0:
        for ch_name in raw.ch_names:
            resulting_features["psd"][
                "_".join(ch_name.rsplit("-", 1))
            ] = []
            for n_bins in POSSIBLE_N_BINS_LZC:
                resulting_features["lzc"][n_bins][ch_name] = []
            for encoding_type in POSSIBLE_PARAMS_S_LAMBDA:
                for params in POSSIBLE_PARAMS_S_LAMBDA[encoding_type]:
                    key = f"{encoding_type}{params}"
                    resulting_features["S"][key][ch_name] = []
                    resulting_features["Lambda"][key][ch_name] = []

    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always", category=UserWarning)
        psd_table = epochs.compute_psd(n_jobs=-1, verbose=False).get_data().mean(axis=2).T

    # Files whose PSD computation warned about a zero spectrum are left out entirely.
    if any("Zero value in spectrum" in str(caught.message) for caught in caught_warnings):
        print(f"Warning caught at file {filename}, skipping")
        continue

    for ch_name, psd_array in zip(raw.ch_names, psd_table):
        resulting_features["psd"][
            "_".join(ch_name.rsplit("-", 1))
        ].extend(psd_array)

    epochs_data = epochs.get_data(copy=False)
    for n_bins in POSSIBLE_N_BINS_LZC:
        for channel_n, channel_name in enumerate(raw.ch_names):
            resulting_features["lzc"][n_bins][channel_name].extend(
                lzc(epochs_data[:, channel_n], n_bins=n_bins)
            )

    for epoch in epochs:
        for channel_epoch_data, channel_name in zip(epoch, raw.ch_names):
            for encoding_type in POSSIBLE_PARAMS_S_LAMBDA:
                for params in POSSIBLE_PARAMS_S_LAMBDA[encoding_type]:
                    symbols = signal_encoding(channel_epoch_data, encoding_type, params)

                    transition_matrix = create_transition_matrix(symbols)
                    S, Lambda = compute_S_and_Lambda(transition_matrix)

                    key = f"{encoding_type}{params}"
                    resulting_features["S"][key][channel_name].append(S)
                    resulting_features["Lambda"][key][channel_name].append(Lambda)

    resulting_features["annotations"]["main"].append(
        epochs.get_annotations_per_epoch()[0][0][2]
    )

    if len(resulting_features["annotations"]["main"]) == MNIST_BATCH_SIZE:
        # For full batches this equals iter_n // MNIST_BATCH_SIZE; the max()
        # only guarantees that an index is never reused.
        last_batch_index = max(iter_n // MNIST_BATCH_SIZE, last_batch_index + 1)
        _save_mnist_batch(resulting_features, last_batch_index)
        resulting_features = None

# Write the final, partially filled batch; without this the recordings
# collected after the last full batch would be silently dropped.
if resulting_features is not None and resulting_features["annotations"]["main"]:
    last_batch_index = max(iter_n // MNIST_BATCH_SIZE, last_batch_index + 1)
    _save_mnist_batch(resulting_features, last_batch_index)
    resulting_features = None

  0%|          | 0/65033 [00:00<?, ?it/s]

Error while computing eigs, trying again
Error while computing eigs, trying again
