In [None]:
%matplotlib inline

import os
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
import librosa
import torch
import pickle
import IPython.display as ipd
import numpy as np
import matplotlib.patches as mpatches
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pylab as plt
import matplotlib.patches as mpatches
import json

In [None]:
# Directory containing real and synthesized samples from various methods for different speakers.
samples_dir = "/home/pneekhara/synthesized_samples3/"

# Group the .wav files under "<wav_type>_<method_info>_<speaker>" keys, keeping only
# the groups whose key contains "real_actual".
wav_paths = {}
for fname in os.listdir(samples_dir):
    if not fname.endswith(".wav"):
        continue
    # The name must split into exactly four "_"-separated fields; the last one
    # carries the sample number followed by the file extension.
    wav_type, method_info, speaker, val_no = fname.split("_")
    key = "{}_{}_{}".format(wav_type, method_info, speaker)
    if "real_actual" not in key:
        continue
    sample_index = int(val_no.split(".")[0])
    wav_paths.setdefault(key, []).append((sample_index, os.path.join(samples_dir, fname)))

# Order every group by sample number (ties broken by path) and keep only the paths.
for key, indexed_paths in wav_paths.items():
    wav_paths[key] = [wav_path for _, wav_path in sorted(indexed_paths)]

In [None]:
# Sanity check: for each group show its key, how many files it holds and the first two paths.
for key, group_paths in wav_paths.items():
    print(key, len(group_paths), group_paths[:2])

In [None]:
# Waveform loader and log-mel front end, both configured for a 44100 Hz sample rate.
# NOTE(review): neither object is referenced by the cells visible here — confirm they are used elsewhere.
wav_featurizer = WaveformFeaturizer(sample_rate=44100, int_values=False, augmentor=None)
mel_processor = AudioToMelSpectrogramPreprocessor(
        window_size = None,
        window_stride = None,
        sample_rate=44100,
        n_window_size=2048,
        n_window_stride=512,
        window="hann",
        normalize=None,
        n_fft=None,
        preemph=None,
        features=80,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=1e-05,
        dither=0.0,
        pad_to=1,
        frame_splicing=1,
        exact_pad=False,
        stft_exact_pad=False,
        stft_conv=False,
        pad_value=0,
        mag_power=1.0
)

# Pretrained speaker-verification model used to compute one embedding per audio file.
speaker_verification_model = EncDecSpeakerLabelModel.from_pretrained("speakerverification_speakernet")
# Use the GPU when one is available; otherwise stay on the CPU so the notebook
# still runs on machines without CUDA instead of raising.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"
speaker_verification_model.eval().to(inference_device)

In [None]:
# Cache file for the computed embeddings. Set `regenerate` to True to ignore an
# existing cache and recompute every embedding.
pickle_path = "/home/pneekhara/synthesized_audio_demo.pkl"
regenerate = False
if os.path.exists(pickle_path) and not regenerate:
    # NOTE(review): pickle.load can execute arbitrary code — only load cache files you created yourself.
    with open(pickle_path, "rb") as f:
        meta_data = pickle.load(f)
    embeddings = meta_data['embeddings']
else:
    # Maps each wav_paths key to a list of flattened numpy embeddings, one per file.
    embeddings = {}
    for key in wav_paths:
        print ("Getting embeddings for:", key, len(wav_paths[key]))
        embeddings[key] = []
        for path in wav_paths[key]:
            embedding = speaker_verification_model.get_embedding(path)
            embeddings[key].append(embedding.cpu().numpy().flatten())

    # Write to the same path the cache check above reads from; writing to a
    # different file name would mean the cache is never found on the next run.
    meta_data = {
        'embeddings' : embeddings
    }
    with open(pickle_path, "wb") as f:
        pickle.dump(meta_data, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
def mscatter(x,y, ax=None, m=None, **kw):
    """
    Scatter plot that allows a different marker for every point.

    Arguments:
    x, y : Point coordinates, forwarded to scatter.
    ax : Optional axes object to draw on. When None, plt.scatter is used.
    m : Optional sequence of markers (marker specs or MarkerStyle objects), one per point.
        It is ignored unless its length equals len(x).
    **kw : Extra keyword arguments forwarded to scatter.

    Returns:
    The object returned by the scatter call.
    """
    # Honor an explicitly supplied axes; fall back to the pyplot call otherwise so
    # existing callers that pass no axes behave exactly as before.
    if ax is None:
        sc = plt.scatter(x, y, **kw)
    else:
        sc = ax.scatter(x, y, **kw)

    if (m is not None) and (len(m) == len(x)):
        # Only needed when per-point markers are actually applied.
        import matplotlib.markers as mmarkers

        paths = []
        for marker in m:
            if isinstance(marker, mmarkers.MarkerStyle):
                marker_obj = marker
            else:
                marker_obj = mmarkers.MarkerStyle(marker)
            # Bake the marker's own transform into its path so each point keeps
            # the intended shape once the paths are swapped in below.
            path = marker_obj.get_path().transformed(
                        marker_obj.get_transform())
            paths.append(path)
        sc.set_paths(paths)
    return sc

def visualize_embeddings(embedding_dict_np, title = "TSNE"):
    """
    Project all embeddings to 2-D with t-SNE and show them in one scatter plot,
    using one color and one marker shape per dictionary key.

    Arguments:
    embedding_dict_np : Dictionary with keys as speaker ids/labels and value as list of np arrays (embeddings).
    title : Title shown above the plot.

    Raises:
    ValueError : If the dictionary holds no embeddings at all.
    """
    color = []
    marker_shape = []
    universal_embed_list = []
    handle_list = []

    marker_list = ['<', '*', 'h', 'X', 's', 'H', 'D', 'd', 'P', 'v', '^', '>', '8', 'p']
    palette = plt.cm.tab20

    for kidx, key in enumerate(embedding_dict_np.keys()):
        universal_embed_list += embedding_dict_np[key]
        _num_samples = len(embedding_dict_np[key])

        # Wrap the palette index the same way the marker index is wrapped, so
        # keys beyond the palette size cycle through colors instead of all
        # collapsing onto a single one.
        id_color = palette(kidx % palette.N)
        color += [id_color] * _num_samples
        marker_shape += [ marker_list[kidx % len(marker_list)] ] * _num_samples
        handle_list.append(mpatches.Patch(color = id_color, label=key))

    if len(universal_embed_list) == 0:
        raise ValueError("embedding_dict_np contains no embeddings to visualize")

    # Fixed random_state keeps the 2-D layout reproducible between runs.
    speaker_embeddings = TSNE(n_components=2, random_state=0).fit_transform(universal_embed_list)

    mscatter(speaker_embeddings[:, 0], speaker_embeddings[:, 1], m = marker_shape,  c=color, s=60)

    plt.legend(handles=handle_list,title="Speaker_ID")
    plt.title(title)
    plt.show()



In [None]:
# `embeddings` maps each key to a list of embedding vectors (np arrays).
# For every key, show the number of vectors, the length of the first one and its type.
for key, vectors in embeddings.items():
    first_vector = vectors[0]
    print(key, len(vectors), len(first_vector), type(first_vector))

In [None]:
# Use a large square canvas for the plot, then draw the t-SNE view of all embeddings.
plt.rcParams.update({"figure.figsize": (12, 12)})
visualize_embeddings(embedding_dict_np=embeddings)