In [1]:
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm

import torchaudio
from sklearn.model_selection import train_test_split

import os
import sys

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
save_path = save_path = Path('/content/drive/My Drive/type3/data')

In [4]:
import IPython.display
import json

def Audio(audio: np.ndarray, sr: int):
    """
    Use instead of IPython.display.Audio as a workaround for VS Code.
    `audio` is an array with shape (channels, samples) or just (samples,) for mono.
    """

    if np.ndim(audio) == 1:
        channels = [audio.tolist()]
    else:
        channels = audio.tolist()

    return IPython.display.HTML("""
        <script>
            if (!window.audioContext) {
                window.audioContext = new AudioContext();
                window.playAudio = function(audioChannels, sr) {
                    const buffer = audioContext.createBuffer(audioChannels.length, audioChannels[0].length, sr);
                    for (let [channel, data] of audioChannels.entries()) {
                        buffer.copyToChannel(Float32Array.from(data), channel);
                    }
            
                    const source = audioContext.createBufferSource();
                    source.buffer = buffer;
                    source.connect(audioContext.destination);
                    source.start();
                }
            }
        </script>
        <button onclick="playAudio(%s, %s)">Play</button>
    """ % (json.dumps(channels), sr))

In [20]:
dm_path = save_path / 'dementia'
nd_path = save_path / 'nodementia'

In [21]:
dm_df = pd.read_csv(save_path/'dementia.csv')
nd_df = pd.read_csv(save_path/'nodementia.csv')

In [22]:
dm_df.head()

Unnamed: 0,name,dementia type,birthdate,deathdate,diagnosis,URLs after symptoms,5 years,5 < 10 years,10 < 15 years,gender,ethnicity,datasplit,language,unknown 1,unkown 2,unknown 3
0,Abe Burrows,Alzheimer,1910,1985,1975.0,,https://www.youtube.com/watch?v=VezbsmCriw4,,,male,Caucasian/White,train,,,,
1,Aileen Hernandez,Dementia,1926,2017,2012.0,https://youtu.be/x7hujcEhQuY,https://youtu.be/CshhDl-YwkY \nhttps://youtu.b...,,,female,Black/African American,train,,,,
2,Alan Ramsey,Dementia,1938,2020,2015.0,,https://www.youtube.com/watch?v=CHeXE4c6EDI,,,male,Caucasian/White,train,,,,
3,Allan Burns,Lewy body,1935,2021,,,https://www.youtube.com/watch?v=aD3hL-kWoPc,,,male,Caucasian/White,train,,,,
4,Andrew Sachs,Dementia,1930,2016,2012.0,,,https://youtu.be/FSgKLooW1LM,https://youtu.be/3V1iFmavqG4,male,,train,,,,


In [23]:
def split_df(df, col, val):
    return df[df[col] == val], df[df[col] != val]

In [24]:
valid_dm, train_dm = split_df(dm_df, 'datasplit', 'valid')
test_dm, train_dm = split_df(train_dm, 'datasplit', 'test')

valid_nd, train_nd = split_df(nd_df, 'datasplit', 'valid')

In [25]:
train_dmlst = train_dm['name'].tolist() 
train_ndlst = train_nd['name'].tolist()
valid_dmlst = valid_dm['name'].tolist()
valid_ndlst = valid_nd['name'].tolist()

In [26]:
print(len(train_dmlst), len(train_ndlst), len(valid_dmlst), len(valid_ndlst))

68 50 14 11


In [27]:
data_train = []
data_valid = []
for path in tqdm(dm_path.glob('**/*.wav')):
    name = str(path).split('/')[-1].split('.')[0]
    person = str(path).split('/')[-2]
    if person in train_dmlst:
        try:
            s = torchaudio.load(path)
            data_train.append({ 'file': name, 'label': 'dementia', 'path': path })
        except Exception as e:
            print(f'{path} is not a valid wav file', e)
            pass
    elif person in valid_dmlst:
        try:
            s = torchaudio.load(path)
            data_valid.append({ 'file': name, 'label': 'dementia', 'path': path })
        except Exception as e:
            print(f'{path} is not a valid wav file', e)
            pass

131it [01:44,  1.26it/s]


In [28]:
for path in tqdm(nd_path.glob('**/*.wav')):
    name = str(path).split('/')[-1].split('.')[0]
    person = str(path).split('/')[-2]

    if person in train_ndlst:
        try:
            s = torchaudio.load(path)
            data_train.append({ 'file': name, 'label': 'nodementia', 'path': path })
        except Exception as e:
            print(f'{path} is not a valid wav file', e)
            pass
    elif person in valid_ndlst:
        try:
            s = torchaudio.load(path)
            data_valid.append({ 'file': name, 'label': 'nodementia', 'path': path })
        except Exception as e:
            print(f'{path} is not a valid wav file', e)
            pass

324it [00:16, 19.90it/s]


In [29]:
train_df = pd.DataFrame(data_train)
valid_df = pd.DataFrame(data_valid)

train_df.head()

Unnamed: 0,file,label,path
0,daningram_15,dementia,/content/drive/My Drive/type3/data/dementia/Da...
1,terryjones_5,dementia,/content/drive/My Drive/type3/data/dementia/Te...
2,maureenforrester_5,dementia,/content/drive/My Drive/type3/data/dementia/Ma...
3,aileenhernandez_0,dementia,/content/drive/My Drive/type3/data/dementia/Ai...
4,aileenhernandez_5_1,dementia,/content/drive/My Drive/type3/data/dementia/Ai...


In [30]:
valid_df.head()

Unnamed: 0,file,label,path
0,JimmyCalderwood_5,dementia,/content/drive/My Drive/type3/data/dementia/Ji...
1,vivnicholson_5,dementia,/content/drive/My Drive/type3/data/dementia/Vi...
2,IanHolm_2,dementia,/content/drive/My Drive/type3/data/dementia/Ia...
3,CharmianCarr_15,dementia,/content/drive/My Drive/type3/data/dementia/Ch...
4,CharmianCarr_5,dementia,/content/drive/My Drive/type3/data/dementia/Ch...


In [31]:
print("Labels: ", train_df.label.unique())
print(len(train_df.label.unique()))

Labels:  ['dementia' 'nodementia']
2


In [32]:
train_df.groupby('label').count()[['path']]

Unnamed: 0_level_0,path
label,Unnamed: 1_level_1
dementia,106
nodementia,121


In [33]:
print(f"train: {len(train_df)}")
print(f"valid: {len(valid_df)}")

train: 227
valid: 48


In [34]:
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)

train_df.to_csv(save_path / 'train_dm.csv', sep='\t', encoding='utf-8', index=False)
valid_df.to_csv(save_path / 'valid_dm.csv', sep='\t', encoding='utf-8', index=False)
