<a href="https://colab.research.google.com/github/ymoslem/Speech/blob/main/Whisper-Fine-Tuning-Speech-Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Speech Translation (Irish-English)
This notebook adapts the original Hugging Face tutorial, [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper), for _Irish-to-English_ **Speech Translation**.


# Prepare Environment

In [None]:
# Show which GPU is available (batch_size further down was tested on an A100 80GB).
!nvidia-smi

In [None]:
!pip3 install --upgrade transformers accelerate datasets -q
!pip3 install evaluate jiwer sacrebleu -q
!pip3 install librosa tensorboardX wandb -q

In [None]:
# Authenticate with Weights & Biases. The API key is read from the WB_TOKEN shell
# environment variable so it never has to be pasted into the notebook.
!wandb login $WB_TOKEN

In [None]:
import os

# Weights & Biases settings for the runs reported via report_to=["tensorboard", "wandb"]:
# the W&B project name and when the model is uploaded as an artifact.
os.environ.update({
    "WANDB_PROJECT": "Whisper-Irish",
    "WANDB_LOG_MODEL": "end",  # or "checkpoint"
})

# Load Dataset(s)

In [None]:
# Cache locations for downloaded datasets and for model/tokenizer files.
data_cache_dir, model_cache_dir = "/workspace/data/", "/workspace/model/"

In [None]:
from datasets import load_dataset, DatasetDict, Audio

# Authentic dataset
# IWSLT 2023 GA-EN: train and dev are merged for training; test is held out.

iwslt2023_gaen_original = DatasetDict({
    split_name: load_dataset("ymoslem/IWSLT2023-GA-EN",
                             split=split_spec,
                             token=True,
                             trust_remote_code=True,
                             cache_dir=data_cache_dir,
                             )
    for split_name, split_spec in (("train", "train+dev"), ("test", "test"))
}).cast_column("audio", Audio(sampling_rate=16000))  # decode/resample audio at 16 kHz

print(iwslt2023_gaen_original)
print(iwslt2023_gaen_original["train"][0])

In [None]:
# Dataset #2: Fleurs (authentic)

fleurs_dataset = (
    load_dataset("ymoslem/FLEURS-GA-EN",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 )
    # drop the example id and the Irish transcript
    .remove_columns(["id", "text_ga"])
    # the English text becomes the common "translation" target column
    .rename_column("text_en", "translation")
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(fleurs_dataset)
print(fleurs_dataset[0])

In [None]:
# Dataset #3: BitesizeIrish (authentic)

bitesize_dataset = (
    load_dataset("ymoslem/BitesizeIrish-GA-EN",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 # download_mode="force_redownload",
                 )
    # drop the Irish text, the raw English text, the pronunciation and the source URL
    .remove_columns(["text_ga", "text_en_raw", "pronunciation", "url"])
    .rename_column("text_en", "translation")
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(bitesize_dataset)
bitesize_dataset[0]

In [None]:
# Dataset #4: SpokenWords (MTed)

spoken_words_dataset = (
    load_dataset("ymoslem/SpokenWords-GA-EN-MTed",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 )
    # only the keyword column is dropped; no rename here (presumably the dataset
    # already provides a "translation" column — verify before concatenation)
    .remove_columns(["keyword"])
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(spoken_words_dataset)
print(spoken_words_dataset[0])

In [None]:
# Dataset #5: Tatoeba-Speech (synthetic)

tatoeba_dataset = (
    load_dataset("ymoslem/Tatoeba-Speech-Irish",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 )
    .remove_columns(["text_ga"])                  # drop the Irish text
    .rename_column("text_en", "translation")      # English text is the target
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(tatoeba_dataset)
print(tatoeba_dataset[0])

In [None]:
# Dataset #6: Wikimedia-Speech (synthetic)

wikimedia_dataset = (
    load_dataset("ymoslem/Wikimedia-Speech-Irish",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 )
    .remove_columns(["text_ga"])                  # drop the Irish text
    .rename_column("text_en", "translation")      # English text is the target
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(wikimedia_dataset)
print(wikimedia_dataset[0])

In [None]:
# Dataset #7: EUbookshop-Speech (synthetic)

ebookshop_dataset = (
    load_dataset("ymoslem/EUbookshop-Speech-Irish",
                 data_dir="data",
                 split="train",
                 token=True,
                 trust_remote_code=True,
                 cache_dir=data_cache_dir,
                 )
    .remove_columns(["text_ga"])                  # drop the Irish text
    .rename_column("text_en", "translation")      # English text is the target
    .cast_column("audio", Audio(sampling_rate=16000))
)

print(ebookshop_dataset)
print(ebookshop_dataset[0])

In [None]:
from datasets import concatenate_datasets

# Train on IWSLT train+dev plus every additional GA-EN corpus; evaluate on the
# IWSLT test split only. Both splits are then shuffled with a fixed seed.
iwslt2023_gaen = DatasetDict(
    train=concatenate_datasets([
        iwslt2023_gaen_original["train"],
        fleurs_dataset,
        bitesize_dataset,
        spoken_words_dataset,
        tatoeba_dataset,
        wikimedia_dataset,
        ebookshop_dataset,
    ]),
    test=iwslt2023_gaen_original["test"],
).shuffle(seed=42)

print(iwslt2023_gaen)

In [None]:
# Inspect one (shuffled) training example.
iwslt2023_gaen["train"][0]

In [None]:
# Inspect one (shuffled) test example.
iwslt2023_gaen["test"][0]

# Prepare Feature Extractor, Tokenizer and Data

In [2]:
# Define the Whisper model name
# Options: "openai/whisper-tiny", "openai/whisper-base",
# "openai/whisper-small", "openai/whisper-medium", "openai/whisper-large-v3"
# NOTE: run_name, the model-card name and the W&B notes further down all say
# "medium"; update them as well if you pick another size.

model_name = "openai/whisper-medium"

### Load WhisperFeatureExtractor

In [None]:
from transformers import WhisperFeatureExtractor

# Converts raw audio arrays into the log-Mel input features fed to the encoder.
feature_extractor = WhisperFeatureExtractor.from_pretrained(
    model_name,
    cache_dir=model_cache_dir,
)

### Load WhisperTokenizer

In [None]:
from transformers import WhisperTokenizer

# Label tokenizer. `language` and `task` select the special prefix tokens that are
# prepended to every label sequence, so they must describe the training objective.
tokenizer = WhisperTokenizer.from_pretrained(
    model_name,
    language="English",  # target language
    task="translate",    # important
    cache_dir=model_cache_dir,
)

### Combine To Create A WhisperProcessor

In [None]:
from transformers import WhisperProcessor

# Bundles a feature extractor and a tokenizer; the data collator pads with both
# (processor.feature_extractor / processor.tokenizer).
processor = WhisperProcessor.from_pretrained(
    model_name,
    language="English",  # target language
    task="translate",    # important
    cache_dir=model_cache_dir,
)

### Prepare Data

In [None]:
def prepare_dataset(batch):
  """Add Whisper model inputs and targets to a single example.

  The "audio" column was cast with ``Audio(sampling_rate=16000)`` when the
  datasets were loaded, so accessing it yields audio decoded at 16 kHz.

  Adds two keys to ``batch`` (which is modified in place and returned):
    input_features: log-Mel features computed by ``feature_extractor``.
    labels: token ids of the English ``translation`` text from ``tokenizer``.
  """
  sample = batch["audio"]
  extracted = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
  batch["input_features"] = extracted.input_features[0]  # one example -> first row

  batch["labels"] = tokenizer(batch["translation"]).input_ids
  return batch

In [None]:
# Compute features and labels for every example in both splits, dropping the
# original columns (the train split's column names are used for both splits).
iwslt2023_gaen = iwslt2023_gaen.map(
    prepare_dataset,
    remove_columns=iwslt2023_gaen.column_names["train"],
    num_proc=None,  # single process; increase to parallelise feature extraction
)

## Training and Evaluation

### Define a Data Collator

In [None]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of examples into a training batch for Whisper.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` (here a ``WhisperProcessor``).
        decoder_start_token_id: id of the token the model prepends when it builds
            decoder inputs from the labels (``model.config.decoder_start_token_id``).
            If None, the tokenizer's ``<|startoftranscript|>`` id is used.
    """
    processor: Any
    decoder_start_token_id: Optional[int] = None

    def _start_token_id(self) -> int:
        """Return the decoder start token id to strip from the front of the labels."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        return self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences and pad them to the longest in the batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so those positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The Whisper tokenizer begins every label sequence with <|startoftranscript|>,
        # and the model prepends the decoder start token again when it shifts the labels
        # right to build decoder inputs. Strip it here so it is not duplicated.
        # (Comparing with tokenizer.bos_token_id never matched: Whisper's bos token is
        # <|endoftext|>, not <|startoftranscript|>.)
        if labels.size(1) > 0 and (labels[:, 0] == self._start_token_id()).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
# Collator passed to the Seq2SeqTrainer to pad each batch of features and labels.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

In [None]:
import evaluate

# BLEU and chrF (sacreBLEU implementations) score the translations; WER is also
# reported. Each metric is loaded once here and reused by compute_metrics.
metric_bleu, metric_chrf, metric_wer = (
    evaluate.load(metric_id) for metric_id in ("sacrebleu", "chrf", "wer")
)
# metric_comet = evaluate.load("comet")

Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/9.01k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [None]:
def compute_metrics(pred):
    """Score generated translations against the references.

    Args:
        pred: evaluation output from the Seq2SeqTrainer with ``predictions``
            (generated token ids, since ``predict_with_generate=True``) and
            ``label_ids`` (reference token ids, ``-100`` at ignored positions).

    Returns:
        dict with ``bleu`` and ``chrf`` (sacreBLEU scores) and ``wer`` (word
        error rate in percent), each rounded to two decimals.
    """
    import numpy as np

    pad_token_id = tokenizer.pad_token_id

    # -100 is not a real token id and cannot be decoded. The data collator uses it
    # for ignored label positions. The Trainer can also use it to pad generated
    # sequences of different lengths when concatenating predictions across eval
    # batches. Map it to the pad id in both arrays.
    # np.where returns new arrays, so the caller's arrays are left unmodified.
    pred_ids = np.asarray(pred.predictions)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids != -100, label_ids, pad_token_id)

    # drop special tokens (task/language prefix, padding) before scoring
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = round(100 * metric_wer.compute(predictions=pred_str, references=label_str), 2)
    bleu = round(metric_bleu.compute(predictions=pred_str, references=label_str)["score"], 2)
    chrf = round(metric_chrf.compute(predictions=pred_str, references=label_str)["score"], 2)

    return {"bleu": bleu, "chrf": chrf, "wer": wer}

### Load a Pre-Trained Checkpoint

In [None]:
from transformers import WhisperForConditionalGeneration

# Pre-trained Whisper checkpoint that is fine-tuned below.
model = WhisperForConditionalGeneration.from_pretrained(
    model_name,
    cache_dir=model_cache_dir,
)

In [None]:
# Uncomment to inspect the full model configuration.
# print(model.config)

In [None]:
# Clear token suppression and the forced decoder ids stored in the model config;
# the decoding language is configured on model.generation_config instead.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = None

In [None]:
# Define the target language, here it is "en" as we want to translate into English
# If you want to translate into another language, change "en" to its language code
model.generation_config.language = "en"

# Also set the task, so that decoding during evaluation and inference uses the same
# <|translate|> token the tokenizer (task="translate") put into the training labels.
# If only the language is set, Whisper's generate() falls back to the
# <|transcribe|> task token.
model.generation_config.task = "translate"

### Define the Training Configuration

In [None]:
# Training arguments

run_name = "whisper-medium-ga2en"
output_dir = f"ymoslem/{run_name}"  # directory passed as output_dir to the trainer

batch_size = 16  # per device; tested on A100-SXM4-80GB GPU; change if the GPU memory is less
gradient_accumulation_steps = 1

learning_rate = 1e-4
warmup_ratio = 0.03

max_steps = 8000  # equivalent to 1.1 epoch for these datasets; change as needed

max_length = 225  # used as generation_max_length during evaluation

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,  # change to a repo name of your choice
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,  # increase by 2x for every 2x decrease in batch size - default 1
    learning_rate=learning_rate,
    # warmup_steps=warmup_steps,
    warmup_ratio=warmup_ratio,
    max_steps=max_steps,
    # gradient_checkpointing=True,  # less memory, but slower
    fp16=True,
    # `evaluation_strategy` was renamed `eval_strategy` (deprecated since v4.41 and
    # removed in later releases; the install cell upgrades transformers to the latest).
    eval_strategy="steps",
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,  # evaluate generated text so BLEU/chrF/WER can be computed
    generation_max_length=max_length,
    save_steps=100,  # must be a multiple of eval_steps when load_best_model_at_end=True
    eval_steps=100,
    logging_steps=25,
    report_to=["tensorboard", "wandb"],
    run_name=run_name,
    load_best_model_at_end=True,
    metric_for_best_model="chrf",  # key returned by compute_metrics
    greater_is_better=True,
    push_to_hub=True,  # push the model to the Hugging Face model hub
    # The TrainingArguments field is `hub_private_repo`; `private=` is not a valid
    # argument and raises TypeError.
    hub_private_repo=True,  # make the Hub repository private
)

In [None]:
# Sanity-check the learning rate actually stored in training_args.
training_args.learning_rate

In [None]:
# Print the complete training configuration for the record.
print(training_args)

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=iwslt2023_gaen["train"],
    # NOTE(review): the IWSLT test split is also used to select the best checkpoint
    # (load_best_model_at_end / metric_for_best_model), so final test scores are
    # optimistic; consider holding out a separate dev split.
    eval_dataset=iwslt2023_gaen["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated in favour of `processing_class=` (transformers >= 4.46)
    # and removed in newer releases; the install cell upgrades to the latest version.
    processing_class=processor.feature_extractor,
)

In [None]:
# Save the processor (feature extractor + tokenizer) into the training output directory.
processor.save_pretrained(training_args.output_dir)

[]

### Training

In [None]:
# Start training (evaluates and saves a checkpoint every 100 steps, per training_args)
trainer.train()

In [None]:
# To resume training from a checkpoint instead, uncomment one of the lines below:
# True resumes from the latest checkpoint in output_dir; a path resumes from that
# specific checkpoint directory.

# trainer.train(resume_from_checkpoint=True)
# trainer.train(resume_from_checkpoint="ymoslem/whisper-medium-ga2en/checkpoint-8000/")

In [None]:
# Optional: view the training curves in TensorBoard.
# NOTE(review): this logdir points to an old Google Drive location; presumably it
# should be "<output_dir>/runs" (output_dir = "ymoslem/whisper-medium-ga2en") — confirm.
# %load_ext tensorboard
# %tensorboard --logdir /content/drive/MyDrive/models/whisper-ga2en/ymoslem/whisper-medium-ga2en/runs

In [3]:
# Model-card metadata, passed to trainer.push_to_hub(**kwargs).
kwargs = dict(
    dataset_tags=[
        "ymoslem/IWSLT2023-GA-EN",
        "ymoslem/FLEURS-GA-EN",
        "ymoslem/BitesizeIrish-GA-EN",
        "ymoslem/SpokenWords-GA-EN-MTed",
        "ymoslem/Tatoeba-Speech-Irish",
        "ymoslem/Wikimedia-Speech-Irish",
        "ymoslem/EUbookshop-Speech-Irish",
    ],
    # a 'pretty' name for the training dataset
    dataset="IWSLT-2023, FLEURS, BiteSize, SpokenWords, Tatoeba, Wikimedia, and EUbookshop",
    # dataset_args="config: en, split: test",
    language=["ga", "en"],
    model_name="Whisper Medium GA-EN Speech Translation",  # a 'pretty' name for our model
    finetuned_from=model_name,
    tasks="automatic-speech-recognition",
)

In [None]:
# Upload the trained model to the Hugging Face Hub with the model-card metadata in kwargs.
trainer.push_to_hub(**kwargs)

In [None]:
import wandb

dataset_names = "IWSLT-2023 GA-EN, FLEURS, BiteSize, SpokenWords, Tatoeba, Wikimedia, EUbookshop"

# Attach a human-readable summary of the experiment to the active W&B run, then
# close it. wandb.run is None when no run is active (e.g. after a kernel restart),
# in which case there is nothing to annotate.
if wandb.run is not None:
    wandb.run.notes = (
        f"Whisper Medium, fine-tuned on {dataset_names} datasets.\n"
        f"Learning rate {learning_rate}, warmup ratio {warmup_ratio},\n"
        f"batch size {batch_size}, max steps {max_steps},\n"
        f"gradient accumulation steps {training_args.gradient_accumulation_steps}."
    )

wandb.finish()

In [None]:
# Print every logged training/evaluation record collected by the trainer.
print(trainer.state.log_history)