## Load necessary packages

In [1]:
import json, ast, torch, random
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification
    # AdamW,
)
from sklearn.metrics import accuracy_score
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
print(torch.cuda.is_available())

True


## Load Data

In [3]:
import pandas as pd

file_path = "Data/PubMedQA_cleaned.json"
QA_data = pd.read_json(file_path)

In [4]:
QA_data

Unnamed: 0,id,context,question,options,gold_index
0,0,(Objective) We evaluated the usefulness of a s...,A short stay or 23-hour ward in a general and ...,"[No, Maybe, Yes]",2
1,1,(Methods) The records of 465 patients with an ...,Amblyopia: is visual loss permanent?,"[No, Maybe, Yes]",0
2,2,(Background) Radiotherapy reduces local recurr...,Does radiotherapy of the primary rectal cancer...,"[No, Maybe, Yes]",2
3,3,(Background) Pterygium is a disease of unknown...,Human papillomavirus and pterygium. Is the vir...,"[No, Maybe, Yes]",1
4,4,(Purpose) Reconstructing the natural joint lin...,Assessing joint line positions by means of the...,"[No, Maybe, Yes]",2
...,...,...,...,...,...
995,995,"(Background) ""America's Best Hospitals,"" an in...","Do ""America's Best Hospitals"" perform better f...","[No, Maybe, Yes]",2
996,996,(Background) Some patients with suspected comm...,The clinical significance of bile duct sludge:...,"[No, Maybe, Yes]",0
997,997,(Objective) To examine longitudinal patterns i...,Does obesity predict knee pain over fourteen y...,"[No, Maybe, Yes]",2
998,998,(Objectives) To assess Internet use amongst yo...,Can the Internet be used to improve sexual hea...,"[No, Maybe, Yes]",1


In [30]:
sample = QA_data.iloc[7].to_frame()
sample

Unnamed: 0,7
id,7
context,(Purpose) To evaluate the efficacy of extracor...
question,Can infundibular height predict the clearance ...
options,"[No, Maybe, Yes]"
gold_index,2


## Hugging Face Login

In [6]:
import os
from huggingface_hub import login

hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## Teacher Model

In [7]:
teacher_model_name = "Henrychur/MMed-Llama-3-8B"  
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [8]:
teacher_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

## Build Prompt

In [39]:
def build_prompt(input):
    input_context = input.loc['context'].values[0]
    input_question = input.loc['question'].values[0]
    prompt = f"""Read an abstract from a PubMed paper and answer the question: {input_context}

Question: {input_question}

Instruction: Return ONLY three confidence scores over the three options ['No', 'Maybe', 'Yes'], the confidence scores should be less than 1 and sum to 1. 
DO NOT include any text output. Format your answer strictly as: [score_no, score_maybe, score_yes].
DO NOT include any explanation or additional text. Only return the scores in the specified format.
Answer:
"""
    return prompt

In [40]:
prompt_text = build_prompt(sample)
print(prompt_text)

Read an abstract from a PubMed paper and answer the question: (Purpose) To evaluate the efficacy of extracorporeal shock wave lithotripsy (SWL) on lower calyceal calculi in relation to the renal anatomical factors and determine which of these factors can be used to select patients who will benefit from SWL.
(Materials and methods) We analyzed retrospectively 78 patients with single radiopaque lower calyceal stones treated with SWL. The patients were evaluated 3 months after lithotripsy with a simple abdominal X-ray and a kidney ultrasound scan. The success of the treatment, removal of all fragments, was correlated with renal anatomical factors measured in the pre-treatment intravenous urography: infundibulopelvic angle, lower infundibulum width, lower infundibulum length, ratio length/width, infundibulum height, and number of minor calyces in the lower calyceal group.
(Results) Three months after SWL treatment, 39 patients were stone-free (NR group) and 39 had residual fragments (R gro

## Run Model

In [11]:
teacher_model = teacher_model.half() # convert float32 to float16
teacher_model = teacher_model.to(device)

In [1]:
inputs = teacher_tokenizer(
    prompt_text,
    return_tensors="pt",
    # padding=True,
    # truncation=True,
).to(device)

NameError: name 'teacher_tokenizer' is not defined

In [81]:
# Teacher inference
with torch.no_grad():
    teacher_logits = teacher_model(**inputs)

In [82]:
teacher_logits

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 5.3828,  3.3457,  1.2578,  ..., -7.1992, -7.1992, -7.1992],
         [ 3.8145,  3.8008,  2.6875,  ..., -6.4961, -6.4961, -6.4961],
         [ 7.6992,  3.9766,  1.2461,  ..., -4.1719, -4.1719, -4.1719],
         ...,
         [ 6.9805,  9.3906,  9.8203,  ..., -5.4609, -5.4609, -5.4609],
         [ 7.0742,  4.0312,  5.1797,  ..., -4.2227, -4.2227, -4.2227],
         [ 6.4336,  9.4453,  9.8125,  ..., -5.3711, -5.3711, -5.3711]]],
       device='cuda:0', dtype=torch.float16), past_key_values=<transformers.cache_utils.DynamicCache object at 0x7f3a0d76f490>, hidden_states=None, attentions=None)

In [None]:
# Generate from the model
teacher_out_ids = teacher_model.generate(
    **inputs, 
    eos_token_id=60,
    pad_token_id=teacher_tokenizer.eos_token_id
    )

In [35]:
# Decode back to text
decoded = teacher_tokenizer.decode(teacher_out_ids[0], skip_special_tokens=False)
decoded = decoded.split("Answer:")[1]
print(decoded)


[0, 0, 1]
Instruction: Return ONLY three confidence scores over the three options


In [36]:
# Get the token ID for "]"
token_id = teacher_tokenizer.convert_tokens_to_ids("]")
print(f"Token ID for ']': {token_id}")

Token ID for ']': 60


## Calculate Accuracy

In [132]:
def eval_prompt(input):
    input_context = input.loc['context'].values[0]
    input_question = input.loc['question'].values[0]
    prompt = f"""INSTRUCTION:
DO NOT include any explanation or additional text. Only return the one word answer from the three options: 'No', 'Maybe', or 'Yes'.

EXAMPLES:
Input: Read an abstract about X. Question: Does this support hypothesis Y?
Answer: Yes

Input: Read an abstract about Z. Question: Is this evidence inconclusive?
Answer: Maybe

TASK:
Read an abstract from a PubMed paper and answer the question: {input_context}

Question: {input_question}
Answer:
"""
    return prompt

In [133]:
prompt = eval_prompt(sample)
print(prompt)

INSTRUCTION:
DO NOT include any explanation or additional text. Only return the one word answer from the three options: 'No', 'Maybe', or 'Yes'.

EXAMPLES:
Input: Read an abstract about X. Question: Does this support hypothesis Y?
Answer: Yes

Input: Read an abstract about Z. Question: Is this evidence inconclusive?
Answer: Maybe

TASK:
Read an abstract from a PubMed paper and answer the question: (Background) Several prospective randomized trials have proved carotid endarterectomy to be safe and effective for both symptomatic and asymptomatic patients younger than 80 years of age. Recently, carotid artery stenting (CAS) has been approved for use in selected high-risk patients. It has been proposed that being an octogenarian places patients in this high-risk category.
(Study design) All patients between the ages of 80 to 89 years undergoing carotid endarterectomy during a 12-year period were included in the study. Information included indications for carotid endarterectomy, associated 

In [134]:
# Tokenize the prompt
inputs = teacher_tokenizer(prompt, return_tensors="pt").to(teacher_model.device)

In [135]:
# Generate output from the teacher model
with torch.no_grad():
    teacher_out_ids = teacher_model.generate(
        **inputs,
        eos_token_id=teacher_tokenizer.eos_token_id
    )

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [136]:
# Decode the output
decoded_output = teacher_tokenizer.decode(teacher_out_ids[0], skip_special_tokens=True)

In [137]:
print(decoded_output)

INSTRUCTION:
DO NOT include any explanation or additional text. Only return the one word answer from the three options: 'No', 'Maybe', or 'Yes'.

EXAMPLES:
Input: Read an abstract about X. Question: Does this support hypothesis Y?
Answer: Yes

Input: Read an abstract about Z. Question: Is this evidence inconclusive?
Answer: Maybe

TASK:
Read an abstract from a PubMed paper and answer the question: (Background) Several prospective randomized trials have proved carotid endarterectomy to be safe and effective for both symptomatic and asymptomatic patients younger than 80 years of age. Recently, carotid artery stenting (CAS) has been approved for use in selected high-risk patients. It has been proposed that being an octogenarian places patients in this high-risk category.
(Study design) All patients between the ages of 80 to 89 years undergoing carotid endarterectomy during a 12-year period were included in the study. Information included indications for carotid endarterectomy, associated 