In [None]:
import os
# Hugging Face access token. Never hardcode it in the notebook: reuse the value
# already present in the environment, otherwise prompt for it without echoing.
if not os.environ.get("HF_TOKEN"):
    import getpass
    os.environ["HF_TOKEN"] = getpass.getpass("Hugging Face token (HF_TOKEN): ")

In [None]:
from datasets import load_dataset, Audio
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm

from src.speaker_diarization import SpeakerDiarizationWrapper, get_speech_mask
from src.audioset_utils import pad_or_trim_to_len

In [None]:
# One parquet shard of the local SOVA dump: decode audio, resample it to 16 kHz,
# then shuffle with a fixed seed so the subset taken later is reproducible.
sova = load_dataset(
    'datasets/sova_128k',
    data_files='00000-of-00608.parquet',
    split='train',
)
sova = sova.cast_column('audio', Audio(decode=True, sampling_rate=16_000))
sova = sova.shuffle(seed=0)

In [None]:
# One parquet shard of the filtered YODAS (ru000) dump: decode audio, resample it
# to 16 kHz, then shuffle with a fixed seed so the subset taken later is reproducible.
yodas = load_dataset(
    'datasets/yodas_ru000_128k_filtered5',
    data_files='00000-of-00079.parquet',
    split='train',
)
yodas = yodas.cast_column('audio', Audio(decode=True, sampling_rate=16_000))
yodas = yodas.shuffle(seed=0)

In [None]:
def _first_waveforms(dataset, n=500):
    """Return the decoded `audio['array']` of the first `n` rows of `dataset`."""
    return [example['audio']['array'] for example in dataset.take(n)]


sova_audios = _first_waveforms(sova)
yodas_audios = _first_waveforms(yodas)

In [None]:
# Clip-duration distributions. Audio was resampled to 16 kHz on load, so
# len(waveform) / 16_000 is the duration in seconds. SOVA counts point upward;
# YODAS counts are negated (negative weights) and point downward, so the two
# shapes can be compared on one axis. Clips longer than 30 s fall outside the
# bins and are not counted.
duration_bins = np.linspace(0, 30, num=61)  # 61 edges -> 60 bins of 0.5 s
sova_durations = [len(w) / 16_000 for w in sova_audios]
yodas_durations = [len(w) / 16_000 for w in yodas_audios]

fig, ax = plt.subplots()
ax.hist(sova_durations, bins=duration_bins, label='sova')
ax.hist(
    yodas_durations,
    bins=duration_bins,
    weights=-np.ones(len(yodas_durations)),
    label='yodas (negated)',
)
ax.set(xlabel='duration, s', ylabel='number of clips', title='Clip duration distribution')
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# Diarization wrapper running on CPU; the batch sizes are passed straight through
# to SpeakerDiarizationWrapper.
diarization_config = dict(
    segmentation_batch_size=32,
    embedding_batch_size=32,
    device='cpu',
)
speaker_diarization = SpeakerDiarizationWrapper(**diarization_config)

In [None]:
# Diarize every SOVA clip and keep one fixed-length speech mask per clip.
# Each mask is built from the diarization segments over the clip duration in
# seconds, then padded/trimmed to 16_000 * 30 entries.
# NOTE(review): that length equals 30 s at 16 kHz if get_speech_mask returns a
# per-sample mask; confirm.
is_speech_sova = []
for waveform in tqdm(sova_audios, desc='diarizing sova'):
    segments = speaker_diarization.predict_on_long_audio(
        waveform, sampling_rate=16_000
    ).segments
    speech_mask = get_speech_mask(segments, duration=len(waveform) / 16_000)
    speech_mask = pad_or_trim_to_len(speech_mask, 16_000 * 30)
    is_speech_sova.append(speech_mask)

In [None]:
# Inspect the padded mask of the last processed SOVA clip. Index the collected list
# explicitly instead of relying on the loop variable leaking out of the loop.
is_speech_sova[-1]

In [None]:
# Show one speech mask as a single-row image: `[None]` adds a leading axis, and
# the colour range is fixed to 0..9 on the categorical tab10 colormap.
# `is_speech_mask` was never defined, so plot the last collected SOVA mask
# (the same object the diarization loop leaves in `speech_mask`).
fig, ax = plt.subplots(figsize=(12, 1.5))
ax.imshow(
    is_speech_sova[-1][None],
    aspect='auto',
    cmap='tab10',
    interpolation='none',
    vmin=0,
    vmax=9,
)
ax.set(yticks=[], xlabel='mask index', title='Speech mask, last SOVA clip')
plt.show()