In [1]:
pwd = '.'
# !pip install --upgrade pip
# !pip install --upgrade datasets transformers accelerate evaluate jiwer
# from google.colab import drive
# drive.mount('/content/drive')
# pwd = './drive/MyDrive/Colab Notebooks/CS4347'

In [2]:
import torch
import torchaudio
import tensorboard
from dataclasses import dataclass
from datasets import load_dataset, DatasetDict, concatenate_datasets
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer, DataCollatorWithPadding, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer, pipeline
import os
import evaluate
from typing import Any, Dict, List, Union
# target = 'hanlo'
target = 'tailo'
# target_column = 'hok_text_hanlo_tai'
target_column = 'hok_text_tailo_number_tone'
size = 'small' # model size
n_epoch = 5

In [3]:
from datasets import load_dataset, DatasetDict

# Specify datasets and their split status
data_sources = [
    {"name": "tat_open_source", "pre_split": True},
    # {"name": "hok_song", "pre_split": False, "test_split_percentage": 1},
    # {"name": "suisiann", "pre_split": False, "data_percentage": 0.1, "test_split_percentage": 0.25},
]

# Initialize an empty DatasetDict
combined_dataset = DatasetDict()

# Loop through each dataset
for data_source in data_sources:
    dataset_name = data_source["name"]
    is_pre_split = data_source["pre_split"]
    data_percentage = data_source.get("data_percentage", 1.0)  # Default to 100% if not specified
    test_split_percentage = data_source.get("test_split_percentage", 0.2) # Default to 20% if not specified
    
    if is_pre_split:
        # For pre-split datasets, load train and test directly
        dataset = load_dataset(
            'csv',
            data_files={
                'train': pwd + f'/data/{dataset_name}/dev/dev.tsv',
                'test': pwd + f'/data/{dataset_name}/test/test.tsv'
            },
            delimiter='\t',
            usecols=['hok_audio', target_column]
        )
    else:
        # Load the non-pre-split dataset
        dataset = load_dataset(
            'csv',
            data_files={'full': pwd + f'/data/{dataset_name}/all.csv'}
        )
    
        # Filter columns using map
        dataset = dataset['full'].map(lambda example: {key: example[key] for key in ['hok_audio', target_column]})
    
        # Dynamically split into train and test
        dataset = dataset.train_test_split(test_size=test_split_percentage)

    # Apply data percentage (limit the rows based on the percentage)
    if data_percentage < 1.0:
        dataset['train'] = dataset['train'].select(range(int(len(dataset['train']) * data_percentage)))
        dataset['test'] = dataset['test'].select(range(int(len(dataset['test']) * data_percentage)))

    def update_audio_path(example, dataset_type):
        if is_pre_split:
            if dataset_type == 'train':
                example['hok_audio'] = pwd + f'/data/{dataset_name}/dev/' + example['hok_audio']
            elif dataset_type == 'test':
                example['hok_audio'] = pwd + f'/data/{dataset_name}/test/' + example['hok_audio']
        else:
            example['hok_audio'] = pwd + f'/data/{dataset_name}/' + example['hok_audio']
        return example

    dataset['train'] = dataset['train'].map(lambda x: update_audio_path(x, 'train'))
    dataset['test'] = dataset['test'].map(lambda x: update_audio_path(x, 'test'))

    # Add a `source` column to indicate the dataset name
    dataset['train'] = dataset['train'].map(lambda x: {**x, 'source': dataset_name})
    dataset['test'] = dataset['test'].map(lambda x: {**x, 'source': dataset_name})

    # Add the current dataset's splits to the combined dataset
    if 'train' not in combined_dataset:
        combined_dataset['train'] = dataset['train']
    else:
        combined_dataset['train'] = concatenate_datasets([combined_dataset['train'], dataset['train']])
    
    if 'test' not in combined_dataset:
        combined_dataset['test'] = dataset['test']
    else:
        combined_dataset['test'] = concatenate_datasets([combined_dataset['test'], dataset['test']])

# Truncate labels for the combined dataset
max_label_length = 448

def truncate_labels(example):
    """Truncates the 'labels' field to the maximum allowed length."""
    example[target_column] = example[target_column][:max_label_length]
    return example

combined_dataset['train'] = combined_dataset['train'].map(truncate_labels)
combined_dataset['test'] = combined_dataset['test'].map(truncate_labels)

Map:   0%|          | 0/722 [00:00<?, ? examples/s]

Map:   0%|          | 0/686 [00:00<?, ? examples/s]

In [4]:
# test dataset loading
print(combined_dataset['train'].num_rows)
print(combined_dataset['train'][710])
# print(combined_dataset['train'][730])

722
{'hok_audio': './data/tat_open_source/dev/hok/TAT-Vol1-eval_0034_5.64_TSM013_concat.wav', 'hok_text_tailo_number_tone': 'hian7-tai7 e5 tai5-uan5 siau3-lian5-lang5 lian5“bong1 la5-a2”to1 m7 tsai1 siann2 i3-su3, beh4 an2-tsuann2 ka7 kai2-sueh4“kiam1 se2 khoo3”?', 'source': 'tat_open_source'}


In [5]:
feature_extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-' + size)
tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-' + size, language='Mandarin', task='transcribe')

In [6]:
input_str = combined_dataset['train'][0][target_column]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

In [7]:
# test
print(input_str)
print(labels)
print(decoded_with_special)
print(decoded_str)
input_str == decoded_str

the5-si7 kha2 pian1-ho7:TA_0009
[50258, 50260, 50359, 50363, 3322, 20, 12, 7691, 22, 350, 1641, 17, 32198, 16, 12, 1289, 22, 25, 8241, 62, 1360, 24, 50257]
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>the5-si7 kha2 pian1-ho7:TA_0009<|endoftext|>
the5-si7 kha2 pian1-ho7:TA_0009


True

In [8]:
processor = WhisperProcessor.from_pretrained('openai/whisper-' + size, language='Mandarin', task='transcribe')

In [9]:
def preprocess_function(examples):
    audio_path = examples['hok_audio']
    # Load audio
    speech_array, sampling_rate = torchaudio.load(audio_path)
    # Resample if necessary
    speech_array = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(speech_array)
    # Convert audio to log-mel spectrogram
    input_features = processor(speech_array.squeeze().numpy(), sampling_rate=16000).input_features
    return {'input_features': input_features, 'transcription': examples[target_column]}

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio_path = batch['hok_audio']
    # Load audio
    speech_array, sampling_rate = torchaudio.load(audio_path)

    speech_array = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(speech_array)
    # compute log-Mel input features from input audio array
    batch["input_features"] =  feature_extractor(speech_array.squeeze().numpy(), sampling_rate=16000).input_features[0]
    # batch["input_features"] = feature_extractor(speech_array, sampling_rate=16000).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch[target_column]).input_ids
    return batch

combined_dataset = combined_dataset.map(prepare_dataset, remove_columns=['hok_audio'])

Map:   0%|          | 0/722 [00:00<?, ? examples/s]

Map:   0%|          | 0/686 [00:00<?, ? examples/s]

In [10]:
# Load the pre-trained Whisper model
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-' + size)

In [11]:
model.generation_config.language = 'Mandarin'
model.generation_config.task = 'transcribe'

model.generation_config.forced_decoder_ids = None

In [12]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [13]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [14]:
# Tailo Tokenizer
#   code snippet from https://github.com/wchang88/Tai-Lo-Tokenizer/blob/main/TailoTokenizer.py
import re
from string import punctuation

class TailoTokenizer():
   def __init__(self):
      self.consonants = ['ph', 'p',
                      'm', 'b',
                      'tshi', 'tsh', 'tsi', 'ts', 'th','t',
                      'n', 'l',
                      'kh', 'k',
                      'ng', 'g',
                      'si', 's',
                      'ji','j',
                      'h']

   def tokenize_helper(self, word):
      for onset in self.consonants:
         if word.lower().find(onset) == 0:
            if onset[-1] == 'i':
               return [word[:len(onset)], word[len(onset) - 1:]]
            else:
               return [word[:len(onset)], word[len(onset):]]
      return [word]

   def tokenize(self, sent):
      tokens = []
      for word in re.split(r' |([%s]+)' % re.escape(punctuation), sent):
         if word is not None:
            if re.search(r'[%s]+' % re.escape(punctuation), word):
               # if any combination of punctuation
               tokens.append(word)
            else:
               # if a tai-lo romanization
               tokens.extend(self.tokenize_helper(word))
      return tokens

   def tokenize_join(self, text):
      # Tokenize into initials and finals
      tokens = self.tokenize(text)
      # Join tokens with spaces for consistency
      return " ".join(tokens)

   def tokenize_join_no_dashes(self, text): # remove "--"" and "-"" in Tailo (not used)
      # Remove dashes between words
      text = text.replace("--", " ").replace("-", " ")
      # Tokenize into initials and finals
      tokens = self.tokenize(text)
      # Join tokens with spaces for consistency
      return " ".join(tokens)

   def remove_tone_numbers(self, token):
      """Removes trailing tone numbers from a token."""
      return re.sub(r'\d+$', '', token)

   def tokenize_join_remove_tones(self, text):
      tokens = self.tokenize(text)
      tokens = [self.remove_tone_numbers(token) for token in tokens]
      return " ".join(tokens)

   def tokenize_join_no_dashes_remove_tones(self, text):
      text = text.replace("--", " ").replace("-", " ")
      tokens = self.tokenize(text)
      tokens = [self.remove_tone_numbers(token) for token in tokens]
      return " ".join(tokens)

   def detokenize(self, tokens):
      i = 0
      sentence = []
      dash_found = False
      while i < len(tokens):
         if re.search(r'[%s]+' % re.escape(punctuation), tokens[i]):
            # if the current token is punctuation
            if '-' in tokens[i]:
               dash_found = True
            sentence.append(tokens[i])
            i += 1
         else:
            if tokens[i] in self.consonants:
               # if the current token is a consonant, combine it with the next
               if tokens[i][-1] == 'i' and tokens[i+1][0] == 'i':
                  # reduce double i into single i
                  sentence.append("".join([tokens[i], tokens[i+1][1:]]))
               else:
                  sentence.append("".join(tokens[i:i+2]))
               i += 2
            else:
               sentence.append(tokens[i])
               i += 1

            if dash_found:
               compound = [sentence.pop() for i in range(3)]
               sentence.append("".join(compound[::-1]))
               dash_found = False

      return " ".join(sentence)

In [15]:
# test Tailo Tokenizer
text = combined_dataset['train'][2][target_column]
tailo_tokenizer = TailoTokenizer()
tailo_tokens_split = tailo_tokenizer.tokenize(text)
tailo_tokens_string = tailo_tokenizer.tokenize_join(text)
tailo_tokens_string_no_dashes = tailo_tokenizer.tokenize_join_no_dashes(text)

tailo_tokens_string_no_tones = tailo_tokenizer.tokenize_join_remove_tones(text)
tailo_tokens_string_no_dashes_no_tones = tailo_tokenizer.tokenize_join_no_dashes_remove_tones(text)
print(text)
print(tailo_tokens_split)
print(tailo_tokens_string)
print(tailo_tokens_string_no_dashes)
print(tailo_tokens_string_no_tones)
print(tailo_tokens_string_no_dashes_no_tones)

sua3-loh8-lai5 khuann3 lam5-tau5-kuan7 bin5-a2-tsai3 sann1 ho7 e5 thinn1-khi3
['s', 'ua3', '-', 'l', 'oh8', '-', 'l', 'ai5', 'kh', 'uann3', 'l', 'am5', '-', 't', 'au5', '-', 'k', 'uan7', 'b', 'in5', '-', 'a2', '-', 'ts', 'ai3', 's', 'ann1', 'h', 'o7', 'e5', 'th', 'inn1', '-', 'kh', 'i3']
s ua3 - l oh8 - l ai5 kh uann3 l am5 - t au5 - k uan7 b in5 - a2 - ts ai3 s ann1 h o7 e5 th inn1 - kh i3
s ua3 l oh8 l ai5 kh uann3 l am5 t au5 k uan7 b in5 a2 ts ai3 s ann1 h o7 e5 th inn1 kh i3
s ua - l oh - l ai kh uann l am - t au - k uan b in - a - ts ai s ann h o e th inn - kh i
s ua l oh l ai kh uann l am t au k uan b in a ts ai s ann h o e th inn kh i


In [16]:
# metrics
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode predictions and references
    pred_str_raw = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str_raw = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Hanlo case: Use CER
    if target == 'hanlo':
        # Load CER metric
        cer_metric = evaluate.load('cer')

        # Calculate CER
        cer = cer_metric.compute(predictions=pred_str_raw, references=label_str_raw)

        # Print examples for debugging
        for i in range(min(5, len(pred_str_raw))):  # Print first 5 examples
            print(f"Prediction: {pred_str_raw[i]}")
            print(f"Ground Truth: {label_str_raw[i]}")
            print("---")

        return {
            "cer": 100 * cer  # CER as percentage
        }

    # Tailo case: Calculate multiple metrics
    else:
        # Initialize TailoTokenizer
        tailo_tokenizer = TailoTokenizer()

        # Processed strings for different metrics
        pred_str_tokenize = [tailo_tokenizer.tokenize_join(p) for p in pred_str_raw]
        label_str_tokenize = [tailo_tokenizer.tokenize_join(l) for l in label_str_raw]

        pred_str_no_tones = [tailo_tokenizer.tokenize_join_remove_tones(p) for p in pred_str_raw]
        label_str_no_tones = [tailo_tokenizer.tokenize_join_remove_tones(l) for l in label_str_raw]

        # Load WER metric
        wer_metric = evaluate.load('wer')

        # Calculate WER for raw text
        wer = wer_metric.compute(predictions=pred_str_raw, references=label_str_raw)

        # SER for tokenized text (after `tokenize_join`)
        ser = wer_metric.compute(predictions=pred_str_tokenize, references=label_str_tokenize)

        # SER for tokenized text with tones removed (after `tokenize_join_remove_tones`)
        ser_no_tones = wer_metric.compute(predictions=pred_str_no_tones, references=label_str_no_tones)

        # Print examples for debugging
        for i in range(min(5, len(pred_str_raw))):  # Print first 5 examples
            print(f"Original Prediction: {pred_str_raw[i]}")
            print(f"Original Ground Truth: {label_str_raw[i]}")
            print(f"Tokenized Prediction: {pred_str_tokenize[i]}")
            print(f"Tokenized Ground Truth: {label_str_tokenize[i]}")
            print(f"Prediction without Tones: {pred_str_no_tones[i]}")
            print(f"Ground Truth without Tones: {label_str_no_tones[i]}")
            print("---")

        # Return all metrics
        return {
            "wer": 100 * wer,  # Original WER
            "ser": 100 * ser,  # SER after `tokenize_join`
            "ser_no_tones": 100 * ser_no_tones  # SER after `tokenize_join_remove_tones`
        }

In [17]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./logs/"+ target + "-whisper-"+ size +"-training-logs",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=20,  # originally was 500
    # max_steps=100,  # originally was 5000
    num_train_epochs=n_epoch,  # Use epochs instead of max_steps
    gradient_checkpointing=True,
    remove_unused_columns=False,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="cer" if target == "hanlo" else "ser", 
    greater_is_better=False,
    push_to_hub=False,
)

In [18]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=combined_dataset['train'],
    eval_dataset=combined_dataset['test'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

  trainer = Seq2SeqTrainer(


In [19]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss




TrainOutput(global_step=225, training_loss=0.7765144877963596, metrics={'train_runtime': 487.7987, 'train_samples_per_second': 7.401, 'train_steps_per_second': 0.461, 'total_flos': 1.03198139154432e+18, 'train_loss': 0.7765144877963596, 'epoch': 4.945054945054945})

In [21]:
save_path = pwd + '/model/' + target +'-whisper-'+ size +'-hokkien-finetuned-' + str(n_epoch)
print(save_path)
model.save_pretrained(save_path)
processor.save_pretrained(save_path)

./model/tailo-whisper-small-hokkien-finetuned-5


[]

In [None]:
# Evaluate
results = trainer.evaluate()
print(results)

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

asr_model = WhisperForConditionalGeneration.from_pretrained(save_path)
processor = WhisperProcessor.from_pretrained(save_path)

asr_pipeline = pipeline("automatic-speech-recognition",
                        model=asr_model,
                        tokenizer=processor.tokenizer,
                        feature_extractor=processor.feature_extractor,
                        chunk_length_s=30,
                        batch_size=16,  # batch size for inference - set based on your device
                        torch_dtype=torch_dtype,
                        device=device)

In [None]:
test_file_name = '/test_hokkien.mp3'
test_audio_path = pwd + test_file_name
# Perform inference on a new audio file
transcription = asr_pipeline(test_audio_path, return_timestamps=True)
print(f"Transcription: {transcription}")

薰一枝一枝一枝咧點
hun tsi̍t ki tsi̍t ki leh tiám

酒一杯一杯一杯咧焦
tsiú tsi̍t pue tsi̍t pue tsi̍t pue leh ta

請你愛體諒我
tshiánn lí ài thé-liōng guá

我酒量無好　莫共我創空
guá tsiú-liōng bô hó, mài kā guá tshòng-khang

時間一工一工一工咧走
sî-kan tsi̍t kang tsi̍t kang tsi̍t kang leh tsáu

汗一滴一滴一滴咧流
kuann tsi̍t tih tsi̍t tih tsi̍t tih leh lâu

有一工　咱攏老
ū tsi̍t kang, lán lóng lāu

𤆬某囝鬥陣
tshuā bóo-kiánn tàu-tīn

浪子回頭
lōng-tsú huê-thâu