## Load necessary packages

In [7]:
import json, ast, torch, random
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification
    # AdamW,
)
from sklearn.metrics import accuracy_score
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
print(torch.cuda.is_available())

True


## Load Data

In [11]:
import pandas as pd

file_path = "Data/PubMedQA_cleaned.json"
QA_data = pd.read_json(file_path)

In [12]:
QA_data

Unnamed: 0,id,context,question,options,gold_index
0,0,(Objective) We evaluated the usefulness of a s...,A short stay or 23-hour ward in a general and ...,"[No, Maybe, Yes]",2
1,1,(Methods) The records of 465 patients with an ...,Amblyopia: is visual loss permanent?,"[No, Maybe, Yes]",0
2,2,(Background) Radiotherapy reduces local recurr...,Does radiotherapy of the primary rectal cancer...,"[No, Maybe, Yes]",2
3,3,(Background) Pterygium is a disease of unknown...,Human papillomavirus and pterygium. Is the vir...,"[No, Maybe, Yes]",1
4,4,(Purpose) Reconstructing the natural joint lin...,Assessing joint line positions by means of the...,"[No, Maybe, Yes]",2
...,...,...,...,...,...
995,995,"(Background) ""America's Best Hospitals,"" an in...","Do ""America's Best Hospitals"" perform better f...","[No, Maybe, Yes]",2
996,996,(Background) Some patients with suspected comm...,The clinical significance of bile duct sludge:...,"[No, Maybe, Yes]",0
997,997,(Objective) To examine longitudinal patterns i...,Does obesity predict knee pain over fourteen y...,"[No, Maybe, Yes]",2
998,998,(Objectives) To assess Internet use amongst yo...,Can the Internet be used to improve sexual hea...,"[No, Maybe, Yes]",1


In [18]:
sample = QA_data.iloc[2].to_frame()
sample

Unnamed: 0,2
id,2
context,(Background) Radiotherapy reduces local recurr...
question,Does radiotherapy of the primary rectal cancer...
options,"[No, Maybe, Yes]"
gold_index,2


## Hugging Face Login

In [5]:
import os
from huggingface_hub import login

hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## Teacher Model

In [9]:
teacher_model_name = "Henrychur/MMed-Llama-3-8B"  # Example: a T5 base model
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.89G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/2.57G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

In [33]:
teacher_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

## Build Prompt

In [26]:
def build_prompt(input):
    input_context = input.loc['context'].values[0]
    input_question = input.loc['question'].values[0]
    prompt = f"""Read an abstract from a PubMed paper and answer the question: {input_context}

Question: {input_question}
Instruction: Return ONLY a confidence score over the three options ['No', 'Maybe', 'Yes'], DO NOT include any text output.
Format your answer strictly as: [prob_yes, prob_maybe, prob_no], where all numbers are between 0 and 1 and sum up to 1.
Example: [0.1, 0.2, 0.7]
"""
    return prompt

In [27]:
prompt_text = build_prompt(sample)
print(prompt_text)

Read an abstract from a PubMed paper and answer the question: (Background) Radiotherapy reduces local recurrence rates but is also capable of short- and long-term toxicity. It may also render treatment of local recurrence more challenging if it develops despite previous radiotherapy.
(Objective) This study examined the impact of radiotherapy for the primary rectal cancer on outcomes after pelvic exenteration for local recurrence.
(Design) We conducted a retrospective review of exenteration databases.
(Setting) The study took place at a quaternary referral center that specializes in pelvic exenteration.
(Patients) Patients referred for pelvic exenteration from October 1994 to November 2012 were reviewed. Patients who did and did not receive radiotherapy as part of their primary rectal cancer treatment were compared.
(Main outcome measures) The main outcomes of interest were resection margins, overall survival, disease-free survival, and surgical morbidities.
(Results) There were 108 pat

## Run Model

In [32]:
teacher_model = teacher_model.to(device)

OutOfMemoryError: CUDA out of memory. Tried to allocate 224.00 MiB. GPU 0 has a total capacity of 22.04 GiB of which 14.12 MiB is free. Including non-PyTorch memory, this process has 22.01 GiB memory in use. Of the allocated memory 21.83 GiB is allocated by PyTorch, and 1.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [30]:
inputs = teacher_tokenizer(
    prompt_text,
    return_tensors="pt",
    # padding=True,
    truncation=True,
).to(device)

In [31]:
# Teacher inference
with torch.no_grad():
    teacher_logits = teacher_model(**inputs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)