# Training (or fine-tuning) a model

In this tutorial, you will learn how to train a `pyannote.audio` model from scratch, or fine-tune a pretrained model.

**Warning:** this tutorial assumes that the [AMI corpus](https://groups.inf.ed.ac.uk/ami/corpus/) has already been [set up for use with `pyannote`](https://github.com/pyannote/AMI-diarization-setup/tree/main/pyannote) and that the `PYANNOTE_DATABASE_CONFIG` environment variable is set accordingly.

We start by defining which `task` the `model` will address.  
Here, we want the `model` to address voice activity detection (`vad`) using the `ami` dataset.

In [None]:
from pyannote.database import get_protocol
ami = get_protocol('AMI.SpeakerDiarization.only_words')

from pyannote.audio.tasks import VoiceActivityDetection
vad = VoiceActivityDetection(ami)

For the purpose of this tutorial, we define a `compute_model_fscore` function that runs a model on the AMI test set and prints the voice activity detection F-score.

In [None]:
from pyannote.audio.pipelines import VoiceActivityDetection as VoiceActivityDetectionPipeline
from pyannote.metrics.detection import DetectionPrecisionRecallFMeasure

def compute_model_fscore(model):

    # instantiate voice activity detection pipeline
    # (named `pipeline` so that it does not shadow the `vad` task defined above)
    pipeline = VoiceActivityDetectionPipeline(segmentation=model)
    pipeline.instantiate({'onset': 0.5, 'offset': 0.5,
                          'min_duration_on': 0.0, 'min_duration_off': 0.0})

    # instantiate precision/recall metrics
    metric = DetectionPrecisionRecallFMeasure()

    for file in ami.test():

        # apply the voice activity detection pipeline
        speech = pipeline(file)
        
        # evaluate its output
        _ = metric(
            file['annotation'],     # this is the reference annotation
            speech,                 # this is the hypothesized annotation
            uem=file['annotated'])  # this is the part of the file that should be evaluated

    # aggregate the performance over the whole test set
    fscore = abs(metric)
    print(f'F-score = {100 * fscore:.1f}%')
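
To make the metric less of a black box, here is a minimal pure-Python sketch of a duration-based detection F-score. It is a toy illustration, not the `pyannote.metrics` implementation: the helper names (`total_duration`, `intersection_duration`, `detection_fscore`) and the example segments are hypothetical, segments are assumed not to overlap within each list, and the `uem` restriction is ignored.

```python
def total_duration(segments):
    """Sum of durations of (start, end) segments, assumed non-overlapping."""
    return sum(end - start for start, end in segments)


def intersection_duration(reference, hypothesis):
    """Total time during which reference and hypothesis both mark speech."""
    overlap = 0.0
    for ref_start, ref_end in reference:
        for hyp_start, hyp_end in hypothesis:
            overlap += max(0.0, min(ref_end, hyp_end) - max(ref_start, hyp_start))
    return overlap


def detection_fscore(reference, hypothesis):
    """Duration-based precision, recall and F-score (toy version)."""
    true_positive = intersection_duration(reference, hypothesis)
    hyp_total = total_duration(hypothesis)
    ref_total = total_duration(reference)
    precision = true_positive / hyp_total if hyp_total > 0 else 1.0
    recall = true_positive / ref_total if ref_total > 0 else 1.0
    if precision + recall == 0:
        return precision, recall, 0.0
    fscore = 2 * precision * recall / (precision + recall)
    return precision, recall, fscore


# hypothetical example: 8 s of reference speech, 7 s of detected speech
reference = [(0.0, 4.0), (6.0, 10.0)]
hypothesis = [(1.0, 5.0), (6.0, 9.0)]
precision, recall, fscore = detection_fscore(reference, hypothesis)
print(f'precision = {precision:.3f}, recall = {recall:.3f}, F-score = {fscore:.3f}')
```

Here 6 s of the hypothesis overlap with the reference, so precision is 6/7, recall is 6/8, and the F-score is their harmonic mean.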

## Using a pretrained model

To serve as a baseline, we use the pretrained [`pyannote/segmentation`](https://hf.co/pyannote/segmentation) speaker segmentation model.

In [None]:
from pyannote.audio import Model
pretrained = Model.from_pretrained('pyannote/segmentation')

This `pretrained` model relies on the `PyanNet` architecture available in `pyannote.audio`, which combines (trainable) SincNet feature extraction, a few LSTM layers, a few linear layers, and a final classification layer.

In [None]:
pretrained.summarize()

In [None]:
compute_model_fscore(pretrained)
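
The `onset` and `offset` values passed to the pipeline in `compute_model_fscore` control how frame-level speech scores are turned into speech regions. The sketch below assumes the usual hysteresis reading of these two thresholds (a region opens above `onset` and closes below `offset`); it is a simplified illustration, not the pipeline's actual code. The `binarize` helper and the example scores are hypothetical, and `min_duration_on` / `min_duration_off` are left out.

```python
def binarize(scores, frame_step, onset=0.5, offset=0.5):
    """Turn frame-level scores into (start, end) speech regions.

    A region opens when the score rises above `onset` and closes when it
    falls below `offset` (hysteresis thresholding).
    """
    regions = []
    active = False
    start = 0.0
    for i, score in enumerate(scores):
        t = i * frame_step
        if not active and score > onset:
            # score crossed the onset threshold: open a new region
            active, start = True, t
        elif active and score < offset:
            # score dropped below the offset threshold: close the region
            active = False
            regions.append((start, t))
    if active:
        # close a region that is still open at the end of the scores
        regions.append((start, len(scores) * frame_step))
    return regions


# hypothetical scores, one per frame, with a 1-second frame step for readability
scores = [0.1, 0.6, 0.8, 0.45, 0.7, 0.2, 0.1]
same_thresholds = binarize(scores, frame_step=1.0, onset=0.5, offset=0.5)
lower_offset = binarize(scores, frame_step=1.0, onset=0.5, offset=0.3)
print(same_thresholds)
print(lower_offset)
```

With `onset == offset`, as in this tutorial, the procedure reduces to a plain threshold. A lower `offset` keeps a region open through short dips in the score, which merges the two regions in this example.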

## Training a model from scratch

We will now train a voice activity detection model from scratch, using the AMI training set.

To make sure we use the exact same architecture, we rely on `pretrained.hparams`, which conveniently keeps track of the hyper-parameters used to instantiate the architecture of the `pretrained` model.

In [None]:
pretrained.hparams

In [None]:
from pyannote.audio.models.segmentation import PyanNet
from_scratch = PyanNet(task=vad, **pretrained.hparams)

👀  Notice how we passed `vad` as the `task` argument of our `from_scratch` model.  
This allows `pyannote.audio` to automagically register the right `classifier` and `activation` layers into the `PyanNet` model.

> Look ma, no hands!

This magic trick is possible because every task in `pyannote.audio` exposes its specifications.

In [None]:
vad.specifications

Voice activity detection is a *binary classification* problem that is trained on *2s* audio chunks.
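
To illustrate what "binary classification on 2s chunks" means in practice, the sketch below builds one speech / non-speech label per frame for a single chunk. It is only a rough picture of the idea, not how `pyannote.audio` prepares its batches: the `frame_targets` helper, the 100-frames-per-chunk resolution, and the example segments are all assumptions made for illustration (the real number of frames depends on the model architecture).

```python
def frame_targets(speech_segments, chunk_start, chunk_duration=2.0, num_frames=100):
    """Binary label per frame: 1 if the frame center falls inside speech."""
    frame_step = chunk_duration / num_frames
    targets = []
    for i in range(num_frames):
        # timestamp of the middle of the i-th frame
        center = chunk_start + (i + 0.5) * frame_step
        is_speech = any(start <= center < end for start, end in speech_segments)
        targets.append(1 if is_speech else 0)
    return targets


# hypothetical reference annotation: two speech turns, in seconds
speech = [(10.0, 11.0), (11.5, 13.0)]

# labels for the 2s chunk covering [10.5, 12.5]
targets = frame_targets(speech, chunk_start=10.5)
print(f'{sum(targets)} speech frames out of {len(targets)}')
```

The model is then trained to predict, for every frame of the chunk, which of the two classes it belongs to.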

In [None]:
import pytorch_lightning as pl
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(from_scratch)

In [None]:
compute_model_fscore(from_scratch)

🤷‍♂️ Training the model for just 2 epochs gives us decent results, but it still performs worse than the pretrained model.

## Fine-tuning a pretrained model

🤔 Can we do better (and faster) by fine-tuning the pretrained model? 

In [None]:
fine_tuned = Model.from_pretrained('pyannote/segmentation')
fine_tuned.task = vad

In [None]:
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(fine_tuned)

In [None]:
compute_model_fscore(fine_tuned)

🎉 Fine-tuning the pretrained model for just one epoch already gives us an improvement (96.8%) over the pretrained model (96.6%).

## Going further

Data augmentation is supported via [`torch-audiomentations`](https://github.com/asteroid-team/torch-audiomentations).

```python
from torch_audiomentations import Compose, ApplyImpulseResponse, AddBackgroundNoise
augmentation = Compose(transforms=[ApplyImpulseResponse(...),
                                   AddBackgroundNoise(...)])
```



A growing collection of tasks can be addressed.
*Here, we address speaker segmentation.*

```python
from pyannote.audio.tasks import Segmentation
seg = Segmentation(ami, augmentation=augmentation)
```

A growing collection of model architectures can be used.
*Here, we use the PyanNet (SincNet + LSTM) architecture.*

```python
from pyannote.audio.models.segmentation import PyanNet
model = PyanNet(task=seg)
```

We benefit from all the nice things that [`pytorch-lightning`](https://www.pytorchlightning.ai/) has to offer: distributed (GPU & TPU) training, model checkpointing, logging, etc.
*In this example, we don't really use any of this...*

```python
from pytorch_lightning import Trainer
trainer = Trainer()
trainer.fit(model)
```

Predictions are obtained by wrapping the model into the `Inference` engine.

```python
from pyannote.audio import Inference
inference = Inference(model)
predictions = inference('audio.wav')
```

Pretrained models can be shared on the [Huggingface.co](https://huggingface.co/pyannote) model hub.
*Here, we download and use a [pretrained](https://huggingface.co/pyannote/segmentation) segmentation model.*

```python
inference = Inference('pyannote/segmentation')
predictions = inference('audio.wav')
```

Fine-tuning is as easy as setting the `task` attribute, freezing early layers, and training.
*Here, we fine-tune the pretrained segmentation model on the AMI dataset.*

```python
from pyannote.audio import Model
model = Model.from_pretrained('pyannote/segmentation')
model.task = Segmentation(ami)
model.freeze_up_to('sincnet')
trainer.fit(model)
```

Transfer learning is also supported out of the box.
*Here, we do transfer learning from segmentation to overlapped speech detection.*

```python
from pyannote.audio.tasks import OverlappedSpeechDetection
osd = OverlappedSpeechDetection(ami)
model.task = osd
trainer.fit(model)
```

The default optimizer (`Adam` with default parameters) is automatically set up for you. Customizing the optimizer (and scheduler) requires overriding the [`model.configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers) method:

```python
from types import MethodType
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR
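def configure_optimizers(self):
    # create the optimizer first so that the scheduler can wrap it
    # (lr=1e-3 is an arbitrary value chosen for illustration)
    optimizer = SGD(self.parameters(), lr=1e-3)
    return {"optimizer": optimizer,
            "lr_scheduler": ExponentialLR(optimizer, 0.9)}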
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer.fit(model)
```
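
To get a feel for what an exponential schedule with a decay factor of `0.9` does, here is a small pure-Python sketch of the arithmetic. It assumes the scheduler is stepped once per epoch and uses a hypothetical initial learning rate of `0.01`; the `exponential_lr` helper is illustrative and does not call PyTorch.

```python
def exponential_lr(initial_lr, gamma, num_epochs):
    """Learning rate at each epoch: initial_lr * gamma ** epoch."""
    return [initial_lr * gamma ** epoch for epoch in range(num_epochs)]


# hypothetical initial learning rate, decay factor taken from the example above
schedule = exponential_lr(initial_lr=0.01, gamma=0.9, num_epochs=4)
for epoch, lr in enumerate(schedule):
    print(f'epoch {epoch}: lr = {lr:.5f}')
```

Each epoch multiplies the learning rate by `0.9`, so it is roughly halved after about 7 epochs (`0.9 ** 7` is about `0.48`).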
