# Exercise 6

Same rules as previous exercises apply.

**NOTE**: Submit the .ipynb file, not a PDF, for this exercise.

Choose the GPU runtime with a T4 GPU.

In [None]:
# mount google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
!pip install git+https://github.com/openai/whisper.git

In [None]:
from whisper import load_model

In [None]:
# load the 'small' model
model_baseline = ...

In [None]:
# print the number of filterbanks used
model_baseline.dims.n_mels

In [None]:
model_baseline = None  # free memory

## Complete the dataloader (5 pt)

In [None]:
# Write code to fine-tune with the ami train set and evaluate with the ami test set

# define data loader given that the audio files are in ami_train_audio_segmented
# and the transcriptions are in ami_train_audio_segmented/text. The text file is in the format
# <audio_file_name> <transcription>

import whisper.audio as whisper_audio
import torch

audio_files = []
transcriptions = []
import os
# TODO: change paths in two places
with open('/content/drive/MyDrive/work/uzh/teaching/2024-speech-technology/ex6_files/text') as f:
    for line in f:
        audio_file, *transcription = line.strip().split()
        transcription = ' '.join(transcription)
        audio_files.append(os.path.join('/content/drive/MyDrive/work/uzh/teaching/2024-speech-technology/ex6_files/', f'{audio_file}.wav'))
        transcriptions.append(transcription)

def extract_audio_features(audio_file):
    # find the number of samples for the audio file
    n_samples = ...
    padding_length = whisper_audio.N_SAMPLES - n_samples
    return whisper_audio.log_mel_spectrogram(audio_file, n_mels=..., padding=padding_length)

# define data loader given that the audio files are in ami_test_audio_segmented
class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, audio_files, transcriptions):
        self.audio_files = audio_files
        self.transcriptions = transcriptions

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        mel = extract_audio_features(self.audio_files[idx])
        return mel, self.transcriptions[idx]

def collate_fn(batch):
    mels, transcriptions = zip(*batch)
    return torch.stack(mels), transcriptions

train_dataset = AudioDataset(audio_files, transcriptions)

## Boilerplate

The next few code cells in this section are library code; nothing needs to be changed here.

In [None]:
# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)

    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
        assert reduction in ("none", "sum", "mean"), reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss without normalization.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index

        # See https://github.com/k2-fsa/icefall/issues/240
        # and https://github.com/k2-fsa/icefall/issues/297
        # for why we don't use target[ignored] = 0 here
        target = torch.where(ignored, torch.zeros_like(target), target)

        true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
        )

        # Set the value of ignored indexes to 0
        #
        # See https://github.com/k2-fsa/icefall/issues/240
        # and https://github.com/k2-fsa/icefall/issues/297
        # for why we don't use true_dist[ignored] = 0 here
        true_dist = torch.where(
            ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
            torch.zeros_like(true_dist),
            true_dist,
        )

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            return loss.sum(dim=-1)


In [None]:
# DO NOT MODIFY
# Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
#              2024  Yuekai Zhang
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Any
from torch import Tensor
from torch.nn.functional import pad as pad_tensor


def batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
    padding_size = max(tensor.shape[0] for tensor in tensors)
    dims = len(tensors[0].shape)
    padded_tensors = []
    for tensor in tensors:
        padding = [0] * 2 * dims
        padding[-1] = padding_size - tensor.shape[0]
        padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
    return torch.stack([tensor for tensor in padded_tensors], dim=0)

In [None]:
!pip install kaldialign

In [None]:
# DO NOT MODIFY THIS CODE
# Code modified from Icefall: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Mingshuang Luo,
#                                                    Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import kaldialign
from typing import Dict, Tuple, List

def get_error_stats(
    results: List[Tuple[str, str]],
    compute_CER: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      Return None.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=False)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)


    print(
        f"%WER {tot_errs / ref_len:.2%} "
        f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
        f"{del_errs} del, {sub_errs} sub ]"
    )

    print(f"%WER = {tot_err_rate}")
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
    )

    print("")
    print("PER-UTT DETAILS: corr or (ref->hyp)  ")
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
        )

    print("")
    print("SUBSTITUTIONS: count ref -> hyp")

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count}   {ref} -> {hyp}")

    print("")
    print("DELETIONS: count ref")
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}")

    print("")
    print("INSERTIONS: count hyp")
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}")

    print("")
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp")
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word}   {corr} {tot_errs} {ref_count} {hyp_count}")
    return float(tot_err_rate)

## Whisper baseline (0 pt)

Write down the WER obtained with the Whisper model in the previous exercise:

### Mel spectrogram for Whisper (10 pt)

The feature extraction for Whisper is done in the code above (see Dataloader section) as follows:

```python
whisper_audio.log_mel_spectrogram(audio_file, n_mels=..., padding=padding_length)
```

where ``audio_file`` is the path to an audio file.

1. What is the value of ``N_SAMPLES``? You may find out by printing it, or by referring to the Whisper source code.
2. What does it refer to?
3. Why is the padding even required?
4. What are the advantages and disadvantages of padding, especially during model training?
5. What do we need to modify in order to avoid padding?

## Finetune for a subset of the AMI train set

In [None]:
# load the small model
model = ...

### Prepare the start of the transcript for training (5 pt)

In [None]:
# Find the start of sentence token in the vocabulary
import whisper
tokenizer = whisper.tokenizer.get_tokenizer(
        True,
        num_languages=model.num_languages,
        language=...,
        task=...,
    )

# Find the token ids of the text "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# How?
# Method 1: Use dir(tokenizer) to find the member that represents the token ids
# Method 2: Use the function mentioned in the previous exercise to find the token ids of each
# of the special tokens and put them together in a list to represent them as a sequence
# of special tokens.
tokens_to_prepend = ...
tokens_to_prepend

In [None]:
# justify that tokens_to_prepend is indeed "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

## Fine-tune Whisper to AMI dataset (15 pt)

In this section, we will fine-tune using only the CPU. We will run the training with a batch size of 1, but update the model parameters only every 8 batches, which makes the effective batch size 8.

This is done to avoid running out of memory while training.
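
The accumulation scheme described above can be illustrated without Whisper. The following is a minimal sketch in plain Python, assuming a hypothetical one-parameter model `y_hat = w * x` with a summed squared-error loss; the names `grad_single`, `grad_full_batch`, `train_with_accumulation` and the toy `data` are illustrative only. It shows that summing the gradients of 8 single-example batches yields the same update as one batch of 8, provided the gradient buffer is reset only after each parameter update.

```python
# Toy data for a hypothetical one-parameter model y_hat = w * x.
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2),
        (5.0, 9.8), (6.0, 12.1), (7.0, 14.0), (8.0, 15.9)]


def grad_single(w, x, y):
    """Gradient of the squared error (w * x - y) ** 2 with respect to w."""
    return 2.0 * (w * x - y) * x


def grad_full_batch(w, batch):
    """Gradient of the summed loss over a whole batch."""
    return sum(grad_single(w, x, y) for x, y in batch)


def train_with_accumulation(w, batch, lr, accum_steps):
    """Process one example at a time, updating w every `accum_steps` examples."""
    grad_buffer = 0.0
    for idx, (x, y) in enumerate(batch):
        grad_buffer += grad_single(w, x, y)   # plays the role of loss.backward()
        if (idx + 1) % accum_steps == 0:
            w -= lr * grad_buffer             # plays the role of optimizer.step()
            grad_buffer = 0.0                 # reset only after the update
    return w


w0, lr = 0.0, 1e-3
w_accum = train_with_accumulation(w0, data, lr, accum_steps=8)
w_batch = w0 - lr * grad_full_batch(w0, data)
print(w_accum, w_batch)
```

Both printed values agree, which is why a batch size of 1 with an update every 8 batches behaves like a batch size of 8 while needing the memory of a single example.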

There are 1400+ audio files in the training set. Run the training on at least 25% of them: modify the break statement below to stop at a point where the training does not take too long to run, but still covers at least 25% of the training data.

In [None]:
# set seed
torch.manual_seed(42)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
model.train()
# model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
for batch_idx, (feats, texts) in enumerate(train_loader):
    # after few batches, stop
    ...
    # normalize and then convert text to tensor of indices using the model's tokenizer
    text_tokens_list = [
        tokens_to_prepend
        + tokenizer.encode(text.lower())
        + [tokenizer.eot]
        for text in texts
    ]
    # convert it to torch tensor
    text_tokens_list = [
        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
    ]


    # 50256 is the index of <pad> for all whisper models
    prev_outputs_tokens = batch_tensors(
        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
    )
    target_tokens = batch_tensors(
        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
    )
    target_lengths = torch.LongTensor(
        [tokens.shape[0] - 1 for tokens in text_tokens_list]
    )


    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
    ignore_prefix_size = 3
    decoder_criterion = LabelSmoothingLoss(
        ignore_index=50256, label_smoothing=0.1, reduction="sum"
    )

    #feats = feats.cuda()
    encoder_out = model.encoder(feats)
    #prev_outputs_tokens = prev_outputs_tokens.cuda()
    text_logits = model.decoder(prev_outputs_tokens, encoder_out)
    text_logits = text_logits[:, ignore_prefix_size:, :]
    target_tokens = target_tokens[:, ignore_prefix_size:]
    #target_tokens = target_tokens.cuda()
    loss = decoder_criterion(text_logits, target_tokens)
    loss_value = loss.item()
    # normalize loss value by number of sequences in the batch and number of
    # frames used as input per batch
    loss_value_normalized = ...
    print(f"batch_idx: {batch_idx}, loss: {loss_value_normalized}")
    # gradients accumulate across batches until the next optimizer step
    loss.backward()
    if (batch_idx + 1) % 8 == 0:
      optimizer.step()
      # reset gradients only after the update; resetting them on every batch
      # would discard the accumulated gradients
      optimizer.zero_grad()

# you may want to add code here if there are some examples
# on which loss.backward() was applied, but optimizer.step() was not
# run because you exited the loop earlier than expected.
...

### Estimate WER with model (5 pt)

Use get_error_stats as done in the previous exercise to display the WER and detailed errors.
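
For reference, WER is the word-level edit distance between reference and hypothesis divided by the number of reference words. The sketch below is a plain-Python illustration of that definition; it is not the `get_error_stats` implementation (which relies on `kaldialign`), and the function names `edit_distance` and `wer` as well as the example sentences are hypothetical.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Minimum number of substitutions, insertions and deletions turning ref into hyp."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match / substitution
        prev = curr
    return prev[-1]


def wer(refs: List[str], hyps: List[str]) -> float:
    """Corpus-level WER in percent: total errors over total reference words."""
    errors = sum(edit_distance(r.split(), h.split()) for r, h in zip(refs, hyps))
    n_ref_words = sum(len(r.split()) for r in refs)
    return 100.0 * errors / n_ref_words


# Made-up sentences for illustration only.
refs = ["for the first day sir i think", "i wanna go"]
hyps = ["for the first day i think", "i want to go"]
print(f"WER = {wer(refs, hyps):.2f}%")
```

Note that the error count is normalized by the number of reference words, so a hypothesis with many insertions can push the WER above 100%.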

In [None]:
...

In [None]:
torch.save(model.state_dict(), './model_fullfinetuning.pt')

In order to load the model again if you need it for later experiments, simply do

```python
model = load_model(...)  # we already did this a couple of times earlier
model.load_state_dict(torch.load('./model_fullfinetuning.pt'))
```

## Fine-tune Whisper on GPUs (10 pt)

Run the fine-tuning code above with the statements containing ``.cuda()`` uncommented. Explain the result of the run.

In [None]:
%%time
# add training code here

In [None]:
# clear memory
model = None

## Fine-tune with LoRA (20 pt)

In [None]:
!pip install peft

In [None]:
from peft import LoraConfig
import peft
import whisper
model_lora = load_model('small', device='cpu', download_root='./')
target_modules = [n for n, m in model_lora.named_modules() if type(m) == whisper.model.Linear]
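lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",  # LoraConfig expects one of "none", "all" or "lora_only"
    target_modules=target_modules,
)
# inject the LoRA layers in place into the plain PyTorch Whisper model
model_lora = peft.inject_adapter_in_model(lora_config, model_lora)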

Run the training on model_lora for the same amount of data as for the full fine-tuning. Run the training on the GPU; if that does not work, you may run it on the CPU. Then evaluate the model on the same test set as before. Also save the model if it performs better than full fine-tuning.

## Number of parameters (10 pt)

You may refer to the Whisper paper to answer the following questions. You may also use the ``model`` variable to answer the questions.

1. How many parameters does Whisper small have?
2. How many layers of self-attention based transformer module does the Whisper encoder have for the 'small' model?
3. How many layers does the Whisper decoder have for the 'small' model?
4. How many linear layers (i.e. nn.Linear) does the Whisper small model have in total?
5. Compute the total number of new parameters that we added with LoRA earlier, given the answer to question 4 and given that the LoRA rank is 64.
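
As a reminder of where LoRA parameter counts come from: in the standard formulation, a linear layer with a weight of shape (d_out, d_in) receives two new matrices, A of shape (r, d_in) and B of shape (d_out, r), so it gains r · (d_in + d_out) parameters. The sketch below assumes that formulation with no additional bias terms; `lora_param_count` is a hypothetical helper, and the layer shapes and rank are made-up values, not those of Whisper small.

```python
from typing import Iterable, Tuple


def lora_param_count(layer_shapes: Iterable[Tuple[int, int]], rank: int) -> int:
    """Total parameters added by LoRA over linear layers given as (d_in, d_out)."""
    return sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)


# Made-up shapes for illustration only (not the Whisper small architecture).
toy_layers = [(128, 128), (128, 512), (512, 128)]
added = lora_param_count(toy_layers, rank=4)
original = sum(d_in * d_out for d_in, d_out in toy_layers)
print(f"LoRA adds {added} parameters next to {original} frozen weights")
```

For the exercise, the same computation needs the actual input and output sizes of every targeted linear layer, which can be read from the ``model`` variable.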

## Transducer model with Whisper encoder (5 pt)

In [None]:
import torch.nn as nn
class Joiner(nn.Module):
    def __init__(self):
      ...
    def forward(self):
      ...

class Predictor(nn.Module):
    def __init__(self):
      ...
    def forward(self):
      ...

class WhisperTransducer(nn.Module):
    def __init__(self, encoder, tokenizer):
        super().__init__()
        self.encoder = encoder
        self.joiner = Joiner()  # assume Joiner has been already defined as in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py
        self.predictor = Predictor()  # assume Predictor has been already defined as in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py

    def forward(self, mel):
        # mel: batch of log-mel spectrograms, e.g. as returned by extract_audio_features
        x = self.encoder(mel)
        # rest of transducer logic (not required to fill)
        ...
        return x

In [None]:
# TODO: extract the encoder only from whisper model and create
# an instance of WhisperTransducer by passing the encoder

## Best performance on the test data (5 pt)

How can one obtain a lower bound on the WER for a given dataset? This is one way to know how far we can keep improving a given model's performance on that dataset. The 'model' here is not restricted to Whisper. It could be any type of ASR system.

## Word boosting for Whisper (15 pt)

Suppose we want to avoid some common mistakes, e.g. we know that the speech has 'wanna' but the ASR outputs 'want to'. Identify one such simple mistake in the result of your best model. Identify all the files (you may do it manually since there are only 150 test files), and add the correct word/phrase to be used to the prompt when decoding the audio where you want to avoid such confusion.
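
Finding the affected files does not have to be fully manual. The sketch below assumes you have kept the evaluation results as `(cut_id, reference, hypothesis)` tuples of plain strings; that tuple layout, the example data, and the helper name `find_confusions` are assumptions made for illustration.

```python
from typing import List, Tuple


def find_confusions(
    results: List[Tuple[str, str, str]], expected: str, produced: str
) -> List[str]:
    """Return ids of utterances whose reference contains `expected`
    while the hypothesis contains `produced` instead."""
    hits = []
    for cut_id, ref, hyp in results:
        # pad with spaces so that only whole words / phrases match
        ref_padded, hyp_padded = f" {ref.lower()} ", f" {hyp.lower()} "
        has_expected = f" {expected} " in ref_padded
        wrong_output = (f" {produced} " in hyp_padded
                        and f" {expected} " not in hyp_padded)
        if has_expected and wrong_output:
            hits.append(cut_id)
    return hits


# Made-up results for illustration only.
results = [
    ("utt1", "i wanna go home", "i want to go home"),
    ("utt2", "i want to go home", "i want to go home"),
    ("utt3", "we wanna start", "we wanna start"),
]
print(find_confusions(results, expected="wanna", produced="want to"))
```

The returned ids are the utterances worth re-decoding with the expected phrase added to the prompt.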

When decoding, the prompt can be passed with the ``initial_prompt`` argument.

Show the output before and after the word boosting. Explain clearly (in technical terms) what makes the boosting work/not work in your case.