### 1. Dependencies and imports.

!pip install torch torchaudio transformers datasets jiwer

In [None]:
import json
import torchaudio
import torch
from datasets import Dataset, load_dataset

### 2. Load and preprocess the dataset for training.

In [None]:
# Load the JSONL dataset: one JSON object per line.
dataset_path = "asr_dataset/synthetic_asr_data.jsonl"

with open(dataset_path, "r", encoding="utf-8") as f:
    # Skip blank / whitespace-only lines (e.g. an extra empty line at the end of
    # the file): json.loads raises JSONDecodeError on them.
    data_entries = [json.loads(line) for line in f if line.strip()]

# Convert to Hugging Face Dataset format
dataset = Dataset.from_list(data_entries)

# Preprocessing function
def preprocess(batch, target_sample_rate=16000, normalize=True):
    """Load one example's audio and attach model inputs and CTC labels.

    Used with ``Dataset.map`` (not batched), so *batch* is a single example
    dict that must contain ``"audio_path"`` and ``"bpe_tokens"``.

    Args:
        batch: Example dict; updated in place and returned.
        target_sample_rate: Rate the waveform is resampled to when the file's
            rate differs. 16000 matches the ``facebook/wav2vec2-large-960h``
            checkpoint loaded later in this notebook.
        normalize: Apply zero-mean / unit-variance scaling, the same formula
            the Wav2Vec2 feature extractor uses when ``do_normalize`` is set,
            so training inputs match what the processor produces at inference.

    Returns:
        The same dict with ``"input_values"`` (1-D float NumPy array) and
        ``"labels"`` (the example's ``bpe_tokens``) added.
    """
    waveform, sample_rate = torchaudio.load(batch["audio_path"])
    # torchaudio returns (channels, frames); averaging over channels gives a
    # 1-D signal for any channel count (squeeze(0) left stereo audio 2-D).
    waveform = waveform.mean(dim=0)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
    if normalize:
        # Population variance plus a small epsilon, matching the feature extractor.
        waveform = (waveform - waveform.mean()) / torch.sqrt(waveform.var(unbiased=False) + 1e-7)
    batch["input_values"] = waveform.numpy()
    batch["labels"] = batch["bpe_tokens"]
    return batch

# Apply preprocessing to every example (one example per call).
dataset = dataset.map(preprocess)

# Hold out 10% for evaluation. A fixed seed makes the split reproducible across
# notebook re-runs, so eval metrics from different runs are comparable.
dataset = dataset.train_test_split(test_size=0.1, seed=42)


### 3. Load the pretrained ASR model and processor

In [None]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Load the Wav2Vec2 processor (feature extractor + character tokenizer) and model.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")

# Labels are presumably GPT-2 BPE ids in 0..50256 (vocab size 50257) — TODO
# confirm against how `bpe_tokens` was generated. CTC needs one extra output
# class for the blank, and the blank must not coincide with any real label, so
# reserve id 50257 for it and size the output layer to 50258.
GPT2_VOCAB_SIZE = 50257
CTC_BLANK_ID = GPT2_VOCAB_SIZE

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-960h",
    vocab_size=GPT2_VOCAB_SIZE + 1,
    # Wav2Vec2ForCTC passes config.pad_token_id to the CTC loss as the blank index.
    pad_token_id=CTC_BLANK_ID,
    # The checkpoint's lm_head is sized for its small character vocabulary.
    # Without this flag from_pretrained aborts on the shape mismatch; with it,
    # the head is freshly initialized at the new size.
    ignore_mismatched_sizes=True,
)


### 4. Define Training Arguments

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./wav2vec2_asr",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # `evaluation_strategy` was renamed `eval_strategy` (transformers 4.41) and
    # the old name was later removed, so the unpinned install above rejects it.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    warmup_steps=500,
    logging_steps=100,
    save_total_limit=2,
    num_train_epochs=10,
    # Mixed precision for speed. fp16 AMP requires an accelerator, and
    # TrainingArguments raises on a CPU-only machine, so enable it only with CUDA.
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    push_to_hub=False,
)


### 5. Fine-Tune the ASR Model

In [None]:
from transformers import Trainer, TrainingArguments

# Data Collator
def data_collator(batch, *, input_pad_value=0.0, label_pad_value=-100):
    """Pad a list of examples into one batch for Wav2Vec2ForCTC.

    Audio clips and label sequences differ in length between examples, so each
    is right-padded to the longest item in the batch. ``torch.stack`` requires
    equal lengths and fails on any batch with mixed-length clips.

    Args:
        batch: Sequence of example dicts, each with ``"input_values"``
            (expected to be a 1-D float sequence) and ``"labels"`` (a 1-D
            sequence of integer token ids).
        input_pad_value: Value used to pad the audio.
        label_pad_value: Value used to pad the labels. Defaults to -100
            because Wav2Vec2ForCTC masks out negative label ids before
            computing the CTC loss.

    Returns:
        dict with ``"input_values"`` (float32, shape ``(batch, max_audio_len)``)
        and ``"labels"`` (int64, shape ``(batch, max_label_len)``).
    """
    input_values = [torch.as_tensor(item["input_values"], dtype=torch.float32) for item in batch]
    labels = [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch]

    pad = torch.nn.utils.rnn.pad_sequence
    return {
        "input_values": pad(input_values, batch_first=True, padding_value=input_pad_value),
        "labels": pad(labels, batch_first=True, padding_value=label_pad_value),
    }

# Assemble the Trainer: model, hyper-parameters, collator and the two splits.
# The feature extractor is handed over as `tokenizer` so the Trainer saves its
# preprocessing config together with the model.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the installed version is pinned.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.feature_extractor,
)

# Run fine-tuning (evaluation/saving cadence comes from training_args).
trainer.train()


### 6. Save and Test the Model

In [None]:
# Save the fine-tuned model. The Trainer in this notebook only holds
# processor.feature_extractor, so save the full processor as well; otherwise
# Wav2Vec2Processor.from_pretrained below cannot find the tokenizer files.
trainer.save_model("fine_tuned_wav2vec2")
processor.save_pretrained("fine_tuned_wav2vec2")

# Reload from disk to check that the saved artifacts round-trip.
processor = Wav2Vec2Processor.from_pretrained("fine_tuned_wav2vec2")
model = Wav2Vec2ForCTC.from_pretrained("fine_tuned_wav2vec2")

# Load a sample audio file; torchaudio returns (channels, frames).
waveform, sample_rate = torchaudio.load("asr_dataset/audio_0.wav")

# Down-mix to mono and resample to the feature extractor's rate: the feature
# extractor raises if the given sampling_rate differs from its own.
waveform = waveform.mean(dim=0)
target_rate = processor.feature_extractor.sampling_rate
if sample_rate != target_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)
    sample_rate = target_rate

# Process input (the feature extractor applies its configured normalization).
input_values = processor(
    waveform.numpy(), return_tensors="pt", sampling_rate=sample_rate
).input_values

# Generate prediction
with torch.no_grad():
    logits = model(input_values).logits

# Greedy CTC decoding: best class per frame, collapse consecutive repeats, then
# drop the blank (Wav2Vec2ForCTC uses config.pad_token_id as the CTC blank).
# The processor's character tokenizer cannot decode BPE ids, so keep raw ids.
# TODO: if bpe_tokens are GPT-2 ids, decode them to text with a GPT-2 tokenizer.
predicted_ids = torch.argmax(logits, dim=-1)
blank_id = model.config.pad_token_id
decoded_tokens = [
    [token for token in torch.unique_consecutive(frame_ids).tolist() if token != blank_id]
    for frame_ids in predicted_ids
]

print("Predicted BPE Tokens:", decoded_tokens)
