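This notebook performs speaker diarization on an audio file. It transcribes the audio with Whisper and diarizes it (who spoke when) with `pyannote.audio`. It then combines the two into a speaker-labelled transcript.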

I would have liked to use [`speechbox`](https://github.com/huggingface/speechbox), but I ran into dependency conflicts involving `torch` and `pyannote.audio`. The relevant parts of its code are therefore adapted below.

## Setup

In [None]:
!pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
!pip install huggingface_hub openai-whisper

In [None]:
# Only needed when running this notebook in a hosted environment (e.g. Colab)
# where the repository is not already checked out: uncomment to clone it and
# move into its notebooks/ directory. Presumably this is so the relative paths
# used below (e.g. `audio/...`) resolve — confirm.

# !git clone https://github.com/ralampay/sherpa-prototype.git
# %cd sherpa-prototype/notebooks

In [None]:
# Log in to the Hugging Face Hub interactively. The pyannote pipeline loaded
# below via `Pipeline.from_pretrained` is downloaded from the Hub.
# NOTE(review): "pyannote/speaker-diarization" is presumably a gated model —
# confirm that its user conditions have been accepted for this account.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Path (relative to the notebook's working directory) of the audio file to
# transcribe and diarize. In Colab, the `#@param` annotation renders this as an
# editable form field.
AUDIO_FILENAME = 'audio/1.mp3' #@param{'type': 'string'}

In [None]:
# Create the folder that holds the transcript (.json) and diarization (.rttm)
# outputs. Pure Python instead of `!mkdir -p`, so it also works on kernels
# without a POSIX shell; exist_ok makes re-running the cell harmless.
import os

os.makedirs('diarization_outputs', exist_ok=True)

## Transcribe

In [None]:
# Transcribe the audio with Whisper and save the full result (text plus
# timestamped segments) as JSON. The "Combine" section can then reload it
# without re-running the model.
# The output file is named after the audio file's stem, e.g.
# audio/1.mp3 -> diarization_outputs/1.json, so different inputs don't
# overwrite each other.
import json
from pathlib import Path

import whisper

WHISPER_MODEL = 'medium'  # checkpoint name passed to whisper.load_model

transcript_path = Path('diarization_outputs') / f'{Path(AUDIO_FILENAME).stem}.json'
# create the output folder here too, so this cell doesn't depend on the mkdir cell
transcript_path.parent.mkdir(parents=True, exist_ok=True)

recognizer = whisper.load_model(WHISPER_MODEL)
speech = recognizer.transcribe(AUDIO_FILENAME)

with open(transcript_path, 'w') as f:
  json.dump(speech, f)

## Diarize

In [None]:
from pyannote.audio import Pipeline
import torch

# Run the pipeline on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

diarizer = Pipeline.from_pretrained("pyannote/speaker-diarization")
# from_pretrained may return None instead of raising (e.g. no Hub token, or the
# model's user conditions not accepted). Fail here with an actionable message
# rather than an opaque AttributeError on `.to`.
if diarizer is None:
  raise RuntimeError(
    'Could not load "pyannote/speaker-diarization". Make sure you are logged in '
    'to the Hugging Face Hub and have accepted the model\'s user conditions.'
  )
diarizer = diarizer.to(device)

In [None]:
# Diarize the audio (who spoke when) and save the result in RTTM format.
# The file is named after the audio file's stem, e.g.
# audio/1.mp3 -> diarization_outputs/1.rttm, matching the transcript's JSON name.
from pathlib import Path

rttm_path = Path('diarization_outputs') / f'{Path(AUDIO_FILENAME).stem}.rttm'

diarization = diarizer(AUDIO_FILENAME)
with open(rttm_path, "w") as rttm:
  diarization.write_rttm(rttm)

## Combine diarization with transcript

In [None]:
# Reload the saved diarization and transcript from disk, so this section can be
# re-run without repeating the Whisper / pyannote steps above.
import json
from pathlib import Path

from pyannote.database.util import load_rttm

audio_stem = Path(AUDIO_FILENAME).stem
output_dir = Path('diarization_outputs')

# load_rttm returns a mapping keyed by file URI.
# NOTE(review): the pipeline is assumed to have written the audio file's stem
# (e.g. "1" for audio/1.mp3) as the URI — confirm if this lookup raises KeyError.
diarization = load_rttm(str(output_dir / f'{audio_stem}.rttm'))[audio_stem]

with open(output_dir / f'{audio_stem}.json', 'r') as f:
  speech = json.load(f)

In [None]:
# view transcript: one "start=... stop=... <text>" line per Whisper segment

transcript_lines = [f"start={speech_segment['start']:.1f}s stop={speech_segment['end']:.1f}s {speech_segment['text']}" for speech_segment in speech['segments']]
for line in transcript_lines:
  print(line)

In [None]:
# view diarization: one line per speaker turn, followed by the annotation object
# itself as the cell's displayed output.
# (`_track` rather than `_`, so IPython's last-output variable isn't shadowed.)

turn_lines = [
  f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}"
  for turn, _track, speaker in diarization.itertracks(yield_label=True)
]
for line in turn_lines:
  print(line)

diarization

In [None]:
# aggregate consecutive speaker segments with the same speaker id
# code adapted from https://github.com/huggingface/speechbox/blob/main/src/speechbox/diarize.py

def _merge_consecutive_turns(tracks):
  """Merge consecutive diarization turns that share a speaker label.

  Args:
    tracks: iterable of (segment, track, label) triples, as yielded by
      ``diarization.itertracks(yield_label=True)``. Only ``segment.start``,
      ``segment.end`` and ``label`` are used.

  Returns:
    List of ``{'start', 'end', 'speaker'}`` dicts in input order; an empty
    list when ``tracks`` is empty.
  """
  merged = []
  for turn, _track, speaker in tracks:
    if merged and merged[-1]['speaker'] == speaker:
      # extend the running segment; max() keeps the end from moving backwards
      # if a turn happens to lie inside the current one
      merged[-1]['end'] = max(merged[-1]['end'], turn.end)
    else:
      merged.append({'start': turn.start, 'end': turn.end, 'speaker': speaker})
  return merged

speaker_segments = _merge_consecutive_turns(diarization.itertracks(yield_label=True))

In [None]:
# create diarized transcript
# code adapted from https://github.com/huggingface/speechbox/blob/main/src/speechbox/diarize.py

import numpy as np

def _join_texts(segments):
  """Join segment texts with single spaces, ignoring surrounding whitespace."""
  return ' '.join(text for text in (seg['text'].strip() for seg in segments) if text)

def _attach_text_to_speakers(speaker_segments, asr_segments):
  """Assign Whisper transcript segments to diarized speaker turns.

  Turns are processed in order. Each turn claims every not-yet-assigned ASR
  segment up to and including the one whose end time is closest to the turn's
  end.

  Args:
    speaker_segments: list of {'start', 'end', 'speaker'} dicts, in time order.
    asr_segments: list of Whisper segments (dicts with at least 'end' and 'text').

  Returns:
    List of {'start', 'end', 'speaker', 'text'} dicts. 'start' comes from the
    speaker turn, 'end' from the last ASR segment assigned to it.
  """
  end_times = np.array([seg['end'] for seg in asr_segments], dtype=float)
  transcript = []
  first = 0  # index of the first ASR segment not yet assigned to a speaker
  for speaker_seg in speaker_segments:
    if first >= len(asr_segments):
      # all text has been assigned; np.argmin would fail on an empty array
      break
    last = first + int(np.argmin(np.abs(end_times[first:] - speaker_seg['end'])))
    transcript.append({
      'start': speaker_seg['start'],
      'end': float(end_times[last]),
      'speaker': speaker_seg['speaker'],
      # strip each piece before joining: Whisper texts usually carry their own
      # leading space, so a plain ' '.join would produce double spaces
      'text': _join_texts(asr_segments[first:last + 1]),
    })
    first = last + 1

  leftover = asr_segments[first:]
  if leftover and transcript:
    # Heuristic: ASR segments past the last matched index would otherwise be
    # dropped silently, so attribute them to the final speaker turn.
    transcript[-1]['end'] = float(end_times[-1])
    transcript[-1]['text'] = ' '.join(
      part for part in (transcript[-1]['text'], _join_texts(leftover)) if part
    )
  return transcript

diarized_transcript = _attach_text_to_speakers(speaker_segments, speech['segments'])

In [None]:
# view diarized transcript: one "<speaker>: <text>" line per speaker turn, with
# underscores in the speaker label rendered as spaces

transcript_view = [f'{seg["speaker"].replace("_", " ")}: {seg["text"]}' for seg in diarized_transcript]
for line in transcript_view:
  print(line)