# Data Loading & Preprocessing

In [None]:
!pip install jiwer

In [None]:
import os
from functools import partial
from typing import Any, Dict

import numpy as np
import torchaudio
from datasets import Dataset, DatasetDict
from transformers import Wav2Vec2Processor


def load_movie2sub_data(root_dir: str):
    audio_paths = []
    texts = []

    for movie_name in os.listdir(root_dir):
        movie_path = os.path.join(root_dir, movie_name)
        if not os.path.isdir(movie_path):
            continue

        for fname in os.listdir(movie_path):
            if fname.endswith(".wav"):
                base = os.path.splitext(fname)[0]
                wav_path = os.path.join(movie_path, f"{base}.wav")
                txt_path = os.path.join(movie_path, f"{base}.txt")

                if os.path.exists(wav_path) and os.path.exists(txt_path):
                    audio_paths.append(wav_path)
                    with open(txt_path, "r", encoding="utf-8") as f:
                        texts.append(f.read().strip())

    return Dataset.from_dict({"audio": audio_paths, "text": texts})

def preprocess(batch: Dict[str, Any], processor: Wav2Vec2Processor, resample_rate=16_000) -> Dict[str, Any]:
    """Turn one raw example into model inputs and CTC label ids.

    The audio is loaded from ``batch["audio"]`` (a file path). It is then:

    1. resampled to ``resample_rate``,
    2. downmixed to mono,
    3. peak-normalized,
    4. passed through ``processor``'s feature extractor,
    5. min-max rescaled to ``[-1, 1]``.

    ``batch["text"]`` is tokenized into CTC labels.

    Args:
        batch: One dataset row with ``"audio"`` (path) and ``"text"`` keys. It is
            updated in place and returned, as ``datasets.Dataset.map`` expects.
        processor: Processor providing the feature extractor and tokenizer.
        resample_rate: Target sampling rate in Hz.

    Returns:
        ``batch`` with ``"input_values"``, ``"attention_mask"`` and ``"labels"`` added.
    """
    waveform, sample_rate = torchaudio.load(batch["audio"])

    # Resampling is a no-op when the rates already match, so skip building the
    # (comparatively expensive) resampling kernel in that case.
    if sample_rate != resample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)
        waveform = resampler(waveform)

    # Downmix multi-channel audio to a single channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Peak-normalize. A silent clip has a zero peak; dividing by it would fill
    # the waveform with NaNs, so leave such clips untouched.
    peak = waveform.abs().max()
    if peak > 0:
        waveform = waveform / peak

    features = processor(waveform.squeeze().numpy(), sampling_rate=resample_rate, return_attention_mask=True)
    input_values = np.asarray(features.input_values[0])

    # Min-max rescale to [-1, 1].
    # NOTE(review): if the feature extractor normalizes to zero mean / unit
    # variance (do_normalize=True), this rescale changes that distribution.
    # Confirm it is intended.
    lo, hi = input_values.min(), input_values.max()
    span = hi - lo
    if span > 0:
        input_values = 2 * (input_values - lo) / span - 1
    else:
        # A constant signal has no range to rescale; emit zeros instead of NaNs.
        input_values = np.zeros_like(input_values)

    batch["input_values"] = input_values
    batch["attention_mask"] = features.attention_mask[0]

    # Uppercase the text and use "|" as the word delimiter. split() also
    # collapses runs of whitespace (spaces, tabs, newlines) so they do not
    # produce consecutive delimiters.
    processed_text = "|".join(batch["text"].upper().split())
    batch["labels"] = processor.tokenizer(processed_text).input_ids

    return batch

def prepare_dataset(
    raw_dataset: Dataset,
    processor: Wav2Vec2Processor,
    test_size: float | None = None,
    seed: int | None = None,
) -> Dataset | DatasetDict:
    """Optionally split ``raw_dataset`` and run ``preprocess`` over every example.

    Args:
        raw_dataset: Dataset with ``"audio"`` and ``"text"`` columns.
        processor: Processor forwarded to ``preprocess``.
        test_size: If given, the dataset is first split with
            ``train_test_split(test_size=test_size)`` and a ``DatasetDict`` with
            ``"train"``/``"test"`` splits is returned; otherwise a ``Dataset``.
        seed: Optional seed for the split. The default ``None`` keeps the split
            random; pass an int to make it reproducible across runs.

    Returns:
        The preprocessed dataset(s), with the raw ``"audio"`` and ``"text"``
        columns removed.
    """
    if test_size is not None:
        dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed)
    else:
        dataset = raw_dataset
    preprocess_fn = partial(preprocess, processor=processor)
    return dataset.map(preprocess_fn, remove_columns=["audio", "text"])

# Trainer

In [None]:
from jiwer import wer
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback, EvalPrediction

# Fine-tuning configuration. Evaluation and checkpointing both run once per
# epoch so that the best checkpoint can be reloaded when training ends.
training_args = TrainingArguments(
    # --- output & logging ---
    output_dir="./wav2vec2-finetuned-moviesubs",
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=10,
    report_to="none",
    # --- evaluation & checkpointing ---
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    # --- optimisation ---
    # NOTE(review): learning_rate=1e-7 and max_grad_norm=0.05 are far smaller
    # than common wav2vec2 fine-tuning settings; confirm they are intentional.
    num_train_epochs=80,
    learning_rate=1e-7,
    max_grad_norm=0.05,
    per_device_train_batch_size=2,
    fp16=False,
    # --- data loading ---
    group_by_length=True,
    dataloader_num_workers=4,
    # Keep every dataset column; the custom data collator picks what it needs.
    remove_unused_columns=False,
)

def wer_metric(pred: EvalPrediction, processor: Wav2Vec2Processor):
    """Compute word error rate for a ``Trainer`` evaluation pass.

    Predictions are greedy (argmax) CTC decodes of the logits. Labels are decoded
    without CTC token grouping so that genuinely repeated characters survive.

    Args:
        pred: Logits in ``pred.predictions`` and label ids in ``pred.label_ids``.
            Label padding is -100, as set by the data collator.
        processor: Processor whose tokenizer decodes ids back to text.

    Returns:
        ``{"wer": <float>}``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 is not a vocabulary id. Map it back to the pad token (the CTC blank)
    # so the tokenizer drops it instead of emitting "<unk>" into the references.
    # np.where builds a new array, so pred.label_ids is not mutated.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer(label_str, pred_str)}

# DataCollator

In [None]:
class DataCollatorCTC:
    """Pad preprocessed examples into a CTC batch of PyTorch tensors.

    Audio ``input_values`` are padded by the processor's feature extractor.
    Depending on its configuration, it may also return an ``attention_mask``.
    ``labels`` are padded by the tokenizer, and the padded positions are then
    replaced with -100 so the loss ignores them.
    """

    def __init__(self, processor: Wav2Vec2Processor):
        self.processor = processor

    def __call__(self, features):
        """Collate a list of examples with ``"input_values"`` and ``"labels"`` keys."""
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        # Outside a target context, processor.pad() dispatches to the feature extractor.
        batch = self.processor.pad(input_features, padding=True, return_tensors="pt")

        # Pad labels with the tokenizer directly. This is what the deprecated
        # `as_target_processor()` context manager routed pad() to.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=True,
            return_tensors="pt",
        )

        # Replace label padding with -100 so it is ignored by the loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)
        return batch

# Model

In [None]:
import torch
import warnings

from transformers import Wav2Vec2ForCTC

# Silence all Python warnings for the rest of the session.
# NOTE(review): this also hides library deprecation warnings — consider
# narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")
# Disable Weights & Biases integration via its environment variable.
os.environ["WANDB_DISABLED"] = "true"
# Anomaly detection makes autograd raise when a backward pass produces NaN.
# It also slows training noticeably — presumably meant for debugging only.
torch.autograd.set_detect_anomaly(True) # to crash in case of anomaly

model_str = "facebook/wav2vec2-base-960h"

# The processor bundles the audio feature extractor and the CTC tokenizer.
processor = Wav2Vec2Processor.from_pretrained(model_str, do_normalize=True, feature_size=1, padding_value=0.0, return_attention_mask=True)
data_collator = DataCollatorCTC(processor=processor)

# Train and test data live in separate directories.
raw_dataset = load_movie2sub_data("/kaggle/input/movie2sub-dataset/dataset/train")
test_dataset = load_movie2sub_data("/kaggle/input/movie2sub-dataset/dataset/test")

# The train directory is further split 80/20 into train/validation (no seed is
# passed, so the split differs between runs). The test set is only preprocessed.
dataset = prepare_dataset(raw_dataset, processor, test_size=0.2)
test_dataset = prepare_dataset(test_dataset, processor)

# Output vocabulary size and pad id come from the tokenizer. Wav2Vec2ForCTC uses
# pad_token_id as the CTC blank. ctc_zero_infinity zeroes infinite losses, which
# occur when an input is too short to align with its label sequence.
model = Wav2Vec2ForCTC.from_pretrained(
    model_str,
    vocab_size=len(processor.tokenizer),
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    ctc_zero_infinity=True,
)

# Bind the processor so the Trainer can call compute_metrics(pred) with one argument.
wrapped_compute_metrics = partial(wer_metric, processor=processor)
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; check against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=wrapped_compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

# `device` is reused by the inference cell. The Trainer normally places the
# model on its own device already, so this move is presumably a no-op for training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

trainer.train()

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt

# Plot training loss (logged every `logging_steps`) and validation loss
# (logged once per evaluation) against epoch, and save the figure.
log_history = trainer.state.log_history

train_epochs, train_loss = [], []
eval_epochs, eval_loss = [], []

for entry in log_history:
    # Only entries tagged with an epoch can be placed on the x-axis.
    if "epoch" not in entry:
        continue
    if "loss" in entry:
        train_epochs.append(entry["epoch"])
        train_loss.append(entry["loss"])
    if "eval_loss" in entry:
        eval_epochs.append(entry["epoch"])
        eval_loss.append(entry["eval_loss"])

plt.figure(figsize=(10, 5))
plt.plot(train_epochs, train_loss, label="Training Loss")
plt.plot(eval_epochs, eval_loss, marker="o", label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.legend()
plt.tight_layout()

plt.savefig("./train_vs_val_loss.png", transparent=True)
plt.show()

In [None]:
import os
from torch.utils.data import DataLoader

# Run the fine-tuned model over the held-out test set, print each
# ground-truth / prediction pair, and write both to ./predictions.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
)

output_dir = "./predictions"
os.makedirs(output_dir, exist_ok=True)

# Switch off dropout so predictions are deterministic; training may have left
# the model in train mode.
model.eval()

# Running index across batches. Indexing files by the position inside the batch
# would make every batch overwrite the previous batch's files.
sample_idx = 0
for batch in test_loader:
    with torch.no_grad():
        output = model(input_values=batch["input_values"].to(device))

    # The collator marks label padding with -100. Map it back to the pad token
    # (the CTC blank) so the tokenizer drops it when decoding.
    label_ids = batch["labels"].masked_fill(batch["labels"] == -100, processor.tokenizer.pad_token_id)
    ground_truths = processor.batch_decode(label_ids, group_tokens=False, skip_special_tokens=True)

    # Greedy CTC decoding of the logits.
    predicted_ids = torch.argmax(output.logits, dim=-1)
    decoded_texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    for gt, pred in zip(ground_truths, decoded_texts):
        print(f"\tInput {sample_idx:02}")
        print(f"Ground truth: {gt}\n")
        print(f"Decoded text: {pred}\n")

        with open(os.path.join(output_dir, f"sample_{sample_idx}_gt.txt"), "w", encoding="utf-8") as f:
            f.write(gt)

        with open(os.path.join(output_dir, f"sample_{sample_idx}_pred.txt"), "w", encoding="utf-8") as f:
            f.write(pred)

        sample_idx += 1