## Constants

In [1]:
# Number of examples kept in the training split.
TRAIN_SIZE = 300
# Number of examples kept in the test split.
TEST_SIZE = 30
# Seed handed to the shuffled train/test split so the subsets are reproducible.
SEED = 123

## Dataset

In [2]:
from datasets import load_dataset

# Load the MedQA question set; the recorded output shows a single "train"
# split with `input`, `instruction` and `output` columns.
dataset = load_dataset("medalpaca/medical_meadow_medqa")
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 10178
    })
})

In [3]:
# Carve small, seeded train/test subsets out of the single "train" split.
# NOTE(review): this rebinds `dataset`, so re-running the cell operates on the
# already-reduced split instead of the full data — re-run the load cell first.
dataset = dataset["train"].train_test_split(
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
    shuffle=True,
    seed=SEED
)
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 300
    })
    test: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 30
    })
})

In [4]:
def print_sample(sample: dict[str, str]):
    """Print each field of ``sample`` as a ``# Key`` heading followed by its value.

    Fields are emitted in the mapping's iteration order, separated by one
    blank line; each key is shown with ``str.capitalize`` applied.
    """
    sections = []
    for field, value in sample.items():
        sections.append(f"# {field.capitalize()}\n{value}")
    print("\n\n".join(sections))

# Show the first training example; the walrus binding also keeps it around as
# `example` for later cells.
print_sample(example := dataset["train"][0])

# Input
Q:A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? 
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'},

# Instruction
Please answer with one of the option in the bracket

# Output
C: Increased adenylyl cyclase activity


In [5]:
import json

def reformat_sample(sample: dict[str, str]) -> dict[str, str]:
    """Turn a raw MedQA record into a prompt / JSON-answer pair.

    The ``input`` field is normalised to start with ``"Q: "``, its trailing
    comma is dropped, and an instruction asking for a JSON answer is inserted
    once, directly before the options dictionary (the last newline that is
    immediately followed by ``{``). If no such marker exists the question is
    left without the instruction.

    The ``output`` field, expected to look like ``"C: Some answer"``, is
    re-encoded as the JSON string ``{"option": "C", "text": "Some answer"}``.

    Args:
        sample: Record with at least the ``"input"`` and ``"output"`` keys.

    Returns:
        A dict with the rewritten ``"input"`` and ``"output"`` strings.

    Raises:
        ValueError: If ``output`` is not of the form ``"<option>: <text>"``.
    """
    instruction = (
        'Provide your answer as a JSON dictionary with the "option" and answer "text".'
    )
    options_marker = "\n{"

    # Named `question` rather than `input` so the builtin is not shadowed.
    question = sample["input"].strip().removeprefix("Q:").removesuffix(",")
    question = "Q: " + question.lstrip()

    # Split on the LAST marker so a stray newline+brace inside the question
    # body cannot cause the instruction to be inserted more than once.
    head, marker, options = question.rpartition(options_marker)
    if marker:
        # Guarantee exactly one space between the question and the instruction.
        question = f"{head.rstrip()} {instruction}{marker}{options}"

    # Parse "<option>: <text>" on the first colon instead of fixed offsets, so
    # spacing variations do not silently truncate the answer text.
    option, separator, text = sample["output"].partition(":")
    option, text = option.strip(), text.strip()
    if not separator or not option:
        raise ValueError(f"Unexpected answer format: {sample['output']!r}")

    output = json.dumps({"option": option, "text": text})

    return {"input": question, "output": output}


# Reformat every row of both splits; `instruction` is dropped because the
# answer-format instruction is now embedded in `input`.
# NOTE(review): `dataset` is rebound here, so this cell is not safe to re-run —
# a second pass would see already-reformatted rows and no `instruction` column.
dataset = dataset.map(reformat_sample).remove_columns("instruction")

display(dataset)
print_sample(example := dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['input', 'output'],
        num_rows: 300
    })
    test: Dataset({
        features: ['input', 'output'],
        num_rows: 30
    })
})

# Input
Q: A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? Provide your answer as a JSON dictionary with the "option" and answer "text".
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'}

# Output
{"option": "C", "text": "Increased adenylyl cyclase activity"}


## Load Gemma 2B instruct model

In [10]:
# pip install torch transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_ID = "google/gemma-2b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Default to an unquantized load; request 8-bit weights only when a CUDA GPU
# is available.
quantization_config = None
if torch.cuda.is_available():
    # pip install bitsandbytes accelerate
    from transformers import BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
)
# Report where the weights ended up (the recorded run shows `cpu`).
print(model.device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cpu


In [12]:
# Preview the prompt string the tokenizer's chat template builds for `example`.
chat = [
    {"role": "user", "content": example["input"]},
    # Assistant turn disabled so the rendered prompt stops after the user message.
    # {"role": "assistant", "content": example["output"]},
]
# tokenize=False returns the rendered text instead of token ids;
# add_generation_prompt=True appends the opening of the model turn
# (the trailing `<start_of_turn>model` in the recorded output).
print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True))

<bos><start_of_turn>user
Q: A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? Provide your answer as a JSON dictionary with the "option" and answer "text".
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'}<end_of_turn>
<start_of_turn>model



In [None]:
input_text = """
Q:An 82-year-old male with congestive heart failure experiences rapid decompensation of his condition, manifesting as worsening dyspnea, edema, and increased fatigue. Labs reveal an increase in his serum creatinine from baseline. As part of the management of this acute change, the patient is given IV dobutamine to alleviate his symptoms. Which of the following effects occur as a result of this therapy?? {'A': 'Slowed atrioventricular conduction velocities', 'B': 'Increased myocardial oxygen consumption', 'C': 'Decreased heart rate', 'D': 'Increased systemic vascular resistance due to systemic vasoconstriction', 'E': 'Decreased cardiac contractility'},

Provide your answer as a JSON dictionary with the "option" and answer "text".
""".strip()

# The checkpoint is an instruction-tuned chat model, so the question is wrapped
# in the tokenizer's chat template (user turn + opening of the model turn)
# rather than being fed to the model as bare text.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": input_text}],
    tokenize=False,
    add_generation_prompt=True,
)

# The rendered template already starts with `<bos>`, so special tokens must not
# be added a second time when tokenizing.
# NOTE: despite its name, `input_ids` is the tokenizer's full mapping output,
# which is unpacked into `generate` below.
input_ids = tokenizer(
    prompt, return_tensors="pt", add_special_tokens=False
).to(model.device)

# Greedy decoding gives a deterministic answer directly, instead of emulating
# it by sampling at a near-zero temperature.
outputs = model.generate(
    **input_ids,
    do_sample=False,
    max_new_tokens=512,
)
print(tokenizer.decode(outputs[0]))