# Supplementary Brain Classification Experiments

In [2]:
from collections import defaultdict
import math
from pathlib import Path
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
from scipy.io.wavfile import read as wavread
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.svm import SVC
from textgrids import TextGrid

import IPython.display as ipd

## Specify Target Words

In [7]:
# We can either use all 55 words that occur in both the Harry Potter chapter and the VariaNTS data as target words...
with open('data/HP_VariaNTS_intersection.txt', 'r', encoding='utf-8') as f:
    # The file is a comma-separated word list. Strip surrounding whitespace so a
    # trailing newline or spaces after commas don't end up inside a word (which
    # would make `interval.text in target_words` silently fail for that word),
    # and drop empty entries, e.g. from a trailing comma.
    target_words = [word.strip() for word in f.read().split(',') if word.strip()]

# ...Or we can specify specific target words by hand:
# target_words = ['dag', 'wel']
# target_words = ['perkamentus', 'anderling']

print(target_words)

['bed', 'boel', 'brief', 'bril', 'dag', 'dier', 'doel', 'dood', 'feest', 'goed', 'greep', 'half', 'hand', 'heel', 'heer', 'hoofd', 'hoop', 'kalm', 'kan', 'kant', 'keel', 'keer', 'kind', 'klein', 'kneep', 'kwaad', 'laat', 'land', 'lang', 'licht', 'los', 'man', 'meer', 'mond', 'mot', 'neus', 'paar', 'raam', 'rest', 'snel', 'stad', 'steen', 'stijf', 'stuur', 'tijd', 'vol', 'vorm', 'vroeg', 'weg', 'wel', 'werk', 'wind', 'zet', 'zin', 'zoon']


## Create Data

In [8]:
# Optional electrode subset: the indices 16*i + j for i in 0..5 and j in 0..7,
# i.e. 0-7, 16-23, 32-39, ..., 80-87 (48 indices, in ascending order).
# NOTE: This selection is for subject 2.
selected_electrodes = (16 * np.arange(6)[:, np.newaxis] + np.arange(8)).ravel()

In [14]:
# Sampling rate of the ECoG recordings (samples per second); used to convert the
# TextGrid interval times (seconds) into ECoG sample indices.
ECOG_SAMPLERATE = 100

# Whether to undersample classes with more occurrences down to the number of
# occurrences of the least frequent class
UNDERSAMPLE = False

# The subject to process
SUB = 2

# The recording runs to load for this subject
RUNS = (1, 2)

# Seed for shuffling the intervals. The shuffle order decides which occurrences
# are kept when UNDERSAMPLE is True, so it is seeded to make that choice (and the
# sample order) reproducible across kernel restarts.
SEED = 0

# Replace this with wherever the data is stored
DATA_DIR = Path('/home/passch/brain2speech-diffusion/data/hp_reading')

data = []
labels = []
lengths = []

word_counter = defaultdict(int)

# Define subject specific path
subj_dir = DATA_DIR / f'sub-{SUB:03d}'

lfb_filename = 'lfb_hp_reading_ecog_car_1-30_avgfirst_100Hz_log_norm.npy'
hfb_filename = 'hfb_hp_reading_ecog_car_70-170_avgfirst_100Hz_log_norm.npy'

# Per-run feature arrays: 'lfb' is loaded from the lfb file and 'hfb' from the
# hfb file (previously run 1's 'lfb' entry mistakenly loaded the hfb file).
ecogs = {
    run_id: {
        band: np.load(subj_dir / f'sub-{SUB:03d}_ses-iemu_acq-ECOG_run-{run_id:02d}_ieeg' / filename)
        for band, filename in (('lfb', lfb_filename), ('hfb', hfb_filename))
    }
    for run_id in RUNS
}

# Collect every word interval whose text is one of the target words, tagging it
# with the run it came from so the matching ECoG array can be looked up later.
target_word_set = set(target_words)
intervals = []

for run in RUNS:
    # The TextGrid which contains all detected word intervals
    textgrid = TextGrid(subj_dir / f'sub-{SUB:03d}_ses-iemu_acq-ECOG_run-{run:02d}_audio.TextGrid')

    for interval in textgrid['words']:
        if interval.text in target_word_set:
            interval.run = run
            intervals.append(interval)

assert intervals, 'None of the target words occur in the TextGrids'

# Every sample is cut to the length (in ECoG samples) of the longest interval
max_length = round(max(interval.xmax - interval.xmin for interval in intervals) * ECOG_SAMPLERATE)
print('Longest segment:', max_length)

# Occurrence count of the least frequent target word (the undersampling cap)
max_n = np.min(np.unique([interval.text for interval in intervals], return_counts=True)[1])

rng = np.random.default_rng(SEED)
rng.shuffle(intervals)

n_truncated = 0

for interval in tqdm(intervals):
    start = math.floor(interval.xmin * ECOG_SAMPLERATE)
    stop = start + max_length
    ecog_hfb_run = ecogs[interval.run]['hfb']  # -> Shape (TIMESTEPS, ELECTRODES)

    # A window that runs past the end of the recording would be shorter than
    # max_length and make `data` ragged (np.array(data) would then fail), so skip
    # it before it counts towards the undersampling cap.
    if stop > ecog_hfb_run.shape[0]:
        n_truncated += 1
        continue

    word_counter[interval.text] += 1
    if UNDERSAMPLE and word_counter[interval.text] > max_n:
        continue

    # ecog_lfb = ecogs[interval.run]['lfb'][start:stop]
    ecog_hfb = ecog_hfb_run[start:stop]

    # ecog = np.concatenate([ecog_lfb, ecog_hfb], axis=1)
    ecog = ecog_hfb

    # Keep the selected electrodes, then flatten electrode by electrode into a
    # single feature vector of length len(selected_electrodes) * max_length.
    ecog = ecog[:, selected_electrodes]
    ecog = ecog.T.flatten()

    data.append(ecog)
    labels.append(interval.text)

if n_truncated:
    print(f'Skipped {n_truncated} interval(s) whose window exceeds the recording')

n_classes = len(np.unique(labels))
print('Classes:', n_classes)
data = np.array(data)
labels = np.array(labels)

# Every sample must have exactly the same number of features
assert data.shape == (len(labels), len(selected_electrodes) * max_length), \
    f'Unexpected data shape {data.shape}'

counts = dict(zip(*np.unique(labels, return_counts=True)))
print(counts)

Longest segment: 55


100%|██████████| 188/188 [00:00<00:00, 43430.78it/s]

Classes: 53
{'bed': 2, 'boel': 1, 'brief': 5, 'bril': 2, 'dag': 13, 'dier': 1, 'doel': 1, 'dood': 3, 'feest': 1, 'goed': 9, 'greep': 2, 'half': 5, 'hand': 3, 'heel': 7, 'heer': 1, 'hoofd': 8, 'hoop': 4, 'kalm': 1, 'kan': 7, 'kant': 2, 'keel': 1, 'keer': 7, 'kind': 2, 'klein': 3, 'kneep': 1, 'kwaad': 1, 'laat': 2, 'land': 5, 'lang': 8, 'licht': 1, 'los': 6, 'man': 4, 'meer': 7, 'mond': 2, 'neus': 2, 'paar': 8, 'raam': 3, 'rest': 1, 'snel': 1, 'stad': 2, 'stijf': 1, 'stuur': 1, 'tijd': 4, 'vol': 1, 'vorm': 2, 'vroeg': 6, 'weg': 4, 'wel': 18, 'werk': 2, 'wind': 1, 'zet': 1, 'zin': 1, 'zoon': 1}





## Train SVMs

Fit SVMs with leave-one-out cross-validation:

In [93]:
# Leave-one-out cross-validation: every sample is the test set exactly once
# (identical splits to KFold with `n_splits` equal to the number of samples).
# Classes with a single sample can never be predicted correctly here, because
# their only sample is absent from the training set of its own fold.
crossval = LeaveOneOut()
print(f'Cross-validating with {crossval.get_n_splits(data)} splits:')

train_accs, test_accs = [], []

for fold_train, fold_test in crossval.split(data):
    clf = SVC(C=2.5)
    clf.fit(data[fold_train], labels[fold_train])

    train_accs.append(clf.score(data[fold_train], labels[fold_train]))
    test_accs.append(clf.score(data[fold_test], labels[fold_test]))

# Averaged per-fold accuracies, rounded for display
train_acc, test_acc = (round(np.mean(accs), 4) for accs in (train_accs, test_accs))
print('  Train Acc. :', train_acc)
print('  Test Acc.  :', test_acc)

Cross-validating with 40 splits:
  Train Acc. : 1.0
  Test Acc.  : 0.625


## Results Collection

In [None]:
# Short words that occur more often
# maar vs. zijn (47c): 1.0 0.883
# maar vs. niet (47c): 1.0 0.8404
# zijn vs. niet (65c): 1.0 0.6308 
# zijn vs. niet vs. maar (47c): 1.0 0.6738

# Long words
# perkamentus vs. anderling (26c): 1.0 0.9615 
# perkamentus vs. professor (33c): 1.0 0.8939
# professor vs. anderling (26c): 1.0 1.0
# professor vs. anderling vs. perkamentus (26c): 1.0 0.9231
# professor vs. anderling vs. perkamentus vs. meneer vs. mevrouw (20c): 1.0 0.74
# meneer vs. mevrouw (20c): 1.0 0.625

In [16]:
# Results of using different thresholds on the minimum no. of occurrences 
# required for a class to be incorporated into the fit 
# (using the 55 intersection words):

# WITHOUT Undersampling:
# Threshold, Classes, Train, Test
# 1, 53, 1.0, 0.1117
# 2, 34, 1.0, 0.1243
# 3, 23, 1.0, 0.1429
# 4, 19, 1.0, 0.1556
# 5, 15, 1.0, 0.1765
# 6, 12, 1.0, 0.2019
# 7, 10, 1.0, 0.2283
# 8,  6, 1.0, 0.3281
# 9,  3, 1.0, 0.575
# 10, 2, 1.0, 0.7419

# WITH Undersampling:
# Threshold, Classes, Train, Test
# 1, 53, 1.0, 0.0
# 2, 34, 1.0, 0.0
# 3, 23, 1.0, 0.0
# 4, 19, 1.0, 0.0395
# 5, 15, 1.0, 0.08
# 6, 12, 1.0, 0.0417
# 7, 10, 1.0, 0.0571
# 8,  6, 1.0, 0.2917
# 9,  3, 1.0, 0.4444
# 13, 2, 1.0, 0.6538

## Misc.

Randomly sample a word interval and play back the corresponding audio segment to check that the TextGrid intervals are correctly aligned with the recording.

In [188]:
# Pick a random interval first, so the audio of the run it belongs to can be loaded.
interval = intervals[np.random.choice(len(intervals))]

# The audio file of the recording this interval came from. Use `interval.run`
# (set when the intervals were collected), not the loop variable `run` left over
# from the data-creation cell: that variable always holds the last run, so
# intervals from earlier runs would be cut out of the wrong recording.
audio_samplerate, audio_signal = wavread(
    subj_dir / f'sub-{SUB:03d}_ses-iemu_acq-ECOG_run-{interval.run:02d}_audio_pitch_shifted.wav')

interval_sig = audio_signal[
    int(interval.xmin * audio_samplerate) :
    int(interval.xmax * audio_samplerate)
]
# wavread returns multi-channel audio as (samples, channels), while IPython's
# Audio expects (channels, samples), so transpose in that case.
if interval_sig.ndim == 2:
    interval_sig = interval_sig.T

print(interval.text)
ipd.Audio(interval_sig, rate=audio_samplerate)

goed
