# Exercise 5

## Instructions

- Make sure you have uploaded the audio files to Google Drive.
- Please read the markdown sections and code comments carefully before answering.
- Treat ``...`` as incomplete code that you are required to complete.
- Each incomplete region marked by ``...`` can be completed with at most 2 statements (2 lines of Python code).
- You may refer to the slides and reference material, but may not use AI code completion.
- Run all code cells in the notebook, even those that do not require an answer from you.
- The points for each section or sub-section are given in parentheses, e.g. (15 pt) means 15 points.
- Pay attention to Q&A questions: the separation between markdown and Python cells is not always obvious.
- **ATTENTION**: There are many places where the path of the audio file needs to be fixed by you.

In [None]:
# mount google drive
#from google.colab import drive
#drive.mount('/content/drive', force_remount=True)

In [None]:
#!pip install git+https://github.com/openai/whisper.git

In [None]:
#!pip install jiwer

In [3]:
# jiwer example usage
import jiwer
reference = "hello world"
hypothesis = "hello duck"

error = jiwer.wer(reference, hypothesis)
error


0.5

In [4]:
character_error = jiwer.cer(reference, hypothesis)
character_error

0.45454545454545453
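To make the two numbers above concrete: both metrics are an edit (Levenshtein) distance divided by the reference length, counted over words for WER and over characters for CER. The following is a minimal pure-Python sketch, not jiwer's implementation; the function names are hypothetical. It assumes spaces count as characters for CER, which is consistent with the 5 edits over 11 characters behind the 0.4545... above. jiwer may also apply default text transformations that this sketch skips.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences, every edit costing 1."""
    prev = list(range(len(hyp) + 1))  # distances from an empty ref prefix
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,              # deletion of r
                           cur[j - 1] + 1,           # insertion of h
                           prev[j - 1] + (r != h)))  # substitution or match
        prev = cur
    return prev[-1]


def word_error_rate(reference, hypothesis):
    """Word edits divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def char_error_rate(reference, hypothesis):
    """Character edits (spaces included) divided by the reference length."""
    return edit_distance(reference, hypothesis) / len(reference)


print(word_error_rate("hello world", "hello duck"))  # 1 substitution / 2 words
print(char_error_rate("hello world", "hello duck"))  # 5 edits / 11 characters
```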

In [5]:
import warnings
warnings.filterwarnings("ignore")

## 1. Word Error Rate (20pt)

Questions:

1. Use jiwer to compute the WER and CER between the following hypothesis and reference. (5)

In [6]:
ref = 'HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE'
hyp = 'HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOWER FATTEN SAUCE'
wer = jiwer.wer(ref, hyp)
cer = jiwer.cer(ref, hyp)
wer, cer  # display both values

2. Give an example of an ASR output with a WER of 10% that contains only one type of error (insertions, deletions, or substitutions) w.r.t. the following reference. (5)

Reference is:

THE DARK SUIT WAS IN GREASY WASH WATER ALL YEAR


3. Give an example of an ASR output with a WER of 150% w.r.t. the following reference. (5)

Reference is:

THE DARK SUIT WAS IN GREASY WASH WATER ALL YEAR
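For questions 2 and 3 it helps to split the edit distance into its parts: WER = (S + D + I) / N, where S, D and I are the substitutions, deletions and insertions in a minimum-edit alignment and N is the number of reference words. With N = 10 here, each error adds 10 percentage points, and since S + D can never exceed N, a WER above 100% requires insertions. The pure-Python sketch below (hypothetical helper names, not jiwer's API) backtraces the alignment so a candidate hypothesis can be checked; when several alignments tie, it may attribute errors differently from other tools, although the total is the same.

```python
def error_counts(ref_words, hyp_words):
    """Return (substitutions, deletions, insertions) of a minimum-edit alignment."""
    n, m = len(ref_words), len(hyp_words)
    # dp[i][j] = edit distance between ref_words[:i] and hyp_words[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i][j] = min(dp[i - 1][j] + 1,      # deletion
                           dp[i][j - 1] + 1,      # insertion
                           dp[i - 1][j - 1] + (ref_words[i - 1] != hyp_words[j - 1]))
    # Walk back from the end and attribute each edit to S, D or I
    subs = dels = ins = 0
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (ref_words[i - 1] != hyp_words[j - 1]):
            subs += ref_words[i - 1] != hyp_words[j - 1]  # adds 0 for a match
            i, j = i - 1, j - 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            dels += 1
            i -= 1
        else:
            ins += 1
            j -= 1
    return subs, dels, ins


def wer_from_counts(subs, dels, ins, n_ref):
    """WER = (S + D + I) / N."""
    return (subs + dels + ins) / n_ref
```

For instance, `error_counts("A B C".split(), "A C".split())` reports a single deletion.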

4. A dataset has 10 speech utterances, and we are given the reference text for each of them. Given only the WER of an ASR system on each of the 10 utterances (i.e. we do not have access to the hypotheses of the ASR system), how can we calculate the WER on the entire dataset? (5)
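One way to reason about question 4: the reference texts give each utterance's length N_i, so WER_i * N_i recovers that utterance's error count, and the dataset WER is the total error count over the total reference length. That is a length-weighted average, not a plain mean of the per-utterance WERs. A minimal sketch with a hypothetical helper name:

```python
def corpus_wer(utt_wers, ref_lengths):
    """Dataset WER from per-utterance WERs and reference lengths (in words).

    WER_i * N_i is the number of errors in utterance i, so the corpus WER
    is the N-weighted average of the per-utterance WERs.
    """
    total_errors = sum(w * n for w, n in zip(utt_wers, ref_lengths))
    return total_errors / sum(ref_lengths)
```

For example, WERs of 0.5 and 0.0 on utterances of 2 and 8 words give 1 error over 10 words (0.1), whereas the unweighted mean would be 0.25. The lengths can be obtained as `len(ref.split())` for each reference.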

## 2. Decoding with Whisper

In this section, we will decode a subset of the test split of the [AMI dataset](https://groups.inf.ed.ac.uk/ami/corpus/). We will refer to this subset as "AMI test".

There are three tasks in this section:

1. Decode on AMI test with default parameters and evaluate the model's performance.
2. Decode on AMI test with beam search with beam size 4 and evaluate the model's performance.
3. Evaluate WER before and after text normalization.

## 2.1 Example usage

In [8]:
from whisper import load_model
model = load_model('small', device='cpu', download_root='./')
# With temperature=0 and no beam_size, Whisper decodes greedily (no sampling fallback).
# TODO: fix path
#result_greedy = model.transcribe("/content/drive/MyDrive/work/uzh/teaching/2024-speech-technology/audio_files_ex5/ami-en2002b/EN2002b-3-379-ihm.wav", language="en", temperature=0)
result_greedy = model.transcribe("/ami-en2002b/EN2002b-3-379-ihm.wav", language="en", temperature=0)
result_greedy['text']  # the decoded hypothesis text


FileNotFoundError: [WinError 2] The system cannot find the file specified

## 2.2 Whisper decoding (5 pt)

In [None]:
utt2ref = {}
# TODO: fix path
with open('/content/drive/MyDrive/work/uzh/teaching/2024-speech-technology/audio_files_ex5/ami-en2002b/text_random150') as ipf:
  for ln in ipf:
    utt, *text = ln.strip().split()
    utt2ref[utt] = " ".join(text)

In [None]:
utt2hyp = {}
# for each utterance in utt2ref, decode and store the result in utt2hyp[utt]
...

In [None]:
results = []
for utt in utt2ref:
  if utt not in utt2hyp:
    print(f"ERROR: Missing hypothesis for utt {utt}")
    break
  results.append((utt, utt2ref[utt].split(), utt2hyp[utt].split()))
results[0] # example showing what results contains

In [None]:
!pip install kaldialign

In [None]:
# DO NOT MODIFY THIS CODE
# Code modified from Icefall: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                    Mingshuang Luo,
#                                                    Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import kaldialign
from typing import Dict, Tuple, List

def get_error_stats(
    results: List[Tuple[str, List[str], List[str]]],
    compute_CER: bool = False,
) -> float:
    """Print error statistics for predicted results against reference transcripts.

    It prints the following:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted as `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

        - Lists of substitutions, deletions and insertions sorted by count,
          and per-word statistics.

    Args:
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result,
        both given as lists of words.
      compute_CER:
        Not used by this function.
    Returns:
      The WER as a percentage, rounded to two decimals.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=False)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)


    print(
        f"%WER {tot_errs / ref_len:.2%} "
        f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
        f"{del_errs} del, {sub_errs} sub ]"
    )

    print(f"%WER = {tot_err_rate}")
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
    )

    print("")
    print("PER-UTT DETAILS: corr or (ref->hyp)  ")
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
        )

    print("")
    print("SUBSTITUTIONS: count ref -> hyp")

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count}   {ref} -> {hyp}")

    print("")
    print("DELETIONS: count ref")
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}")

    print("")
    print("INSERTIONS: count hyp")
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}")

    print("")
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp")
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word}   {corr} {tot_errs} {ref_count} {hyp_count}")
    return float(tot_err_rate)

Run the above code. Check the most common errors (insertions, substitutions, deletions) and propose some methods to fix them.

In [None]:
get_error_stats(results)

### Text normalization (20 pt)

Whisper outputs human-readable text (with casing and punctuation) by default. Notice that this is not the case for the reference.

Normalize the reference and/or hypothesis text so that both follow the same conventions. Then compute the WER again with ``get_error_stats()``.

You may Google for any Python-related help with string processing. If you use an external reference, please cite it in a code comment.

You may reuse code from previous sections to compute WER.
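A minimal normalization sketch, assuming the AMI references are upper-case words without punctuation (check a few lines of `text_random150` to confirm). It upper-cases the text, maps the typographic apostrophe to a plain one, replaces every other character outside A-Z, 0-9 and the apostrophe with a space, and collapses whitespace. Numbers, hyphenation, spelling variants and non-ASCII letters are not handled (the latter are simply dropped). Whisper also ships `whisper.normalizers.EnglishTextNormalizer`, which could instead be applied to both the references and the hypotheses. The function name below is hypothetical.

```python
import re


def normalize(text):
    """Upper-case, drop punctuation (keeping apostrophes), collapse whitespace."""
    text = text.replace("\u2019", "'").upper()     # curly apostrophe -> plain
    text = re.sub(r"[^A-Z0-9'\s]", " ", text)      # punctuation -> space
    return " ".join(text.split())                  # collapse runs of whitespace


print(normalize("Hello, world!  How's it going?"))
```

Apply the same function to both `utt2ref[utt]` and `utt2hyp[utt]` before calling `.split()`, so that the two sides are compared under identical conventions.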


In [None]:
# Enter your code here. You can use multiple cells if necessary.
# The 2-statement limit does not apply to this cell.

In [None]:
results = []
...

results[0] # example showing what results contains

In [None]:
get_error_stats(results)

## 2.5 Error analysis (5 pt)

Pick two specific types of errors and discuss how we can improve the ASR for such errors. Note that we are not referring to INSERTION, DELETION and SUBSTITUTION errors, but something more specific. For instance, you could mention a particular error like 'TWO -> TO' and propose a remedy. Choose the ASR output with the best WER so far for your analysis.

### 2.6 Plot WER vs utterance duration (10pt)

Create a plot to check whether the duration of an utterance is related to its WER. Based on visual inspection, conclude whether there is a correlation between the two quantities, and justify your answer. Use the system with the best WER so far.
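A sketch of one way to gather the data for the plot, under stated assumptions: `results` holds `(utt, ref_words, hyp_words)` tuples as built in Section 2.2, `utt2wav` is a hypothetical dict from utterance ID to the path of its uncompressed PCM WAV file, and durations are read from the WAV header with Python's `wave` module. The Pearson coefficient is only a numeric complement to the visual check the question asks for; matplotlib is imported inside the plotting function so that the rest runs without it.

```python
import math
import statistics
import wave


def wav_duration_seconds(path):
    """Duration of an uncompressed PCM WAV file, read from its header."""
    with wave.open(path, "rb") as wf:
        return wf.getnframes() / wf.getframerate()


def utterance_wer(ref_words, hyp_words):
    """Word-level Levenshtein distance divided by the reference length."""
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        cur = [i]
        for j, h in enumerate(hyp_words, start=1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h)))
        prev = cur
    return prev[-1] / max(len(ref_words), 1)


def duration_wer_pairs(results, utt2wav):
    """Return a (duration in seconds, WER) pair per utterance."""
    return [(wav_duration_seconds(utt2wav[utt]), utterance_wer(ref, hyp))
            for utt, ref, hyp in results]


def pearson_r(pairs):
    """Pearson correlation between duration and WER (undefined if either is constant)."""
    xs, ys = zip(*pairs)
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def plot_wer_vs_duration(pairs):
    import matplotlib.pyplot as plt  # only needed for plotting
    durations, wers = zip(*pairs)
    plt.scatter(durations, wers, s=10)
    plt.xlabel("Utterance duration (s)")
    plt.ylabel("WER")
    plt.title(f"WER vs duration (Pearson r = {pearson_r(pairs):.2f})")
    plt.show()
```

Usage would be `pairs = duration_wer_pairs(results, utt2wav)` followed by `plot_wer_vs_duration(pairs)`. When reading the scatter, keep in mind that short utterances have few reference words, so their WER moves in large steps (a single error in a 2-word utterance is already 50%).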

## 3. Whisper tokenizer (15pt)

In the next code cell, the tokenizer used by multilingual models is initialized in the variable ``tokenizer``. Without changing the ``tokenizer`` variable, do the following:

1. Print the token indices of the beginning-of-sentence and end-of-sentence special tokens.
2. Based on [this dictionary](https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/tokenizer.py#L10), print the token indices of two languages: French and German. An example for English is already given below.
3. Given that the supported languages are in ``whisper.tokenizer.LANGUAGES``, find out, using only Python code, whether (a) your native language is supported and (b) Kurmanji is supported.
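For question 3, a small helper can accept either a language code or an English name. This sketch assumes, as in the linked `tokenizer.py`, that `whisper.tokenizer.LANGUAGES` maps lowercase codes to lowercase names (e.g. `"en": "english"`); the helper name is hypothetical. Kurmanji is also known as Northern Kurdish, so a name lookup may need to try more than one name. For question 2, the language tokens follow the same `<|code|>` pattern as the English example, e.g. `tokenizer.special_tokens.get('<|fr|>')`.

```python
def find_language(query, languages):
    """Return the language code if `query` (code or name) is supported, else None.

    `languages` is assumed to map lowercase codes to lowercase names,
    like whisper.tokenizer.LANGUAGES.
    """
    q = query.strip().lower()
    if q in languages:  # query is already a code
        return q
    for code, name in languages.items():
        if name == q:  # query is a language name
            return code
    return None
```

In the notebook this would be called as, for example, `find_language("german", whisper.tokenizer.LANGUAGES)`.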

In [None]:
from whisper import load_model

model = load_model('base', device='cpu', download_root='./')

In [None]:
import whisper.tokenizer as whisper_tokenizer

tokenizer = whisper_tokenizer.get_tokenizer(
    True,  # assume a multilingual model
    num_languages=model.num_languages,
)

In [None]:
## 1. Print the token index of bos, eos (2.5 pt)

In [None]:
# This is an example to get the language token for English.
tokenizer.special_tokens.get('<|en|>')

In [None]:
## 2. Print the token indices of French and German (2.5 pt)

In [None]:
## 3. Find out, using only Python code, whether (a) your native language is supported and (b) Kurmanji is supported. (10 pt)

## 4. 1D convolution vs TDNN (25pt)

As seen in class, Whisper uses two convolutional layers before the Transformer layers.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

n_mels = 80
n_state = 768
conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3)
conv2 = nn.Conv1d(n_state, n_state, kernel_size=3)

In [None]:
# 1. Create an appropriate input for the convolutional layers above and apply
# the first convolutional layer to it. The input is a 3-dimensional tensor.
# The first dimension is the number of sequences; we assume there is only one.
# The second dimension is the number of filterbank energies per frame (n_mels).
# The third dimension is the maximum audio length in frames.

x = torch.randn(1, n_mels, 3000)

In [None]:
# 2. Get the output after passing it through the first 2 convolution layers.

# forward pass
x = F.gelu(conv1(x))
x = F.gelu(conv2(x))

# debug statement only, not part of forward pass
x.shape
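The sequence length after each layer follows the output-size formula in the `torch.nn.Conv1d` documentation, L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. The pure-Python sketch below (hypothetical helper name) shows that the two unpadded layers above shrink 3000 frames to 2996, so `x.shape` should be `torch.Size([1, 768, 2996])`. For comparison, Whisper's `AudioEncoder` uses `padding=1` on both layers and `stride=2` on the second, which gives 1500.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d, following the formula in the PyTorch docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Two unpadded kernel-3 layers, as defined above: 3000 -> 2998 -> 2996
after_conv2 = conv1d_out_len(conv1d_out_len(3000, kernel_size=3), kernel_size=3)

# Whisper's encoder front-end (padding=1, then stride=2): 3000 -> 3000 -> 1500
whisper_len = conv1d_out_len(conv1d_out_len(3000, 3, padding=1), 3, stride=2, padding=1)

print(after_conv2, whisper_len)
```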

In [None]:
# 3. Adjust the hyperparameters of the two Conv1d layers above to get an output
# sequence length of 1000. Rewrite all the code you need in this cell; do not
# modify the cells above.

# conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, bias=False, padding=1)
# conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, padding=1, stride=3)
...

x = torch.randn(1, n_mels, 3000)
x = F.gelu(conv1(x))
x = F.gelu(conv2(x))
x.shape
# assert x.shape == torch.Size([1, 768, 1000])

4. Would your approach to getting the desired sequence length at the output of the two convolutional layers change if we did not apply the non-linearity in the forward pass?

5. Use PyTorch's ``nn.Conv1d`` class instead of the TDNN class from Exercise 4 to create one of the layers of the X-vector network. The layer to re-create is specified in the code below.

NOTE: Only re-write it. No need to test it (unless you prefer to demonstrate that they are equal).

In [None]:
# The class is commented out. It is only to be used as a reference.
# class TDNN(nn.Module):
#     def __init__(
#         self,
#         feat_dim,
#         output_dim,
#         context_len=1,
#     ):
#         super(TDNN, self).__init__()
#         self.linear = nn.Linear(feat_dim*context_len, output_dim)
#         self.context_len = torch.tensor(context_len, requires_grad=False)


#     def forward(self, input):
#         mb, T, D = input.shape
#         padded_input = input.reshape(mb, -1).unfold(1, D*self.context_len, D).contiguous()
#         x = self.linear(padded_input)
#         return x

# Rewrite the line below with nn.Conv1d
# tdnn1 = TDNN(feat_dim=40, output_dim=128, context_len=3)

tdnn1 = nn.Conv1d(...)
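A TDNN layer with a context of `context_len` frames applies one linear layer to each window of consecutive frames, which is what a 1D convolution with `kernel_size=context_len`, stride 1 and no padding computes, up to the layout of the weights and of the input. The pure-Python sketch below checks this on a small deterministic example, dropping the batch dimension for brevity; all function names are hypothetical. In PyTorch terms, `nn.Conv1d` expects channels-first input, i.e. `x.transpose(1, 2)` for a `(mb, T, D)` tensor, and returns `(mb, output_dim, T - context_len + 1)`; `nn.Linear` weights reshaped with `.view(output_dim, context_len, feat_dim).permute(0, 2, 1)` have the `Conv1d` weight layout.

```python
def tdnn_forward(frames, weight, bias, context_len):
    """Linear layer over each window of `context_len` stacked frames.

    frames: T frames, each a list of D features (time-major, like the TDNN input).
    weight: output_dim rows of length D * context_len (nn.Linear layout),
    so column k * D + c multiplies feature c of frame t + k.
    """
    out = []
    for t in range(len(frames) - context_len + 1):
        window = [v for frame in frames[t:t + context_len] for v in frame]
        out.append([sum(w * v for w, v in zip(row, window)) + b
                    for row, b in zip(weight, bias)])
    return out  # (T - context_len + 1) x output_dim


def conv1d_forward(x_cf, weight, bias):
    """Cross-correlation as computed by nn.Conv1d with stride 1 and no padding.

    x_cf: channels-first input, D lists of length T.
    weight: output_dim x D x kernel_size (nn.Conv1d layout).
    """
    D, T, K = len(x_cf), len(x_cf[0]), len(weight[0][0])
    return [[sum(weight[o][c][k] * x_cf[c][t + k]
                 for c in range(D) for k in range(K)) + bias[o]
             for t in range(T - K + 1)]
            for o in range(len(weight))]  # output_dim x (T - K + 1)


def linear_to_conv_weight(weight, feat_dim, context_len):
    """Reorder nn.Linear weights (out, K*D) into the Conv1d layout (out, D, K)."""
    return [[[row[k * feat_dim + c] for k in range(context_len)]
             for c in range(feat_dim)]
            for row in weight]


# Small deterministic example: 6 frames, 4 features, 3 outputs, context of 3
T_DEMO, D_DEMO, OUT_DEMO, CTX_DEMO = 6, 4, 3, 3
demo_frames = [[float((t * 7 + c * 3) % 5 - 2) for c in range(D_DEMO)]
               for t in range(T_DEMO)]
demo_w = [[((o * 11 + j * 5) % 7) / 7 - 0.5 for j in range(D_DEMO * CTX_DEMO)]
          for o in range(OUT_DEMO)]
demo_b = [0.1 * o for o in range(OUT_DEMO)]

y_tdnn = tdnn_forward(demo_frames, demo_w, demo_b, CTX_DEMO)
demo_x_cf = [list(col) for col in zip(*demo_frames)]  # (T, D) -> (D, T)
y_conv = conv1d_forward(demo_x_cf, linear_to_conv_weight(demo_w, D_DEMO, CTX_DEMO), demo_b)

# Same numbers, transposed: y_tdnn[t][o] should equal y_conv[o][t]
max_abs_diff = max(abs(y_tdnn[t][o] - y_conv[o][t])
                   for t in range(T_DEMO - CTX_DEMO + 1) for o in range(OUT_DEMO))
print(max_abs_diff)
```

The only differences between the two layers are bookkeeping: which axis holds time, and how the weight matrix is arranged.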