In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append('../')
import torch
import pickle5
from aux import load_audio
from config import DATA_PATH, GMM_FILE, OUT_TRAIN, CALLHOME_ENG_10_SEC
from aux_evaluate import draw_speech_and_channel

import numpy as np
from kaldi_feats import logmel_feats
import matplotlib.pyplot as plt
import os

# Load audio

In [None]:
# Load one stereo CallHome recording, run the GMM VAD on each channel and on
# the mono mix, and preview the mono VAD.
SAMPLE_RATE = 8000   # rate passed to every load_audio / detect_speech call below
INT16_SCALE = 32768  # divisor used to bring raw samples into [-1, 1) for plotting (assumes 16-bit PCM — TODO confirm)

# audio_file = CALLHOME_ENG_10_SEC + os.listdir(CALLHOME_ENG_10_SEC)[19]
audio_file = CALLHOME_ENG_10_SEC + '4289.wav'

sample_rate, audio = load_audio(audio_file, SAMPLE_RATE, mono=False)

plt.figure(figsize=(14, 5))

# Pre-trained GMM voice-activity detector stored as a pickle.
# NOTE(review): unpickling can execute arbitrary code — GMM_FILE must be trusted.
with open(GMM_FILE, 'rb') as fid:
    vad = pickle5.load(fid)

channel_0, channel_1 = audio[0], audio[1]
speech_0 = vad.detect_speech(signal=channel_0, sampling_rate=SAMPLE_RATE, fit_to_audio=True)
speech_1 = vad.detect_speech(signal=channel_1, sampling_rate=SAMPLE_RATE, fit_to_audio=True)

# Samples where the VAD fires on both channels at once (overlapped speech).
overlapping_speech = np.logical_and(speech_0, speech_1)
# Rescale for plotting only; the VAD calls above already ran on the raw samples.
channel_0, channel_1 = channel_0 / INT16_SCALE, channel_1 / INT16_SCALE

_, mono = load_audio(audio_file, SAMPLE_RATE, mono=True)
speech_mono = vad.detect_speech(signal=mono, sampling_rate=SAMPLE_RATE, fit_to_audio=True)

draw_speech_and_channel(speech_mono, mono, 'mono', 'purple')

# Load the trained LSTM model

In [None]:
sys.path.append('../train')
from aux_evaluate import prepare_mono_for_forward
# Presumably imported so torch.load can resolve the pickled model class.
from model import LSTM_Diarization
from config import MODEL_FILE

# The loaded object is used directly as the model, i.e. a whole pickled module.
# NOTE(review): unpickling can execute arbitrary code — MODEL_FILE must be trusted.
model = torch.load(MODEL_FILE, map_location='cpu')

# Sliding-window parameters read by the model at inference time
# (units — frames? — TODO confirm in LSTM_Diarization).
model.window_size = 20
model.shift = 10

# Switch to inference mode (disables dropout, uses running batch-norm stats, if any).
# `model.train = False` did NOT do this: it only shadowed the nn.Module.train
# method with a bool, left `model.training` untouched, and made any later
# model.eval()/model.train() call fail with "'bool' object is not callable".
# NOTE(review): if LSTM_Diarization.forward reads `self.train` as a flag,
# change it to `self.training`.
model.eval()



# Auxiliary function: map cluster labels back to per-sample speaker masks

In [None]:
from aux import get_embeddings
# from numpy import linalg as LA
# from sklearn.cluster import spectral_clustering
# from spectral_cluster import blur, row_wise_threshold, row_wise_normalize, symmetrize, diffuse
from spectral_cluster import adjust_labels_to_signal



def labels_to_audio_size(labels, mono, speech_mono):
    """Expand speaker cluster labels into two per-sample boolean masks over `mono`.

    Args:
        labels: integer speaker labels (0 or 1) as returned by cluster_affinity,
            presumably one per model output window — adjust_labels_to_signal
            stretches them to one label per speech sample (TODO confirm its rule).
        mono: 1-D signal; only its length is used.
        speech_mono: mask over `mono` marking the samples the labels describe
            (the speech that was fed to the model).

    Returns:
        (spk_0, spk_1): boolean arrays of length ``mono.shape[0]``; ``spk_k`` is
        True where the sample was labelled ``k``. Samples outside
        ``speech_mono`` are False in both.
    """
    speech_mask = np.asarray(speech_mono, dtype=bool)
    adjusted_size_labels = adjust_labels_to_signal(labels, speech_mask.sum())

    # 3 is a "no label" sentinel for non-speech samples; only 0 and 1 are read back.
    labels_audio_size = np.full((mono.shape[0],), 3)
    labels_audio_size[speech_mask] = adjusted_size_labels

    # Builtin-bool masks: `np.bool` was removed in NumPy 1.24.
    spk_0 = labels_audio_size == 0
    spk_1 = labels_audio_size == 1
    return spk_0, spk_1

# Model forward pass and clustering on a single recording

In [None]:
import sys

from spectral_cluster import get_affinity_matrix, cluster_affinity

# Keep only single-speaker speech: mono speech minus samples where both
# channels' VADs fire. `speech_mono` is the mono VAD mask from the load cell.
speech_indexes = np.logical_and(speech_mono, ~overlapping_speech)
filtered_channel = mono[speech_indexes]

prepared_feats = prepare_mono_for_forward(filtered_mono_channel=filtered_channel, sampling_rate=8000)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    d_vectors = model(prepared_feats)

s = get_affinity_matrix(d_vectors.squeeze(0))
# `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
pred_labels = cluster_affinity(s.detach().cpu().numpy(), dtype=int)

# Print every label without changing the global print options.
with np.printoptions(threshold=sys.maxsize):
    print(pred_labels)

spk_0, spk_1 = labels_to_audio_size(pred_labels, mono, speech_indexes)



# Draw result

In [None]:
def _draw_diarization_panels(channels, channel_vads, speaker_masks, speech_mono, mono):
    """Plot a 6-row figure comparing the reference channels with the prediction.

    Rows 1-2: channel waveform (y fixed to [-1, 1]) with its VAD mask shaded.
    Rows 3-4: predicted speaker A / B masks.
    Row 5: VAD mask of the mono mix. Row 6: the mono signal, shaded.
    Every shaded area is the input scaled by 0.8.
    """
    plt.figure(figsize=(14, 7))

    waveform_rows = zip(("channel_0", 'channel 1'), channels, channel_vads, ('red', 'green'))
    for row, (name, wave, vad_mask, shade) in enumerate(waveform_rows, start=1):
        plt.subplot(6, 1, row)
        plt.ylim(-1., 1.)
        plt.plot(wave)
        plt.title(name, x=0.5, y=0.6)
        plt.fill_between(range(vad_mask.shape[0]), vad_mask * .8, color=shade, alpha=0.7)

    shaded_rows = [
        ('spk_A_prediction', speaker_masks[0], 'purple'),
        ('spk_B_prediction', speaker_masks[1], 'purple'),
        ('speech mono', speech_mono, 'gray'),
        ('mono', mono, 'gray'),
    ]
    for row, (name, values, shade) in enumerate(shaded_rows, start=3):
        plt.subplot(6, 1, row)
        plt.title(name, x=0.5, y=0.6)
        plt.fill_between(range(values.shape[0]), values * .8, color=shade, alpha=0.7)

    plt.show()


_draw_diarization_panels(
    channels=(channel_0, channel_1),
    channel_vads=(speech_0, speech_1),
    speaker_masks=(spk_0, spk_1),
    speech_mono=speech_mono,
    mono=mono,
)


# Run diarization on every TalkBank CallHome recording in `CALLHOME_ENG_10_SEC`

In [None]:

import sys

from spectral_cluster import get_affinity_matrix, cluster_affinity


def _diarize_recording(audio_path, vad, sample_rate=8000):
    """Run the single-recording pipeline on one stereo file.

    Mirrors the cells above: per-channel GMM VAD, removal of overlapped
    speech, LSTM d-vectors on the remaining mono speech, and spectral
    clustering into two speakers. Also draws the mono VAD preview figure and
    prints the predicted labels. Uses `model`, `prepare_mono_for_forward` and
    `labels_to_audio_size` from earlier cells.

    Returns:
        dict with the rescaled channels and their VAD masks, the mono signal
        and its VAD mask, and the predicted speaker masks `spk_0` / `spk_1`.
    """
    _, audio = load_audio(audio_path, sample_rate, mono=False)
    channel_0, channel_1 = audio[0], audio[1]
    speech_0 = vad.detect_speech(signal=channel_0, sampling_rate=sample_rate, fit_to_audio=True)
    speech_1 = vad.detect_speech(signal=channel_1, sampling_rate=sample_rate, fit_to_audio=True)
    # Both channels active at once -> overlapped speech, excluded from the model input.
    overlapping_speech = np.logical_and(speech_0, speech_1)

    _, mono = load_audio(audio_path, sample_rate, mono=True)
    speech_mono = vad.detect_speech(signal=mono, sampling_rate=sample_rate, fit_to_audio=True)

    plt.figure(figsize=(14, 5))
    draw_speech_and_channel(speech_mono, mono, 'mono', 'purple')

    speech_indexes = np.logical_and(speech_mono, ~overlapping_speech)
    prepared_feats = prepare_mono_for_forward(
        filtered_mono_channel=mono[speech_indexes], sampling_rate=sample_rate)
    with torch.no_grad():
        d_vectors = model(prepared_feats)

    affinity = get_affinity_matrix(d_vectors.squeeze(0))
    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
    pred_labels = cluster_affinity(affinity.detach().cpu().numpy(), dtype=int)
    with np.printoptions(threshold=sys.maxsize):
        print(pred_labels)

    spk_0, spk_1 = labels_to_audio_size(pred_labels, mono, speech_indexes)
    return {
        # Rescaled for plotting only (assumes 16-bit PCM); VAD ran on raw samples.
        'channel_0': channel_0 / 32768,
        'channel_1': channel_1 / 32768,
        'speech_0': speech_0,
        'speech_1': speech_1,
        'mono': mono,
        'speech_mono': speech_mono,
        'spk_0': spk_0,
        'spk_1': spk_1,
    }


def _plot_diarization(result):
    """Six stacked panels: both channels with their VAD, both predicted speakers, mono VAD, mono."""
    plt.figure(figsize=(14, 7))
    waveform_rows = [('channel_0', 'speech_0', 'channel_0', 'red'),
                     ('channel_1', 'speech_1', 'channel 1', 'green')]
    for row, (wave_key, vad_key, title, color) in enumerate(waveform_rows, start=1):
        plt.subplot(6, 1, row)
        plt.ylim(-1., 1.)
        plt.plot(result[wave_key])
        plt.title(title, x=0.5, y=0.6)
        vad_mask = result[vad_key]
        plt.fill_between(range(vad_mask.shape[0]), vad_mask * .8, color=color, alpha=0.7)

    shaded_rows = [('spk_0', 'spk_A_prediction', 'purple'),
                   ('spk_1', 'spk_B_prediction', 'purple'),
                   ('speech_mono', 'speech mono', 'gray'),
                   ('mono', 'mono', 'gray')]
    for row, (key, title, color) in enumerate(shaded_rows, start=3):
        plt.subplot(6, 1, row)
        plt.title(title, x=0.5, y=0.6)
        values = result[key]
        plt.fill_between(range(values.shape[0]), values * .8, color=color, alpha=0.7)
    plt.show()


# Load the VAD once instead of re-unpickling it for every file.
# NOTE(review): assumes detect_speech does not mutate the VAD (the single-recording
# cells already reuse one instance for several calls). GMM_FILE must be trusted.
with open(GMM_FILE, 'rb') as fid:
    vad = pickle5.load(fid)

for audio_file in sorted(os.listdir(CALLHOME_ENG_10_SEC)):
    print(audio_file)
    try:
        _plot_diarization(_diarize_recording(os.path.join(CALLHOME_ENG_10_SEC, audio_file), vad))
    except Exception as err:
        # Best-effort over the corpus, but report failures instead of hiding them
        # (the old bare `except: pass` also hid KeyboardInterrupt and code errors).
        plt.close('all')  # don't let half-drawn figures show up with the next file
        print('skipped {}: {}: {}'.format(audio_file, type(err).__name__, err))
    print()
    print()

