In [None]:
pip install torch torchaudio transformers peft torchcodec evaluate jiwer

In [None]:

import os

# Set before any audio library is imported so the flag is visible to them.
# NOTE(review): intended to keep torchcodec out of audio decoding; confirm the
# installed torchcodec/torchaudio versions actually read this variable.
os.environ.update({"TORCHCODEC_DISABLE": "1"})

In [None]:
import os
import json
import tarfile

# ---- CONFIG ----
DATA_ARCHIVE = "./digital_assistant_prompt_samples.tar"  # path to your .tar
DATA_DIR = "./digital_assistant_prompt_samples"           # extraction folder
META_FILE_NAME = "digital_assistant_metadata.json"
SAMPLING_RATE = 16_000  # Whisper's feature extractor expects 16 kHz audio

# ---- EXTRACT TAR IF NOT ALREADY DONE ----
# NOTE(review): extraction goes into DATA_DIR's parent, so this assumes the archive's
# top-level folder is named like DATA_DIR — confirm against the actual archive.
if not os.path.exists(DATA_DIR):
    print(f"Extracting {DATA_ARCHIVE} to {DATA_DIR}...")
    extract_root = os.path.dirname(DATA_DIR) or "."
    with tarfile.open(DATA_ARCHIVE, "r") as tar:
        # The "data" filter rejects absolute paths, `..` traversal and links that
        # escape extract_root; fall back on interpreters that predate filters.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_root, filter="data")
        else:
            tar.extractall(path=extract_root)
    print("Extraction complete.")
else:
    print(f"{DATA_DIR} already exists, skipping extraction.")

META_FILE = os.path.join(DATA_DIR, META_FILE_NAME)

# Load metadata (JSON is UTF-8; don't depend on the platform's default encoding)
with open(META_FILE, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Build list of (audio_path, transcript) pairs.
# Entries without a filename/transcript are skipped; entries whose audio file is
# missing on disk are skipped and counted so data loss is visible.
data = []
skipped_missing = 0
for item in metadata:
    filename = item.get("Filename")
    # `or {}` also covers an explicit `"Prompt": null` in the metadata
    transcript = (item.get("Prompt") or {}).get("Transcript", "")
    if not filename or not transcript:
        continue
    audio_path = os.path.join(DATA_DIR, filename)
    if os.path.exists(audio_path):
        data.append((audio_path, transcript))
    else:
        skipped_missing += 1

print(f"Found {len(data)} usable audio files")
if skipped_missing:
    print(f"Skipped {skipped_missing} metadata entries whose audio file is missing")


In [None]:
import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperProcessor

BATCH_SIZE = 4
TEST_SPLIT = 0.1
SEED = 42

# ---- 3. Shuffle and split into train/test ----
import random

# Seed both RNGs: `random` drives the split below, while torch's global RNG seeds
# DataLoader(shuffle=True) sampling and any later torch-random ops (e.g. LoRA init).
random.seed(SEED)
torch.manual_seed(SEED)

shuffled_data = data.copy()  # copy so `data` keeps its original order
random.shuffle(shuffled_data)

split_idx = int(len(shuffled_data) * (1 - TEST_SPLIT))
train_data = shuffled_data[:split_idx]
test_data = shuffled_data[split_idx:]

# ---- 4. Define PyTorch Dataset ----
class AudioTextDataset(Dataset):
    """Map-style dataset turning ``(audio_path, transcript)`` pairs into Whisper inputs.

    Each audio file is decoded, averaged to mono if multi-channel, resampled to
    ``sampling_rate`` and passed through ``processor.feature_extractor``; the
    transcript is tokenized with ``processor.tokenizer``.

    Args:
        data_list: sequence of ``(audio_path, transcript)`` tuples.
        processor: object exposing ``feature_extractor`` and ``tokenizer``
            (a ``WhisperProcessor`` in this notebook).
        sampling_rate: target sample rate in Hz for resampling and feature extraction.
    """

    def __init__(self, data_list, processor, sampling_rate=SAMPLING_RATE):
        self.data_list = data_list
        self.processor = processor
        self.sr = sampling_rate

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``{"input_features": float32 tensor, "labels": int64 tensor}`` for ``idx``."""
        # The notebook has no top-level `import soundfile`, so `sf` must be brought
        # into scope here; repeated imports are cheap lookups in sys.modules.
        import soundfile as sf

        audio_path, transcript = self.data_list[idx]

        # sf.read returns (samples, sample_rate); samples are (frames,) for mono
        # or (frames, channels) for multi-channel files.
        waveform, sr = sf.read(audio_path)
        waveform = torch.tensor(waveform).float()

        # Multi-channel → average across channels to get a 1-D mono signal
        if waveform.ndim > 1:
            waveform = waveform.mean(dim=1)
        if sr != self.sr:
            waveform = torchaudio.functional.resample(waveform, sr, self.sr)
        # The waveform is already 1-D; no squeeze (squeeze(0) would turn a
        # one-sample clip into a 0-d array).
        waveform = waveform.numpy()

        # Feature extraction
        input_features = self.processor.feature_extractor(
            waveform,
            sampling_rate=self.sr
        ).input_features[0]  # single array

        # Tokenize transcript
        labels = self.processor.tokenizer(
            transcript,
            add_special_tokens=True
        ).input_ids

        return {"input_features": torch.tensor(input_features, dtype=torch.float32),
                "labels": torch.tensor(labels, dtype=torch.long)}

# ---- 5. Initialize processor ----
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

# ---- 6. Create train/test datasets ----
train_dataset = AudioTextDataset(train_data, processor)
test_dataset = AudioTextDataset(test_data, processor)


# ---- 7. Create DataLoaders ----
# Label positions holding this value are ignored by the model's cross-entropy loss.
LABEL_IGNORE_INDEX = -100
# When given `labels`, Whisper builds decoder inputs by shifting labels right and
# prepending this token itself, so a copy at the start of the labels is redundant.
DECODER_START_TOKEN_ID = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")


def collate_batch(items):
    """Collate AudioTextDataset items into one batch for Whisper seq2seq training.

    Args:
        items: list of dicts with "input_features" (float tensor) and
            "labels" (1-D long tensor of token ids).

    Returns:
        dict with batched "input_features" and right-padded "labels", where
        padded positions are LABEL_IGNORE_INDEX so they do not contribute to the loss.
    """
    input_features = torch.nn.utils.rnn.pad_sequence(
        [item["input_features"] for item in items], batch_first=True
    )
    # Pad with -100 rather than pad_token_id: Whisper's pad token shares its id with
    # the <|endoftext|> EOS the tokenizer appends, so value-based padding would either
    # train on padding or hide the real EOS target.
    labels = torch.nn.utils.rnn.pad_sequence(
        [item["labels"] for item in items],
        batch_first=True,
        padding_value=LABEL_IGNORE_INDEX,
    )
    # Drop the leading start-of-transcript token when every sequence has it,
    # since the model prepends it again when shifting labels into decoder inputs.
    if labels.size(1) > 0 and bool((labels[:, 0] == DECODER_START_TOKEN_ID).all()):
        labels = labels[:, 1:]
    return {"input_features": input_features, "labels": labels}


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")


In [None]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from peft import LoraConfig, get_peft_model


import torch

# Use the GPU when one is present; falling back to CPU keeps the notebook runnable
# (slowly) on machines without CUDA instead of crashing on `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- 7. Load Whisper-small model ----
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.to(DEVICE)

# ---- 8. Apply LoRA (no quantization) ----
# Rank-16 LoRA adapters on the modules named q_proj and v_proj; get_peft_model
# wraps the model, and print_trainable_parameters reports how many weights train.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, config)
model.print_trainable_parameters()


In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad a list of dataset items into a batch of Whisper inputs and loss-ready labels.

    Fields:
        processor: object exposing ``feature_extractor.pad`` and ``tokenizer.pad``
            (a ``WhisperProcessor`` in this notebook).
        decoder_start_token_id: token stripped from the front of the labels when every
            sequence starts with it. ``None`` looks up ``<|startoftranscript|>`` via
            ``processor.tokenizer.convert_tokens_to_ids``.
    """
    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Return the padded feature batch with a ``"labels"`` tensor added.

        Padded label positions are set to -100 so the loss ignores them.
        """
        input_features = [{"input_features": f["input_features"]} for f in features]
        labels_batch = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        padded_labels = self.processor.tokenizer.pad(labels_batch, return_tensors="pt")
        labels = padded_labels["input_ids"]

        # Mask by attention mask, not by value: Whisper's pad token shares its id with
        # the EOS token the tokenizer appends, and that real EOS must remain a target.
        attention_mask = padded_labels.get("attention_mask")
        if attention_mask is not None:
            labels = labels.masked_fill(attention_mask.ne(1), -100)
        else:
            labels = labels.masked_fill(labels == self.processor.tokenizer.pad_token_id, -100)

        # The model prepends the decoder start token itself when shifting labels
        # into decoder inputs, so drop a leading copy shared by every sequence.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if labels.size(1) > 0 and bool((labels[:, 0] == start_id).all()):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Collator used for the single-batch smoke test; the DataLoaders have their own collate_fn.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

In [None]:
from types import MethodType

# Route calls straight to the inner model so Whisper receives `input_features`.
# NOTE(review): presumably a workaround for PeftModelForSeq2SeqLM.forward passing
# text-model arguments (e.g. `input_ids`) that Whisper's forward rejects — confirm
# against the installed peft version.
def patched_forward(self, input_features=None, **kwargs):
    """Forward ``input_features`` plus any other keyword arguments to ``self.model.forward``.

    ``self.model`` is presumably resolved via PEFT's attribute forwarding to the
    LoRA-injected Whisper model — TODO confirm.
    """
    inner = self.model
    return inner.forward(input_features=input_features, **kwargs)


# Bound as an instance attribute, so nn.Module.__call__ on `model` dispatches here.
model.forward = MethodType(patched_forward, model)


In [None]:
# Smoke test: one forward pass on (up to) two training examples to confirm that
# features, labels and device placement line up before training starts.
# no_grad avoids building an autograd graph that would only be discarded.
with torch.no_grad():
    smoke_items = [train_dataset[i] for i in range(min(2, len(train_dataset)))]
    batch = data_collator(smoke_items)
    batch = {k: v.to(model.device) for k, v in batch.items()}

    outputs = model(
        input_features=batch["input_features"],
        labels=batch["labels"]
    )

outputs.loss

In [None]:
import os

# ---- 9. Define optimizer ----
# Hand AdamW only the parameters that require grad (the LoRA adapters after
# get_peft_model); frozen parameters would never receive gradients anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

# ---- 10. Training loop with progress prints and checkpointing ----
EPOCHS = 3  # adjust as needed
SAVE_DIR = "./checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

PRINT_EVERY = 10        # print every N batches
CHECKPOINT_EVERY = 100  # save checkpoint every N batches

global_step = 0  # track batches across epochs

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0  # summed loss over the current PRINT_EVERY window
    epoch_loss = 0.0    # summed loss over the whole epoch

    for i, batch in enumerate(train_loader, start=1):
        global_step += 1

        input_features = batch["input_features"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(input_features=input_features, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss

        # ---- PRINT PROGRESS ----
        if i % PRINT_EVERY == 0:
            avg_loss = running_loss / PRINT_EVERY
            print(f"[Epoch {epoch+1} Batch {i}/{len(train_loader)}] Avg Loss: {avg_loss:.4f}")
            running_loss = 0.0

        # ---- SAVE CHECKPOINT EVERY 100 BATCHES ----
        # NOTE(review): model.state_dict() also contains the frozen base weights, so each
        # checkpoint is a full model copy; saving only adapter weights would be far smaller.
        if global_step % CHECKPOINT_EVERY == 0:
            checkpoint_path = os.path.join(SAVE_DIR, f"checkpoint_step{global_step}.pt")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint at step {global_step} → {checkpoint_path}")

    # ---- SAVE END-OF-EPOCH CHECKPOINT ----
    epoch_ckpt = os.path.join(SAVE_DIR, f"whisper_lora_epoch{epoch+1}.pt")
    torch.save(model.state_dict(), epoch_ckpt)
    print(f"Epoch {epoch+1} finished. Checkpoint saved to {epoch_ckpt}")
    if len(train_loader) > 0:
        print(f"Epoch {epoch+1} avg train loss: {epoch_loss / len(train_loader):.4f}")

# ---- 11. Optional: evaluation with incremental prints ----
model.eval()
eval_loss = 0.0

with torch.no_grad():
    for i, batch in enumerate(test_loader, start=1):
        input_features = batch["input_features"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(input_features=input_features, labels=labels)
        loss = outputs.loss

        eval_loss += loss.item()

        if i % PRINT_EVERY == 0:
            avg_loss = eval_loss / i
            print(f"[Eval Batch {i}/{len(test_loader)}] Avg Loss so far: {avg_loss:.4f}")

# A tiny dataset can yield an empty test split; avoid ZeroDivisionError and report it.
if len(test_loader) > 0:
    final_eval_loss = eval_loss / len(test_loader)
else:
    final_eval_loss = float("nan")
    print("Test split is empty; no evaluation batches were run.")
print(f"Final Evaluation Loss: {final_eval_loss:.4f}")
