In [None]:
"""
You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""
# If you're using Google Colab and not running locally, run this cell.

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install text-unidecode
!pip install ipython

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[asr]

## Install TorchAudio
!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html

# Streaming Multitalker ASR

## Streaming Multitalker ASR with Self-Speaker Adaptation

This tutorial shows you how to use NeMo's streaming multitalker ASR system based on the approach described in [(Wang et al., 2025)](https://arxiv.org/abs/2506.22646). This system transcribes each speaker separately in multispeaker audio using speaker activity information from a streaming diarization model.

### How This Approach Works

The streaming multitalker Parakeet model uses **self-speaker adaptation (SSA)**, which means:

1. **No Speaker Enrollment Required**: You only need speaker activity predictions from a diarization model (like Streaming Sortformer)
2. **Speaker Kernel Injection**: The model injects speaker-specific kernels into encoder layers to focus on each target speaker
3. **Multi-Instance Architecture**: You run one model instance per speaker, and each instance processes the same audio
4. **Handles Overlapping Speech**: Each instance focuses on one speaker, so it can transcribe overlapped speech segments

### Cache-Aware Streaming

The model uses stateful cache-based inference [(Noroozi et al., 2023)](https://arxiv.org/abs/2312.17279) for streaming:
- Left and right contexts in the encoder are constrained for low latency
- An activation caching mechanism enables the encoder to operate autoregressively during inference
- The model maintains consistent behavior between training and inference
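
The caching idea can be illustrated with a toy example. The sketch below is not the Fast-Conformer encoder: it replaces the encoder with a simple causal moving average (an assumption made purely for illustration) and uses a hypothetical `StreamingAverager` class. It shows that keeping only a fixed-size cache of past frames lets chunk-by-chunk processing reproduce the offline result.

```python
from collections import deque

def causal_average(frames, left_context):
    """Offline reference: average each frame with up to `left_context` past frames."""
    out = []
    for t in range(len(frames)):
        window = frames[max(0, t - left_context): t + 1]
        out.append(sum(window) / len(window))
    return out

class StreamingAverager:
    """Toy streaming layer that caches only the last `left_context` frames."""

    def __init__(self, left_context):
        self.cache = deque(maxlen=left_context)

    def process_chunk(self, chunk):
        out = []
        for x in chunk:
            window = list(self.cache) + [x]
            out.append(sum(window) / len(window))
            self.cache.append(x)  # the oldest cached frame is dropped automatically
        return out

frames = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0]
offline = causal_average(frames, left_context=2)

# Feed the same frames in chunks of 3, carrying the cache across chunks.
streamer = StreamingAverager(left_context=2)
streamed = []
for start in range(0, len(frames), 3):
    streamed.extend(streamer.process_chunk(frames[start:start + 3]))

print(streamed == offline)
```

Because the cache holds exactly the left context the offline computation would see, the chunk size affects latency but not the output.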

## Background: Multi-Instance Architecture Overview

The streaming multitalker Parakeet model employs a **multi-instance approach** where one model instance is deployed per speaker:

<img src="images/multi_instance.png" alt="Multi-instance inference of Multitalker Parakeet model" style="width: 800px;"/>

Each model instance:
- Receives the same mixed audio input
- Injects **speaker-specific kernels** generated from diarization-based speaker activity
- Produces transcription output specific to its target speaker
- Operates independently and can run in parallel with other instances
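
As a rough sketch of this data flow (not NeMo's API), the snippet below splits a frames-by-speakers activity matrix into one input per instance. `InstanceInput`, `build_instance_inputs`, the sample values, and the 0.5 threshold are all hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass

@dataclass
class InstanceInput:
    """Hypothetical container for what one model instance would receive."""
    speaker_index: int
    audio: list      # the same mixed audio for every instance
    activity: list   # this speaker's per-frame activity, between 0.0 and 1.0

def build_instance_inputs(audio, diar_probs):
    """Split a frames-by-speakers activity matrix into one input per speaker."""
    num_speakers = len(diar_probs[0])
    return [
        InstanceInput(spk, audio, [frame[spk] for frame in diar_probs])
        for spk in range(num_speakers)
    ]

mixed_audio = [0.1, -0.2, 0.3, 0.0]  # placeholder values standing in for audio frames
diar_probs = [[0.9, 0.1], [0.8, 0.7], [0.2, 0.9], [0.1, 0.1]]  # made-up diarization output

inputs = build_instance_inputs(mixed_audio, diar_probs)

# Frames where more than one speaker is active are overlapped speech.
overlap_frames = [
    t for t, frame in enumerate(diar_probs) if sum(p > 0.5 for p in frame) > 1
]
print(len(inputs), overlap_frames)
```

Every instance shares the same audio object and differs only in the activity track it is conditioned on, which is why overlapped frames can still be attributed to each speaker separately.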

### Speaker Kernel Injection Mechanism

Learnable speaker kernels are injected into selected layers of the Fast-Conformer encoder:

<img src="images/speaker_injection.png" alt="Speaker Kernel Injection Mechanism" style="width: 800px;"/>

The speaker kernels are generated through speaker supervision activations that detect speech activity for each target speaker from the streaming diarization output. This enables the encoder states to become more responsive to the targeted speaker's speech characteristics.
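
One simple way to picture the injection is as an activity-weighted bias added to each frame's hidden state. This is an assumption for illustration only: the real model generates and injects its kernels inside the encoder, and the exact formulation may differ. The function `inject_speaker_kernel` and all values below are hypothetical.

```python
def inject_speaker_kernel(hidden, activity, kernel):
    """Add an activity-weighted kernel vector to each frame's hidden state.

    hidden:   list of T frames, each a list of D floats
    activity: list of T per-frame activities for the target speaker
    kernel:   list of D floats (would be learnable in a real model)
    """
    return [
        [h + a * k for h, k in zip(frame, kernel)]
        for frame, a in zip(hidden, activity)
    ]

hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
activity = [1.0, 0.0, 0.5]  # fully active, silent, partially active
kernel = [0.5, -1.0]

adapted = inject_speaker_kernel(hidden, activity, kernel)
print(adapted)
```

In this toy version, frames where the target speaker is silent pass through unchanged, while active frames are shifted by the kernel.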

# Training Streaming Multitalker ASR Model

As covered in the background section, SSA-based multitalker ASR uses the logits from an external streaming speaker diarization model to make the multitalker model concentrate on the targeted speaker only. We therefore train only the ASR model and keep the speaker diarization model frozen.

### Data preparation

Let's download the Mini LibriSpeech (English) dataset. It is sufficient for this tutorial, but for real training you will need at least the full LibriSpeech dataset (960 hours).

In [None]:
# Downloading MiniLibrispeech
!mkdir -p datasets/mini


We will use the `get_librispeech_data.py` script. If you cloned the NeMo repository, it is located in the `scripts/dataset_processing` directory; otherwise, the next cell downloads it.

In [None]:
import os
if not os.path.exists("get_librispeech_data.py"):
    !wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/dataset_processing/get_librispeech_data.py

In [None]:
!python get_librispeech_data.py \
  --data_root "datasets/mini/" \
  --data_sets mini

Now let's prepare the LibriSpeechMix dataset from this Mini LibriSpeech dataset. First, add speaker IDs to the manifests.

In [None]:
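import json
import os

def add_speaker_ids(manifest_path):
    """Write a copy of the manifest (suffix `_spk.json`) with a `speaker_id` field per entry."""
    new_data = []
    with open(manifest_path, "r") as f:
        for line in f:
            data = json.loads(line)
            # LibriSpeech file names start with the speaker ID: <speaker>-<chapter>-<utterance>
            data["speaker_id"] = os.path.basename(data["audio_filepath"]).split("-")[0]
            new_data.append(data)

    with open(manifest_path.replace(".json", "_spk.json"), "w") as f:
        for data in new_data:
            f.write(json.dumps(data) + "\n")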

add_speaker_ids("datasets/mini/train_clean_5.json")
add_speaker_ids("datasets/mini/dev_clean_2.json")
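
To see what the speaker ID extraction does, here is the same parsing step applied to a single path. The path is a hypothetical example that follows LibriSpeech's `<speaker>-<chapter>-<utterance>` naming pattern; it is not taken from the downloaded data.

```python
import os

def speaker_id_from_path(audio_filepath):
    """Return the leading speaker ID from a LibriSpeech-style file name."""
    return os.path.basename(audio_filepath).split("-")[0]

# Hypothetical path; hyphens in directory names do not matter because
# only the file name is split.
example_path = "datasets/mini/train-clean-5/1088-134315-0000.wav"
print(speaker_id_from_path(example_path))  # prints 1088
```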

### Generate dataset for training

In [None]:
from lhotse import CutSet
from nemo.collections.asr.parts.utils.asr_multispeaker_utils import MultiSpeakerMixtureGenerator

train_manifest = "datasets/mini/train_clean_5_spk.json"
val_manifest = "datasets/mini/dev_clean_2_spk.json"

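# Simulate two-speaker mixtures (the "lsmix" simulator) from the training manifest.
train_cuts = CutSet(
    MultiSpeakerMixtureGenerator(
        manifest_filepath=train_manifest,
        simulator_type="lsmix",
        sample_rate=16000,
        min_delay=0.5,
        min_duration=0.1,
        max_duration=60,
        num_speakers=2
    )
)

# Keep only the first 10 mixtures for a quick inspection.
train_cuts = CutSet([train_cuts[i] for i in range(10)])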

print("Done")
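
Conceptually, a two-speaker mixture overlays a second utterance on the first after some delay. The sketch below is a simplified illustration, not the simulator's implementation. It assumes `min_delay` is the minimum offset in seconds between the speakers' start times and that the delay is already chosen; the real simulator also handles delay sampling and other details. `mix_with_delay` is a hypothetical helper.

```python
def mix_with_delay(first, second, delay_sec, sample_rate=16000):
    """Overlay `second` onto `first`, starting `delay_sec` after `first` begins."""
    offset = int(round(delay_sec * sample_rate))
    total = max(len(first), offset + len(second))
    mixture = [0.0] * total
    for i, x in enumerate(first):
        mixture[i] += x
    for i, x in enumerate(second):
        mixture[offset + i] += x
    return mixture

# Toy sample rate of 4 Hz so the numbers are easy to follow:
# a 0.5 s delay corresponds to an offset of 2 samples.
mixture = mix_with_delay(
    [1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5], delay_sec=0.5, sample_rate=4
)
print(mixture)  # [1.0, 1.0, 1.5, 1.5, 0.5, 0.5]
```

The middle samples contain both speakers, which is the overlapped region the multitalker model has to untangle.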

Let's see some samples in the training cuts:

In [None]:
cut = train_cuts[0]

print("Speaker 0: ")
print('Start time: ', cut.supervisions[0].start)
print('End time: ', cut.supervisions[0].end)
print('Text: ', cut.supervisions[0].text)

print("Speaker 1: ")
print('Start time: ', cut.supervisions[1].start)
print('End time: ', cut.supervisions[1].end)
print('Text: ', cut.supervisions[1].text)
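
The start and end times printed above tell you how much the two speakers overlap. A small helper (hypothetical, with made-up example times) makes the computation explicit:

```python
def overlap_seconds(start_a, end_a, start_b, end_b):
    """Return the duration in seconds during which both segments are active."""
    return max(0.0, min(end_a, end_b) - max(start_a, start_b))

# Made-up times: speaker 0 talks from 0.0 s to 4.25 s, speaker 1 from 1.5 s to 6.0 s.
print(overlap_seconds(0.0, 4.25, 1.5, 6.0))  # 2.75
```

In the notebook, you could pass `cut.supervisions[0].start`, `cut.supervisions[0].end`, and the corresponding values for `cut.supervisions[1]`.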

Let's play this audio:

In [None]:
from IPython.display import Audio

Audio(cut.load_audio()[0], rate=16000)

### Model loading

In [None]:
from nemo.collections.asr.models import ASRModel
import torch

config = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1", return_config=True)

Let's see the config:

In [None]:
print(config.train_ds)

In [None]:
config.train_ds.input_cfg = "train_cfg.yaml"
config.validation_ds.input_cfg = "val_cfg.yaml"
print(config.train_ds)

In [None]:
import yaml

train_cfg = [{
    "input_cfg": [
        {
            "type": "multi_speaker_simulator",
            "manifest_filepath": train_manifest,
            "weight": 1,
            "simulator_type": "lsmix", 
            "num_speakers": 2,
            "min_delay": 0.5,
            "is_tarred": True
        }
    ],
    "type": "group"
}]
with open("train_cfg.yaml", 'w') as f:
    f.write(yaml.dump(train_cfg, sort_keys=False))

val_cfg = [{
    "input_cfg": [
        {
            "type": "multi_speaker_simulator",
            "manifest_filepath": val_manifest,
            "weight": 1,
            "simulator_type": "lsmix", 
            "num_speakers": 2,
            "min_delay": 0.5,
            "is_tarred": True
        }
    ],
    "type": "group"
}]
with open("val_cfg.yaml", 'w') as f:
    f.write(yaml.dump(val_cfg, sort_keys=False))

Before reloading the model with the updated config, check the generated training config file:

In [None]:
!cat train_cfg.yaml

In [None]:
# Adjust a few trainer settings for this demo.
# Use the GPU if one is available; otherwise fall back to the CPU.
import torch
import lightning.pytorch as pl

accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(
    strategy="auto",
    devices=1,
    
    accelerator=accelerator,
    max_epochs=-1,
    max_steps=1000,
    limit_train_batches=100,
    limit_val_batches=5,
    enable_progress_bar=True,
    check_val_every_n_epoch=1
)

config.train_ds.batch_size = 4
config.train_ds.max_duration = 20
config.validation_ds.batch_size = 4
config.validation_ds.max_duration = 20
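
It helps to work out what these trainer limits imply. The arithmetic below is a sketch under two assumptions: gradient accumulation is left at Lightning's default of 1, and each epoch actually yields at least 100 training batches. With `max_epochs=-1`, the epoch count is unbounded, so training stops when `max_steps` is reached.

```python
max_steps = 1000
limit_train_batches = 100
limit_val_batches = 5
accumulate_grad_batches = 1  # assumption: Lightning's default value

# One optimizer step per batch when gradients are not accumulated.
steps_per_epoch = limit_train_batches // accumulate_grad_batches

# Ceiling division: number of epochs needed to reach max_steps.
epochs_to_finish = -(-max_steps // steps_per_epoch)

# Validation runs once per epoch (check_val_every_n_epoch=1),
# not counting Lightning's initial sanity check.
val_batches_total = epochs_to_finish * limit_val_batches

print(steps_per_epoch, epochs_to_finish, val_batches_total)  # 100 10 50
```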



In [None]:
asr_model = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1", override_config_path=config, trainer=trainer)

In [None]:
trainer.fit(asr_model)

## References

[1] [Speaker Targeting via Self-Speaker Adaptation for Multi-talker ASR](https://arxiv.org/abs/2506.22646)

[2] [Stateful Conformer with Cache-based Inference for Streaming Automatic Speech Recognition](https://arxiv.org/abs/2312.17279)

[3] [Streaming Sortformer: Speaker Cache-Based Online Speaker Diarization with Arrival-Time Ordering](https://arxiv.org/abs/2507.18446)

[4] [NEST: Self-supervised Fast Conformer as All-purpose Seasoning to Speech Processing Tasks](https://arxiv.org/abs/2408.13106)

[5] [Fast Conformer with Linearly Scalable Attention for Efficient Speech Recognition](https://arxiv.org/abs/2305.05084)