In [1]:
!pip install -q "transformers>=4.43.0" peft bitsandbytes accelerate


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m00:01[0m:00:01[0m
[?25h

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import login
from getpass import getpass


In [3]:
HF_TOKEN = getpass("Enter your HuggingFace token (read access): ")
login(HF_TOKEN)


In [4]:
base_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    use_fast=True,
    token=HF_TOKEN,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
)

base_model.eval()


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRM

In [5]:
adapter_repo_id = "smedara/llama3-med-easi-explainer"

model = PeftModel.from_pretrained(
    base_model,
    adapter_repo_id,
    token=HF_TOKEN,
)

model.eval()


adapter_config.json:   0%|          | 0.00/1.07k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.9M [00:00<?, ?B/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.

In [6]:
def build_explain_prompt(expert_text: str) -> str:
    system_message = (
        "You are a medical explainer. Your job is to rewrite the medical text below "
        "into a clear explanation for a patient.\n\n"
        "Rules:\n"
        "- Only use information found in the original text.\n"
        "- Do NOT add any new causes, treatments, or details.\n"
        "- Use short, simple sentences (8th-grade level).\n"
        "- Avoid medical jargon. If you must use it, briefly define it.\n"
        "- Be friendly and reassuring.\n"
    )

    user_message = (
        "Explain the following medical diagnosis to a patient:\n\n"
        f"{expert_text}\n\n"
        "Patient-friendly explanation:"
    )

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    return prompt


In [11]:
def explain_for_patient(expert_text: str, max_new_tokens=150):
    prompt = build_explain_prompt(expert_text)

    enc = tokenizer(
        [prompt],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(model.device)

    input_ids = enc["input_ids"]
    input_len = input_ids.shape[1]

    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            min_new_tokens=50,
            num_beams=4,
            length_penalty=1.1,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )

    gen_ids = out[0, input_len:]
    decoded = tokenizer.decode(gen_ids, skip_special_tokens=True)
    return decoded.strip()


In [13]:
expert_example = (
    "The patient presents with acute decompensated systolic heart failure secondary "
    "to long-standing ischemic cardiomyopathy. Echocardiography shows a left "
    "ventricular ejection fraction of 20% with global hypokinesis and moderate "
    "functional mitral regurgitation. He also has poorly controlled hypertension "
    "and stage 3 chronic kidney disease, which further exacerbate volume overload "
    "and limit the use of certain guideline-directed medical therapies."
)

print("EXPERT TEXT\n", expert_example)
print("MODEL EXPLANATION\n")
print(explain_for_patient(expert_example))


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


EXPERT TEXT
 The patient presents with acute decompensated systolic heart failure secondary to long-standing ischemic cardiomyopathy. Echocardiography shows a left ventricular ejection fraction of 20% with global hypokinesis and moderate functional mitral regurgitation. He also has poorly controlled hypertension and stage 3 chronic kidney disease, which further exacerbate volume overload and limit the use of certain guideline-directed medical therapies.
MODEL EXPLANATION

The patient has a history of heart failure, high blood pressure, and kidney disease. His heart is not pumping well, and his kidneys are not working as well as they should. These conditions make it harder to use some of the treatments that are usually recommended for heart failure.
