<a href="https://colab.research.google.com/github/rfclara/fa_xhosa/blob/main/evaluation_asr_xhosa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Are evangelist speech models able to understand natural language?

# Introduction

In this notebook, we explore the performance of an existing Automatic Speech Recognition (ASR) model for Xhosa, developed under the [Massively Multilingual Speech (MMS)](https://huggingface.co/docs/transformers/model_doc/mms) project by Facebook AI Research. We evaluate the model on a corpus of natural speech provided by the University of Gothenburg's Spraakbanken (Language Bank), accessible at [Spraakbanken's Xhosa Corpus](https://spraakbanken.gu.se/korp/?mode=xhosa#?lang=eng&cqp=%5B%5D&corpus=xhosa).

## Objectives

1. **Align and preprocess the data:**
   - The original Spraakbanken Xhosa corpus was transcribed entirely by hand by specially trained students in South Africa. Each sentence has been aligned to the corresponding audio using the MMS [forced alignment tool](https://github.com/facebookresearch/fairseq/tree/main/examples/mms/data_prep), as detailed in this [Notebook](https://github.com/rfclara/fa_xhosa/blob/main/xhosa_forced_alignement.ipynb). Preprocessing removed punctuation and comments such as \<laugh> or \<code-switching> so that the corpus is in a consistent format suitable for evaluation.

2. **Evaluate Existing MMS ASR Model:**
   - Assess the robustness and accuracy of the MMS ASR model for Xhosa using a natural language dataset from Spraakbanken. This natural data gives us the opportunity to test Meta's claim: "while the content of the audio recordings is religious, our analysis shows that this does not overly bias the model to produce more religious language".


3. **Fine-Tuning the ASR Model:**
   - Investigate whether fine-tuning the MMS ASR model on the natural data corpus improves its performance, given that the original model appears to be primarily trained on biblical texts.

4. **Compare the performance:**
   - Compare the WER and CER metrics of the MMS ASR model before and after fine-tuning.
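
The preprocessing described in objective 1 can be sketched roughly as follows. This is a minimal illustration, not the exact rules applied to the corpus: the helper name `normalize_transcript` is hypothetical, and we assume that annotator comments always appear in angle brackets and that only ASCII punctuation has to be removed. The sample sentence is illustrative and not taken from the corpus.

```python
import re
import string

# Assumption: annotator comments are written in angle brackets, e.g. <laugh>.
TAG_PATTERN = re.compile(r"<[^<>]*>")
# Assumption: only ASCII punctuation needs to be stripped.
PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def normalize_transcript(text: str) -> str:
    """Hypothetical helper: drop <...> comments and punctuation, collapse spaces."""
    text = TAG_PATTERN.sub(" ", text)    # remove comments such as <laugh>
    text = text.translate(PUNCT_TABLE)   # remove punctuation characters
    return " ".join(text.split())        # collapse repeated whitespace


print(normalize_transcript("Molo, <laugh> unjani?"))
```

Note that `string.punctuation` also contains hyphens and apostrophes, so word-internal occurrences would be removed as well; whether that is desirable depends on the transcription conventions.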

## Structure of the Notebook

1. **Data Preparation:**
   - Steps to download and preprocess the dataset.
   
2. **Model Evaluation:**
   - Application of the MMS ASR model on the prepared dataset.
   - Performance metrics and error analysis.
   
3. **Fine-Tuning:**
   - Process of fine-tuning the MMS ASR model using the natural data corpus.
   - Comparison of performance metrics before and after fine-tuning.
   
4. **Results and Discussion:**
   - Insights gained from the evaluation.
   - Discussion on the model's strengths and areas for improvement.
   
5. **Conclusion:**
   - Summary of findings.
   - Future work and potential improvements.

By following this structured approach, we aim to provide a comprehensive assessment of the MMS ASR model's capabilities in handling real-world Xhosa speech data, contributing valuable insights to the ongoing development and refinement of multilingual ASR technologies.


# Step 1: Set Up the Environment
Install the required libraries and import the required packages.

In [None]:
!pip install torchaudio transformers jiwer
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import os
import json

In [None]:
!git clone https://github.com/pytorch/fairseq
!pwd
%cd "/content/fairseq"
!pip install --editable ./
!pip install tensorboardX

# Step 2: Load the Model
Load the processor and model for Xhosa from Facebook's MMS

In [None]:
from transformers import Wav2Vec2ForCTC, AutoProcessor

model_id = "facebook/mms-1b-all" # mms-1b-fl102, mms-1b-l1107, mms-1b-all
target_lang = "xho"  # Xhosa language code

processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)

# Step 3: Download the clean and aligned audio files and gold transcriptions

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/aligned.zip /content
!unzip /content/aligned.zip -d /
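
The prediction loop in Step 4 reads each `manifest.json` line by line, which suggests a JSON Lines layout with one segment per line and the fields `audio_filepath` and `normalized_text`. Under that assumption, the reading logic could be isolated in a small helper such as the one below. The name `read_manifest` is hypothetical and this is only a sketch, not part of the original pipeline.

```python
import json
from pathlib import Path
from typing import Iterator, Tuple


def read_manifest(manifest_path: str) -> Iterator[Tuple[str, str]]:
    """Hypothetical helper: yield (audio_filepath, normalized_text) pairs."""
    with Path(manifest_path).open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines instead of failing in json.loads
            entry = json.loads(line)
            text = entry.get("normalized_text", "").strip()
            if text:  # skip segments whose reference text is empty
                yield entry["audio_filepath"], text
```

A call such as `list(read_manifest(path))` then gives the segments that have a non-empty reference transcription.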

# Step 4: Prediction


In [None]:
def load_and_process_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)
    waveform = waveform.squeeze()  # Remove channel dimension if it's mono
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    return waveform

# Iterate through directories and process each manifest file
base_path = "/content/xhosa/aligned"
all_transcriptions = []
all_references = []

for root, dirs, files in os.walk(base_path):
    for file in files:
        if file == "manifest.json":
            manifest_path = os.path.join(root, file)
            with open(manifest_path, 'r') as f:
                for line in f:
                    entry = json.loads(line.strip())
                    audio_path = entry["audio_filepath"]
                    reference_text = entry["normalized_text"]
                    if not reference_text.strip():
                        continue
                    # Load and process the audio file
                    waveform = load_and_process_audio(audio_path)

                    # Process the audio input
                    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)

                    # Ensure input is 2D [batch_size, sequence_length]
                    if inputs.input_values.dim() == 3 and inputs.input_values.size(1) == 2:
                        inputs.input_values = inputs.input_values.mean(dim=1)  # Convert stereo to mono

                    # Perform inference
                    with torch.no_grad():
                        logits = model(inputs.input_values).logits

                    # Decode the predicted IDs to text
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = processor.batch_decode(predicted_ids)[0]

                    # Collect transcriptions and references
                    all_transcriptions.append(transcription)
                    all_references.append(reference_text)

                    # Print the result for each segment
                    print(f"Transcription: {transcription}")
                    print(f"Reference: {reference_text}")
                    print()
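
The decoding step above takes the most likely token for every audio frame (`torch.argmax`) and leaves the rest to `processor.batch_decode`. Conceptually, greedy CTC decoding then merges consecutive repeated tokens and drops the blank token. The sketch below illustrates that idea on a toy example. The vocabulary and the blank ID of 0 are assumptions made for illustration and do not reflect the actual MMS vocabulary.

```python
from typing import List, Optional, Sequence


def ctc_greedy_collapse(frame_ids: Sequence[int], blank_id: int = 0) -> List[int]:
    """Merge consecutive repeated IDs, then drop blanks (greedy CTC decoding)."""
    collapsed: List[int] = []
    previous: Optional[int] = None
    for token_id in frame_ids:
        # Keep a token only when it starts a new run and is not the blank.
        if token_id != previous and token_id != blank_id:
            collapsed.append(token_id)
        previous = token_id
    return collapsed


# Toy vocabulary, purely illustrative (not the MMS vocabulary).
vocab = {0: "<blank>", 1: "m", 2: "o", 3: "l"}
frames = [1, 1, 0, 2, 2, 0, 3, 3, 0, 0, 2]  # per-frame argmax IDs
print("".join(vocab[i] for i in ctc_greedy_collapse(frames)))
```

The blank token is what lets the model emit a genuinely doubled letter: `[1, 0, 1]` collapses to two tokens, whereas `[1, 1]` collapses to one.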

# Step 5: Evaluation

*   WER
*   CER
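
Both metrics are edit-distance ratios: the number of substitutions, deletions and insertions needed to turn the hypothesis into the reference, divided by the length of the reference, counted in words for WER and in characters for CER. The following pure-Python sketch shows the computation. It is meant as an illustration of the definition, not as a replacement for `jiwer`, and the helper names `edit_distance` and `error_rate` are hypothetical. Counting spaces as characters for CER is an assumption of this sketch.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two sequences (words or characters)."""
    previous = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        current = [i]
        for j, h in enumerate(hyp, start=1):
            current.append(min(
                previous[j] + 1,             # deletion
                current[j - 1] + 1,          # insertion
                previous[j - 1] + (r != h),  # substitution (or match)
            ))
        previous = current
    return previous[-1]


def error_rate(refs: List[str], hyps: List[str], unit: str = "word") -> float:
    """Corpus-level rate: total edits divided by total reference length."""
    split = str.split if unit == "word" else list
    edits = sum(edit_distance(split(r), split(h)) for r, h in zip(refs, hyps))
    length = sum(len(split(r)) for r in refs)
    return edits / length


# Illustrative pair: one of two words is wrong, so WER is 0.5.
print(error_rate(["molo unjani"], ["molo unjan"], unit="word"))
print(error_rate(["molo unjani"], ["molo unjan"], unit="char"))
```

Because a single wrong character makes the whole word count as an error, WER is usually much higher than CER, which is why we report both.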



In [None]:
!pip install jiwer

In [None]:
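import jiwer

# Use distinct names for the scores so the jiwer functions are not shadowed
# (reassigning `wer`/`cer` would break this cell when it is run a second time).
wer_score = jiwer.wer(all_references, all_transcriptions)
cer_score = jiwer.cer(all_references, all_transcriptions)
print(f"Word Error Rate (WER): {wer_score}")
print(f"Character Error Rate (CER): {cer_score}")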
out = jiwer.process_words(
    all_references,
    all_transcriptions,
)

print(jiwer.visualize_alignment(out))
