In [None]:
import torch
from pyannote.database import get_protocol, FileFinder

# Load the 'emb' entry point from the 'pyannote/pyannote-audio' torch.hub
# repository and report the dimension it exposes.
emb = torch.hub.load(
    'pyannote/pyannote-audio',
    'emb',
)
print(f'Embedding has dimension {emb.dimension:d}.')

# Preprocessors handed to the protocol: the 'audio' key is filled in by a
# FileFinder instance (presumably resolving each file's audio path --
# confirm against the pyannote.database documentation).
preprocessors = {'audio': FileFinder()}
protocol = get_protocol(
    'VOXCON.SpeakerDiarization.Sample',
    preprocessors=preprocessors,
)

In [None]:
# Work on the first item produced by the protocol's test() iterator.
test_file = next(protocol.test())

# Run the embedding model on that file.
embeddings = emb(test_file)

# The sliding window attached to the result carries the step and duration
# (multiplied by 1000 below for display in milliseconds).
chunks = embeddings.sliding_window
print(
    f'Embeddings were extracted every {1000 * chunks.step:g}ms on {1000 * chunks.duration:g}ms-long windows.'
)

In [None]:
# Re-fetch the first item of the test iterator (same statement as in the
# previous cell, so `test_file` is simply rebound).
test_file = next(protocol.test())
# Bare last expression: the notebook displays whatever protocol.test()
# returns; nothing is iterated here.
protocol.test()

In [None]:
import numpy as np
from IPython.display import display, clear_output

# Build X (one embedding row per sliding window of `embeddings`) and Y (the
# speaker label associated with each window), then encode the labels as
# integers in `y_true` for the later scatter plot.
X, Y = [], []
length = len(embeddings)

# `idx` instead of `id`, so the builtin id() is not shadowed for the rest of
# the kernel session.
for idx, (window, embedding) in enumerate(embeddings):
    # Iterating `embeddings` is expected to yield ONE 1-D vector per window
    # (TODO confirm against pyannote.core.SlidingWindowFeature). Averaging a
    # 1-D vector over axis 0 would collapse it to a single scalar, so only
    # average when several frames (a 2-D array) are actually provided.
    vector = np.asarray(embedding)
    if vector.ndim > 1:
        vector = np.nanmean(vector, axis=0)
    X.append(vector)

    # keep track of speaker label (for later scatter plot)
    # argmax() is expected to return None when no annotated segment overlaps
    # the window (TODO confirm); None cannot be sorted against string labels
    # by np.unique, so such windows get an explicit placeholder label.
    y = test_file['annotation'].argmax(window)
    Y.append('non_speech' if y is None else y)

    # Single-line progress indicator, refreshed in place.
    clear_output(wait=True)
    display(f'{idx+1} {100*(idx+1)/length:g}%')

X = np.vstack(X)
# Index [1] is the inverse mapping (integer code of each label); taking it by
# index avoids overwriting IPython's special `_` variable.
y_true = np.unique(Y, return_inverse=True)[1]

In [None]:
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
# Project the embeddings to 2-D with t-SNE using the cosine metric.
# A fixed random_state makes the layout reproducible across re-runs.
tsne = TSNE(n_components=2, metric="cosine", random_state=0)
X_2d = tsne.fit_transform(X)

# plot
# Draw directly on the axes returned by plt.subplots(); the previous
# plt.clf() discarded those axes and left `ax` pointing at nothing visible.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(*X_2d.T, c=y_true)
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
ax.set_title('t-SNE of embeddings (colored by speaker label)');



In [None]:
# NOTE: everything in this cell is commented out, so the cell executes
# nothing. The snippets below are kept as exploration notes only.

# --- Snippet 1: inspect the entries yielded by protocol.test() ---
# for resource in protocol.test():
#     print(resource["audio"])
#     print(resource["uri"])

# test_file = next(protocol.test())
# test_file["audio"]

###########################################

# --- Snippet 2: crop embeddings over a coarser sliding window ---
# NOTE(review): `end=len(embeddings)` passes a number of windows where the
# other arguments look like durations -- confirm the intended unit, and note
# that SlidingWindow is not imported anywhere in the visible cells.
# sw = SlidingWindow(duration=4, step=1, start=0.0, end=len(embeddings))

# for segment in sw:
#     # "strict" only keeps embedding strictly included in segment
#     x = embeddings.crop(segment, mode='strict')

############################################

# --- Snippet 3: check the types yielded when iterating `embeddings` ---
# from pyannote.core import Segment
# import numpy as np

# for id, (window, emb) in enumerate(embeddings):
#     print(window, emb)
#     assert isinstance(window, Segment)
#     assert isinstance(emb, np.ndarray)