In [None]:
# ⬛ SET‑UP ⬛
!pip install -q datasets transformers torchaudio jiwer accelerate
from datasets import load_dataset, Audio
from transformers import (Wav2Vec2ForCTC, Wav2Vec2Processor,
                          TrainingArguments, Trainer)
from jiwer import wer
import torch, json, os, re

# Run configuration: target sampling rate, pretrained checkpoint, and the
# location of the dataset together with its training manifest.
SAMPLE_RATE = 16000
MODEL_NAME = "facebook/wav2vec2-base-960h"
DATA_DIR = "data/raw/toy_lang_dataset"
MANIFEST = DATA_DIR + "/manifest_train.jsonl"

# Helper: read JSONL → HF dataset
def jsonl_to_dataset(path):
    """Load a JSONL manifest into train/test splits with decoded audio.

    Every non-blank line of *path* must be a JSON object holding a ``"file"``
    entry (joined onto ``DATA_DIR`` to locate the audio) and a
    ``"transcription"`` entry.

    Args:
        path: Location of the JSONL manifest.

    Returns:
        The result of ``train_test_split(test_size=0.2, seed=42)`` on a
        dataset with columns ``"file"`` (cast to ``Audio`` at ``SAMPLE_RATE``)
        and ``"transcription"``.

    Raises:
        KeyError: If a manifest entry lacks ``"file"`` or ``"transcription"``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Imported here because the file-level import only brings in
    # ``load_dataset`` and ``Audio``.
    from datasets import Dataset

    with open(path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        items = [json.loads(line) for line in f if line.strip()]

    columns = {
        "file": [os.path.join(DATA_DIR, it["file"]) for it in items],
        "transcription": [it["transcription"] for it in items],
    }
    # Build the dataset from the resolved paths. Previously this dict was
    # discarded and the manifest reloaded as-is, so the DATA_DIR prefix was
    # never applied to the audio paths.
    ds = Dataset.from_dict(columns)
    ds = ds.cast_column("file", Audio(sampling_rate=SAMPLE_RATE))
    return ds.train_test_split(test_size=0.2, seed=42)

# Build the train/test splits from the training manifest.
ds = jsonl_to_dataset(path=MANIFEST)

# ⬛ TOKENIZER / PROCESSOR ⬛
# Processor loaded from the same checkpoint name as the model.
processor = Wav2Vec2Processor.from_pretrained(
    MODEL_NAME,
)
def prepare(batch):
    """Convert one decoded example into model inputs and label ids.

    Args:
        batch: A single example whose ``"file"`` entry is a decoded audio
            mapping with an ``"array"`` key, and whose ``"transcription"``
            entry is the reference text.

    Returns:
        The same mapping with ``"input_values"`` (processed audio) and
        ``"labels"`` (token ids of the transcription) added.
    """
    audio = batch["file"]["array"]
    batch["input_values"] = processor(audio, sampling_rate=SAMPLE_RATE).input_values[0]
    # Tokenise with the tokenizer directly: the ``as_target_processor()``
    # context manager is deprecated, and inside it ``processor(text)`` simply
    # forwarded to the tokenizer anyway.
    # NOTE(review): presumably the checkpoint vocabulary is upper-case only —
    # confirm the transcriptions are normalised to match it.
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    return batch

# Featurise both splits. The pre-existing columns are dropped afterwards, so
# only the fields added by ``prepare`` remain.
_original_columns = ds["train"].column_names
ds = ds.map(prepare, num_proc=2, remove_columns=_original_columns)

# ⬛ MODEL + TRAINER ⬛
# Load the checkpoint with the output vocabulary sized to the tokenizer.
_tokenizer_size = len(processor.tokenizer)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME, vocab_size=_tokenizer_size)

# Training configuration; checkpoints are written under "models/asr".
# NOTE(review): newer transformers releases appear to rename
# ``evaluation_strategy`` to ``eval_strategy`` — confirm against the installed
# version.
args = TrainingArguments(
    "models/asr",
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    logging_steps=10,
    learning_rate=1e-4,
    num_train_epochs=5,
    # Mixed precision is only requested when CUDA is present, so the script
    # still starts on a CPU-only machine.
    fp16=torch.cuda.is_available(),
    save_total_limit=1,
)

def compute_metrics(pred):
    """Compute word error rate from an evaluation prediction.

    Args:
        pred: Object exposing ``predictions`` (per-frame logits) and
            ``label_ids`` (reference token ids, possibly padded with -100).

    Returns:
        A dict ``{"wer": <float>}`` comparing decoded references with decoded
        greedy (argmax) predictions.
    """
    pred_ids = torch.argmax(torch.tensor(pred.predictions), dim=-1)

    # -100 marks padding/ignored label positions and is not a valid token id.
    # Swap it for the pad token so the labels can be decoded. ``masked_fill``
    # returns a new tensor, leaving ``pred.label_ids`` untouched.
    label_ids = torch.as_tensor(pred.label_ids)
    label_ids = label_ids.masked_fill(
        label_ids == -100, processor.tokenizer.pad_token_id
    )

    pred_str = processor.batch_decode(pred_ids)
    # References are plain text, not CTC output: do not collapse repeats.
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    return {"wer": wer(label_str, pred_str)}

def _pad_ctc_batch(features):
    """Collate examples by padding audio and label sequences separately.

    Args:
        features: List of examples, each holding ``"input_values"`` and
            ``"labels"`` as produced by ``prepare``.

    Returns:
        A batch of padded tensors. Padded label positions are set to -100 so
        that they are ignored by the loss.
    """
    audio = [{"input_values": ex["input_values"]} for ex in features]
    text = [{"input_ids": ex["labels"]} for ex in features]

    batch = processor.feature_extractor.pad(audio, padding=True, return_tensors="pt")
    padded = processor.tokenizer.pad(text, padding=True, return_tensors="pt")
    batch["labels"] = padded["input_ids"].masked_fill(
        padded["attention_mask"].ne(1), -100
    )
    return batch


# Audio and label sequences have different lengths and need different padding,
# so a dedicated collator is supplied rather than relying on the default one.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=_pad_ctc_batch,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
)
trainer.train()

# ⬛ EVALUATE & SAVE ⬛
# Final evaluation on the held-out split, then persist metrics and artefacts.
metrics = trainer.evaluate()
print(metrics)
# open() does not create missing parent directories; make sure the target
# directory exists so the metrics are not lost after a full training run.
os.makedirs("results", exist_ok=True)
with open("results/asr_metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)
# Store model and processor side by side so both can be reloaded from the
# same directory.
model.save_pretrained("models/asr")
processor.save_pretrained("models/asr")
