In [8]:
import collections
import logging
import os
import shutil
import subprocess
import torch
import torchaudio
import wave
from distutils.dir_util import copy_tree
from glob import glob
from pyannote.audio import Pipeline
from pydub import AudioSegment
from speechbrain.pretrained import SpectralMaskEnhancement
from tqdm import tqdm

In [2]:
logging.basicConfig(filename='logs.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [3]:
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="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")
pipeline.to(torch.device("cuda"))

<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7fb11910c100>

In [4]:
audio_paths = glob('./audio_raw/*')
for audio_path in tqdm(audio_paths):
    audio_name = audio_path.split('/')[-1]
    audio_name_no_ext = audio_name.split('.')[0]
    output_path = f'./wav_raw/{audio_name_no_ext}.wav'
    ffmpeg_script = f'ffmpeg -i {audio_path} -vn -acodec pcm_s16le -ar 44100 -ac 2 {output_path} -y'
    ffmpeg_output = subprocess.run(
        ffmpeg_script, 
        shell=True, 
        check=True, 
        capture_output=True,
        text=True)
    logger.info(ffmpeg_output)

100%|█████████████████████████████████████████████████████████████████████████| 55/55 [03:48<00:00,  4.15s/it]


In [5]:
def get_wav_duration(file_path):
    with wave.open(file_path, 'rb') as wav_file:
        num_frames = wav_file.getnframes()
        frame_rate = wav_file.getframerate()
        duration = num_frames / float(frame_rate)
        return duration


wav_paths = glob('./wav_raw/*.wav')
for wav_path in tqdm(wav_paths):
    duration = get_wav_duration(wav_path)
    wav_name = wav_path.split('/')[-1]
    wav_name_no_ext = wav_name.split('.')[0]
    if duration > 600:
        segment_wav_path = f'./wav_10_minutes/{wav_name_no_ext}_%06d.wav'
        ffmpeg_script = f'ffmpeg -i {wav_path} -f segment -segment_time 600 -c copy {segment_wav_path}'
        ffmpeg_output = subprocess.run(
            ffmpeg_script, 
            shell=True, 
            check=True, 
            capture_output=True,
            text=True)
        logger.info(ffmpeg_output)
    else:
        shutil.copy(wav_path, f'./wav_10_minutes/{wav_name_no_ext}.wav')

100%|█████████████████████████████████████████████████████████████████████████| 55/55 [01:52<00:00,  2.05s/it]


In [6]:
def diarize_and_remove_overlap(segmented_wav_path):
    diarization = pipeline(segmented_wav_path)
    start_end_speakers_list = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        start_end = [turn.start, turn.end]
        start_end_speakers_list.append([start_end, speaker])
    start_end_speakers_list = sorted(start_end_speakers_list, key=lambda sublist: sublist[0][0])
    start_end_speakers_list_no_overlap = []
    compared_sublist = start_end_speakers_list[0]
    for index in range(1, len(start_end_speakers_list)):
        current_sublist = start_end_speakers_list[index]
        if compared_sublist[0][1] > current_sublist[0][0]:
            continue
        start_end_speakers_list_no_overlap.append(current_sublist)
        compared_sublist = current_sublist
    logger.info(
        f'Original diarized segments: {len(start_end_speakers_list)}'
        f'Overlap removed: {len(start_end_speakers_list_no_overlap)}'
    )
    logger.info(f'Lost {1 - len(start_end_speakers_list_no_overlap) / (len(start_end_speakers_list)):.2f}%')
    speaker_dict = {segmented_wav_path: {}}
    for start_end, speaker in start_end_speakers_list_no_overlap:
        if speaker not in speaker_dict[segmented_wav_path]:
            speaker_dict[segmented_wav_path][speaker] = []
        speaker_dict[segmented_wav_path][speaker].append(start_end)
    return speaker_dict

In [7]:
speaker_dict_list = []
segmented_wav_paths = glob('./wav_10_minutes/*.wav')
for segmented_wav_path in tqdm(segmented_wav_paths):
    segmented_wav_name = segmented_wav_path.split('/')[-1]
    segmented_wav_name_no_ext = segmented_wav_name.split('.')[0]
    temp_speaker_dict = diarize_and_remove_overlap(segmented_wav_path)
    speaker_dict_list.append(temp_speaker_dict)

100%|███████████████████████████████████████████████████████████████████| 442/442 [18:52:43<00:00, 153.76s/it]


In [23]:
# created diarized_results dir

In [24]:
for speaker_dict in tqdm(speaker_dict_list):
    segmented_wav_name = list(speaker_dict.keys())[0].split('/')[-1]
    segmented_wav_name_no_ext = segmented_wav_name.split('.')[0]
    diarized_sub_dir_path = f'./diarized_results/{segmented_wav_name_no_ext}'
    if not os.path.exists(diarized_sub_dir_path):
        os.mkdir(diarized_sub_dir_path)
    wav_source = AudioSegment.from_wav(list(speaker_dict.keys())[0])
    for speaker in speaker_dict[list(speaker_dict.keys())[0]]:
        speaker_path = f'{diarized_sub_dir_path}/{speaker}'
        if not os.path.exists(speaker_path):
            os.mkdir(speaker_path)
        speaker_snippet_index = 0
        for start_end in speaker_dict[list(speaker_dict.keys())[0]][speaker]:
            t1 = start_end[0] * 1000
            t2 = start_end[1] * 1000
            if (t2 - t1) < 3000:
                continue
            wav_snippet = wav_source[t1: t2]
            wav_snippet_path = f'{speaker_path}/{segmented_wav_name_no_ext}_{speaker_snippet_index}.wav'
            wav_snippet.export(wav_snippet_path, format="wav")
            speaker_snippet_index += 1

100%|███████████████████████████████████████████████████████████████████████| 442/442 [01:14<00:00,  5.96it/s]


In [133]:
# TODO: Why are enhanced results bad?
# copied_paths = copy_tree('./diarized_results', './diarized_results_enhanced')
# enhance_model = SpectralMaskEnhancement.from_hparams(
#     source="speechbrain/metricgan-plus-voicebank",
#     savedir="./pretrained_models/metricgan-plus-voicebank",
# )

# for copied_path in copied_paths:
#     noisy = enhance_model.load_audio(copied_path).unsqueeze(0)
#     enhanced = enhance_model.enhance_batch(noisy, lengths=torch.tensor([1.]))
#     torchaudio.save(copied_path, enhanced.cpu(), 16000)

In [13]:
def get_wav_sample_rate(file_path):
    with wave.open(file_path, 'rb') as wav_file:
        frame_rate = wav_file.getframerate()
        return frame_rate

In [19]:
diarized_results = glob('./diarized_results/*/*/*')
len(diarized_results)

17538

In [21]:
sample_rate_dict = {}
for diarized_wav_path in diarized_results:
    temp_sample_rate = get_wav_sample_rate(diarized_wav_path)
    if temp_sample_rate not in sample_rate_dict:
        sample_rate_dict[temp_sample_rate] = []
    sample_rate_dict[temp_sample_rate].append(diarized_wav_path)
sample_rate_count_dict = {}
for sample_key in list(sample_rate_dict):
    sample_rate_count_dict[sample_key] = len(sample_rate_dict[sample_key])
sample_rate_count_dict = collections.OrderedDict(sorted(sample_rate_count_dict.items()))
print(sample_rate_count_dict)

OrderedDict([(44100, 17538)])
