In [1]:
!pip install -q "transformers>=4.43.0" peft bitsandbytes accelerate


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login
from getpass import getpass


In [3]:
HF_TOKEN = getpass("Enter your HuggingFace token (read access): ")
login(HF_TOKEN)


In [18]:
base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    use_fast=True,
    token=HF_TOKEN,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
)

base_model.eval()


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRM

In [None]:
adapter_repo_id = "smedara/llama3-med-easi-explainer" #can leave this as is, since it's public, but if you trained your own adapter, change to your adapter repo ID

model = PeftModel.from_pretrained(
    base_model,
    adapter_repo_id,
    token=HF_TOKEN,
)

model.eval()


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.

In [21]:
def build_explain_prompt(expert_text: str) -> str:
    system_message = (
        "You are a medical explainer. Your job is to rewrite the medical text below "
        "into a clear explanation for a patient.\n\n"
        "Rules:\n"
        "- Only use information found in the original text.\n"
        "- Do NOT add any new causes, treatments, or details.\n"
        "- Use short, simple sentences (8th-grade level).\n"
        "- Avoid medical jargon. If you must use it, briefly define it.\n"
        "- Be friendly and reassuring.\n"
    )

    user_message = (
        "Explain the following medical diagnosis to a patient:\n\n"
        f"{expert_text}\n\n"
        "Patient-friendly explanation:"
    )

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    return prompt


In [22]:
def explain_for_patient(expert_text: str, max_new_tokens=150):
    prompt = build_explain_prompt(expert_text)

    enc = tokenizer(
        [prompt],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(model.device)

    input_ids = enc["input_ids"]
    input_len = input_ids.shape[1]

    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            min_new_tokens=50,
            num_beams=4,
            length_penalty=1.1,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )

    gen_ids = out[0, input_len:]
    decoded = tokenizer.decode(gen_ids, skip_special_tokens=True)
    return decoded.strip()


In [23]:
expert_example = (
    "The patient presents with acute decompensated systolic heart failure secondary "
    "to long-standing ischemic cardiomyopathy. Echocardiography shows a left "
    "ventricular ejection fraction of 20% with global hypokinesis and moderate "
    "functional mitral regurgitation. He also has poorly controlled hypertension "
    "and stage 3 chronic kidney disease, which further exacerbate volume overload "
    "and limit the use of certain guideline-directed medical therapies."
)

print("EXPERT TEXT\n", expert_example)
print("MODEL EXPLANATION\n")
print(explain_for_patient(expert_example))


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


EXPERT TEXT
 The patient presents with acute decompensated systolic heart failure secondary to long-standing ischemic cardiomyopathy. Echocardiography shows a left ventricular ejection fraction of 20% with global hypokinesis and moderate functional mitral regurgitation. He also has poorly controlled hypertension and stage 3 chronic kidney disease, which further exacerbate volume overload and limit the use of certain guideline-directed medical therapies.
MODEL EXPLANATION

The patient has a history of heart failure, high blood pressure, and kidney disease. His heart is not pumping well, and his kidneys are not working as well as they should. These conditions make it harder to use some of the treatments that are usually recommended for heart failure.


Evaluation on Med-EASi test dataset

In [24]:
from datasets import load_dataset

#load Med-EASi dataset
ds = load_dataset("cbasu/Med-EASi")
print(ds)
print("Train sample:\n", ds["train"][0])

DatasetDict({
    train: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 1397
    })
    validation: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 196
    })
    test: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 300
    })
})
Train sample:
 {'Expert': '75-90 % of the affected people have mild intellectual disability.', 'Simple': "People with syndromic intellectual disabi

In [25]:
test_ds = ds["test"]



In [13]:
!pip install -q evaluate
!pip install -q sacremoses sacrebleu
!pip install -q rouge_score
!pip intall -q tqdm

ERROR: unknown command "intall" - maybe you meant "install"


In [12]:
import evaluate
sari_metric = evaluate.load("sari")
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [26]:
from tqdm.auto import tqdm

sources = []
references = []
predictions = []

for example in tqdm(test_ds, desc="Generating explanations"):
    expert = example["Expert"]
    simple = example["Simple"]

    pred = explain_for_patient(expert, max_new_tokens=150)

    sources.append(expert)
    references.append(simple)
    predictions.append(pred)

Generating explanations:   0%|          | 0/300 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

In [27]:

sari_metric = evaluate.load("sari")
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [28]:
sari_result = sari_metric.compute(
    sources=sources,
    predictions=predictions,
    references=[[r] for r in references],
)

print("SARI:", sari_result)

bleu_result = bleu_metric.compute(
    predictions=predictions,
    references=[[r] for r in references],
)

print("BLEU:", bleu_result)

rouge_result = rouge_metric.compute(
    predictions=predictions,
    references=references,
)

print("ROUGE:", rouge_result)



SARI: {'sari': 38.24646861536587}
BLEU: {'bleu': 0.042200231904256924, 'precisions': [0.2425221267523742, 0.07003096383160946, 0.021103568788225016, 0.00884834350778517], 'brevity_penalty': 1.0, 'length_ratio': 1.8728372655777374, 'translation_length': 15479, 'reference_length': 8265}
ROUGE: {'rouge1': np.float64(0.28946684009831014), 'rouge2': np.float64(0.09242298307883069), 'rougeL': np.float64(0.21672256435524395), 'rougeLsum': np.float64(0.21682074843098142)}


In [29]:
!pip install -q textstat

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m52.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h

In [30]:
import textstat
fk_expert = []
fk_reference = []
fk_model = []

for example in tqdm(test_ds, desc="Computing FK scores"):
    expert = example["Expert"]
    simple = example["Simple"]

    fk_expert.append(textstat.flesch_kincaid_grade(expert))
    fk_reference.append(textstat.flesch_kincaid_grade(simple))

fk_model = [textstat.flesch_kincaid_grade(p) for p in predictions]

Computing FK scores:   0%|          | 0/300 [00:00<?, ?it/s]

In [31]:
import pandas as pd

df = pd.DataFrame({
    "expert_text": [ex["Expert"] for ex in test_ds],
    "simple_reference": [ex["Simple"] for ex in test_ds],
    "model_output": predictions,
    "fk_expert": fk_expert,
    "fk_reference": fk_reference,
    "fk_model": fk_model if len(fk_model) else None,
})

df.head()


Unnamed: 0,expert_text,simple_reference,model_output,fk_expert,fk_reference,fk_model
0,Intervention for obese adolescents should be f...,The treatment of adolescent obesity is focused...,"For obese adolescents, the focus should be on ...",15.953913,15.953913,11.651596
1,"The liver may be enlarged, hard, or tender; ma...","Typically, the liver is enlarged and hard. It ...",The liver is often enlarged and may feel hard ...,12.557647,4.84,6.255556
2,"Frequency, urgency, and nocturia are due to in...","At first, men may have difficulty starting uri...","The bladder does not empty completely, so the ...",13.4275,8.763333,7.222341
3,Desmopressin,"Sometimes, the drug desmopressin",Nose drops (nasal decongestants) such as oxyme...,8.4,3.67,12.407778
4,"Some patients have weight loss, rarely enough ...","Some people are undernourished, have mild weig...",Some people with celiac disease lose weight be...,11.981957,8.474444,11.436418


In [33]:
df.to_csv("llama3_med_easi_output.csv", index=True)

In [34]:
!ls

llama3_med_easi_output.csv  sample_data
