# Training (or fine-tuning) a model

In this tutorial, you will learn how to train a `pyannote.audio` model from scratch, or fine-tune a pretrained model.

**Warning:** this tutorial assumes that the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote) and that the `PYANNOTE_DATABASE_CONFIG` environment variable is set accordingly.

We start by defining which `task` the `model` will address.  
Here, we want the `model` to address voice activity detection (`vad`) using the `ami` dataset.

In [9]:
from pyannote.database import get_protocol
ami = get_protocol('AMI.SpeakerDiarization.only_words')

from pyannote.audio.tasks import VoiceActivityDetection
vad = VoiceActivityDetection(ami)

For the purpose of this tutorial, we define a `compute_model_fscore` function that runs a model on the AMI test set and prints the voice activity detection F-score.

In [20]:
from pyannote.audio.pipelines import VoiceActivityDetection as VoiceActivityDetectionPipeline
from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure

def compute_model_fscore(model):

    # instantiate voice activity detection pipeline
    vad = VoiceActivityDetectionPipeline(segmentation=model)
    vad.instantiate({'onset': 0.5, 'offset': 0.5, 
                     'min_duration_on': 0.0, 'min_duration_off': 0.0})

    # instantiate precision/recall metrics
    metric = DetectionPrecisionRecallFMeasure()

    for file in ami.test():
        
        # apply the voice activity detection pipeline
        speech = vad(file)
        
        # evaluate its output
        _ = metric(
            file['annotation'],     # this is the reference annotation
            speech,                 # this is the hypothesized annotation
            uem=file['annotated'])  # this is the part of the file that should be evaluated

    # aggregate the performance over the whole test set
    fscore = abs(metric)
    print(f'F-score = {100 * fscore:.1f}%')
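
The `onset` and `offset` values passed to `instantiate` are thresholds applied to the frame-level speech scores produced by the model. The sketch below illustrates the general idea of such two-threshold (hysteresis) binarization in plain Python. It is a simplified illustration, not the `pyannote.audio` implementation: the `binarize` function and its `frame_duration` argument are hypothetical, and the `min_duration_on` / `min_duration_off` post-processing is left out because both are set to 0 here.

```python
def binarize(scores, frame_duration, onset=0.5, offset=0.5):
    """Turn frame-level speech scores into (start, end) speech regions.

    Hypothetical helper: a region opens when the score rises above `onset`
    and closes when the score falls below `offset`.
    """
    regions = []
    active = False
    start = 0.0
    for i, score in enumerate(scores):
        t = i * frame_duration
        if not active and score > onset:
            # score crossed the onset threshold: speech starts here
            active, start = True, t
        elif active and score < offset:
            # score dropped below the offset threshold: speech ends here
            active = False
            regions.append((start, t))
    if active:
        # close a region that is still open at the end of the scores
        regions.append((start, len(scores) * frame_duration))
    return regions


scores = [0.1, 0.2, 0.7, 0.9, 0.8, 0.3, 0.1, 0.6, 0.6, 0.2]

# same thresholds as in the tutorial
print(binarize(scores, frame_duration=0.5))
# [(1.0, 2.5), (3.5, 4.5)]

# a stricter onset and a looser offset give fewer, longer regions
print(binarize(scores, frame_duration=0.5, onset=0.75, offset=0.25))
# [(1.5, 3.0)]
```

With `onset == offset`, as in this tutorial, the procedure reduces to a single threshold on the speech score.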

## Using a pretrained model

To serve as a baseline, we use the pretrained [`pyannote/segmentation`](https://hf.co/pyannote/segmentation) speaker segmentation model.

In [11]:
from pyannote.audio import Model
pretrained = Model.from_pretrained('pyannote/segmentation')

This `pretrained` model relies on the `PyanNet` architecture available in `pyannote.audio`, which combines (trainable) SincNet feature extraction, a few LSTM layers, a few linear layers, and a final classification layer.

In [12]:
pretrained.summarize()

  | Name       | Type       | Params | In sizes      | Out sizes                                  
--------------------------------------------------------------------------------------------------------
0 | sincnet    | SincNet    | 42.6 K | [3, 1, 32000] | [3, 60, 115]                               
1 | lstm       | LSTM       | 1.4 M  | [3, 115, 60]  | [[3, 115, 256], [[8, 3, 128], [8, 3, 128]]]
2 | linear     | ModuleList | 49.4 K | ?             | ?                                          
3 | classifier | Linear     | 516    | [3, 115, 128] | [3, 115, 4]                                
4 | activation | Sigmoid    | 0      | [3, 115, 4]   | [3, 115, 4]                                
--------------------------------------------------------------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.892     Total estimated model params size (MB)
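
The parameter counts in this summary can be reproduced by hand. The sketch below assumes the standard PyTorch `nn.LSTM` parameter layout (four gates per direction, each with input weights, recurrent weights and two bias vectors), an LSTM input size of 60 (the SincNet output size shown in the table), and the hyper-parameters stored in `pretrained.hparams`: 4 bidirectional LSTM layers of hidden size 128, followed by 2 linear layers of hidden size 128. The helper functions are illustrative and not part of `pyannote.audio`.

```python
def lstm_params(input_size, hidden_size, num_layers, bidirectional):
    """Number of trainable parameters of a multi-layer LSTM (PyTorch layout)."""
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # the first layer sees the features, the next ones see the previous layer's output
        layer_input = input_size if layer == 0 else hidden_size * directions
        # 4 gates, each with input weights, recurrent weights and 2 bias vectors
        per_direction = 4 * hidden_size * (layer_input + hidden_size + 2)
        total += directions * per_direction
    return total


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


lstm = lstm_params(input_size=60, hidden_size=128, num_layers=4, bidirectional=True)
# the bidirectional LSTM outputs 2 * 128 = 256 features per frame
linear = linear_params(256, 128) + linear_params(128, 128)
classifier = linear_params(128, 4)

print(lstm)        # 1380352, reported as 1.4 M
print(linear)      # 49408, reported as 49.4 K
print(classifier)  # 516
```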

In [18]:
compute_model_fscore(pretrained)

F-score = 96.6%
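
To make this number more concrete, the sketch below computes a duration-based detection precision, recall and F-score on toy segments. It is a simplified illustration under stated assumptions, not the `pyannote.metrics` implementation: segments are plain `(start, end)` tuples in seconds, segments within one list do not overlap each other, and the evaluated region (`uem`) is ignored. All function names are hypothetical.

```python
def total_duration(segments):
    """Sum of segment durations, in seconds."""
    return sum(end - start for start, end in segments)


def overlap_duration(reference, hypothesis):
    """Duration during which both reference and hypothesis contain speech."""
    return sum(
        max(0.0, min(r_end, h_end) - max(r_start, h_start))
        for r_start, r_end in reference
        for h_start, h_end in hypothesis
    )


def detection_scores(reference, hypothesis):
    """Return (precision, recall, F-score) based on speech durations."""
    true_positive = overlap_duration(reference, hypothesis)
    retrieved = total_duration(hypothesis)  # speech found by the model
    relevant = total_duration(reference)    # speech in the ground truth
    precision = true_positive / retrieved if retrieved > 0 else 1.0
    recall = true_positive / relevant if relevant > 0 else 1.0
    if precision + recall == 0:
        return precision, recall, 0.0
    fscore = 2 * precision * recall / (precision + recall)
    return precision, recall, fscore


reference = [(0.0, 4.0), (6.0, 10.0)]   # 8 seconds of actual speech
hypothesis = [(1.0, 5.0), (6.0, 9.0)]   # 7 seconds of detected speech

precision, recall, fscore = detection_scores(reference, hypothesis)
print(f'precision = {100 * precision:.1f}%')  # precision = 85.7%
print(f'recall = {100 * recall:.1f}%')        # recall = 75.0%
print(f'F-score = {100 * fscore:.1f}%')       # F-score = 80.0%
```

In this toy example, 6 of the 7 detected seconds are actual speech (precision) and 6 of the 8 seconds of actual speech are detected (recall); the F-score is their harmonic mean.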


## Training a model from scratch

We will now train a voice activity detection model from scratch, using the AMI training set.

To make sure we use the exact same architecture, we rely on `pretrained.hparams`, which conveniently keeps track of the hyper-parameters used to instantiate the architecture of the `pretrained` model.

In [13]:
pretrained.hparams

"linear":       {'hidden_size': 128, 'num_layers': 2}
"lstm":         {'hidden_size': 128, 'num_layers': 4, 'bidirectional': True, 'monolithic': True, 'dropout': 0.5, 'batch_first': True}
"num_channels": 1
"sample_rate":  16000
"sincnet":      {'stride': 10, 'sample_rate': 16000}

In [14]:
from pyannote.audio.models.segmentation import PyanNet
from_scratch = PyanNet(task=vad, **pretrained.hparams)

👀  Notice how we passed `vad` as the `task` argument of our `from_scratch` model.  
This allows `pyannote.audio` to automagically register the right `classifier` and `activation` layers into the `PyanNet` model.

> Look ma, no hands!

This magic trick is possible because every task in `pyannote.audio` exposes its specifications.

In [15]:
vad.specifications

Specifications(problem=<Problem.BINARY_CLASSIFICATION: 0>, resolution=<Resolution.FRAME: 1>, duration=2.0, warm_up=(0.0, 0.0), classes=['speech'], permutation_invariant=False)

Voice activity detection is a *binary classification* problem, and the model is trained on *2s* audio chunks.
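
These specifications determine the shapes that appear in the model summaries. As a rough sketch, assuming the 16 kHz sample rate from `pretrained.hparams` and a final linear layer with 128 outputs, the number of samples per training chunk and the size of the `classifier` layer can be worked out as follows (`classifier_params` is a hypothetical helper, not a `pyannote.audio` function):

```python
sample_rate = 16000  # from pretrained.hparams
duration = 2.0       # from vad.specifications
hidden_size = 128    # output size of the last linear layer

# number of audio samples in one training chunk
samples_per_chunk = int(duration * sample_rate)


def classifier_params(hidden_size, num_classes):
    """Weights plus biases of the final linear classification layer."""
    return hidden_size * num_classes + num_classes


print(samples_per_chunk)                              # 32000
print(classifier_params(hidden_size, num_classes=1))  # 129 (one 'speech' class)
print(classifier_params(hidden_size, num_classes=4))  # 516 (pretrained model)
```

A single `speech` class therefore needs a much smaller `classifier` layer than the 4-output layer of the pretrained segmentation model.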

In [22]:
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(from_scratch)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params | In sizes       | Out sizes                                     
-------------------------------------------------------------------------------------------------------------------
0 | sincnet           | SincNet    | 42.6 K | [32, 1, 32000] | [32, 60, 115]                                 
1 | lstm              | LSTM       | 1.4 M  | [32, 115, 60]  | [[32, 115, 256], [[8, 32, 128], [8, 32, 128]]]
2 | linear            | ModuleList | 49.4 K | ?              | ?                                             
3 | classifier        | Linear     | 129    | [32, 115, 128] | [32, 115, 1]                                  
4 | activation        | Sigmoid    | 0      | [32, 115, 1]   | [32, 115, 1]                                  
5 | validation_metric | AUROC      | 0      | ?              | ?        

In [23]:
compute_model_fscore(from_scratch)

F-score = 96.4%


🤷‍♂️ Training the model for just 2 epochs gives us decent results, but it still performs worse (96.4%) than the pretrained model (96.6%).

## Fine-tuning a pretrained model

🤔 Can we do better (and faster) by fine-tuning the pretrained model? 

In [17]:
fine_tuned = Model.from_pretrained('pyannote/segmentation')
fine_tuned.task = vad

In [25]:
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(fine_tuned)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params | In sizes       | Out sizes                                     
-------------------------------------------------------------------------------------------------------------------
0 | sincnet           | SincNet    | 42.6 K | [32, 1, 32000] | [32, 60, 115]                                 
1 | lstm              | LSTM       | 1.4 M  | [32, 115, 60]  | [[32, 115, 256], [[8, 32, 128], [8, 32, 128]]]
2 | linear            | ModuleList | 49.4 K | ?              | ?                                             
3 | classifier        | Linear     | 129    | [32, 115, 128] | [32, 115, 1]                                  
4 | activation        | Sigmoid    | 0      | [32, 115, 1]   | [32, 115, 1]                                  
5 | validation_metric | AUROC      | 0      | ?              | ?        

In [26]:
compute_model_fscore(fine_tuned)

F-score = 96.8%


🎉 Fine-tuning the pretrained model for just one epoch already gives us an improvement (96.8%) over the pretrained model (96.6%).