# MedSpeak: A Knowledge Graph-Aided ASR Error Correction Framework for Spoken Medical QA

Llama-3.1-8B-Instruct (Full Fine-Tuning) + Whisper Small + Medical KG (Semantic + Phonetic)
Spoken medical question answering (SQA) pipelines typically rely on automatic speech recognition (ASR) transcripts. Medical terminology, however, is frequently misrecognized, and those errors propagate into downstream QA. MedSpeak addresses this by injecting medical knowledge graph (KG) evidence, covering both semantic relations and phonetic similarity, into an LLM that jointly performs transcript correction and multiple-choice (MCQ) answering.
Core idea: retrieve KG evidence conditioned on the noisy ASR transcript and the answer options, then prompt a (fine-tuned) LLM to generate (see the sketch after this list):
- Corrected transcript, and
- Selected answer option (A/B/C/D)
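A minimal sketch of what the joint prompt could look like. The function, field names, and output format here are hypothetical; the actual template used by the repository may differ:

```python
# Illustrative only: a hypothetical prompt builder for joint correction + MCQ
# answering; the real template lives in the repository's training scripts.
def build_joint_prompt(asr_text: str, options: dict[str, str], kg_evidence: str) -> str:
    opts = "\n".join(f"{k}. {v}" for k, v in options.items())
    return (
        "Using the medical knowledge graph evidence, first correct the ASR "
        "transcript, then answer the question.\n\n"
        f"KG evidence:\n{kg_evidence}\n\n"
        f"ASR transcript: {asr_text}\n\n"
        f"Options:\n{opts}\n\n"
        "Output format:\nCorrected: <transcript>\nAnswer: <A|B|C|D>"
    )

prompt = build_joint_prompt(
    "patient shows signs of my o cardial infraction",
    {"A": "Myocardial infarction", "B": "Pericarditis",
     "C": "Pulmonary embolism", "D": "Aortic dissection"},
    "'my o cardial infraction' is phonetically close to 'myocardial infarction'",
)
```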
End-to-end pipeline:
- Build Knowledge Graph (SQLite + phonetic JSONL) from CSVs
- Build SFT JSONL with KG snippets (semantic + phonetic evidence)
- Full fine-tune Llama-3.1-8B-Instruct (budgeted KG context)
- Inference: Whisper Small ASR → KG retrieval → LLM joint correction + MCQ answer
- Compare multiple inference modes:
  - Whisper → Non-Fine-Tuned LLM
  - Whisper → Fine-Tuned LLM
  - GT text → Non-Fine-Tuned / Fine-Tuned LLM
  - (Optional) enable/disable KG for ablations
- Evaluation: WER + QA accuracy
Figure 1: MedSpeak overview (Whisper → KG retrieval → LLM joint correction + QA).
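For reference, the ASR stage on its own can be reproduced with the `openai-whisper` package; the audio path below is a placeholder, and the pipeline scripts wrap this step for you:

```python
import whisper

# Load Whisper Small and transcribe one file; the resulting (possibly noisy)
# transcript is what MedSpeak feeds into KG retrieval and the LLM.
model = whisper.load_model("small")
result = model.transcribe("data/qa/audio/example.wav")  # placeholder path
print(result["text"])
```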
| Script | Purpose |
| --- | --- |
| `scripts/download_and_prepare_benchmarks.py` | Download public benchmarks, optionally synthesize audio (TTS), build `manifest.csv`. |
| `utils/csv_to_audio_manifest.py` | Build `manifest.csv` from your CSVs; run Whisper transcription; generate audio if needed. |
| `scripts/prepare_kg.py` | Build semantic KG (SQLite) + phonetic KG (JSONL). |
| `scripts/build_training_jsonl.py` | Build SFT JSONL with KG evidence snippets. |
| `scripts/convert_train_jsonl.py` | Optional: adjust system context to reduce repeated outputs. |
| `scripts/full_finetuning.py` | Full fine-tuning for Llama-3.1-8B-Instruct with KG budgets. |
| `scripts/infer_medspeak_allmodes.py` | Multi-GPU server inference + shard merge for all evaluation modes. |
```bash
conda create -n medspeak python=3.10 -y
conda activate medspeak
pip install -r requirements.txt
```

`scripts/download_and_prepare_benchmarks.py` downloads benchmarks from the internet and can create WAVs + a unified manifest:
```bash
python scripts/download_and_prepare_benchmarks.py \
  --mmlu_limit 0 \
  --medmcqa_limit 0 \
  --medqa_limit 0 \
  --tts auto \
  --out_manifest data/qa/manifest.csv
```

To build the manifest from your own CSVs instead, use `utils/csv_to_audio_manifest.py`:

```bash
python utils/csv_to_audio_manifest.py \
  --mmlu_csv data/csv_files/mmlu_qa.csv \
  --medqa_csv data/csv_files/medqa_qa.csv \
  --medmcqa_csv data/csv_files/medmcqa_qa.csv \
  --tts auto \
  --transcribe_whisper small \
  --manifest data/qa/manifest.csv
```

Replace the CSV paths with your own files if needed.
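A quick sanity check on the generated manifest (this assumes nothing about the schema beyond the file being a CSV):

```python
import pandas as pd

# Inspect the manifest produced above: row count, column names, sample rows.
df = pd.read_csv("data/qa/manifest.csv")
print(df.shape)
print(df.columns.tolist())
print(df.head(3))
```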
```bash
python scripts/prepare_kg.py \
  --phonetic_csv data/kg_csv/KG-phonetic.csv \
  --rel_csv data/kg_csv/KG-RELATIONSHIP.csv \
  --rel_csv2 data/kg_csv/SELECT_DISTINCT_t1_Term_AS_Term_Name__r.csv \
  --kg_big_csv data/kg_csv/kg.csv \
  --out_sqlite artifacts/kg_semantic.sqlite \
  --out_phonetic artifacts/kg_phonetic.jsonl
```

```bash
python scripts/build_training_jsonl.py \
  --manifest data/qa/manifest.csv \
  --kg_sql artifacts/kg_semantic.sqlite \
  --kg_phon artifacts/kg_phonetic.jsonl \
  --out_jsonl data/qa/train.jsonl
```
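To give a feel for how KG evidence snippets might be retrieved when building the SFT data, here is an illustrative sketch. The SQLite schema (a `triples` table with `head`/`relation`/`tail` columns) and the phonetic JSONL fields are assumptions, not the actual layout produced by `scripts/prepare_kg.py`:

```python
import json
import sqlite3

import jellyfish  # pip install jellyfish

def semantic_evidence(term: str, db: str = "artifacts/kg_semantic.sqlite", k: int = 5):
    """Fetch up to k triples mentioning `term` (assumed schema: triples(head, relation, tail))."""
    con = sqlite3.connect(db)
    rows = con.execute(
        "SELECT head, relation, tail FROM triples WHERE head LIKE ? LIMIT ?",
        (f"%{term}%", k),
    ).fetchall()
    con.close()
    return [" | ".join(r) for r in rows]

def phonetic_evidence(asr_word: str, path: str = "artifacts/kg_phonetic.jsonl", k: int = 5):
    """Return KG terms whose phonetic code matches the ASR word (assumed fields: term, phonetic)."""
    code = jellyfish.metaphone(asr_word)
    hits = []
    with open(path) as f:
        for line in f:
            entry = json.loads(line)
            if entry.get("phonetic") == code:
                hits.append(entry["term"])
            if len(hits) >= k:
                break
    return hits
```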
Full fine-tuning is GPU-intensive; adjust epochs, batch size, and sequence length as needed.

```bash
export TRANSFORMERS_NO_TORCHVISION=1
CUDA_VISIBLE_DEVICES=1,2,4,5 \
python scripts/full_finetuning.py \
--base_model meta-llama/Meta-Llama-3.1-8B-Instruct \
--train_jsonl data/qa/train.jsonl \
--out_dir outputs/fullft-medspeak_hist_server_version \
--epochs 10 \
--batch_size 4 \
--grad_accum 8 \
--lr 5e-5 \
--max_seq_len 2048 \
--kg_sem_budget 600 \
--kg_phon_budget 300 \
--show_hist \
--top_longest 5 \
  --full_finetune
```
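One plausible reading of the `--kg_sem_budget`/`--kg_phon_budget` flags is a cap on how much KG evidence enters each training prompt; whether the unit is tokens or characters is an assumption here (check `scripts/full_finetuning.py` for the real behavior). A character-based sketch:

```python
def clip_evidence(snippets: list[str], budget: int) -> str:
    """Greedily keep whole snippets until the (character) budget is spent."""
    kept, used = [], 0
    for s in snippets:
        if used + len(s) > budget:
            break
        kept.append(s)
        used += len(s)
    return "\n".join(kept)

sem_snippets = ["myocardial infarction | is_a | heart disease"]        # hypothetical
phon_snippets = ["'my o cardial infraction' ~ myocardial infarction"]  # hypothetical
sem_ctx = clip_evidence(sem_snippets, 600)    # mirrors --kg_sem_budget
phon_ctx = clip_evidence(phon_snippets, 300)  # mirrors --kg_phon_budget
```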
`scripts/infer_medspeak_allmodes.py` is a unified entrypoint that supports multiple evaluation modes and shard-based multi-GPU inference.

```bash
CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \
python scripts/infer_medspeak_allmodes.py infer
```

Repeat the `infer` step for each evaluation mode, then merge the shards:

```bash
python scripts/infer_medspeak_allmodes.py merge \
  --shard_dir runs/shards_medspeak_full \
  --out runs/medspeak_full.jsonl \
  --include_existing_final
```

We report:
- WER (Word Error Rate) for ASR correction quality
- QA Accuracy for multiple-choice answering correctness
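A metric sketch using the `jiwer` package (`pip install jiwer`); the repository's evaluation may apply its own text normalization before scoring:

```python
import jiwer

# WER between ground-truth transcripts and (corrected) hypotheses.
refs = ["patient shows signs of myocardial infarction"]
hyps = ["patient shows signs of my o cardial infraction"]
print("WER:", jiwer.wer(refs, hyps))

# QA accuracy over predicted option letters.
gold = ["A", "C", "B"]
pred = ["A", "C", "D"]
print("QA accuracy:", sum(g == p for g, p in zip(gold, pred)) / len(gold))
```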
Notes:
- `meta-llama/Meta-Llama-3.1-8B-Instruct` may require access approval (Hugging Face / model license).
- Full fine-tuning is compute-heavy; for quick tests reduce `--epochs`, `--batch_size`, `--max_seq_len`, or use fewer GPUs.