In [1]:
import torch
import os

from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.nn import Module
from torch.utils.data import Subset
import torchaudio
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm

import random

import ast

from typing import List, Dict

import pandas as pd

import speechbrain
from speechbrain.lobes.models.Xvector import Xvector, Classifier
from speechbrain.lobes.models import ECAPA_TDNN
from speechbrain.inference.speaker import EncoderClassifier
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.nnet.pooling import AttentionPooling, StatisticsPooling
from speechbrain.utils.seed import seed_everything
from speechbrain.dataio.dataio import length_to_mask
from speechbrain.nnet.linear import Linear


from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler, normalize
from sklearn.model_selection import train_test_split

from transformers import get_linear_schedule_with_warmup

import librosa

from scipy.io import wavfile

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [allow_tf32, disable_jit_profiling]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
INFO:matplotlib.font_manager:generated new fontManager


In [2]:
def sample_to_df(data_tv, data_wav, org_data_df, sample_file, segment_length_frames,
                 min_valid_frames=70):
    """Cut one recording's feature arrays into segments and return one row per segment.

    Both arrays are sliced along their first axis with the same windows
    ``[start, start + segment_length_frames)``. The window starts are derived from
    ``data_tv.shape[0]``. A window is kept only when its ``data_tv`` slice has at
    least ``min_valid_frames`` rows; the matching ``data_wav`` slice is kept with it
    and is not length-checked itself.

    Parameters
    ----------
    data_tv : array-like
        Per-frame features; must expose ``.shape[0]`` and support slicing.
    data_wav : array-like
        Second per-frame feature array, sliced with the same windows as ``data_tv``.
    org_data_df : pandas.Series
        Metadata that is repeated unchanged on every output row.
    sample_file : str
        Path of the source file. Its basename without extension becomes the prefix
        of ``origin_id`` and the full path is stored in ``filename``.
    segment_length_frames : int
        Window length (and stride) in frames; must be positive.
    min_valid_frames : int, optional
        Minimum number of ``data_tv`` frames a window needs to be kept (default 70).

    Returns
    -------
    pandas.DataFrame
        One row per kept segment with the metadata columns plus ``sample_wav``,
        ``sample_tv``, ``origin_id`` (``"<file id>_<segment index>"``) and
        ``filename``. The frame is empty when no window is long enough.

    Raises
    ------
    ValueError
        If ``segment_length_frames`` is not positive.
    """
    if segment_length_frames <= 0:
        raise ValueError(
            f"segment_length_frames must be positive, got {segment_length_frames}"
        )

    num_frames_tv = data_tv.shape[0]

    origin_id = os.path.splitext(os.path.basename(sample_file))[0]
    # This one file name carries a stray trailing underscore; normalise it so the
    # id matches the naming of every other recording.
    if origin_id == 'AVPEPUDEAC0041_Monologo-NR_':
        origin_id = 'AVPEPUDEAC0041_Monologo-NR'

    segments_tv = []
    segments_wav = []
    for start in range(0, num_frames_tv, segment_length_frames):
        stop = start + segment_length_frames
        segment_tv = data_tv[start:stop]
        # Drop windows that are too short (typically the tail of the recording).
        if len(segment_tv) < min_valid_frames:
            continue
        segments_tv.append(segment_tv)
        segments_wav.append(data_wav[start:stop])

    n_segments = len(segments_tv)
    origin_id_list = [f"{origin_id}_{k}" for k in range(n_segments)]

    # Repeat the metadata once per segment, then attach the per-segment columns.
    copied_data = pd.DataFrame([org_data_df] * n_segments)
    copied_data['sample_wav'] = segments_wav
    copied_data['sample_tv'] = segments_tv
    copied_data['origin_id'] = origin_id_list
    copied_data['filename'] = sample_file

    return copied_data

In [3]:
def new_testset_perfold(fold, base_dir="/home/stinasb/SSL4PR/TV_train_test"):
    """Build, pickle and return the segmented test set of one fold.

    Reads ``<base_dir>/Fold<fold>/test/test.csv``. For every row it

    * loads the pickled object referenced by the ``sample_tv`` column and converts
      it with ``.to_numpy()``,
    * reads the WAV file referenced by ``sample_wav`` and computes an 80-band
      log-mel spectrogram (hop of 10 ms, FFT size of 20 ms, Hamming window),
      transposed so that frames are on the first axis,
    * derives an integer ``labels`` value from ``status`` (hc/HC -> 0, pd/PD -> 1),
    * splits both feature arrays into segments of 1000 frames with ``sample_to_df``.

    The per-file frames are concatenated and written to
    ``<base_dir>/Fold<fold>/test/test_segment.pkl``.

    Parameters
    ----------
    fold : int or str
        Fold identifier used in the ``Fold<fold>`` directory name.
    base_dir : str, optional
        Directory that contains the ``Fold*`` sub-directories.

    Returns
    -------
    pandas.DataFrame
        The concatenated segment table that was pickled.

    Raises
    ------
    ValueError
        If a row's ``status`` is not one of ``hc``, ``HC``, ``pd``, ``PD``.
    """
    test_dir = os.path.join(base_dir, f"Fold{str(fold)}", "test")
    org_test_set = pd.read_csv(os.path.join(test_dir, "test.csv"))
    target_dir = os.path.join(test_dir, "test_segment.pkl")

    # Segment length expressed in 10 ms frames: 10 s -> 1000 frames.
    sr = 16000
    hop_length = int(0.01*sr)
    segment_length = 10
    segment_length_frames = int(segment_length*sr/hop_length)

    columns_to_copy = ['speaker_id', 'status', 'UPDRS', 'UPDRS-speech','H/Y', 'SEX', 'AGE', 'time after diagnosis']
    label_by_status = {'hc': 0, 'HC': 0, 'pd': 1, 'PD': 1}

    df_list = []

    for _, row in org_test_set.iterrows():
        sample_tv = row['sample_tv']
        sample_wav = row['sample_wav']

        # NOTE(security): unpickling can execute arbitrary code; the paths listed in
        # test.csv must point to trusted files only.
        tv_data = pd.read_pickle(sample_tv).to_numpy()
        samplerate, file = wavfile.read(sample_wav)

        # The hop is 10 ms of the file's own sample rate, so one spectrogram frame
        # covers 10 ms whatever the rate of the recording is.
        feat = librosa.feature.melspectrogram(
            y=file.astype(float),
            sr=samplerate,
            n_fft=2 * int(0.01*samplerate),
            hop_length=int(0.01*samplerate),
            window="hamming",
            center=True,
            n_mels=80,
        )
        log_mel = librosa.power_to_db(feat)
        # Put frames on the first axis so the array is sliced per frame downstream.
        wav_data = log_mel.transpose(1,0)

        # Work on an independent copy so adding 'labels' never touches the source table.
        org_data_copy = row[columns_to_copy].copy()

        status = org_data_copy['status']
        if status not in label_by_status:
            # Fail here, naming the file, instead of later with a NaN-to-int cast error.
            raise ValueError(
                f"Unknown status {status!r} for {sample_wav}; expected one of "
                f"{sorted(label_by_status)}"
            )
        org_data_copy['labels'] = label_by_status[status]

        df_list.append(
            sample_to_df(tv_data, wav_data, org_data_copy, sample_wav, segment_length_frames)
        )

    combined_df = pd.concat(df_list, ignore_index=True)
    combined_df['labels'] = combined_df['labels'].astype(int)

    combined_df.to_pickle(target_dir)

    return combined_df

In [4]:
# Rebuild and pickle the segmented test set for folds 1 through 10.
# `ddf` is rebound on every iteration, so after the loop it only holds the frame
# returned for fold 10; the frame of each fold is written to disk by the call itself.
for i in range(1, 11):    
    ddf = new_testset_perfold(i)
# Single-fold variant kept for interactive inspection (disabled):
#df = new_testset_perfold(3)
#df

In [43]:
# Segment ids of one fold-3 test speaker. The table is loaded from the pickle that
# new_testset_perfold(3) writes, so this cell runs on a fresh kernel instead of
# relying on a `df` variable left over from an earlier, out-of-order execution.
segments_fold3 = pd.read_pickle("/home/stinasb/SSL4PR/TV_train_test/Fold3/test/test_segment.pkl")
segments_fold3.loc[segments_fold3['speaker_id'] == 'AVPEPUDEA0055', 'origin_id']

88     AVPEPUDEA0055_Monologo-NR_0
89     AVPEPUDEA0055_Monologo-NR_1
90     AVPEPUDEA0055_Monologo-NR_2
91            AVPEPUDEA0055_juan_0
92              AVPEPUDEA0055_ka_0
93           AVPEPUDEA0055_laura_0
94       AVPEPUDEA0055_loslibros_0
95           AVPEPUDEA0055_luisa_0
96          AVPEPUDEA0055_micasa_0
97            AVPEPUDEA0055_omar_0
98              AVPEPUDEA0055_pa_0
99          AVPEPUDEA0055_pakata_0
100         AVPEPUDEA0055_pataka_0
101         AVPEPUDEA0055_petaka_0
102     AVPEPUDEA0055_preocupado_0
103       AVPEPUDEA0055_readtext_0
104       AVPEPUDEA0055_readtext_1
105         AVPEPUDEA0055_rosita_0
106             AVPEPUDEA0055_ta_0
107         AVPEPUDEA0055_triste_0
108          AVPEPUDEA0055_viste_0
Name: origin_id, dtype: object

In [None]:
# Load the pickled fold-3 segment table and display every row of one speaker.
org = pd.read_pickle("/home/stinasb/SSL4PR/TV_train_test/Fold3/test/test_segment.pkl")

org.loc[org['speaker_id'].eq('AVPEPUDEA0055')]

Unnamed: 0,speaker_id,status,UPDRS,UPDRS-speech,H/Y,SEX,AGE,time after diagnosis,labels,sample_wav,sample_tv,origin_id,filename
88,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[29.916489999272926, 28.46756730547154, 26.19...","[[0.017235443, -0.08916493, 0.032738045, -0.11...",AVPEPUDEA0055_Monologo-NR_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
89,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[35.13232024512537, 32.922482016593285, 27.53...","[[0.03801344, 0.00540906, 0.011352032, 0.01563...",AVPEPUDEA0055_Monologo-NR_1,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
90,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[32.155171093117694, 29.591306860790134, 20.6...","[[0.011900089, 0.005456824, -0.0032276711, 0.0...",AVPEPUDEA0055_Monologo-NR_2,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
91,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[26.590774059793823, 26.261463689399566, 26.0...","[[-0.00854967, -0.033905514, -0.038085274, -0....",AVPEPUDEA0055_juan_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
92,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[23.979576836663142, 23.53035926161073, 23.16...","[[0.10660799, -0.061311573, 0.07857328, -0.070...",AVPEPUDEA0055_ka_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
93,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[32.38540850662159, 30.76785685047087, 28.012...","[[0.032478053, 0.10337065, 0.004802118, 0.0835...",AVPEPUDEA0055_laura_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
94,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[28.0176730768982, 28.421172319793556, 29.002...","[[0.059710644, -0.16134395, -0.015438917, -0.1...",AVPEPUDEA0055_loslibros_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
95,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[37.301973430651174, 36.287557248333115, 35.0...","[[0.083751604, -0.20897914, -0.023192098, -0.1...",AVPEPUDEA0055_luisa_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
96,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[30.61957787739776, 43.54018339040852, 46.803...","[[0.08754513, -0.16801138, -0.04283937, -0.148...",AVPEPUDEA0055_micasa_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...
97,AVPEPUDEA0055,pd,53.0,2.0,2.0,M,65.0,19.0,1,"[[36.04947354195551, 36.66216946670973, 37.424...","[[-0.09281662, 0.1683, 0.064931124, 0.0888924,...",AVPEPUDEA0055_omar_0,/talebase/data/speech_raw/PC-GITA-v2/PC-GITA_p...


In [33]:
# Load the fold-3 test CSV and list the origin ids of one speaker.
df = pd.read_csv("/home/stinasb/SSL4PR/TV_train_test/Fold3/test/test.csv")
df.loc[df['speaker_id'].eq('AVPEPUDEA0055'), 'origin_id']

72    AVPEPUDEA0055_Monologo-NR
73           AVPEPUDEA0055_juan
74             AVPEPUDEA0055_ka
75          AVPEPUDEA0055_laura
76      AVPEPUDEA0055_loslibros
77          AVPEPUDEA0055_luisa
78         AVPEPUDEA0055_micasa
79           AVPEPUDEA0055_omar
80             AVPEPUDEA0055_pa
81         AVPEPUDEA0055_pakata
82         AVPEPUDEA0055_pataka
83         AVPEPUDEA0055_petaka
84     AVPEPUDEA0055_preocupado
85       AVPEPUDEA0055_readtext
86         AVPEPUDEA0055_rosita
87             AVPEPUDEA0055_ta
88         AVPEPUDEA0055_triste
89          AVPEPUDEA0055_viste
Name: origin_id, dtype: object