In [20]:
import torch

# Report the CUDA toolkit version PyTorch was built against; prints nothing without a GPU.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    print("CUDA version: {}".format(torch.version.cuda))

CUDA version: 12.6


In [2]:
!pip install -q transformers accelerate evaluate jiwer tensorboard audiomentations librosa soundfile

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.1/86.1 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.4/109.4 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m248.5/248.5 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
!unzip dataset_2_0.zip

Archive:  dataset_2_0.zip
   creating: downloaded_audio_fixed_to_mp3/
  inflating: downloaded_audio_fixed_to_mp3/3 meters forward.mp3  
  inflating: downloaded_audio_fixed_to_mp3/A bit more to the other side. Left by 6 meters.mp3  
  inflating: downloaded_audio_fixed_to_mp3/A bit more to the other side. Left by 6 meters_1.mp3  
  inflating: downloaded_audio_fixed_to_mp3/advance 6 meters forward.mp3  
  inflating: downloaded_audio_fixed_to_mp3/advance a little.mp3  
  inflating: downloaded_audio_fixed_to_mp3/advance forward 3 meters.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 3 meters_1.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 5 meters.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 5 meters_1.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 5 meters_4.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 5 meters_5.mp3  
  inflating: downloaded_audio_fixed_to_mp3/Advance forward 5 meters_6.mp3  
  

In [17]:
import os
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Union

import audiomentations
import evaluate
import librosa
import numpy as np
import soundfile as sf
import torch
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
from datasets import Dataset, DatasetDict, Audio
from pydub import AudioSegment
from tqdm import tqdm
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

In [18]:


def load_data(src='dataset', dst='dataset_processed'):
    os.makedirs(dst, exist_ok=True)
    mp3s = [f for f in os.listdir(src) if f.endswith('.mp3')]
    print(f"Found {len(mp3s)} files")

    a_paths, texts = [], []
    for i, f in enumerate(tqdm(mp3s, desc="Converting")):
        txt = f.replace('.mp3', '')
        src_p = os.path.join(src, f)
        dst_p = os.path.join(dst, f"{i:04d}.wav")
        try:
            a = AudioSegment.from_mp3(src_p).set_frame_rate(16000).set_channels(1)
            a.export(dst_p, format='wav')
            a_paths.append(dst_p)
            texts.append(txt)
        except Exception as e:
            print(f"Skip {f}: {e}")
    return a_paths, texts

# NOTE(review): the unzip cell extracts to downloaded_audio_fixed_to_mp3/, while
# load_data() defaults to src='dataset'. Confirm which directory holds the corpus.
print("Loading data...")
a_paths, texts = load_data()
print("Processed: ", len(a_paths))
# Fail fast: an empty corpus would otherwise only surface as an obscure error much later, in training.
assert a_paths, "No audio was converted; check the source directory passed to load_data()"

def split_data(a_paths, texts, t_size=0.15):
    """Shuffle paired (audio, text) samples and split them into train/test.

    The global ``random`` module is reseeded with 42 first, so the split is
    repeatable for a given input order. The first ``int(n * (1 - t_size))``
    shuffled pairs form the training set; the rest form the test set.

    Returns:
        ``(train_audio, train_text, test_audio, test_text)``, all plain lists.
    """
    pairs = list(zip(a_paths, texts))
    random.seed(42)
    random.shuffle(pairs)
    cut = int(len(pairs) * (1 - t_size))
    train_pairs = pairs[:cut]
    test_pairs = pairs[cut:]
    train_audio = [audio for audio, _ in train_pairs]
    train_text = [text for _, text in train_pairs]
    test_audio = [audio for audio, _ in test_pairs]
    test_text = [text for _, text in test_pairs]
    return train_audio, train_text, test_audio, test_text

# Hold out a fixed-seed test split before augmentation, so the test set only contains original clips.
print("Splitting data...")
(tr_a, tr_t, ts_a, ts_t) = split_data(a_paths, texts)
print("train: {}, Test: {}".format(len(tr_a), len(ts_a)))

def augment_data(a_paths, texts, factor=2, out_dir='dataset_processed'):
    """Expand a training set with randomly augmented copies of each clip.

    Each readable clip is kept as-is. ``factor - 1`` augmented variants are
    also written to ``out_dir`` as ``aug_{i}_{j}.wav``, and each variant
    reuses the source clip's transcript.

    Args:
        a_paths: paths of the training audio files (loaded at 16 kHz mono).
        texts: transcripts parallel to ``a_paths``.
        factor: total copies per clip (original + augmented); ``<= 1`` adds none.
        out_dir: directory for the augmented WAVs; created if missing.

    Returns:
        ``(aug_a, aug_t)``: parallel lists of paths and transcripts. Clips that
        fail to load are reported and dropped.
    """
    # Keyword arguments keep the configuration stable across audiomentations
    # releases, whose positional signatures (notably Shift's) have changed.
    aug = Compose([
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5),
        PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
        Shift(min_shift=-0.5, max_shift=0.5, shift_unit="fraction", rollover=False, p=0.3),
    ])
    os.makedirs(out_dir, exist_ok=True)

    aug_a, aug_t = [], []
    for i, (p, t) in enumerate(tqdm(zip(a_paths, texts), total=len(a_paths), desc="Augmenting")):
        try:
            s, _ = librosa.load(p, sr=16000, mono=True)
        except Exception as e:
            # Best effort, but visible: a bare except would also hide KeyboardInterrupt.
            print(f"Skip {p}: {e}")
            continue
        aug_a.append(p)
        aug_t.append(t)
        for j in range(factor - 1):
            s2 = aug(samples=s, sample_rate=16000)
            out = os.path.join(out_dir, f"aug_{i}_{j}.wav")
            sf.write(out, s2, 16000)
            aug_a.append(out)
            aug_t.append(t)
    return aug_a, aug_t

# Only the training split is augmented; the test split stays untouched.
print("augmenting...")
tr_a, tr_t = augment_data(tr_a, tr_t)
print("after aug: {} samples".format(len(tr_a)))

class WhisperDS(torch.utils.data.Dataset):
    """Map-style dataset that loads audio lazily and tokenizes its transcript.

    Args:
        a_p: list of audio file paths.
        txt: transcripts parallel to ``a_p``.
        proc: processor exposing ``feature_extractor`` and ``tokenizer``.

    Each item is ``{"input_features": <features for one clip>, "labels": <token id list>}``.
    """

    def __init__(self, a_p, txt, proc):
        self.a_p = a_p
        self.txt = txt
        self.p = proc

    def __len__(self):
        return len(self.a_p)

    def __getitem__(self, i):
        waveform, _sr = librosa.load(self.a_p[i], sr=16000, mono=True)
        extracted = self.p.feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
        label_ids = self.p.tokenizer(self.txt[i]).input_ids
        return {"input_features": extracted.input_features[0], "labels": label_ids}

# Base checkpoint plus its English/transcribe-configured preprocessing components.
# NOTE(review): f_ext is not used by any visible cell; proc.feature_extractor is used instead.
m_id = "openai/whisper-tiny"
print("loading model ", m_id)
lang_task = {"language": "English", "task": "transcribe"}
f_ext = WhisperFeatureExtractor.from_pretrained(m_id)
tok = WhisperTokenizer.from_pretrained(m_id, **lang_task)
proc = WhisperProcessor.from_pretrained(m_id, **lang_task)

tr_ds, ts_ds = WhisperDS(tr_a, tr_t, proc), WhisperDS(ts_a, ts_t, proc)
print("train DS: {}, Test DS: {}".format(len(tr_ds), len(ts_ds)))

@dataclass
class Collator:
    """Batch Whisper examples: pad input features and mask padded label tokens.

    Fields:
        proc: processor-like object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
        dec_start: decoder start token id. It is dropped from the labels when
            every sequence in the batch begins with it (presumably because the
            model re-inserts it when shifting labels right, as in the HF Whisper
            fine-tuning recipe; confirm for other models).
    """
    proc: Any  # typing.Any; the builtin `any` is a function, not a type
    dec_start: int

    def __call__(self, feats):
        """Collate a list of ``{"input_features", "labels"}`` dicts into a batch.

        Returns the padded feature batch with a ``"labels"`` tensor added, in
        which padded positions are -100 so the loss ignores them.
        """
        inp = [{"input_features": f["input_features"]} for f in feats]
        batch = self.proc.feature_extractor.pad(inp, return_tensors="pt")
        lbls = [{"input_ids": f["labels"]} for f in feats]
        l_batch = self.proc.tokenizer.pad(lbls, return_tensors="pt")
        y = l_batch["input_ids"].masked_fill(l_batch["attention_mask"].ne(1), -100)
        # Strip the leading start token only if the whole batch has it, so columns stay aligned.
        if (y[:, 0] == self.dec_start).all().item():
            y = y[:, 1:]
        batch["labels"] = y
        return batch

print("trainer setup...")
model = WhisperForConditionalGeneration.from_pretrained(m_id)

# Generation settings: English transcription, with forced_decoder_ids cleared.
gen_cfg = model.generation_config
gen_cfg.language = "english"
gen_cfg.task = "transcribe"
gen_cfg.forced_decoder_ids = None

coll = Collator(proc, dec_start=model.config.decoder_start_token_id)
metric = evaluate.load("wer")

def metrics(pred):
    """Corpus-level WER (in percent) for the ``Seq2SeqTrainer`` ``compute_metrics`` hook.

    Args:
        pred: ``EvalPrediction``-like object. ``predictions`` holds generated
            token ids (``predict_with_generate`` is enabled in the training
            args); ``label_ids`` holds collator labels, with -100 at padded positions.

    Returns:
        ``{"wer": <percent>}``, decoded with the module-level ``tok`` and scored
        with the module-level ``metric``.
    """
    p_ids, l_ids = pred.predictions, pred.label_ids
    # Predictions may arrive as a tuple whose first element holds the token ids.
    if isinstance(p_ids, tuple):
        p_ids = p_ids[0]
    # -100 is not a decodable token id. It marks padded label positions (see
    # Collator) and may also appear in predictions if they were padded across
    # eval batches. np.where builds new arrays, so the caller's data is not mutated.
    p_ids = np.where(p_ids == -100, tok.pad_token_id, p_ids)
    l_ids = np.where(l_ids == -100, tok.pad_token_id, l_ids)
    p_str = tok.batch_decode(p_ids, skip_special_tokens=True)
    l_str = tok.batch_decode(l_ids, skip_special_tokens=True)
    return {"wer": 100 * metric.compute(predictions=p_str, references=l_str)}


augmenting...


Augmenting: 100%|██████████| 1148/1148 [00:27<00:00, 41.42it/s]


after aug: 2296 samples
loading model  openai/whisper-tiny
train DS: 2296, Test DS: 203
trainer setup...


In [19]:
from transformers import EarlyStoppingCallback

In [14]:
import shutil
import os

# Remove a stale export directory from an earlier run, if one exists.
# NOTE(review): no visible cell writes this path (the model is saved to
# ./whisper-drone-final), so confirm this is still the intended target.
folder_path = 'whisper-drone-command-final'
if not os.path.exists(folder_path):
    print(f"Folder '{folder_path}' does not exist.")
else:
    shutil.rmtree(folder_path)
    print(f"Folder '{folder_path}' deleted successfully.")

Folder 'whisper-drone-command-final' deleted successfully.


In [21]:
args = Seq2SeqTrainingArguments(
    output_dir="./whisper-drone",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,      # effective batch size 16 per device
    learning_rate=1e-5,
    warmup_steps=100,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    # Mixed precision needs a GPU; fall back to fp32 instead of failing on CPU.
    fp16=torch.cuda.is_available(),
    predict_with_generate=True,         # metrics() decodes generated token ids
    generation_max_length=225,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,            # lower WER is better
    push_to_hub=False,
    save_total_limit=2,
    remove_unused_columns=False,
)

# Stop once eval WER has failed to improve for 3 consecutive evaluations (one per epoch).
# early_stopping_threshold is an absolute difference in the metric's own units. metrics()
# reports WER in percent, so 0.01 means 0.01 WER points, not 1%.
stop_cb = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01,
)

trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tr_ds,
    eval_dataset=ts_ds,
    data_collator=coll,
    compute_metrics=metrics,
    # `tokenizer=` is deprecated in recent transformers releases in favour of `processing_class=`.
    processing_class=proc.feature_extractor,
    callbacks=[stop_cb],
)

trainer.train()


  trainer = Seq2SeqTrainer(
You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Wer
1,1.9585,1.847484,40.913327
2,0.5087,0.511926,33.178006
3,0.3021,0.452119,31.500466
4,0.1895,0.462052,47.343896
5,0.1037,0.479424,32.059646
6,0.0844,0.480311,63.560112


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=864, training_loss=0.8988283615283392, metrics={'train_runtime': 690.7721, 'train_samples_per_second': 33.238, 'train_steps_per_second': 2.085, 'total_flos': 3.3914976141312e+17, 'train_loss': 0.8988283615283392, 'epoch': 6.0})

In [29]:
# Persist the model loaded at the end of training (the best checkpoint) plus its processor, side by side.
final_dir = "./whisper-drone-final"
trainer.save_model(final_dir)
proc.save_pretrained(final_dir)

[]

In [10]:
!zip -r whisper-drone-command-final.zip whisper-drone-command-final/
print("zipped")

  adding: whisper-drone-command-final/ (stored 0%)
  adding: whisper-drone-command-final/vocab.json (deflated 58%)
  adding: whisper-drone-command-final/training_args.bin (deflated 54%)
  adding: whisper-drone-command-final/tokenizer_config.json (deflated 96%)
  adding: whisper-drone-command-final/merges.txt (deflated 54%)
  adding: whisper-drone-command-final/model.safetensors (deflated 8%)
  adding: whisper-drone-command-final/preprocessor_config.json (deflated 44%)
  adding: whisper-drone-command-final/tokenizer.json (deflated 82%)
  adding: whisper-drone-command-final/added_tokens.json (deflated 80%)
  adding: whisper-drone-command-final/special_tokens_map.json (deflated 80%)
  adding: whisper-drone-command-final/config.json (deflated 60%)
  adding: whisper-drone-command-final/generation_config.json (deflated 71%)
  adding: whisper-drone-command-final/normalizer.json (deflated 81%)

Model zipped successfully!


In [23]:
import numpy as np, pandas as pd, torch, librosa, os, time
from tqdm import tqdm
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer
import evaluate

In [30]:
# Evaluate the saved fine-tuned checkpoint on the held-out split from the training cell.
MODEL_DIR = "./whisper-drone-final"
model = WhisperForConditionalGeneration.from_pretrained(MODEL_DIR)
proc = WhisperProcessor.from_pretrained(MODEL_DIR)

# The split cell names the held-out lists `ts_a` / `ts_t`. `test_audio` / `test_text`
# are not defined by any visible cell, so bind them explicitly for Restart & Run All.
test_audio, test_text = ts_a, ts_t

dev = "cuda" if torch.cuda.is_available() else "cpu"
model.to(dev).eval()

print(f"Device: {dev}")
print(f"Samples: {len(test_audio)}\n")

# The four result lists are appended together only on success, so they stay
# index-aligned even when a sample fails (the results table relies on this).
files, preds, refs, times = [], [], [], []
print("Running inference...")

for i, (p, ref) in enumerate(tqdm(zip(test_audio, test_text), total=len(test_audio))):
    try:
        s, _ = librosa.load(p, sr=16000, mono=True)
        inp = proc.feature_extractor(s, sampling_rate=16000, return_tensors="pt").input_features.to(dev)
        t0 = time.perf_counter()  # monotonic, high-resolution clock for interval timing
        with torch.no_grad():
            ids = model.generate(inp)
            txt = proc.batch_decode(ids, skip_special_tokens=True)[0]
        elapsed = time.perf_counter() - t0
    except Exception as e:
        print(f"⚠️ Error at sample {i}: {e}")
        continue
    files.append(os.path.basename(p))
    preds.append(txt)
    refs.append(ref)
    times.append(elapsed)

if not refs:
    raise RuntimeError("No test sample was transcribed; nothing to evaluate.")


def _norm(text):
    """Lower-case, replace punctuation with spaces, and collapse whitespace."""
    return " ".join(re.sub(r"[^\w\s]", " ", text.lower()).split())


print("\n=== METRICS ===")
m = evaluate.load("wer")
wer_score = m.compute(predictions=preds, references=refs)
# Raw WER counts casing/punctuation differences ("Drone land now" vs "Drone, land now")
# as word errors. The normalized score isolates genuine word mistakes.
norm_wer_score = m.compute(predictions=[_norm(p) for p in preds], references=[_norm(r) for r in refs])
print(f"WER: {wer_score*100:.2f}%")
print(f"WER (normalized): {norm_wer_score*100:.2f}%")
print(f"Avg time/sample: {np.mean(times):.3f}s  |  Total: {np.sum(times):.2f}s")

print("\n=== STATS ===")
correct = sum(p.strip().lower() == r.strip().lower() for p, r in zip(preds, refs))
acc = (correct / len(refs)) * 100
print(f"Perfect: {correct}/{len(refs)} ({acc:.2f}%)")
print(f"Errors: {len(refs) - correct}/{len(refs)}")

w_list = [wer(r, p) if r and p else 1.0 for p, r in zip(preds, refs)]
print(f"Mean: {np.mean(w_list)*100:.2f}% | Median: {np.median(w_list)*100:.2f}% | Min: {np.min(w_list)*100:.2f}% | Max: {np.max(w_list)*100:.2f}% | Std: {np.std(w_list)*100:.2f}%")

print("\n=== SAMPLE PREDICTIONS ===")
for i in range(min(10, len(preds))):
    wv = w_list[i] * 100
    print(f"\n{'✅' if wv==0 else '❌'} {i+1}. WER {wv:.1f}%")
    print(f"Ref: {refs[i]}")
    print(f"Pred: {preds[i]}")

print("\n=== WORST PREDICTIONS ===")
bad_idx = np.argsort(w_list)[-5:][::-1]
for r, i in enumerate(bad_idx, 1):
    print(f"\n{r}. {i+1} (WER {w_list[i]*100:.1f}%)")
    print(f"Ref: {refs[i]}")
    print(f"Pred: {preds[i]}")

print("\n=== MODERATE ERRORS ===")
mid = [(i, w) for i, w in enumerate(w_list) if 0 < w < 0.3]
if mid:
    for r, (i, w) in enumerate(sorted(mid, key=lambda x: x[1])[:5], 1):
        print(f"\n{r}. {i+1} (WER {w*100:.1f}%)")
        print(f"Ref: {refs[i]}")
        print(f"Pred: {preds[i]}")
else:
    print("None found.")

print("\n=== SAVING RESULTS ===")
df = pd.DataFrame({
    "id": range(1, len(preds)+1),
    "file": files,  # aligned with preds/refs (failed samples are excluded from all lists)
    "ref": refs,
    "pred": preds,
    "wer": [w*100 for w in w_list],
    "time": times,
    "match": [p.strip().lower()==r.strip().lower() for p, r in zip(preds, refs)]
})
df.to_csv("test_results_detailed.csv", index=False)
print("Saved detailed results.")

summary = {
    "Samples": len(refs),
    "WER (%)": round(wer_score*100, 2),
    "Normalized WER (%)": round(norm_wer_score*100, 2),
    "Perfect": correct,
    "Acc (%)": round(acc, 2),
    "Mean WER (%)": round(np.mean(w_list)*100, 2),
    "Median WER (%)": round(np.median(w_list)*100, 2),
    "Std WER (%)": round(np.std(w_list)*100, 2),
    "Avg Time (s)": round(np.mean(times), 3),
    "Total Time (s)": round(np.sum(times), 2)
}
pd.DataFrame([summary]).to_csv("test_results_summary.csv", index=False)
print("Saved summary.\n\n=== EVALUATION DONE ===")

Device: cuda
Samples: 203

Running inference...


100%|██████████| 203/203 [00:14<00:00, 13.91it/s]



=== METRICS ===
WER: 31.50%
Avg time/sample: 0.056s  |  Total: 11.42s

=== STATS ===
Perfect: 68/203 (33.50%)
Errors: 135/203
Mean: 35.20% | Median: 27.27% | Min: 0.00% | Max: 150.00% | Std: 33.05%

=== SAMPLE PREDICTIONS ===

❌ 1. WER 40.0%
Ref: Turn drone right 3 meters_1
Pred:  turn drone right 3 meters

✅ 2. WER 0.0%
Ref: Drone, go up 6 meters
Pred: Drone, go up 6 meters

❌ 3. WER 100.0%
Ref: Drone go down 3 meters_2
Pred:  Go move backward now

❌ 4. WER 33.3%
Ref: Drone land now
Pred:  Drone, land now

❌ 5. WER 25.0%
Ref: Stop after 11 meters!
Pred:  Stop after 11 meters

❌ 6. WER 16.7%
Ref: shift right 0m from current position.m4a
Pred:  safe right 0m from current position.m4a

✅ 7. WER 0.0%
Ref: Proceed straight up 75 meters
Pred:  Proceed straight up 75 meters

❌ 8. WER 20.0%
Ref: Drone move down 1 meter
Pred:  drone move down 1 meter

❌ 9. WER 60.0%
Ref: descend down by 19 meter
Pred:  Descend down by 19m

✅ 10. WER 0.0%
Ref: Drone stop after 2 meters
Pred: Drone stop after 2