In [1]:
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from omegaconf import DictConfig,OmegaConf
from pyannote.core import Annotation
from pyannote.metrics.diarization import DiarizationErrorRate
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
from tqdm.notebook import tqdm

from nemo.collections.asr.metrics.der import score_labels
from nemo.collections.asr.models.clustering_diarizer import (
    _MODEL_CONFIG_YAML,
    _SPEAKER_MODEL,
    _VAD_MODEL,
    get_available_model_names,
)
from nemo.collections.asr.models.msdd_models import ClusterEmbedding,EncDecDiarLabelModel
from nemo.collections.asr.models.configs.diarizer_config import NeuralDiarizerInferenceConfig
from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.utils.speaker_utils import (
    audio_rttm_map,
    get_id_tup_dict,
    get_uniq_id_list_from_manifest,
    labels_to_pyannote_object,
    make_rttm_with_overlap,
    rttm_to_labels,
)
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging

# Prefer the real mixed-precision context manager when this torch build exposes it.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Fallback for torch builds without ``torch.cuda.amp``: provide a stand-in with the
    # same call shape so ``with autocast():`` blocks below keep working.
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        """No-op replacement for ``torch.cuda.amp.autocast``; ``enabled`` is accepted and ignored."""
        yield

In [2]:
class NeuralDiarizer(LightningModule):
    """Diarization inference wrapper that combines clustering output with an MSDD model.

    The instance loads an ``EncDecDiarLabelModel`` into ``self.msdd_model``, builds a
    ``ClusterEmbedding`` helper in ``self.clustering_embedding`` and exposes
    :meth:`diarize` for manifest-driven inference and :meth:`__call__` for a single
    audio file.
    """

    def __init__(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
        """Read the MSDD inference options from ``cfg`` and load the required models.

        Args:
            cfg: Inference configuration. The code reads ``cfg.device``,
                ``cfg.diarizer.msdd_model.*`` and ``cfg.diarizer.clustering.parameters``.
        """
        super().__init__()
        self._cfg = cfg

        # Parameter settings for MSDD model (each falls back to the default shown here).
        msdd_params = cfg.diarizer.msdd_model.parameters
        self.use_speaker_model_from_ckpt = msdd_params.get('use_speaker_model_from_ckpt', True)
        self.use_clus_as_main = msdd_params.get('use_clus_as_main', False)
        self.max_overlap_spks = msdd_params.get('max_overlap_spks', 2)
        self.num_spks_per_model = msdd_params.get('num_spks_per_model', 2)
        self.use_adaptive_thres = msdd_params.get('use_adaptive_thres', True)
        self.max_pred_length = msdd_params.get('max_pred_length', 0)
        # Each entry is a (collar, ignore_overlap) pair consumed by ``run_overlap_aware_eval``.
        self.diar_eval_settings = msdd_params.get('diar_eval_settings', [(0.25, True), (0.25, False), (0.0, False)])

        # ``_init_msdd_model`` reads ``self.use_speaker_model_from_ckpt``, so it must run after the block above.
        self._init_msdd_model(cfg)
        self.diar_window_length = cfg.diarizer.msdd_model.parameters.diar_window_length
        self.msdd_model.cfg = self.transfer_diar_params_to_model_params(self.msdd_model, cfg)

        # Initialize clustering and embedding preparation instance (as a diarization encoder).
        self.clustering_embedding = ClusterEmbedding(
            cfg_diar_infer=cfg, cfg_msdd_model=self.msdd_model.cfg, speaker_model=self._speaker_model
        )

        # Parameters for creating diarization results from MSDD outputs.
        self.clustering_max_spks = self.msdd_model._cfg.max_num_of_spks
        self.overlap_infer_spk_limit = cfg.diarizer.msdd_model.parameters.get(
            'overlap_infer_spk_limit', self.clustering_max_spks
        )

    def transfer_diar_params_to_model_params(self, msdd_model, cfg):
        """Copy output/manifest/batch settings from the diarizer config onto the MSDD model config.

        Args:
            msdd_model: Model whose ``cfg`` (and ``_cfg``) is updated in place.
            cfg: Diarizer inference configuration to read the values from.

        Returns:
            ``msdd_model.cfg`` after the update.
        """
        msdd_model.cfg.diarizer.out_dir = cfg.diarizer.out_dir
        msdd_model.cfg.test_ds.manifest_filepath = cfg.diarizer.manifest_filepath
        msdd_model.cfg.test_ds.emb_dir = cfg.diarizer.out_dir
        msdd_model.cfg.test_ds.batch_size = cfg.diarizer.msdd_model.parameters.infer_batch_size
        msdd_model.cfg.test_ds.seq_eval_mode = cfg.diarizer.msdd_model.parameters.seq_eval_mode
        msdd_model._cfg.max_num_of_spks = cfg.diarizer.clustering.parameters.max_num_speakers
        return msdd_model.cfg

    @rank_zero_only
    def save_to(self, save_path: str):
        """Write the config, VAD model (if any), speaker model and MSDD model into one archive.

        All parts are first saved into a temporary directory, which is then packed
        into ``save_path`` by the clustering diarizer's archive helper.

        Args:
            save_path: Destination path of the resulting archive.
        """
        self.clus_diar = self.clustering_embedding.clus_diar_model
        _NEURAL_DIAR_MODEL = "msdd_model.nemo"

        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = os.path.join(tmpdir, _MODEL_CONFIG_YAML)
            spkr_model = os.path.join(tmpdir, _SPEAKER_MODEL)
            neural_diar_model = os.path.join(tmpdir, _NEURAL_DIAR_MODEL)

            self.clus_diar.to_config_file(path2yaml_file=config_yaml)
            if self.clus_diar.has_vad_model:
                vad_model = os.path.join(tmpdir, _VAD_MODEL)
                self.clus_diar._vad_model.save_to(vad_model)
            self.clus_diar._speaker_model.save_to(spkr_model)
            self.msdd_model.save_to(neural_diar_model)
            # The helper is a double-underscore (private) member of the clustering diarizer.
            # Writing ``self.clus_diar.__make_nemo_file_from_folder`` inside this class would be
            # mangled by Python to ``_NeuralDiarizer__make_nemo_file_from_folder`` and fail with
            # AttributeError, so the name as mangled for the owning class is spelled out.
            # NOTE(review): assumes ``clus_diar_model`` is a ``ClusteringDiarizer`` - confirm.
            self.clus_diar._ClusteringDiarizer__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)

    def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') -> EncDecSpeakerLabelModel:
        """Build a standalone speaker-embedding model from weights stored in the MSDD model.

        Args:
            prefix: Substring identifying speaker-model entries in the MSDD state dict;
                it is stripped from the key names before loading.

        Returns:
            An ``EncDecSpeakerLabelModel`` created from ``msdd_model.cfg.speaker_model_cfg``
            with the extracted weights loaded.
        """
        model_state_dict = self.msdd_model.state_dict()
        spk_emb_state_dict = {
            name.replace(prefix, ''): weight for name, weight in model_state_dict.items() if prefix in name
        }

        _speaker_model = EncDecSpeakerLabelModel.from_config_dict(self.msdd_model.cfg.speaker_model_cfg)
        _speaker_model.load_state_dict(spk_emb_state_dict)
        return _speaker_model

    def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
        """Load ``self.msdd_model`` and, optionally, ``self._speaker_model``.

        ``cfg.diarizer.msdd_model.model_path`` is treated as a local ``.nemo`` file, a
        local ``.ckpt`` checkpoint, or otherwise as a pretrained model name.
        """
        model_path = cfg.diarizer.msdd_model.model_path
        if model_path.endswith('.nemo'):
            logging.info(f"Using local nemo file from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.restore_from(restore_path=model_path, map_location=cfg.device)
        elif model_path.endswith('.ckpt'):
            logging.info(f"Using local checkpoint from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.load_from_checkpoint(
                checkpoint_path=model_path, map_location=cfg.device
            )
        else:
            # An unknown name only triggers a warning; the download attempt below still happens.
            if model_path not in get_available_model_names(EncDecDiarLabelModel):
                logging.warning(f"requested {model_path} model name not available in pretrained models, instead")
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self.msdd_model = EncDecDiarLabelModel.from_pretrained(model_name=model_path, map_location=cfg.device)
        # Load speaker embedding model state_dict which is loaded from the MSDD checkpoint.
        if self.use_speaker_model_from_ckpt:
            self._speaker_model = self.extract_standalone_speaker_model()
        else:
            self._speaker_model = None

    def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -> torch.Tensor:
        """Merge the per-speaker-group predictions of one session into a single matrix.

        Args:
            data_list: Items of the form ``(speaker_index_tuple, pred_mat)``. A 3-D
                ``pred_mat`` has its leading dimension squeezed away.

        Returns:
            The accumulated predictions divided by ``n_est_spks - 1`` (at least 1), where
            ``n_est_spks`` is the number of distinct speaker indices in ``data_list``.
        """
        all_tups = tuple()
        for data in data_list:
            all_tups += data[0]
        uniq_spks = sorted(set(all_tups))
        n_est_spks = len(uniq_spks)
        # Map the original speaker indices to consecutive column indices.
        digit_map = dict(zip(uniq_spks, range(n_est_spks)))
        total_len = max([sess[1].shape[1] for sess in data_list])
        sum_pred = torch.zeros(total_len, n_est_spks)
        for _dim_tup, pred_mat in data_list:
            dim_tup = [digit_map[x] for x in _dim_tup]
            if len(pred_mat.shape) == 3:
                pred_mat = pred_mat.squeeze(0)
            if n_est_spks <= self.num_spks_per_model:
                # A single model already covers every speaker: use its output directly.
                sum_pred = pred_mat
            else:
                _end = pred_mat.shape[0]
                sum_pred[:_end, dim_tup] += pred_mat.cpu().float()
        # With one estimated speaker ``n_est_spks - 1`` is zero; clamp the divisor so the
        # predictions are returned unscaled instead of turning into inf/NaN.
        sum_pred = sum_pred / max(n_est_spks - 1, 1)
        return sum_pred

    def get_integrated_preds_list(
        self, uniq_id_list: List[str], test_data_collection: List[Any], preds_list: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Produce one merged prediction tensor per session, ordered like ``uniq_id_list``.

        Sessions grouped by ``get_id_tup_dict`` are merged with :meth:`get_pred_mat` and
        given a leading dimension of size 1. Ids absent from the grouping keep ``[]``.
        """
        session_dict = get_id_tup_dict(uniq_id_list, test_data_collection, preds_list)
        output_dict = {uniq_id: [] for uniq_id in uniq_id_list}
        for uniq_id, data_list in session_dict.items():
            sum_pred = self.get_pred_mat(data_list)
            output_dict[uniq_id] = sum_pred.unsqueeze(0)
        output_list = [output_dict[uniq_id] for uniq_id in uniq_id_list]
        return output_list

    def get_emb_clus_infer(self, cluster_embeddings):
        """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`."""
        self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict
        self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
        self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test

    @torch.no_grad()
    def diarize(self) -> Optional[List[Optional[List[Tuple[DiarizationErrorRate, Dict]]]]]:
        """Run clustering, pairwise MSDD inference and evaluation.

        Returns:
            One :meth:`run_overlap_aware_eval` result per value in
            ``cfg.diarizer.msdd_model.parameters.sigmoid_threshold``.
        """
        self.clustering_embedding.prepare_cluster_embs_infer()
        self.msdd_model.pairwise_infer = True
        self.get_emb_clus_infer(self.clustering_embedding)
        # Only the integrated predictions are needed here; targets and lengths are discarded.
        preds_list, _, _ = self.run_pairwise_diarization()
        thresholds = list(self._cfg.diarizer.msdd_model.parameters.sigmoid_threshold)
        return [self.run_overlap_aware_eval(preds_list, threshold) for threshold in thresholds]

    def get_range_average(
        self, signals: torch.Tensor, emb_vectors: torch.Tensor, diar_window_index: int, test_data_collection: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Compute per-speaker average embeddings and the padded sequence for one window.

        Args:
            signals: Embedding sequence of one sample; sliced along dim 0 by window.
            emb_vectors: Session-level embeddings; the result has the same shape, with the
                last dimension indexed by speaker.
            diar_window_index: Zero-based index of the window of ``self.diar_window_length`` steps.
            test_data_collection: A single entry exposing ``audio_file`` and ``target_spks``.

        Returns:
            ``(emb_vectors_split, emb_seq, seq_len)``: the per-speaker window averages (zeros
            for speakers absent from the window), the window's sequence zero-padded to
            ``self.diar_window_length`` steps, and the number of unpadded steps (0 when the
            window starts past the end of the cluster labels).
        """
        emb_vectors_split = torch.zeros_like(emb_vectors)
        uniq_id = os.path.splitext(os.path.basename(test_data_collection.audio_file))[0]
        clus_label_tensor = torch.tensor([x[-1] for x in self.msdd_model.clus_test_label_dict[uniq_id]])
        total_steps = clus_label_tensor.shape[0]

        # The window bounds do not depend on the speaker, so they are computed once.
        stt = diar_window_index * self.diar_window_length
        end = min((diar_window_index + 1) * self.diar_window_length, total_steps)

        if stt >= total_steps:
            # The window lies entirely past the labelled range: all-zero sequence, length 0.
            emb_seq = torch.zeros(self.diar_window_length, emb_vectors.shape[0], emb_vectors.shape[1]).to(
                signals.device
            )
            return emb_vectors_split, emb_seq, 0

        seq_len = end - stt
        target_clus_label_tensor = clus_label_tensor[stt:end]
        emb_seq = signals[stt:end, :, :]

        for spk_idx, target_spk in enumerate(test_data_collection.target_spks):
            target_clus_label_bool = target_clus_label_tensor == target_spk
            # There are cases where there is no corresponding speaker in split range,
            # so the mask can be all False; that speaker's slot then stays zero.
            if target_clus_label_bool.any():
                emb_vectors_split[:, :, spk_idx] = torch.mean(emb_seq[target_clus_label_bool], dim=0)

        # In case when the window reaches the end of the sequence, pad up to the window length.
        if seq_len < self.diar_window_length:
            padding = torch.zeros(self.diar_window_length - seq_len, emb_seq.shape[1], emb_seq.shape[2]).to(
                signals.device
            )
            emb_seq = torch.cat([emb_seq, padding], dim=0)
        return emb_vectors_split, emb_seq, seq_len

    def get_range_clus_avg_emb(
        self,
        test_batch: List[torch.Tensor],
        _test_data_collection: List[Any],
        device: torch.device = torch.device('cpu'),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split every sample of a batch into windows and collect the per-window inputs.

        Args:
            test_batch: ``(signals, signal_lengths, targets, emb_vectors)``.
            _test_data_collection: Collection entries aligned with the batch samples.
            device: Device the stacked outputs are moved to.

        Returns:
            Stacked per-window speaker averages, padded sequences and sequence lengths, in
            sample-major, window-minor order.
        """
        _signals, signal_lengths, _targets, _emb_vectors = test_batch
        sess_emb_vectors, sess_emb_seq, sess_sig_lengths = [], [], []
        # Number of windows needed to cover dim 1 of the batch signals.
        split_count = torch.ceil(torch.tensor(_signals.shape[1] / self.diar_window_length)).int()
        self.max_pred_length = max(self.max_pred_length, self.diar_window_length * split_count)
        for k in range(_signals.shape[0]):
            signals, emb_vectors, test_data_collection = _signals[k], _emb_vectors[k], _test_data_collection[k]
            for diar_window_index in range(split_count):
                emb_vectors_split, emb_seq, seq_len = self.get_range_average(
                    signals, emb_vectors, diar_window_index, test_data_collection
                )
                sess_emb_vectors.append(emb_vectors_split)
                sess_emb_seq.append(emb_seq)
                sess_sig_lengths.append(seq_len)
        sess_emb_vectors = torch.stack(sess_emb_vectors).to(device)
        sess_emb_seq = torch.stack(sess_emb_seq).to(device)
        sess_sig_lengths = torch.tensor(sess_sig_lengths).to(device)
        return sess_emb_vectors, sess_emb_seq, sess_sig_lengths

    def diar_infer(
        self, test_batch: List[torch.Tensor], test_data_collection: List[Any]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the MSDD forward pass for one batch, optionally window by window.

        Args:
            test_batch: ``(signals, signal_lengths, targets, emb_vectors)``.
            test_data_collection: Collection entries aligned with the batch samples.

        Returns:
            ``(preds, targets, signal_lengths)``. ``preds`` is zero-padded along dim 1 to
            ``self.max_pred_length``; ``targets`` is an all-zero tensor of the same shape.
        """
        signals, signal_lengths, _targets, emb_vectors = test_batch
        if self._cfg.diarizer.msdd_model.parameters.split_infer:
            split_count = torch.ceil(torch.tensor(signals.shape[1] / self.diar_window_length)).int()
            sess_emb_vectors, sess_emb_seq, sess_sig_lengths = self.get_range_clus_avg_emb(
                test_batch, test_data_collection, device=self.msdd_model.device
            )
            with autocast():
                _preds, scale_weights = self.msdd_model.forward_infer(
                    input_signal=sess_emb_seq,
                    input_signal_length=sess_sig_lengths,
                    emb_vectors=sess_emb_vectors,
                    targets=None,
                )
            # Fold the windows back into one sequence per sample, then drop the padded tail.
            _preds = _preds.reshape(len(signal_lengths), split_count * self.diar_window_length, -1)
            _preds = _preds[:, : signals.shape[1], :]
        else:
            with autocast():
                _preds, scale_weights = self.msdd_model.forward_infer(
                    input_signal=signals, input_signal_length=signal_lengths, emb_vectors=emb_vectors, targets=None
                )
        # Pad every batch to the longest prediction seen so far.
        self.max_pred_length = max(_preds.shape[1], self.max_pred_length)
        preds = torch.zeros(_preds.shape[0], self.max_pred_length, _preds.shape[2])
        targets = torch.zeros(_preds.shape[0], self.max_pred_length, _preds.shape[2])
        preds[:, : _preds.shape[1], :] = _preds
        return preds, targets, signal_lengths

    @torch.no_grad()
    def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """Run :meth:`diar_infer` over the MSDD test dataloader and merge the results per session.

        Returns:
            ``(integrated_preds_list, targets_list, signal_lengths_list)``. The first list
            holds one merged prediction per id in the manifest; the other two hold one
            entry per dataloader sample.
        """
        self.out_rttm_dir = self.clustering_embedding.out_rttm_dir
        self.msdd_model.setup_test_data(self.msdd_model.cfg.test_ds)
        self.msdd_model.eval()
        # Running sample offsets used to slice the collection entries belonging to each batch.
        cumul_sample_count = [0]
        preds_list, targets_list, signal_lengths_list = [], [], []
        uniq_id_list = get_uniq_id_list_from_manifest(self.msdd_model.cfg.test_ds.manifest_filepath)
        test_data_collection = [d for d in self.msdd_model.data_collection]
        for test_batch in tqdm(self.msdd_model.test_dataloader()):
            signal_lengths = test_batch[1]
            cumul_sample_count.append(cumul_sample_count[-1] + signal_lengths.shape[0])
            preds, targets, signal_lengths = self.diar_infer(
                test_batch, test_data_collection[cumul_sample_count[-2] : cumul_sample_count[-1]]
            )
            if self._cfg.diarizer.msdd_model.parameters.seq_eval_mode:
                self.msdd_model._accuracy_test(preds, targets, signal_lengths)

            preds_list.extend(list(torch.split(preds, 1)))
            targets_list.extend(list(torch.split(targets, 1)))
            signal_lengths_list.extend(list(torch.split(signal_lengths, 1)))

        if self._cfg.diarizer.msdd_model.parameters.seq_eval_mode:
            f1_score, simple_acc = self.msdd_model.compute_accuracies()
            logging.info(f"Test Inference F1 score. {f1_score:.4f}, simple Acc. {simple_acc:.4f}")
        integrated_preds_list = self.get_integrated_preds_list(uniq_id_list, test_data_collection, preds_list)
        return integrated_preds_list, targets_list, signal_lengths_list

    def run_overlap_aware_eval(
        self, preds_list: List[torch.Tensor], threshold: float
    ) -> List[Optional[Tuple[DiarizationErrorRate, Dict]]]:
        """Create RTTM output from the predictions and score it for every evaluation setting.

        Args:
            preds_list: Per-session predictions from :meth:`run_pairwise_diarization`.
            threshold: Threshold forwarded to ``make_rttm_with_overlap``.

        Returns:
            One ``score_labels`` result per ``(collar, ignore_overlap)`` pair in
            ``self.diar_eval_settings``.
        """
        logging.info(
            f"     [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] [diar_window={self.diar_window_length}]"
        )
        outputs = []
        manifest_filepath = self.msdd_model.cfg.test_ds.manifest_filepath
        rttm_map = audio_rttm_map(manifest_filepath)
        for collar, ignore_overlap in self.diar_eval_settings:
            all_reference, all_hypothesis = make_rttm_with_overlap(
                manifest_filepath,
                self.msdd_model.clus_test_label_dict,
                preds_list,
                threshold=threshold,
                infer_overlap=True,
                use_clus_as_main=self.use_clus_as_main,
                overlap_infer_spk_limit=self.overlap_infer_spk_limit,
                use_adaptive_thres=self.use_adaptive_thres,
                max_overlap_spks=self.max_overlap_spks,
                out_rttm_dir=self.out_rttm_dir,
            )
            output = score_labels(
                rttm_map,
                all_reference,
                all_hypothesis,
                collar=collar,
                ignore_overlap=ignore_overlap,
                verbose=self._cfg.verbose,
            )
            outputs.append(output)
        logging.info("  \n")
        return outputs

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        vad_model_name: str = 'vad_multilingual_marblenet',
        map_location: Optional[str] = None,
        verbose: bool = False,
    ):
        """Create an instance from model names using ``NeuralDiarizerInferenceConfig.init_config``.

        Args:
            model_name: Passed to the config as ``diar_model_path``.
            vad_model_name: Passed to the config as ``vad_model_path``.
            map_location: Passed to the config as ``map_location``.
            verbose: Sets the NeMo log level to INFO when true, WARNING otherwise.

        Returns:
            A new instance of ``cls`` built from that config.
        """
        logging.setLevel(logging.INFO if verbose else logging.WARNING)
        cfg = NeuralDiarizerInferenceConfig.init_config(
            diar_model_path=model_name,
            vad_model_path=vad_model_name,
            map_location=map_location,
            verbose=verbose,
        )
        return cls(cfg)

    def __call__(
        self,
        audio_filepath: str,
        batch_size: int = 64,
        num_workers: int = 1,
        max_speakers: Optional[int] = None,
        num_speakers: Optional[int] = None,
        out_dir: Optional[str] = None,
        verbose: bool = False,
    ) -> Union[Annotation, List[Annotation]]:
        """Diarize a single audio file and return the predicted labels as a pyannote object.

        A one-line manifest is written into a temporary directory (created under
        ``out_dir`` when given), the configs are pointed at it, :meth:`diarize` is run,
        and the RTTM produced for the file is read back before the directory is removed.

        Args:
            audio_filepath: Path of the audio file to process.
            batch_size: Stored on the config as ``batch_size``.
            num_workers: Stored on the config as ``num_workers``.
            max_speakers: When truthy, overrides the clustering ``max_num_speakers``.
            num_speakers: Written to the manifest; a non-``None`` value also enables
                ``oracle_num_speakers``.
            out_dir: Parent directory for the temporary working directory.
            verbose: Stored on the config as ``verbose``.

        Returns:
            The result of ``labels_to_pyannote_object`` for the predicted labels.
        """
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with tempfile.TemporaryDirectory(dir=out_dir) as tmpdir:
            manifest_path = os.path.join(tmpdir, 'manifest.json')
            meta = [
                {
                    'audio_filepath': audio_filepath,
                    'offset': 0,
                    'duration': None,
                    'label': 'infer',
                    'text': '-',
                    'num_speakers': num_speakers,
                    'rttm_filepath': None,
                    'uem_filepath': None,
                }
            ]

            with open(manifest_path, 'w') as f:
                f.write('\n'.join(json.dumps(x) for x in meta))

            self._initialize_configs(
                manifest_path=manifest_path,
                max_speakers=max_speakers,
                num_speakers=num_speakers,
                tmpdir=tmpdir,
                batch_size=batch_size,
                num_workers=num_workers,
                verbose=verbose,
            )

            self.msdd_model.cfg.test_ds.manifest_filepath = manifest_path
            self.diarize()

            # Must be read inside the ``with`` block: the directory is deleted on exit.
            pred_labels_clus = rttm_to_labels(f'{tmpdir}/pred_rttms/{Path(audio_filepath).stem}.rttm')
        return labels_to_pyannote_object(pred_labels_clus)

    def _initialize_configs(
        self,
        manifest_path: str,
        max_speakers: Optional[int],
        num_speakers: Optional[int],
        tmpdir: str,
        batch_size: int,
        num_workers: int,
        verbose: bool,
    ) -> None:
        """Point ``self._cfg`` at a per-call manifest/output directory and sync the MSDD config.

        Args:
            manifest_path: Manifest file to diarize.
            max_speakers: When truthy, overrides the clustering ``max_num_speakers``.
            num_speakers: A non-``None`` value switches ``oracle_num_speakers`` on.
            tmpdir: Path used as ``diarizer.out_dir``.
            batch_size: Value for ``cfg.batch_size``.
            num_workers: Value for ``cfg.num_workers``.
            verbose: Value for ``cfg.verbose``.
        """
        self._cfg.batch_size = batch_size
        self._cfg.num_workers = num_workers
        self._cfg.diarizer.manifest_filepath = manifest_path
        self._cfg.diarizer.out_dir = tmpdir
        self._cfg.verbose = verbose
        self._cfg.diarizer.clustering.parameters.oracle_num_speakers = num_speakers is not None
        if max_speakers:
            self._cfg.diarizer.clustering.parameters.max_num_speakers = max_speakers
        self.transfer_diar_params_to_model_params(self.msdd_model, self._cfg)

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """Return the pretrained models listed by ``EncDecDiarLabelModel``."""
        return EncDecDiarLabelModel.list_available_models()


In [3]:
# Location of the diarizer inference config. Set the NEMO_DIAR_CONFIG environment
# variable to run the notebook on another machine; the fallback is the original path.
CONFIG_PATH = os.environ.get(
    'NEMO_DIAR_CONFIG',
    '/mnt/d/Programs/Python/PW/projects/speech/diarization/nemo_diarization/config.yaml',
)
config = OmegaConf.load(CONFIG_PATH)

In [4]:
# Echo the loaded configuration as the cell output for a quick visual check.
config

{'name': 'ClusterDiarizer', 'num_workers': 1, 'sample_rate': 16000, 'batch_size': 64, 'device': None, 'verbose': True, 'diarizer': {'manifest_filepath': '/mnt/d/Programs/Python/PW/projects/speech/diarization/nemo_diarization/manifest.txt', 'out_dir': '/mnt/d/Programs/Python/PW/projects/speech/diarization/nemo_diarization/out_dir', 'oracle_vad': False, 'collar': 0.25, 'ignore_overlap': True, 'vad': {'model_path': 'vad_multilingual_marblenet', 'external_vad_manifest': None, 'parameters': {'window_length_in_sec': 0.15, 'shift_length_in_sec': 0.01, 'smoothing': 'median', 'overlap': 0.5, 'onset': 0.1, 'offset': 0.1, 'pad_onset': 0.1, 'pad_offset': 0, 'min_duration_on': 0, 'min_duration_off': 0.2, 'filter_speech_first': True}}, 'speaker_embeddings': {'model_path': 'titanet_large', 'parameters': {'window_length_in_sec': [1.5, 1.25, 1.0, 0.75, 0.5], 'shift_length_in_sec': [0.75, 0.625, 0.5, 0.375, 0.25], 'multiscale_weights': [1, 1, 1, 1, 1], 'save_embeddings': True}}, 'clustering': {'paramete

In [5]:
class NemoNeuralDiarizer(LightningModule):
    """Reduced MSDD diarizer that only loads the models and the clustering/embedding helper."""

    def __init__(self,cfg:OmegaConf):
        """Read the MSDD options from ``cfg``, load the models and build ``ClusterEmbedding``."""
        super().__init__()
        self._cfg = cfg

        # MSDD options; every lookup falls back to the default given here.
        msdd_params = cfg.diarizer.msdd_model.parameters
        self.use_speaker_model_from_ckpt = msdd_params.get('use_speaker_model_from_ckpt', True)
        self.use_clus_as_main = msdd_params.get('use_clus_as_main', False)
        self.max_overlap_spks = msdd_params.get('max_overlap_spks', 2)
        self.num_spks_per_model = msdd_params.get('num_spks_per_model', 2)
        self.use_adaptive_thres = msdd_params.get('use_adaptive_thres', True)
        self.max_pred_length = msdd_params.get('max_pred_length', 0)
        default_eval_settings = [(0.25, True), (0.25, False), (0.0, False)]
        self.diar_eval_settings = msdd_params.get('diar_eval_settings', default_eval_settings)

        # Model loading relies on ``use_speaker_model_from_ckpt`` set above.
        self._init_msdd_model(cfg)
        self.diar_window_length = msdd_params.diar_window_length
        self.msdd_model.cfg = self.transfer_diar_params_to_model_params(self.msdd_model, cfg)

        # Clustering and embedding preparation instance (used as a diarization encoder).
        self.clustering_embedding = ClusterEmbedding(
            cfg_diar_infer=cfg,
            cfg_msdd_model=self.msdd_model.cfg,
            speaker_model=self._speaker_model,
        )

    def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') -> EncDecSpeakerLabelModel:
        """Return a speaker model loaded with the MSDD weights whose names contain ``prefix``."""
        full_state = self.msdd_model.state_dict()
        # Keep only the speaker-model entries and strip the prefix from their names.
        speaker_state = {
            key.replace(prefix, ''): tensor for key, tensor in full_state.items() if prefix in key
        }

        standalone = EncDecSpeakerLabelModel.from_config_dict(self.msdd_model.cfg.speaker_model_cfg)
        standalone.load_state_dict(speaker_state)
        return standalone

    def transfer_diar_params_to_model_params(self, msdd_model, cfg):
        """Copy output, manifest and batch settings from ``cfg`` onto the MSDD model config."""
        diar_cfg = cfg.diarizer
        infer_params = diar_cfg.msdd_model.parameters

        msdd_model.cfg.diarizer.out_dir = diar_cfg.out_dir
        msdd_model.cfg.test_ds.manifest_filepath = diar_cfg.manifest_filepath
        msdd_model.cfg.test_ds.emb_dir = diar_cfg.out_dir
        msdd_model.cfg.test_ds.batch_size = infer_params.infer_batch_size
        msdd_model.cfg.test_ds.seq_eval_mode = infer_params.seq_eval_mode
        msdd_model._cfg.max_num_of_spks = diar_cfg.clustering.parameters.max_num_speakers
        return msdd_model.cfg

    def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
        """Load ``self.msdd_model`` from a ``.nemo`` file, a ``.ckpt`` file or a pretrained name."""
        model_path = cfg.diarizer.msdd_model.model_path
        is_nemo_file = model_path.endswith('.nemo')
        is_checkpoint = (not is_nemo_file) and model_path.endswith('.ckpt')

        if is_nemo_file:
            logging.info(f"Using local nemo file from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.restore_from(restore_path=model_path, map_location=cfg.device)
        elif is_checkpoint:
            logging.info(f"Using local checkpoint from {model_path}")
            self.msdd_model = EncDecDiarLabelModel.load_from_checkpoint(
                checkpoint_path=model_path, map_location=cfg.device
            )
        else:
            # Unknown names only produce a warning; loading is attempted regardless.
            known_names = get_available_model_names(EncDecDiarLabelModel)
            if model_path not in known_names:
                logging.warning(f"requested {model_path} model name not available in pretrained models, instead")
            logging.info("Loading pretrained {} model from NGC".format(model_path))
            self.msdd_model = EncDecDiarLabelModel.from_pretrained(model_name=model_path, map_location=cfg.device)

        # Speaker embedding model: taken from the MSDD checkpoint weights, or left unset.
        self._speaker_model = self.extract_standalone_speaker_model() if self.use_speaker_model_from_ckpt else None

In [6]:
# Build the diarizer from the loaded config; per the log output below, this loads the
# pretrained MSDD model and the VAD model.
neural_diarizer=NemoNeuralDiarizer(config)

[NeMo I 2024-09-28 00:15:57 nemo_logging:381] Loading pretrained diar_msdd_telephonic model from NGC
[NeMo I 2024-09-28 00:15:57 nemo_logging:381] Found existing object /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2024-09-28 00:15:57 nemo_logging:381] Re-using file from: /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2024-09-28 00:15:57 nemo_logging:381] Instantiating model from pre-trained checkpoint


[NeMo W 2024-09-28 00:16:01 nemo_logging:393] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2024-09-28 00:16:01 nemo_logging:393] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2024-09-28 00:16:01 nemo_logging:393] Please call the ModelPT.setup_test_data() or ModelPT

[NeMo I 2024-09-28 00:16:01 nemo_logging:381] PADDING: 16
[NeMo I 2024-09-28 00:16:02 nemo_logging:381] PADDING: 16


      return torch.load(model_weights, map_location='cpu')
    


[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Model EncDecDiarLabelModel was successfully restored from /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] PADDING: 16
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Loading pretrained vad_multilingual_marblenet model from NGC
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Found existing object /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Re-using file from: /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Instantiating model from pre-trained checkpoint


[NeMo W 2024-09-28 00:16:04 nemo_logging:393] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json
    sample_rate: 16000
    labels:
    - background
    - speech
    batch_size: 256
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy

[NeMo I 2024-09-28 00:16:04 nemo_logging:381] PADDING: 16
[NeMo I 2024-09-28 00:16:04 nemo_logging:381] Model EncDecClassificationModel was successfully restored from /home/rahim/.cache/torch/NeMo/NeMo_2.1.0rc0/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.
