In [7]:
import os

# Directory containing mtsamples.csv (the dataset below is loaded with root=".").
# Override via the PYHEALTH_ROOT environment variable instead of editing the
# notebook; the default preserves the original hardcoded location.
PROJECT_ROOT = os.environ.get("PYHEALTH_ROOT", "/workspace/PyHealth")
os.chdir(PROJECT_ROOT)

In [12]:
# Optional one-time environment setup (uncomment to run).
# Use %pip (not !pip) so packages install into the kernel's own environment, and
# quote version specifiers: an unquoted `>=` is parsed by the shell as output
# redirection and would silently write pip's output to a file named "=4.5.0".
# %pip install -r requirements.txt
# %pip install tqdm pydantic "typing_extensions>=4.5.0"

## Load the medical transcriptions dataset, define the labeling prompt template, initialize the uncertainty-aware zero-shot classifier, and run batched predictions

In [None]:
from pyhealth.datasets import MedicalTranscriptionsDataset
from pyhealth.models.uncertainty_aware_transformer import UncertaintyAwareZeroShotClassifier

# Configuration for the cheap (small-model) first pass.
SMALL_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
SMALL_BATCH_SIZE = 16
# Passed as max_batches to predict_dataset; presumably caps the evaluated samples
# at SMALL_BATCH_SIZE * SMALL_MAX_BATCHES -- TODO confirm against the classifier.
SMALL_MAX_BATCHES = 100
HF_CACHE_DIR = "./.hf_cache"

# Keep the raw dataset and the task-specific sample dataset under separate names
# rather than rebinding one variable to two different kinds of object.
base_ds = MedicalTranscriptionsDataset(root=".")
# Default task; the log output shows it is MedicalTranscriptionsClassification.
ds = base_ds.set_task()

# Prompt template; {labels} and {text} are presumably filled in by the classifier
# with the label vocabulary and the clinical note, respectively.
prompt = (
    "System: You are a medical text classifier.\n"
    "User: Classify the following clinical note into **one** most likely matching label "
    "from the list below and reply with only that label and nothing else.\n\n"
    "Possible labels:\n{labels}\n\n"
    "Clinical note:\n{text}\n\n"
    "Answer:"
)

uazs = UncertaintyAwareZeroShotClassifier(
    ds,
    model_name=SMALL_MODEL_NAME,
    batch_size=SMALL_BATCH_SIZE,
    cache_dir=HF_CACHE_DIR,
    prompt_template=prompt,
)

uazs.predict_dataset(max_batches=SMALL_MAX_BATCHES)


No config path provided, using default config
Initializing medical_transcriptions dataset from . (dev mode: False)
Scanning table: mtsamples from /workspace/PyHealth/mtsamples.csv
Setting task MedicalTranscriptionsClassification for medical_transcriptions base dataset...
Collecting global event dataframe...
Collected dataframe with shape: (4999, 8)
Generating samples with 8 worker(s)...
Generating samples for MedicalTranscriptionsClassification
Label medical_specialty vocab: {' Allergy / Immunology': 0, ' Autopsy': 1, ' Bariatrics': 2, ' Cardiovascular / Pulmonary': 3, ' Chiropractic': 4, ' Consult - History and Phy.': 5, ' Cosmetic / Plastic Surgery': 6, ' Dentistry': 7, ' Dermatology': 8, ' Diets and Nutritions': 9, ' Discharge Summary': 10, ' ENT - Otolaryngology': 11, ' Emergency Room Reports': 12, ' Endocrinology': 13, ' Gastroenterology': 14, ' General Medicine': 15, ' Hematology - Oncology': 16, ' Hospice - Palliative Care': 17, ' IME-QME-Work Comp etc.': 18, ' Lab Medicine - Pa

Processing samples: 100%|██████████| 4966/4966 [00:00<00:00, 117494.93it/s]

Generated 4966 samples for task MedicalTranscriptionsClassification



Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Zero-shot generate:  49%|████▉     | 49/100 [02:35<02:41,  3.17s/it]

## Retrieve the model's predictions and compute the initial accuracy

In [2]:
# Small-model predictions; each record carries a "prediction" and a ground-truth "gt".
results = uazs.results
# Accuracy = fraction of records whose predicted label equals the ground truth.
results_acc = sum(x["prediction"] == x["gt"] for x in results) / len(results)

# (The previous format string had a stray ")" that was printed after the number.)
print(f"Initial Accuracy: {results_acc:.4f}")

Initial Accuracy: 0.2150)


## Select the 10% most uncertain predictions and compute accuracy separately on the uncertain samples and on the remaining ones

In [3]:
# Fraction of predictions to isolate as "most uncertain" (0.1 -> top 10%).
frac = .1

# Split the small-model results into the most uncertain `frac` and everything else.
top_samples, remaining_samples = uazs.get_uncertain(k=frac)

# Accuracy on each split: count matching predictions, divide by split size.
top_acc = sum(1 for rec in top_samples if rec["prediction"] == rec["gt"]) / len(top_samples)
rest_acc = sum(1 for rec in remaining_samples if rec["prediction"] == rec["gt"]) / len(remaining_samples)

print(f"Top-k Uncertain Accuracy: {top_acc:.4f}")
print(f"Remaining Accuracy: {rest_acc:.4f}")

Top-k Uncertain Accuracy: 0.1938
Remaining Accuracy: 0.2174


In [4]:
# Control arm: a random subset drawn with the same fraction `frac`, so the effect of
# uncertainty-based routing can be separated from simply using a bigger model.
# NOTE: bg_idx is later iterated as prediction records ("prediction"/"gt" keys),
# not as bare indices, despite its name.
(random_dataset_subset,
 random_rest,
 sm_idx,
 bg_idx) = uazs.get_random_datasets(k=frac)

# Treatment arm: the `frac` most uncertain samples, returned in a form the next
# cells pass straight into a second classifier.
(uncertain_dataset,
 uncertain_rest) = uazs.get_uncertain_datasets(k=frac)

# Control:

## Initialize a new uncertainty-aware classifier with the larger (more expensive) model on a random control subset, compute its accuracy alongside the cheaper model's accuracy on the rest, and compute the blended accuracy

In [5]:
# Re-score the random control subset with the larger 7B model, same prompt.
# NOTE(review): batch_size=4 here vs. 8 in the uncertainty-aware arm; padded batched
# generation may perturb outputs slightly -- consider matching for a fair comparison.
uazs_random = UncertaintyAwareZeroShotClassifier(
    random_dataset_subset,
    model_name="Qwen/Qwen2.5-7B-Instruct",
    batch_size=4,
    cache_dir="./.hf_cache",
    prompt_template=prompt,
)

uazs_random.predict_dataset()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Zero-shot generate: 100%|██████████| 40/40 [00:41<00:00,  1.05s/it]


In [6]:
results_random = uazs_random.results

# Expensive model on the random subset vs. the cheap model's records for the
# remaining samples (bg_idx, presumably taken from uazs.results).
results_acc_random = sum(1 for rec in results_random if rec["prediction"] == rec["gt"]) / len(results_random)
rest_acc_random = sum(1 for rec in bg_idx if rec["prediction"] == rec["gt"]) / len(bg_idx)

print(f"Random Subset Expensive Model Accuracy: {results_acc_random:.4f}")
print(f"Remainder Cheap Model Accuracy: {rest_acc_random:.4f}")

Random Subset Expensive Model Accuracy: 0.3063
Remainder Cheap Model Accuracy: 0.2104


In [7]:
# "index" values of the records re-scored by the expensive model (not used below;
# kept for inspection).
subset_idx = {r["index"] for r in uazs_random.results}

# Correct predictions: expensive model on the random subset, cheap model elsewhere.
correct_big = sum(r["prediction"] == r["gt"] for r in uazs_random.results)
correct_rest = sum(r["prediction"] == r["gt"] for r in bg_idx)

# Normalize by the number of records actually scored in the two parts. This is the
# same formula as the uncertainty-aware blend, so the two blended accuracies stay
# comparable even if the split does not cover uazs.results exactly.
n_scored = len(uazs_random.results) + len(bg_idx)
blended_acc = (correct_big + correct_rest) / n_scored
print(f"Blended accuracy: {blended_acc:.4f}")

Blended accuracy: 0.2200


# Uncertainty Aware:

## Initialize a new uncertainty-aware classifier with the larger (more expensive) model on the most uncertain samples, compare it with the cheaper model on those samples, and compute the blended accuracy

In [8]:
# Re-score only the most uncertain samples with the larger 7B model, same prompt.
uazs_big = UncertaintyAwareZeroShotClassifier(
    uncertain_dataset,
    model_name="Qwen/Qwen2.5-7B-Instruct",
    batch_size=8,
    cache_dir="./.hf_cache",
    prompt_template=prompt,
)

uazs_big.predict_dataset()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Zero-shot generate: 100%|██████████| 20/20 [00:43<00:00,  2.19s/it]


In [9]:
results_big = uazs_big.results

# The same uncertain samples, now scored by the 7B model instead of the 3B model.
results_acc_big = sum(1 for rec in results_big if rec["prediction"] == rec["gt"]) / len(results_big)

print(f"Original Model Accuracy on Originally Uncertain Samples: {top_acc:.4f}")
print(f"Expensive Model Accuracy on Originally Uncertain Samples: {results_acc_big:.4f}")

Original Model Accuracy on Originally Uncertain Samples: 0.1938
Expensive Model Accuracy on Originally Uncertain Samples: 0.3187


In [10]:
# Blend: expensive-model predictions on the uncertain samples + cheap-model
# predictions on the remaining samples. Count correct predictions directly
# rather than reconstructing counts from float accuracies.
n_correct_uncertain = sum(x["prediction"] == x["gt"] for x in results_big)
n_correct_remaining = sum(x["prediction"] == x["gt"] for x in remaining_samples)
blended_acc = (n_correct_uncertain + n_correct_remaining) / (len(results_big) + len(remaining_samples))
# (The previous format string had a stray ")" that was printed after the number.)
print(f"Blended Uncertainty Aware Accuracy: {blended_acc:.4f}")

Blended Uncertainty Aware Accuracy: 0.2275)
