# Dataset Splits for Brain2Speech Models

Use the code in this notebook to create training and validation splits for training the brain2speech models. We first create splits for the ECoG data created using `src/preprocessing/create_ecog_data.py`, and then adjust the splits for the VariaNTS data created using `src/preprocessing/create_speech_data.py` to accommodate the skewed distribution of the ECoG words.

In [None]:
import os
import sys
sys.path.append('..')

import numpy as np

from utils.generic import get_word_from_filepath

### Brain Data

First, get the names of the sound files from the paired ECoG+voice recording dataset:

In [40]:
hp_data = '../../data/HP1_ECoG_conditional/sub-002'

files = [file for file in os.listdir(hp_data) if file.endswith('.wav')]

print(sorted(files))

['bed1.wav', 'bed2.wav', 'boel1.wav', 'brief1.wav', 'brief2.wav', 'brief3.wav', 'brief4.wav', 'brief5.wav', 'bril1.wav', 'bril2.wav', 'dag1.wav', 'dag10.wav', 'dag11.wav', 'dag12.wav', 'dag13.wav', 'dag2.wav', 'dag3.wav', 'dag4.wav', 'dag5.wav', 'dag6.wav', 'dag7.wav', 'dag8.wav', 'dag9.wav', 'dier1.wav', 'doel1.wav', 'dood1.wav', 'dood2.wav', 'dood3.wav', 'feest1.wav', 'goed1.wav', 'goed2.wav', 'goed3.wav', 'goed4.wav', 'goed5.wav', 'goed6.wav', 'goed7.wav', 'goed8.wav', 'goed9.wav', 'greep1.wav', 'greep2.wav', 'half1.wav', 'half2.wav', 'half3.wav', 'half4.wav', 'half5.wav', 'hand1.wav', 'hand2.wav', 'hand3.wav', 'heel1.wav', 'heel2.wav', 'heel3.wav', 'heel4.wav', 'heel5.wav', 'heel6.wav', 'heel7.wav', 'heer1.wav', 'hoofd1.wav', 'hoofd2.wav', 'hoofd3.wav', 'hoofd4.wav', 'hoofd5.wav', 'hoofd6.wav', 'hoofd7.wav', 'hoofd8.wav', 'hoop1.wav', 'hoop2.wav', 'hoop3.wav', 'hoop4.wav', 'kalm1.wav', 'kan1.wav', 'kan2.wav', 'kan3.wav', 'kan4.wav', 'kan5.wav', 'kan6.wav', 'kan7.wav', 'kant1.wav', 

Below are the word counts for subject 2 (words that occur fewer than 4 times are not shown):

WORD | wel | dag | goed | paar | hoofd | lang | heel | kan | meer | keer | vroeg | los | brief | half | land | hoop | weg | man | tijd
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
COUNT | 18 | 13 | 9 | 8 | 8 | 8 | 7 | 7 | 7 | 7 | 6 | 6 | 5 | 5 | 5 | 4 | 4 | 4 | 4
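
A table like this can be derived from the file names with a counter. The sketch below is illustrative only: `word_from_filename` is a hypothetical stand-in for `get_word_from_filepath` (whose implementation is not shown in this notebook) and assumes names of the form `<word><index>.wav`. The toy list is a small subset of the listing above, so the threshold is lowered to 2.

```python
import re
from collections import Counter


def word_from_filename(name):
    """Hypothetical stand-in for get_word_from_filepath: 'dag10.wav' -> 'dag'."""
    stem = name.rsplit('/', 1)[-1].rsplit('.', 1)[0]  # drop directory and extension
    return re.sub(r'\d+$', '', stem)                  # drop the trailing repetition index


# Toy subset of the file listing, not the full dataset
example_files = ['bed1.wav', 'bed2.wav', 'boel1.wav', 'brief1.wav',
                 'dag1.wav', 'dag10.wav', 'dag2.wav']

counts = Counter(word_from_filename(f) for f in example_files)

# Keep only words that occur at least `min_count` times, most frequent first
min_count = 2
frequent = {word: n for word, n in counts.most_common() if n >= min_count}
print(frequent)
```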



Because the data are so unevenly distributed, we take a different number of validation files for each word. Below we define the number per word:

In [41]:
words_to_choose = {'wel': 5, 'dag': 3, 'goed': 2, 'paar': 1, 'hoofd': 1, 'lang': 1, 'heel': 1, 'kan': 1, 'meer': 1, 'keer': 1}

Note that the data for subject 1 are distributed differently, so we may choose different numbers per word there.

Then we randomly choose the designated number for each word from the files:

In [42]:
rng = np.random.default_rng(1144)

val_files = []

for (word, count) in words_to_choose.items():
    word_files = [file for file in files if get_word_from_filepath(file) == word]
    val_files.extend(
        rng.choice(word_files, size=count, replace=False)
    )

print(val_files)

['wel12.wav', 'wel17.wav', 'wel5.wav', 'wel7.wav', 'wel1.wav', 'dag2.wav', 'dag4.wav', 'dag9.wav', 'goed8.wav', 'goed4.wav', 'paar1.wav', 'hoofd7.wav', 'lang8.wav', 'heel6.wav', 'kan4.wav', 'meer6.wav', 'keer1.wav']
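
Because the generator is seeded, rerunning the selection should give the same validation files. The sketch below demonstrates this on toy data; `sample_validation` and the toy file names are illustrative and not part of the repository.

```python
import numpy as np


def sample_validation(files_by_word, n_per_word, seed):
    """Draw n_per_word[word] files for each word, without replacement."""
    rng = np.random.default_rng(seed)
    chosen = []
    for word, count in n_per_word.items():
        picks = rng.choice(files_by_word[word], size=count, replace=False)
        chosen.extend(picks.tolist())  # plain Python strings instead of NumPy strings
    return chosen


# Toy data: six recordings of 'wel', four of 'dag'
toy_files = {
    'wel': [f'wel{i}.wav' for i in range(1, 7)],
    'dag': [f'dag{i}.wav' for i in range(1, 5)],
}

first = sample_validation(toy_files, {'wel': 2, 'dag': 1}, seed=1144)
second = sample_validation(toy_files, {'wel': 2, 'dag': 1}, seed=1144)
print(first == second)  # same seed, same selection
```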


Now we can write train and validation splits from these:

In [59]:
splits_path = '../../data/datasplits/HP1_ECoG_conditional/sub-002/'

os.makedirs(splits_path, exist_ok=False)

In [43]:
train_files = [file for file in files if file not in val_files]
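
Before writing the splits, it can be worth checking that they are disjoint and together cover all files. The function below is a sketch of such a check on toy data; `check_split` is a hypothetical helper, not part of the repository.

```python
def check_split(all_files, train, val):
    """Raise AssertionError if the splits overlap or do not cover all_files."""
    train_set, val_set = set(train), set(val)
    assert not (train_set & val_set), 'train and validation overlap'
    assert train_set | val_set == set(all_files), 'some files are unassigned'
    assert len(train) + len(val) == len(all_files), 'duplicate file names'
    return True


# Toy data in the same naming pattern as the ECoG files
toy_all = ['dag1.wav', 'dag2.wav', 'dag3.wav', 'wel1.wav']
toy_val = ['dag2.wav']
toy_train = [f for f in toy_all if f not in toy_val]

ok = check_split(toy_all, toy_train, toy_val)
```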

In [None]:
with open(splits_path + 'train.csv', 'w') as f:
    f.write(','.join(train_files))
with open(splits_path + 'val.csv', 'w') as f:
    f.write(','.join(val_files))
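
The split files are a single line of comma-separated file names, so they can be read back with `split(',')`. The round trip below is a sketch using a temporary directory; `write_split` and `read_split` are hypothetical helper names.

```python
import os
import tempfile


def write_split(path, filenames):
    """Write file names as one comma-separated line."""
    with open(path, 'w') as f:
        f.write(','.join(filenames))


def read_split(path):
    """Read a split file written by write_split."""
    with open(path, 'r') as f:
        return f.read().split(',')


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'val.csv')
    write_split(path, ['wel12.wav', 'dag2.wav'])
    restored = read_split(path)

print(restored)
```

One caveat of this format: an empty split reads back as `['']` rather than `[]`, because `''.split(',')` returns a list containing one empty string.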

### VariaNTS Data

For the brain-conditional finetuning experiment that maps ECoG to VariaNTS data, we also need to change the train/validation split for the VariaNTS data so that it matches the ECoG split just created. To stay consistent with pretraining, we reuse the splits that were used there and only filter them.

The VariaNTS datasplits were created such that the validation set contains 3 randomly chosen speakers for each word. Since the ECoG validation set does not contain every word because of the uneven distribution of words, we can only keep the words in the VariaNTS validation split that are also in the ECoG split.

First we load the original VariaNTS datasplits:

In [44]:
# The splits used in both unconditional and class-conditional pretraining runs
vnts_splits_path = '../../data/datasplits/VariaNTS/HP_90-10/'

with open(vnts_splits_path + 'train.csv', 'r') as f:
    vnts_train_files = f.read().split(',')
with open(vnts_splits_path + 'val.csv', 'r') as f:
    vnts_val_files = f.read().split(',')

len(vnts_train_files), len(vnts_val_files)

(27115, 165)

Then, for both train and validation files, we get the unique words used in the ECoG data. 

We have to do this not only for validation, but also for training, because the ECoG data may not contain all 55 words (indeed for subject 2, it contains 53 words, while for subject 1 only 43).
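
To see which words are absent for a given subject, a set difference between the full vocabulary and the observed words is enough. The sketch below uses a toy vocabulary; `missing_words` is a hypothetical helper, and the real 55-word list is not reproduced here.

```python
def missing_words(vocabulary, observed_words):
    """Return the vocabulary words that never occur in observed_words."""
    return sorted(set(vocabulary) - set(observed_words))


# Toy vocabulary and observations, not the real word list
toy_vocabulary = ['bed', 'dag', 'goed', 'wel']
toy_observed = ['dag', 'wel', 'dag']

absent = missing_words(toy_vocabulary, toy_observed)
print(absent)
```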

In [45]:
hp_train_words = np.unique(
    [get_word_from_filepath(file) for file in train_files])
hp_val_words = np.unique(
    [get_word_from_filepath(file) for file in val_files])

Then we remove all words from the VariaNTS training and validation splits that are not in the ECoG ones:

In [46]:
vnts_train_files = [
    file for file in vnts_train_files 
        if get_word_from_filepath(file) in hp_train_words
]
vnts_val_files = [
    file for file in vnts_val_files 
        if get_word_from_filepath(file) in hp_val_words
]

len(vnts_train_files), len(vnts_val_files)

(26129, 30)

Looking at the number of remaining files, we can see that, relatively speaking, more files were removed from the validation split than from the training split. This makes sense: only 2 words have to be removed from training for subject 2, while the ECoG validation split contains a much smaller selection of words to begin with.
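
The retained fractions can be worked out directly from the counts printed above. The second half of the sketch is a consistency check under the assumption that the VariaNTS files are spread evenly over the 55 words, which the notebook does not state explicitly but which the counts are consistent with.

```python
# Counts printed by the cells above
n_train_before, n_val_before = 27115, 165
n_train_after, n_val_after = 26129, 30

train_kept = n_train_after / n_train_before
val_kept = n_val_after / n_val_before
print(f'train kept: {train_kept:.1%}, val kept: {val_kept:.1%}')
# prints: train kept: 96.4%, val kept: 18.2%

# Assumption: files are evenly distributed over the 55 words
n_words_total = 55
train_files_per_word = n_train_before // n_words_total  # 493
val_files_per_word = n_val_before // n_words_total      # 3 speakers per word

train_words_kept = n_train_after // train_files_per_word  # 53 words in ECoG training data
val_words_kept = n_val_after // val_files_per_word        # 10 words in ECoG validation data
print(train_words_kept, val_words_kept)
```

Under that assumption, the counts match the 53 words reported for subject 2 and the 10 words listed in `words_to_choose`.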

In [48]:
vnts_new_splits_path = '../../data/datasplits/VariaNTS/HP_b2s_90-10/'

os.makedirs(vnts_new_splits_path, exist_ok=False)

In [49]:
with open(vnts_new_splits_path + 'train.csv', 'w') as f:
    f.write(','.join(vnts_train_files))
with open(vnts_new_splits_path + 'val.csv', 'w') as f:
    f.write(','.join(vnts_val_files))

For completeness, we do the same for the unaugmented training files and the augmented validation files:

In [50]:
with open(vnts_splits_path + 'train_noaug.csv', 'r') as f:
    vnts_train_files = f.read().split(',')
with open(vnts_splits_path + 'val_aug.csv', 'r') as f:
    vnts_val_files = f.read().split(',')

print(len(vnts_train_files), len(vnts_val_files))

vnts_train_files = [
    file for file in vnts_train_files 
        if get_word_from_filepath(file) in hp_train_words
]
vnts_val_files = [
    file for file in vnts_val_files 
        if get_word_from_filepath(file) in hp_val_words
]

print(len(vnts_train_files), len(vnts_val_files))

with open(vnts_new_splits_path + 'train_noaug.csv', 'w') as f:
    f.write(','.join(vnts_train_files))
with open(vnts_new_splits_path + 'val_aug.csv', 'w') as f:
    f.write(','.join(vnts_val_files))

1595 2805
1537 510
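
Since the same filtering is applied to several pairs of split files, it could be factored into a small helper. The sketch below is one possible shape, not the repository's code: `filter_by_words` is a hypothetical name, the word extractor is passed in as an argument, and the toy file names follow the ECoG naming pattern rather than the actual VariaNTS file names, which are not shown in this notebook.

```python
import re


def filter_by_words(files, allowed_words, get_word):
    """Keep only the files whose word is in allowed_words."""
    allowed = set(allowed_words)  # set membership avoids rescanning the word list per file
    return [f for f in files if get_word(f) in allowed]


def toy_get_word(name):
    """Toy extractor for names like 'dag3.wav'; stands in for get_word_from_filepath."""
    return re.sub(r'\d+\.wav$', '', name)


toy_split = ['bed1.wav', 'dag1.wav', 'dag2.wav', 'wel1.wav']
kept = filter_by_words(toy_split, ['dag', 'wel'], toy_get_word)
print(kept)
```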
