In [1]:
from jadia import Jadia, Segments
from jadia.metrics import DER
import torch

### Diarize

In [2]:
# Load model
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without CUDA.
# NOTE(review): assumes Jadia also accepts a CPU device — TODO confirm.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
diarizer = Jadia(
    device=device,
    model="lite",
    batch_size=64,
)

In [3]:
# Params — everything a reader might want to tune lives in this cell.

# Value passed as `num_voices` when setting up the diarizer.
NUM_VOICES = 5

# Inputs: the audio to diarize and the reference annotation used for scoring.
AUDIO_FILENAME = "S_R004S04C01.flac"
GROUND_TRUTH_RTTM = "S_R004S04C01.rttm"

# Outputs: the RTTM file and the two plot images.
RTTM_FILENAME = "output.rttm"
PREDICTIONS_IMAGE_FILENAME = "preds.png"
SEGMENTS_IMAGE_FILENAME = "segments.png"

In [4]:
# Diarize AUDIO_FILENAME with the diarizer configured for NUM_VOICES voices.
#
# The pipeline can be run as a single command:
# segments = diarizer.process(AUDIO_FILENAME, num_voices=NUM_VOICES)
# or step by step, as done here. The step-by-step form also keeps the
# intermediate `predictions`, which the plotting cell passes to
# plot_predictions.

diarizer.setup(num_voices=NUM_VOICES)
diarizer.load_audio(filename=AUDIO_FILENAME)
predictions = diarizer.predict()
segments = diarizer.predictions_to_segments(predictions)

# Save the segments to RTTM_FILENAME.
# overwrite=True presumably replaces an existing file, so the cell can be
# re-run — TODO confirm against the jadia docs.
segments.to_rttm_save(fpath=RTTM_FILENAME, overwrite=True)

### Eval

In [5]:
# Load the ground-truth annotation and score the predicted segments against it.
reference = Segments().from_rttm_load(GROUND_TRUTH_RTTM)
der = DER(reference, segments)

# `der` is indexed by component name. Judging by the recorded output of this
# cell, the first four values are absolute amounts and the last one is the
# ratio (confusion + false alarm + missed detection) / total, so it can
# exceed 1 — TODO confirm units against the jadia.metrics docs.
print(f"       Confusion: {der['confusion']:.4f}")
print(f"     False alarm: {der['false alarm']:.4f}")
print(f"Missed detection: {der['missed detection']:.4f}")
print(f"           Total: {der['total']:.4f}")
print(f"             DER: {der['diarization error rate']:.4f}")




       Confusion: 1847.1136
     False alarm: 5808.6176
Missed detection: 229.2205
           Total: 2078.8480
             DER: 3.7929


### Plot (requires jadia-plot to be installed)

In [6]:
import jadia_plot as plot

In [None]:
# Plot the predicted segments against the ground truth.
plot.plot_segments(
    filename=SEGMENTS_IMAGE_FILENAME,
    pred=segments,
    ground_truth=reference,
)
# Also plot the predictions, together with the segments and the ground truth.
plot.plot_predictions(
    filename=PREDICTIONS_IMAGE_FILENAME,
    predictions=predictions,
    segments=segments,
    ground_truth=reference,
)