# In [None]:
import os
import datasets
import pandas as pd
import numpy as np
import librosa
from datasets import Audio, DatasetDict, Dataset
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)

# Root of the fine-tuning workspace. This is an absolute path by default; set
# the WHISPER_BASE_DIR environment variable (e.g. in the SLURM job script) to
# point the script at another workspace without editing the file.
BASE_DIR = os.environ.get("WHISPER_BASE_DIR", "/home/yyt005/whisper_fine_tuning")
PROCESSED_AUDIO_PATH = os.path.join(BASE_DIR, "processed_audio")
MODEL_OUTPUT_PATH = os.path.join(BASE_DIR, "fine-tuned-whisper-large")

# Create output directory if it doesn't exist
os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)

# Load the metadata table; later code reads its "audio_filepath" and "text"
# columns.
meta_df = pd.read_csv(os.path.join(BASE_DIR, 'metadata.csv'))

# Process audio file paths
def update_filepath(path):
    """Return *path* re-rooted under PROCESSED_AUDIO_PATH.

    Only the file name and the name of its immediate parent directory are
    kept, e.g. ``<anything>/speaker_01/clip.wav`` becomes
    ``PROCESSED_AUDIO_PATH/speaker_01/clip.wav``.
    """
    head, tail = os.path.split(path)
    return os.path.join(PROCESSED_AUDIO_PATH, os.path.basename(head), tail)

# Re-root every audio path under PROCESSED_AUDIO_PATH.
meta_df['audio_filepath'] = meta_df['audio_filepath'].apply(update_filepath)

# Keep only rows whose audio file exists on disk, warning about each missing
# one (in row order).
_exists_mask = meta_df['audio_filepath'].map(os.path.exists).astype(bool)
for _missing_path in meta_df.loc[~_exists_mask, 'audio_filepath']:
    print(f"Warning: File not found - {_missing_path}")

valid_files = meta_df.index[_exists_mask].tolist()
meta_df = meta_df.loc[_exists_mask]
print(f"Processing {len(meta_df)} valid audio files")

if meta_df.empty:
    # Fail fast: with zero rows the split/train steps below would fail much
    # later with a far less obvious error.
    raise FileNotFoundError(
        f"None of the audio files listed in metadata.csv were found under {PROCESSED_AUDIO_PATH}"
    )

# Create train/test splits. A fixed seed makes the split reproducible across
# runs/jobs, so results from different runs are comparable. The test rows are
# drawn only from rows not used for training.
SPLIT_SEED = 42
train_sample = meta_df.sample(n=min(100, len(meta_df)//2), random_state=SPLIT_SEED)
test_sample = meta_df.drop(train_sample.index).sample(
    n=min(20, len(meta_df)//4), random_state=SPLIT_SEED
)

# Create dataset dictionary. preserve_index=False keeps the shuffled pandas
# index from being stored as an extra "__index_level_0__" column.
sample_dataset = DatasetDict({
    "train": Dataset.from_pandas(train_sample, preserve_index=False),
    "test": Dataset.from_pandas(test_sample, preserve_index=False),
})

# Load Whisper components from a single checkpoint:
# - a feature extractor,
# - a tokenizer configured for English transcription,
# - a processor bundling the same pieces.
_WHISPER_CHECKPOINT = "openai/whisper-large"
_LANG_TASK = {"language": "english", "task": "transcribe"}

feature_extractor = WhisperFeatureExtractor.from_pretrained(_WHISPER_CHECKPOINT)
tokenizer = WhisperTokenizer.from_pretrained(_WHISPER_CHECKPOINT, **_LANG_TASK)
processor = WhisperProcessor.from_pretrained(_WHISPER_CHECKPOINT, **_LANG_TASK)

# Load and process audio files
def load_audio(example, sampling_rate=16000):
    """Decode one example's audio file into a waveform for ``Dataset.map``.

    Args:
        example: dataset row; only ``example["audio_filepath"]`` is read.
        sampling_rate: target rate passed to ``librosa.load`` (which resamples
            to it). Defaults to 16 kHz, the rate used throughout this script.

    Returns:
        ``{"audio": {"array": waveform, "sampling_rate": rate}}``. If the file
        cannot be loaded, an error is printed and 0.1 s of float32 silence is
        returned instead, so one bad file does not abort the whole map.
        NOTE(review): such rows still carry their real transcript; consider
        filtering them out before training.
    """
    audio_path = example["audio_filepath"]
    try:
        audio, sr = librosa.load(audio_path, sr=sampling_rate)
    except Exception as e:  # best-effort: any decode/IO failure falls back to silence
        print(f"Error loading audio file {audio_path}: {e}")
        # float32 matches librosa.load's default output dtype, keeping the
        # "audio.array" column a single type across all rows/shards.
        silence = np.zeros(sampling_rate // 10, dtype=np.float32)
        return {"audio": {"array": silence, "sampling_rate": sampling_rate}}
    return {"audio": {"array": audio, "sampling_rate": sr}}

# Prepare features and labels
def prepare_dataset(batch):
    """Add Whisper model inputs to one decoded example and return it.

    Reads ``batch["audio"]`` (a mapping with ``"array"`` and
    ``"sampling_rate"``) and ``batch["text"]``. Writes two new keys:

    - ``input_features``: the first item produced by the feature extractor;
    - ``labels``: the tokenizer's input ids for the transcript.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]

    encoded = tokenizer(batch["text"])
    batch["labels"] = encoded.input_ids
    return batch

# Process the dataset in three passes:
#   1. decode each clip from disk with load_audio (adds an "audio" column),
#   2. re-type that column as a 16 kHz Audio feature,
#   3. build Whisper inputs/labels with prepare_dataset, dropping every column
#      that existed before this pass (both splits share the train schema).
sample_dataset = sample_dataset.map(load_audio, num_proc=4)
sample_dataset = sample_dataset.cast_column("audio", Audio(sampling_rate=16000))

processed_dataset = sample_dataset.map(
    prepare_dataset,
    num_proc=4,
    remove_columns=sample_dataset.column_names["train"],
)

# Data Collator
def data_collator(features):
    """Pad a list of processed examples into one Whisper training batch.

    Args:
        features: examples produced by ``prepare_dataset``; each has
            ``"input_features"`` and ``"labels"`` (a list of token ids).

    Returns:
        The feature extractor's padded batch (PyTorch tensors) plus a
        ``"labels"`` tensor in which:

        * Padding positions are set to -100 so the loss ignores them. Masking
          uses the tokenizer's attention mask rather than the pad token id.
          NOTE: for the openai/whisper checkpoints the pad token is the same
          id as end-of-text, so masking by id would also drop the real end
          token.
        * A leading ``<|startoftranscript|>`` column is removed when every row
          starts with it. The model prepends the decoder start token itself
          when it shifts labels into decoder inputs, so keeping it would
          duplicate the token.
    """
    input_features = [{"input_features": feature["input_features"]} for feature in features]
    batch = processor.feature_extractor.pad(input_features, padding="longest", return_tensors="pt")

    labels = [feature["labels"] for feature in features]
    labels_batch = tokenizer.pad({"input_ids": labels}, padding="longest", return_tensors="pt")
    label_ids = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

    decoder_start_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
    if label_ids.shape[1] > 0 and bool((label_ids[:, 0] == decoder_start_id).all()):
        label_ids = label_ids[:, 1:]

    batch["labels"] = label_ids
    return batch

# Training arguments.
# - The Trainer below is built without `compute_metrics`, so evaluation only
#   reports the loss. Best-checkpoint selection therefore uses "eval_loss";
#   "wer" would never appear in the metrics and would fail once evaluation ran.
# - Evaluation and saving happen once per epoch. The training split is capped
#   at 100 examples (<= 25 steps/epoch at batch size 4 on one device), so
#   step-based eval/save every 100 steps would never fire within 3 epochs.
seq2seq_training_args = Seq2SeqTrainingArguments(
    output_dir=MODEL_OUTPUT_PATH,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    num_train_epochs=3,
    gradient_checkpointing=True,
    fp16=True,  # GPU should support this on HPC
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    logging_steps=25,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower loss is better
    push_to_hub=False,
    remove_unused_columns=True,
    report_to=["tensorboard"],
)

# Model Initialization
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

# Pin generation to English transcription, matching the tokenizer
# configuration, so the saved fine-tuned model does not rely on language
# auto-detection at inference time. forced_decoder_ids is cleared because
# language/task are now specified through generation_config.
model.generation_config.language = "english"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

# Trainer Initialization
trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_training_args,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
    data_collator=data_collator,
)

# Training
trainer.train()

# Save the fine-tuned model together with the processor (feature extractor and
# tokenizer). trainer.save_model only writes the model plus any processing
# object given to the Trainer, and none was given above. Without this call the
# output directory could not be loaded with
# WhisperProcessor.from_pretrained(MODEL_OUTPUT_PATH).
trainer.save_model(MODEL_OUTPUT_PATH)
processor.save_pretrained(MODEL_OUTPUT_PATH)
print(f"Model saved to {MODEL_OUTPUT_PATH}")