In [69]:
%%writefile training_script_for_yadin.py
import librosa
import torch
import numpy as np
import pandas as pd
import gradio as gr
import transformers
import torch
import evaluate

from dataclasses import dataclass
from typing import Any, Dict, List, Union
import datasets
import pydub
import whisper
    
# Whisper checkpoint size to fine-tune.
model_arch = "tiny"  # TODO: consider "small" for the real experiments
print("model_arch:", model_arch)
output_dir = f"/home/moshebr/dharelg/moshe/expeimental_model_SBC_{model_arch}"

seed = 42  # any fixed integer; kept for reproducibility

checkpoint_name = f"openai/whisper-{model_arch}"

# The processor bundles the feature extractor and the tokenizer, both
# configured for English transcription.
processor = transformers.WhisperProcessor.from_pretrained(
    checkpoint_name, language="english", task="transcribe"
)
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer

model = transformers.WhisperForConditionalGeneration.from_pretrained(checkpoint_name)
# Force the English/transcribe decoder prompt and disable token suppression.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="english", task="transcribe"
)
model.config.suppress_tokens = []

# Dataset prepared by Yadin (TODO: confirm this is the intended snapshot).
ds = datasets.Dataset.load_from_disk("/home/yadinb/dharelg/ds_for_yadin_SBC_wo_augmentation/")

## Train/test split: the first `train_fraction` of rows (in stored order) go
## to training and the remainder to evaluation. No shuffling is performed.
train_fraction = 0.8
split_at = int(len(ds) * train_fraction)
ds_train = ds.select(range(split_at))
ds_test = ds.select(range(split_at, len(ds)))


## Compute Whisper input features directly from the audio file.
## NOTE: experimental path. Be wary of it and verify that it yields features
## equivalent to the dataset's precomputed "input_features" before training on it.
def load_input_features_from_path(audio_path, sampling_rate=16000):
    """Return the Whisper input features for a single audio file.

    The file is decoded and resampled to ``sampling_rate`` with
    ``whisper.audio.load_audio`` (ffmpeg-based). The samples are then passed
    through the module-level Whisper ``feature_extractor``.

    Args:
        audio_path: Path to an audio file that ffmpeg can decode.
        sampling_rate: Rate to decode/resample to and to declare to the
            feature extractor. Defaults to 16000, the value previously
            hard-coded here.

    Returns:
        The feature array for this one clip, without the batch dimension the
        feature extractor adds. This is presumably the same per-example layout
        as the dataset's "input_features" column; confirm against the data.
    """
    # Decode at the same rate we declare below, so the two cannot drift apart.
    audio = whisper.audio.load_audio(audio_path, sr=sampling_rate)
    features = feature_extractor(audio, sampling_rate=sampling_rate)["input_features"]
    # The feature extractor always returns a batch; a single clip is entry 0.
    return features[0]

# Build a new dataset from the full `ds` with an extra "input_features_from_path"
# column, recomputed from each row's "path" audio file (using 8 worker processes).
# NOTE(review): `ds_from_path` is not referenced anywhere else in this script.
# The train/test splits above are taken from `ds`, and the data collator reads
# "input_features", not "input_features_from_path". As written, this map only
# costs time. Either wire it in deliberately or drop it.
ds_from_path=ds.map(lambda x:{"input_features_from_path":load_input_features_from_path(x["path"])},num_proc=8)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech seq2seq examples into a padded batch.

    Each feature dict must carry ``"input_features"`` (audio features) and
    ``"labels"`` (token ids). The two are padded separately because they have
    different lengths and need different padding methods.

    Attributes:
        processor: Processor exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` (here a ``WhisperProcessor``).
        decoder_start_token_id: Token the model prepends to the decoder input
            on its own. If every label sequence already starts with it, it is
            stripped so it is not duplicated. When ``None``, the tokenizer's
            ``bos_token_id`` is used instead (the previous behaviour).
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio inputs: pad with the feature extractor and return torch tensors.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Labels: pad the token-id sequences to the longest one in the batch.
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The start token is appended again later by the model. If tokenization
        # already put it at position 0 of every sequence, cut it here.
        start_token_id = self.decoder_start_token_id
        if start_token_id is None:
            start_token_id = self.processor.tokenizer.bos_token_id
        if (labels[:, 0] == start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


# Whisper labels produced by this tokenizer (assuming special tokens were added
# at tokenization time; confirm for this dataset) begin with
# <|startoftranscript|>, which is the model's decoder_start_token_id. They do
# not begin with bos_token_id, so pass the start token explicitly for the strip
# above to take effect.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

metric = evaluate.load("wer")


def compute_metrics(pred):
    """Compute the word error rate (as a percentage) for a trainer evaluation.

    Args:
        pred: Evaluation output with ``predictions`` (generated token ids,
            since ``predict_with_generate`` is enabled) and ``label_ids``
            (reference token ids, with -100 at positions masked by the
            data collator).

    Returns:
        ``{"wer": 100 * WER}``.
    """
    pred_ids = pred.predictions
    # Guard against predictions arriving as a tuple whose first element holds
    # the token ids.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid token id, and batch_decode fails on it. It is present
    # in the labels (loss masking) and can also appear in the predictions when
    # the trainer pads generations of different lengths across eval batches.
    # Map it to the pad token without mutating the arrays owned by `pred`.
    pad_id = tokenizer.pad_token_id
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    # Drop special tokens (prompt, padding, end-of-text) from the decoded text.
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

# Training configuration.
# Effective train batch size = per_device_train_batch_size *
# gradient_accumulation_steps = 16 per device.
# With eval_steps == save_steps == max_steps == 1000, evaluation and
# checkpointing each happen once, at the end of training.
training_args = transformers.Seq2SeqTrainingArguments(
    output_dir=output_dir,
    seed=seed,  # use the seed declared at the top of the script
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=1000,
    gradient_checkpointing=True,  # trade extra compute for lower GPU memory use
    fp16=True,
    evaluation_strategy="steps",  # NOTE: named `eval_strategy` in newer transformers releases
    per_device_eval_batch_size=1,
    predict_with_generate=True,  # evaluate WER on generated transcripts
    generation_max_length=448,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["none"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=False,
)

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
# Save the processor next to the checkpoints so the output dir is self-contained.
processor.save_pretrained(training_args.output_dir)
trainer.train()

Overwriting training_script_for_yadin.py


In [None]:
# Play an audio clip inline in the notebook.
# NOTE(review): `audio_path` is not defined in any visible notebook cell. The
# %%writefile cell above only writes the script to disk and does not run it,
# and there `audio_path` is just a function parameter. Define it (e.g. a value
# from the dataset's "path" column) before running this cell.
from IPython.display import Audio
Audio(audio_path)