In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

model_name_or_path = "openai/whisper-large-v3-turbo"
language = "English"
language_abbr = "en"
task = "transcribe"
dataset_name = "/"

org = ""
trained_adapter_name = "whisper-turbo-names-adapters"
trained_model_name = "whisper-turbo-names"

trained_adapter_repo = org + "/" + trained_adapter_name
trained_model_repo = org + "/" + trained_model_name


In [2]:
from transformers import pipeline
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperTimeStampLogitsProcessor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
whisper_asr = pipeline(
    "automatic-speech-recognition",
    model=model_name_or_path,
    chunk_length_s=30,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [4]:
import re

def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = int(seconds % 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

def process_audio_create_vtt(audio_filename, audio_type, whisper_asr, output_filename=None):
    prediction = whisper_asr(f"{audio_filename}.{audio_type}", return_timestamps=True)
    
    vtt_filename = output_filename if output_filename else f"{audio_filename}.vtt"

    with open(vtt_filename, "w", encoding="utf-8") as vtt_file:
        vtt_file.write("WEBVTT\n\n")

        for i, chunk in enumerate(prediction.get("chunks", [])):
            start, end = chunk.get("timestamp", (None, None))
            text = chunk.get("text", "").strip()

            if start is None or end is None or not text:
                continue

            start_time = format_time(start)
            end_time = format_time(end)

            if not re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3}$", start_time) or not re.match(r"^\d{2}:\d{2}:\d{2}$", end_time):
                continue
            vtt_file.write(f"{i+1}\n{start_time}.000 --> {end_time}.000\n{text}\n\n")

In [5]:
from datasets import DatasetDict, Audio, Dataset
import webvtt
from datetime import datetime
import librosa
import soundfile as sf
import os
from huggingface_hub import login

# setup
hf_username = ""
repo_name = "names"
train_audio_file = ""
train_vtt_file = ""
validation_audio_file = ""
validation_vtt_file = ""
save_path = f"data/{repo_name}-dataset"

def parse_time(time_str):
    return datetime.strptime(time_str, "%H:%M:%S.%f")

def milliseconds(time_obj):
    return time_obj.hour * 3600000 + time_obj.minute * 60000 + time_obj.second * 1000 + time_obj.microsecond // 1000

def time_to_samples(time_ms, sr):
    return int(time_ms * sr / 1000)

def transform_data(data):
    audio_path = data["train_audio_file"]
    vtt_path = data["train_vtt_file"]
    output_dir = data["save_path"]

    full_audio, sr = librosa.load(audio_path, sr=None, mono=True)

    captions = webvtt.read(vtt_path)
    data = []
    current_text = []
    current_start = None
    current_end = None
    accumulated_duration = 0
    segment_counter = 0

    for caption in captions:
        start = milliseconds(parse_time(caption.start))
        end = milliseconds(parse_time(caption.end))
        duration = (end - start).total_seconds()

        if current_start is None:
            current_start = start

        if accumulated_duration + duration > 30:
            current_text.append(caption.text)
            current_end = end
            accumulated_duration += duration
        
        else:
            segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
            start_samples = time_to_samples(milliseconds(current_start.time()), sr)
            end_samples = time_to_samples(milliseconds(current_end.time()), sr)
            audio_segment = full_audio[start_samples:end_samples]
            sf.write(segment_filename, audio_segment, sr, format="mp3")

            data.append({
                "audio": segment_filename,
                "text": " ".join(current_text),
                "start_time": current_start.strftime("%H:%M:%S.%f")[:-3],
                "end_time": current_end.strftime("%H:%M:%S.%f")[:-3],
            })

            current_text = [caption.text]
            current_start = start
            current_end = end
            accumulated_duration = duration
            segment_counter += 1
    
    if current_text:
        segment_filename = f"{output_dir}/segment_{segment_counter}.mp3"
        start_samples = time_to_samples(milliseconds(current_start.time()), sr)
        end_samples = time_to_samples(milliseconds(current_end.time()), sr)
        audio_segment = full_audio[start_samples:end_samples]
        sf.write(segment_filename, audio_segment, sr, format="mp3")

        data.append({
            "audio": segment_filename,
            "text": " ".join(current_text),
            "start_time": current_start.strftime("%H:%M:%S.%f")[:-3],
            "end_time": current_end.strftime("%H:%M:%S.%f")[:-3],
        })
    
    return data

def create_dataset(train_audio_file, train_vtt_file, validation_audio_file, validation_vtt_file, save_path):
    os.makedirs(save_path, exist_ok=True)

    train_data = process_audio_create_vtt(train_audio_file, train_vtt_file, f"{save_path}")

    validation_data = process_audio_create_vtt(validation_audio_file, validation_vtt_file, f"{save_path}/validation.vtt")

    train_dataset = Dataset.from_dict(transform_data(train_data))
    valid_dataset = Dataset.from_dict(transform_data(validation_data))

    dataset_dict = DatasetDict({
        "train": train_dataset, 
        "validation": valid_dataset
    })

    return dataset_dict


dataset = create_dataset(train_audio_file, train_vtt_file, validation_audio_file, validation_vtt_file, save_path)

dataset.save_to_disk(save_path)

dataset = Dataset.cast_column("audio", Audio())


TypeError: 'str' object is not callable

In [6]:
from datasets import load_dataset, DatasetDict

dataset = DatasetDict()
dataset["train"] = load_dataset(dataset_name, split="train")
dataset["validation"] = load_dataset(dataset_name, split="validation")

print(dataset)

IndexError: list index out of range

In [7]:
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)

In [8]:
print(dataset["train"][0])

KeyError: 'train'

In [9]:
from datasets import Audio

dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [10]:
def prepare_dataset(batch):
    audio = batch["audio"]

    batch["input_features"] = feature_extractor(audio, sampling_rate=audio.sampling_rate)

    batch["labels"] = tokenizer(batch["text"]).input_ids
    return batch

In [11]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        
        if (labels[:, 0] == self.processor.tokenizer.pad_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [12]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [15]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)
model.config.forced_decoder_ids = None  # Optional: Let the model generate freely
model.config.suppress_tokens = []

In [16]:
print(model)

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bia

In [19]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=8, use_rslora=True,
                    target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
                    lora_dropout=0.05, bias ="none")

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 27,852,800 || all params: 836,730,880 || trainable%: 3.3288


In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    per_device_eval_batch_size=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=5e-4,
    evaluation_strategy="steps",
    bf16=True,
    generation_max_length=128,
    logging_steps=2,
    save_steps=2,
    eval_steps=2,
    remove_unused_columns=False,
    lr_scheduler_type="cosntant",
    warmup_steps=2,
    save_total_limit=2,
    num_train_epochs=2,
    overwrite_output_dir=True,
    logging_dir="logs",
)