In [None]:
!pip install -q transformers datasets evaluate soundfile librosa pandas tqdm torch jiwer peft seaborn matplotlib


In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model
import torch
from google.colab import drive
import os
import glob
import json
import pandas as pd
import librosa
import soundfile as sf
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import evaluate
from jiwer import wer, cer

# Google Drive mount
# Mounts the user's Drive at /content/drive; the dataset folder and the LoRA
# checkpoint configured below are both read from under this mount point.
# NOTE(review): `drive` is already imported in the import block above, so this
# second import is redundant (harmless).
from google.colab import drive
drive.mount('/content/drive')

# Load audio into Whisper input features
def load_audio_for_whisper(path, processor):
    """Read one audio file and convert it into Whisper input features.

    Args:
        path: Location of the audio file; it is decoded with ``librosa.load``
            at a 16 kHz sampling rate.
        processor: Callable processor (e.g. a ``WhisperProcessor``) invoked with
            the waveform, ``sampling_rate=16000`` and ``return_tensors="pt"``.

    Returns:
        The ``input_features`` attribute of the processor's output.
    """
    # librosa also returns the sampling rate; it is always the requested 16 kHz.
    waveform, _ = librosa.load(path, sr=16000)
    encoded = processor(waveform, sampling_rate=16000, return_tensors="pt")
    return encoded.input_features

# Path to held out + rated samples
# Folder that holds the evaluation WAV files together with
# "held_out_data_with_ratings_metadata.json"; metric results are written back here.
dataset_folder = "/content/drive/My Drive/capstone/held_out_data_with_ratings"

# LoRA checkpoint path
# State dict read with torch.load in the model-loading cell.
# NOTE(review): this path spells the Drive root "MyDrive" while dataset_folder
# uses "My Drive" — confirm both resolve under the mounted drive.
lora_checkpoint = "/content/drive/MyDrive/whisper_lora_epoch1.pt"

# Notebook ID for results
# NOTE(review): compute_metrics_whisper defines its own local NOTEBOOK_ID with the
# same value, so this module-level constant is not what ends up in the file name.
NOTEBOOK_ID = "1xm-qp_aGvQw0vkTj9wCq4spsoeI8LYlL"


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Base model and its processor (feature extractor + tokenizer).
base_model_name = "openai/whisper-small"
base_model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
processor = WhisperProcessor.from_pretrained(base_model_name)

# Load the fine-tuned checkpoint onto the CPU first; the model is moved to DEVICE
# once all weights are in place.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(lora_checkpoint, map_location="cpu")

# Vocabulary size the checkpoint was trained with. A checkpoint that stores only
# adapter weights has no embedding matrix; in that case keep the base vocabulary.
embed_key = 'base_model.model.model.decoder.embed_tokens.weight'
vocab_base = base_model.config.vocab_size
vocab_ckpt = state_dict[embed_key].shape[0] if embed_key in state_dict else vocab_base

# Resize embeddings BEFORE PEFT wraps the model.
# resize_token_embeddings resizes the decoder token embedding together with the
# output projection and updates config.vocab_size. Swapping in fresh nn.Embedding /
# nn.Linear modules by hand does not do this: the output projection lives at
# `base_model.proj_out` (not `base_model.model.proj_out`), so assigning the latter
# leaves the real head at the old size and load_state_dict then fails on the shape.
if vocab_base != vocab_ckpt:
    base_model.resize_token_embeddings(vocab_ckpt)

# Must mirror the adapter configuration used during fine-tuning, otherwise the
# checkpoint's LoRA keys/shapes will not line up with the wrapped model.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
lora_model = get_peft_model(base_model, peft_config)

# strict=False tolerates key differences, but it also hides the case where none of
# the adapter weights matched — which would silently evaluate an untrained adapter.
# Inspect the result instead of discarding it.
load_result = lora_model.load_state_dict(state_dict, strict=False)
missing_lora = [key for key in load_result.missing_keys if "lora_" in key]
if missing_lora:
    raise RuntimeError(
        f"Checkpoint {lora_checkpoint} did not provide {len(missing_lora)} LoRA weights "
        f"(e.g. {missing_lora[0]}); the adapter would not contain the fine-tuned weights."
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used "
        f"(e.g. {load_result.unexpected_keys[0]})"
    )

lora_model = lora_model.to(DEVICE)
lora_model.eval()
print("LoRA model loaded successfully!")


In [None]:

def compute_metrics_whisper(dataset_folder, model=None, processor=None, model_name_or_path=None):
    """
    Transcribe every WAV file in ``dataset_folder`` with a Whisper model and score
    each transcript (WER and CER) against the reference prompt from the metadata.

    Only the first disease per file is considered, and patient ID is extracted
    from the filename (the text before the first "-").
    Saves a JSON file with results using the notebook ID.

    Args:
        dataset_folder: Folder containing the ``.wav`` files and
            ``held_out_data_with_ratings_metadata.json``. The metadata is read as a
            list of entries with ``"Filename"``, ``"Prompt" -> "Transcript"`` and an
            optional ``"Ratings"`` list.
        model: optional, pre-loaded Whisper model (e.g., LoRA-attached)
        processor: optional, corresponding WhisperProcessor
        model_name_or_path: str, used if model/processor not provided. When given,
            its basename is also used in the name of the saved results file.

    Returns:
        pandas.DataFrame with one row per audio file and the columns
        ``Filename``, ``Patient_ID``, ``Disease``, ``Symptoms``, ``Reference``,
        ``Prediction``, ``WER`` and ``CER``.

    Raises:
        ValueError: if neither a model/processor pair nor ``model_name_or_path`` is
            given, or if the folder contains no WAV files.
        KeyError: if an audio file has no matching metadata entry.
    """
    NOTEBOOK_ID = "1xm-qp_aGvQw0vkTj9wCq4spsoeI8LYlL"
    BATCH_SIZE = 11
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    if model is None or processor is None:
        if model_name_or_path is None:
            raise ValueError("Either model/processor or model_name_or_path must be provided.")
        print("Using device:", DEVICE)
        print("Using model/path:", model_name_or_path)
        processor = WhisperProcessor.from_pretrained(model_name_or_path)
        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(DEVICE)

    model.eval()

    # Send inputs to wherever the model actually lives. A pre-loaded model is not
    # guaranteed to sit on the device this function would have picked itself.
    input_device = next(model.parameters()).device

    # Load metadata
    metadata_path = os.path.join(dataset_folder, "held_out_data_with_ratings_metadata.json")

    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    # Index the metadata once instead of scanning the whole list for every file.
    # setdefault keeps the FIRST entry for a filename, matching a first-match scan.
    metadata_by_filename = {}
    for entry in metadata:
        metadata_by_filename.setdefault(entry.get("Filename"), entry)

    audio_files = sorted([os.path.join(dataset_folder, f)
                          for f in os.listdir(dataset_folder)
                          if f.lower().endswith(".wav")])
    if not audio_files:
        raise ValueError("No WAV files found in dataset folder!")

    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")
    results = []

    for i in tqdm(range(0, len(audio_files), BATCH_SIZE)):
        batch_files = audio_files[i:i+BATCH_SIZE]
        batch_inputs = []

        for path in batch_files:
            # Audio is resampled to 16 kHz on load, the rate passed to the processor.
            audio, _ = librosa.load(path, sr=16000)
            feats = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
            batch_inputs.append(feats)

        # Concatenate the per-file features along the first (batch) dimension.
        batch_inputs = torch.cat(batch_inputs).to(input_device)

        with torch.no_grad():
            # `input_features` is Whisper's documented name for its audio input.
            pred_ids = model.generate(input_features=batch_inputs)
        batch_texts = processor.batch_decode(pred_ids, skip_special_tokens=True)

        for path, pred_text in zip(batch_files, batch_texts):
            filename = os.path.basename(path)

            meta_entry = metadata_by_filename.get(filename)
            if meta_entry is None:
                raise KeyError(f"No metadata entry found for {filename}")

            # Reference and prediction are both lower-cased before scoring.
            ref_text = meta_entry["Prompt"]["Transcript"].strip().lower()
            pred_norm = pred_text.lower()
            file_wer = wer_metric.compute(predictions=[pred_norm], references=[ref_text])
            file_cer = cer_metric.compute(predictions=[pred_norm], references=[ref_text])

            # Treat a missing or null "Ratings" field as "no ratings".
            ratings = meta_entry.get("Ratings") or []

            symptoms_dict = {r["Dimension Description"]: r["Level"] for r in ratings if r.get("Dimension Category Description")}
            first_disease = ratings[0]["Dimension Category Description"] if ratings and ratings[0].get("Dimension Category Description") else None
            patient_id = filename.split("-")[0]

            results.append({
                "Filename": filename,
                "Patient_ID": patient_id,
                "Disease": first_disease,
                "Symptoms": symptoms_dict,
                "Reference": ref_text,
                "Prediction": pred_text,
                "WER": file_wer,
                "CER": file_cer
            })

    df = pd.DataFrame(results)

    # Corpus-level scores over all files at once (not the mean of per-file scores).
    # Plain lists are passed so the metrics do not depend on pandas Series handling.
    all_predictions = df["Prediction"].str.lower().tolist()
    all_references = df["Reference"].str.lower().tolist()
    overall_wer = wer_metric.compute(predictions=all_predictions, references=all_references)
    overall_cer = cer_metric.compute(predictions=all_predictions, references=all_references)
    print("\n==== Overall Metrics ====")
    print("Overall WER:", overall_wer)
    print("Overall CER:", overall_cer)

    save_path = os.path.join(dataset_folder, f"metrics_results_{NOTEBOOK_ID}_{os.path.basename(model_name_or_path) if model_name_or_path else 'custom_model'}.json")
    df.to_json(save_path, orient="records", indent=4)
    print(f"Saved results → {save_path}")

    return df


In [None]:
def plot_results(df, top_symptoms=10):
    """Draw the summary figures for one evaluation run.

    Figures, in order: WER histogram, CER histogram, mean WER and mean CER per
    disease (only when at least one disease label is present), WER per patient,
    and a WER box plot for the most frequent symptoms split by level.

    Args:
        df: Results frame with the columns ``WER``, ``CER``, ``Disease``,
            ``Patient_ID`` and ``Symptoms`` (a dict of symptom -> level per row).
        top_symptoms: How many of the most frequently rated symptoms to include
            in the box plot.
    """
    sns.set(style="whitegrid")

    # WER & CER distributions
    histogram_specs = (
        ("WER", "skyblue", "Distribution of WER"),
        ("CER", "salmon", "Distribution of CER"),
    )
    for column, shade, heading in histogram_specs:
        plt.figure(figsize=(12,4))
        sns.histplot(df[column], bins=20, kde=True, color=shade)
        plt.title(heading)
        plt.show()

    # WER/CER by Disease (skipped when no row carries a disease label)
    has_disease_labels = df["Disease"].nunique() > 0
    if has_disease_labels:
        disease_specs = (
            ("WER", "Average WER by Disease"),
            ("CER", "Average CER by Disease"),
        )
        for column, heading in disease_specs:
            plt.figure(figsize=(10,4))
            sns.barplot(x="Disease", y=column, data=df)
            plt.title(heading)
            plt.show()

    # WER by Patient
    plt.figure(figsize=(14,4))
    sns.barplot(x="Patient_ID", y="WER", data=df)
    plt.title("WER per Patient")
    plt.xticks(rotation=90)
    plt.show()

    # Symptoms vs WER (top N for readability): one long-format record per
    # (file, symptom) pair, carrying that file's error rates.
    symptom_records = [
        {"Symptom": name, "Level": level, "WER": record["WER"], "CER": record["CER"]}
        for _, record in df.iterrows()
        for name, level in record["Symptoms"].items()
    ]
    if not symptom_records:
        return

    long_df = pd.DataFrame(symptom_records)
    most_common = long_df["Symptom"].value_counts().head(top_symptoms).index
    keep_mask = long_df["Symptom"].isin(most_common)
    long_df_top = long_df[keep_mask]

    plt.figure(figsize=(12,5))
    sns.boxplot(x="Symptom", y="WER", hue="Level", data=long_df_top)
    plt.title(f"WER by Top {top_symptoms} Symptoms and Level")
    plt.xticks(rotation=45)
    plt.show()


In [None]:
dataset_folder = "/content/drive/My Drive/capstone/held_out_data_with_ratings"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Base
# The function loads the pretrained model and processor itself from the hub name.
base_model_name = "openai/whisper-small"
df_base = compute_metrics_whisper(dataset_folder, model_name_or_path=base_model_name)

# LoRA tuned model
# Reuse `lora_model` and `processor` built in the model-loading cell instead of
# constructing a second copy here. Rebuilding duplicated that cell while skipping
# its vocabulary-resize step and hard-coding the checkpoint path a second time, so
# the two copies could disagree, and a checkpoint with a different vocabulary size
# would fail in load_state_dict with a shape mismatch.
df_lora = compute_metrics_whisper(dataset_folder, model=lora_model, processor=processor)


In [None]:
# For base model
plot_results(df_base)

# For LoRA model
plot_results(df_lora)


In [None]:
def plot_symptom_histograms(df):
    """Plot, for every rated symptom, overlaid WER histograms split by level.

    Args:
        df: Results frame with a ``WER`` column and a ``Symptoms`` column holding
            a dict of symptom -> level for each row.

    Prints a message and returns without plotting when no row has any symptom.
    """
    sns.set(style="whitegrid")

    # Flatten all symptoms: one record per (file, symptom) pair.
    records = [
        {"Symptom": name, "Level": level, "WER": record["WER"]}
        for _, record in df.iterrows()
        for name, level in record["Symptoms"].items()
    ]

    if len(records) == 0:
        print("No symptoms found in dataset!")
        return

    long_df = pd.DataFrame(records)

    # One figure per symptom, in order of first appearance.
    for symptom in long_df["Symptom"].unique():
        per_symptom = long_df.loc[long_df["Symptom"] == symptom]

        plt.figure(figsize=(8,5))
        # Semi-transparent bars so the per-level distributions stay visible when overlaid.
        for lvl in per_symptom["Level"].unique():
            wer_values = per_symptom.loc[per_symptom["Level"] == lvl, "WER"]
            plt.hist(wer_values, bins=10, alpha=0.6, label=f"Level {lvl}")

        plt.title(f"WER Histogram for Symptom: {symptom}")
        plt.xlabel("WER")
        plt.ylabel("Count")
        plt.legend()
        plt.show()


In [None]:
# For the base model
plot_symptom_histograms(df_base)

# For the LoRA-tuned model
plot_symptom_histograms(df_lora)