In [1]:
import pickle
import glob, os
import time
import copy
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
#scikitlearn
from sklearn.manifold import TSNE
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram
from sklearn.semi_supervised import LabelSpreading
#Pyannote
from pyannote.core import Segment, notebook, Timeline
from pyannote.audio.utils.signal import Binarize, Peak
#Pydub
from pydub import AudioSegment
from pydub.playback import play
#Praat
import parselmouth

In [11]:
# Root folder containing one sub-folder per patient. The run cell appends the
# patient ID directly to this string, so keep the trailing slash.
location = "/location/to/extracted/pyannote/and/audio/files/"

# IDs of the patients to process; each ID names a sub-folder of `location`.
patientIdList = [
    "[ID of patient]",
]

In [3]:
from pyannote.core import SlidingWindowFeature


def plot_ready(scores):
    """Return an exponentiated copy of `scores` without its first column.

    `np.exp` is applied to every column of `scores.data` except column 0 and
    the result is wrapped in a new `SlidingWindowFeature` that reuses
    `scores.sliding_window`.
    Presumably this turns log-probabilities into probabilities -- TODO confirm
    against the code that produced the score files.
    """
    exponentiated = np.exp(scores.data[:, 1:])
    return SlidingWindowFeature(exponentiated, scores.sliding_window)

In [4]:
# function to cluster the audio segments within a session audio file and output the result.
def ClusterAudio(patientIdLocation, wavfile):
    """Cluster the long speech turns of one session and export the results.

    Steps:
      1. Load the pickled SAD / SCD / overlap scores and embeddings that sit
         next to the session (`<session>_sad.p`, `_scd.p`, `_ovl.p`, `_emb.p`).
      2. Derive speech turns (speech regions cropped by speaker-change
         partition) and keep those longer than 0.5 s.
      3. Build one feature vector per kept turn: mean embedding, mean overlap
         score and mean intensity feature (left minus right channel for stereo
         recordings, plain intensity for other recordings).
      4. Cluster the vectors into 4 groups (Ward agglomerative clustering).
      5. Export a t-SNE plot, one audio file per cluster and a CSV with one
         row per pitch frame (`t`, `p`, `cluster`; cluster is -1 when the
         frame lies in no clustered turn).

    Parameters
    ----------
    patientIdLocation : str
        Directory containing `wavfile`.
    wavfile : str
        File name of the session recording (expected to end in `.wav`).

    Notes
    -----
    The score pickles and all outputs are opened via the bare session name,
    i.e. relative to the current working directory; only the audio itself is
    read from `patientIdLocation`.

    Raises
    ------
    ValueError
        If fewer than 4 speech turns yield a feature vector, because 4
        clusters cannot be formed from fewer samples.
    """
    ## Load the scores
    print("Loading scores")
    sessionName = wavfile.replace('.wav', '')
    # NOTE: pickle.load can execute arbitrary code -- only load score files
    # that were produced by a trusted pipeline.
    loaded = {}
    for suffix in ("sad", "scd", "ovl", "emb"):
        # `with` closes each file handle (the handles used to be leaked).
        with open(sessionName + "_" + suffix + ".p", "rb") as handle:
            loaded[suffix] = pickle.load(handle)
    sad_scores = loaded["sad"]
    scd_scores = loaded["scd"]
    ovl_scores = loaded["ovl"]
    embeddings = loaded["emb"]

    ## speech regions
    binarize = Binarize(offset=0.52, onset=0.52, log_scale=True,
                        min_duration_off=0.1, min_duration_on=0.1)
    speech = binarize.apply(sad_scores, dimension=1)
    ## speaker change point
    peak = Peak(alpha=0.10, min_duration=0.10, log_scale=True)
    partition = peak.apply(scd_scores, dimension=1)
    ## speech turns are simply the intersection of SAD and SCD
    speech_turns = partition.crop(speech)
    ## only keep long speech turns
    alpha = 0.5  # 0.5 DEFAULT (minimum turn duration in seconds)
    long_turns = Timeline(segments=[s for s in speech_turns if s.duration > alpha])
    print(str(len(long_turns)) + " long turns")

    ## Get sound with praat to calculate left right intensity
    snd = parselmouth.Sound(os.path.join(patientIdLocation, wavfile))
    # get_number_of_channels is a method: it has to be *called*. Comparing the
    # bound method itself to 2 was always False.
    if snd.get_number_of_channels() == 2:
        Int_left = snd.extract_left_channel().to_intensity()
        Int_right = snd.extract_right_channel().to_intensity()
        intensity_frames = Int_left
        Int_LR = Int_left.as_array()[0] - Int_right.as_array()[0]
    else:
        # No left/right pair available: fall back to the plain intensity
        # contour, stored as an array like in the stereo case.
        intensity_frames = snd.to_intensity()
        Int_LR = intensity_frames.as_array()[0]
    n_intensity_frames = len(Int_LR)

    ## Get features for speech turns
    print("building features")
    # The exponentiated overlap scores do not depend on the segment, so they
    # are computed once instead of once per turn.
    overlap_scores = plot_ready(ovl_scores)
    Xa, keep, kept_turns = [], [], []
    for segment in long_turns:
        xemb = embeddings.crop(segment)  # , mode='strict')
        # a turn without any embedding frame cannot be clustered
        if len(xemb) == 0:
            keep.append(0)
            continue
        keep.append(1)
        kept_turns.append(segment)
        ## embedding
        meanx = np.mean(xemb, axis=0)
        ## overlap
        meanxovl = np.mean(overlap_scores.crop(segment), axis=0)
        ## Left Right Intensity
        # Praat frame numbers are 1-based while the intensity array is
        # 0-based, hence the "- 1". Indices are clamped so that a turn
        # touching the start/end of the file cannot wrap around or overflow.
        fromframe_LR = int(np.rint(intensity_frames.get_frame_number_from_time(segment.start))) - 1
        toframe_LR = int(np.rint(intensity_frames.get_frame_number_from_time(segment.end))) - 1
        fromframe_LR = min(max(fromframe_LR, 0), n_intensity_frames - 1)
        toframe_LR = min(max(toframe_LR, 0), n_intensity_frames - 1)
        if toframe_LR <= fromframe_LR:
            meanlr = Int_LR[fromframe_LR]
        else:
            meanlr = np.mean(Int_LR[fromframe_LR:toframe_LR], axis=0)
        if np.isnan(meanlr):
            meanlr = 0
        Xa.append(np.concatenate([meanx, meanxovl, [meanlr]]))  # DEFAULT

    n_clusters = 4
    if len(Xa) < n_clusters:
        raise ValueError(
            "Only " + str(len(Xa)) + " usable speech turns in " + wavfile
            + "; at least " + str(n_clusters) + " are needed for clustering"
        )
    Xa = np.vstack(Xa)

    ## Cluster the result
    print("clustering result")
    # Ward linkage only supports (and defaults to) the euclidean distance, so
    # the `affinity` argument -- removed in recent scikit-learn -- is omitted.
    clusterModel = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
    Y = clusterModel.fit_predict(Xa)

    ## plot the clusters
    print("export cluster plot")
    OutputClusterPlot(Xa, Y, sessionName)

    ## Export cluster sound
    print("export cluster audio files")
    OutputClusterAudio(patientIdLocation, wavfile, sessionName, long_turns, Y, keep)

    ## create result matrix
    print("create result matrix")
    # Y holds one label per *kept* turn, so the labels are paired with
    # kept_turns; pairing them with every long turn misaligns (or overruns Y)
    # as soon as a single turn was skipped.
    clusterTimes = np.array(
        [[segment.start, segment.end, label] for segment, label in zip(kept_turns, Y)],
        dtype=float,
    )

    ## Put clusters on pitch
    print("clusters to pitch")
    pitch = snd.to_pitch()
    t1 = pitch.t1
    dt = pitch.dt
    # columns: time, pitch value of the frame, cluster label (-1 = none)
    clusterPitch = np.full((len(pitch), 3), -1.0)
    for row, frame in enumerate(pitch):
        # computed from the frame index instead of accumulating dt, which
        # avoids floating point drift on long recordings
        t = t1 + row * dt
        clusterPitch[row][0] = t
        clusterPitch[row][1] = frame.as_array()[0][0]
        clust = clusterTimes[(clusterTimes[:, 0] < t) & (clusterTimes[:, 1] > t)]
        if len(clust) == 1:
            clusterPitch[row][2] = clust[0, 2]
        elif len(clust) > 1:
            # ambiguous frame: reported and left at -1
            print("ERROR more than 2 segments in timeframe")
        else:
            clusterPitch[row][2] = -1

    ## Export the result
    print("Exporting result")
    clusterPitch = pd.DataFrame(clusterPitch, columns=['t', 'p', 'cluster'])
    clusterPitch.to_csv(sessionName + "_result.csv", index=False, header=True)


In [5]:
# Function for writing cluster output to 4 audio files
def OutputClusterAudio(patientIdLocation,wavfile,sessionName,long_turns,Y,keep,n_clusters=4):
    """Write one mp3 per cluster containing all speech turns of that cluster.

    Parameters
    ----------
    patientIdLocation : str
        Directory containing `wavfile`.
    wavfile : str
        File name of the session recording.
    sessionName : str
        Prefix of the output files (`<sessionName>_<cluster>.mp3`).
    long_turns : iterable of segments
        Speech turns with `.start` / `.end` in seconds, in the order in which
        `keep` was built.
    Y : sequence of int
        Cluster label of every kept turn (one entry per 1 in `keep`).
    keep : sequence of int
        1 if the turn at the same position of `long_turns` was clustered,
        anything else if it was skipped.
    n_clusters : int, optional
        Number of output files; labels outside `0 .. n_clusters-1` are
        ignored. Defaults to 4.
    """
    fullSesWav = AudioSegment.from_wav(os.path.join(patientIdLocation, wavfile))
    # Every cluster file starts with the first 100 ms of the session, so a
    # cluster without any turn still yields a non-empty file.
    clusterAudio = [fullSesWav[0:100] for _ in range(n_clusters)]
    keptIndex = 0  # position in Y, which only covers the kept turns
    for position, segment in enumerate(long_turns):
        if keep[position] != 1:
            continue
        label = int(Y[keptIndex])
        keptIndex += 1
        if 0 <= label < n_clusters:
            # pydub slices are expressed in milliseconds
            begin = segment.start*1000
            end = segment.end*1000
            clusterAudio[label] = clusterAudio[label] + fullSesWav[begin:end]
    # total duration per cluster in minutes
    for label, audio in enumerate(clusterAudio):
        print(str(label)+":"+str(audio.duration_seconds/60))
    for label, audio in enumerate(clusterAudio):
        audio.export(sessionName+"_"+str(label)+".mp3", format="mp3")

In [6]:
# Function for creating a t-SNE plot of the data that is used for clustering. Use the plot to visually inspect the diarization.
def OutputClusterPlot(Xa,Y,sessionName,close_figure=True):
    """Save a 2-D t-SNE scatter plot of the clustered feature vectors.

    Parameters
    ----------
    Xa : array-like
        Feature matrix with one row per clustered sample.
    Y : sequence of int
        Cluster label per row of `Xa`; used to colour the points.
    sessionName : str
        Prefix of the output file (`<sessionName>_plot.png`).
    close_figure : bool, optional
        Close the figure after saving (default). This function runs once per
        recording, so leaving every figure open accumulates memory; pass
        False to keep the figure open, e.g. for inline display.
    """
    # t-SNE needs perplexity < number of samples; cap the default of 30 so
    # short sessions can still be plotted.
    n_samples = len(Xa)
    perplexity = min(30, max(n_samples - 1, 1))
    tsne = TSNE(n_components=2, metric="euclidean", perplexity=perplexity, random_state=42)
    X_2d = tsne.fit_transform(Xa)
    fig, ax = plt.subplots(figsize=(5, 5))
    scatter = ax.scatter(*X_2d.T, c=Y)
    legend1 = ax.legend(*scatter.legend_elements(),
                        loc="lower left", title="Classes")
    ax.add_artist(legend1)
    # label the figure so that it can be identified on its own
    ax.set_title(sessionName)
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    fig.savefig(sessionName+"_plot.png")
    if close_figure:
        plt.close(fig)

In [None]:
### Run for multiple patients
for patientId in patientIdList:
    # os.path.join also works when `location` has no trailing separator
    patientIdLocation = os.path.join(location, patientId)
    # ClusterAudio opens the score pickles and writes its outputs relative to
    # the working directory, so switch to the patient directory first.
    os.chdir(patientIdLocation)
    # sorted: glob order is platform-dependent; process sessions deterministically
    for wavfile in sorted(glob.glob("*.wav")):
        print(wavfile)
        ClusterAudio(patientIdLocation, wavfile)