## Constants

In [1]:
# Number of examples sampled into the train and test splits
TRAIN_SIZE, TEST_SIZE = 30, 3
# Seed for the shuffled train/test split
SEED = 123
# Number of prompts per generation batch during evaluation
BATCH_SIZE = 8

## Dataset

In [2]:
# pip install datasets
from datasets import load_dataset

# Fetch the medalpaca Medical Meadow MedQA dataset from the Hugging Face Hub
dataset = load_dataset(path="medalpaca/medical_meadow_medqa")

### Preprocessing steps

We're not going to focus much on the preprocessing in this tutorial, so feel free to skim over or skip this subsection.

In [3]:
# Sample train and test sets
display(dataset)

# Only split the raw dataset. Without this guard, re-running the cell would
# re-split the already-sampled "train" split, which holds only TRAIN_SIZE rows
# (fewer than TRAIN_SIZE + TEST_SIZE), so train_test_split would fail.
if "test" not in dataset:
    dataset = dataset["train"].train_test_split(
        train_size=TRAIN_SIZE,
        test_size=TEST_SIZE,
        shuffle=True,
        seed=SEED,
    )
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 10178
    })
})

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 30
    })
    test: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 3
    })
})

In [4]:
# original format of the dataset
def print_sample(sample: dict[str, str]):
    """Print every field of ``sample`` as a ``# <field>`` heading followed by its value.

    Consecutive fields are separated by a blank line.
    """
    sections = [f"# {field}\n{value}" for field, value in sample.items()]
    print("\n\n".join(sections))

example = dataset["train"][0]
print_sample(example)

# input
Q:A 9-year-old girl is brought to the pediatrician by her parents because of unremitting cough, fevers, night sweats, anorexia, and weight loss for 4 weeks. Her vaccinations are up to date. When asked about recent exposure to an ill person, the parents mention that she is frequently under the care of a middle-aged woman who recently immigrated from a small rural community in north India. Her temperature is 39.0°C (102.2°F), respiratory rate is 30/min, and heart rate is 120/min. Her weight is 2 standard deviations below normal for her age. Chest auscultation shows fine crackles in both lung fields.  The patient is referred to a nearby children’s hospital where her clinical condition rapidly worsens over several weeks. A chest radiograph is shown. Microbiological evaluation of a bronchial aspirate reveals an organism with a cell wall that is impervious to Gram stain. Which of the following best describes the cell wall of the causative agent?? 
{'A': 'Low muramic acid content', 'B

In [5]:
# reformat the dataset
import json

def reformat_sample(sample: dict[str, str]) -> dict[str, str]:
    """Convert a raw MedQA sample into a JSON-answer prompt / target pair.

    Args:
        sample: raw row with an ``input`` field of the form
            ``"Q:<question> {<choices dict repr>},"`` and an ``output`` field
            presumably of the form ``"<letter>: <answer text>"`` (only the first
            character is used as the letter).

    Returns:
        dict with ``input`` (question + answer-format instruction + choices),
        ``output`` (JSON string ``{"option": ..., "text": ...}``) and
        ``true_label`` (the correct option letter).
    """
    instruction = (
        'Give your answer as a JSON dictionary in the form of'
        ' {"option": "A-E", "text": "corresponding text"}.'
        ' No yapping.'
    )
    # strip() first so the trailing "," is removed even if whitespace follows it
    question = "Q: " + sample["input"].strip().removeprefix("Q:").removesuffix(",")

    # Insert the instruction right before the answer-choices dict (a Python dict
    # repr, so it starts with "{'"). Locating "{'" instead of "\n{" also covers
    # samples where the dict follows the question after a space, not a newline.
    choices_start = question.rfind("{'")
    if choices_start != -1:
        question = (
            question[:choices_start].rstrip("\n")
            + instruction
            + "\n"
            + question[choices_start:]
        )

    raw_output = sample["output"]
    true_label = raw_output[0]
    # Drop the separator after the letter (": ", ":", ". ", ") ") instead of
    # assuming it is exactly two characters long.
    answer_text = raw_output[1:].lstrip(":.)").strip()
    output = json.dumps({"option": true_label, "text": answer_text})

    return {"input": question, "output": output, "true_label": true_label}


# Reformat only once. Re-running this cell would otherwise apply
# reformat_sample to its own output (corrupting "input"/"output"), and
# remove_columns would raise because "instruction" is already gone.
if "instruction" in dataset["train"].column_names:
    dataset = dataset.map(reformat_sample).remove_columns("instruction")

display(dataset)
print_sample(example := dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'true_label'],
        num_rows: 30
    })
    test: Dataset({
        features: ['input', 'output', 'true_label'],
        num_rows: 3
    })
})

# input
Q: A 9-year-old girl is brought to the pediatrician by her parents because of unremitting cough, fevers, night sweats, anorexia, and weight loss for 4 weeks. Her vaccinations are up to date. When asked about recent exposure to an ill person, the parents mention that she is frequently under the care of a middle-aged woman who recently immigrated from a small rural community in north India. Her temperature is 39.0°C (102.2°F), respiratory rate is 30/min, and heart rate is 120/min. Her weight is 2 standard deviations below normal for her age. Chest auscultation shows fine crackles in both lung fields.  The patient is referred to a nearby children’s hospital where her clinical condition rapidly worsens over several weeks. A chest radiograph is shown. Microbiological evaluation of a bronchial aspirate reveals an organism with a cell wall that is impervious to Gram stain. Which of the following best describes the cell wall of the causative agent?? Give your answer as a JSON dictionar

### Finetuning dataset

## Load Gemma 2B instruct model

In [6]:
# pip install torch transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Quantize the weights to 8-bit when a GPU is available; otherwise load them unquantized.
quantization_config = None
if torch.cuda.is_available():
    # pip install bitsandbytes accelerate
    from transformers import BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", quantization_config=quantization_config
)
print(model.device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cpu


In [7]:
# Sanity check on a single prompt: generate an answer for `example`.
chat = [
    {"role": "user", "content": example["input"]},
]
input_ids = tokenizer.apply_chat_template(
    chat, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Greedy decoding is deterministic. Sampling with a tiny temperature only
# approximates it, so reruns could differ.
output_ids = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=512,
)

print(tokenizer.decode(output_ids[0]))

<bos><start_of_turn>user
Q: A 9-year-old girl is brought to the pediatrician by her parents because of unremitting cough, fevers, night sweats, anorexia, and weight loss for 4 weeks. Her vaccinations are up to date. When asked about recent exposure to an ill person, the parents mention that she is frequently under the care of a middle-aged woman who recently immigrated from a small rural community in north India. Her temperature is 39.0°C (102.2°F), respiratory rate is 30/min, and heart rate is 120/min. Her weight is 2 standard deviations below normal for her age. Chest auscultation shows fine crackles in both lung fields.  The patient is referred to a nearby children’s hospital where her clinical condition rapidly worsens over several weeks. A chest radiograph is shown. Microbiological evaluation of a bronchial aspirate reveals an organism with a cell wall that is impervious to Gram stain. Which of the following best describes the cell wall of the causative agent?? Give your answer as

In [8]:
def get_generated_texts(input_ids: torch.Tensor, output_ids: torch.Tensor, remove_eos: bool = True) -> list[str]:
    """Retrieve only the generated text based on input and output ids.

    Each output row starts with its (padded) prompt, so the first
    ``len(in_seq)`` ids of every row are dropped before decoding.

    Args:
        input_ids (torch.Tensor): batch of (padded) prompt ids passed to ``generate``.
        output_ids (torch.Tensor): corresponding output ids returned by ``generate``.
        remove_eos (bool, optional): whether to drop the trailing end-of-sequence
            token and any padding after it. Defaults to True.

    Returns:
        list[str]: one decoded continuation per row of the batch.
    """
    # In a batch, rows that finish early are filled with pad ids up to the
    # longest generation ("...<eos><pad><pad>"). Removing a single "<eos>"
    # suffix from the decoded string is therefore not enough, so trailing
    # eos/pad ids are trimmed before decoding.
    trailing_ids = {tokenizer.eos_token_id, tokenizer.pad_token_id} - {None}
    texts = []
    for in_seq, out_seq in zip(input_ids, output_ids):
        generated = out_seq[len(in_seq):].tolist()
        if remove_eos:
            while generated and generated[-1] in trailing_ids:
                generated.pop()
        texts.append(tokenizer.decode(generated))
    return texts

print(get_generated_texts(input_ids=input_ids, output_ids=output_ids)[0])

{"option": "D", "text": "Absence of cellular wall"}


## Evaluate before finetuning

### Evaluation utils

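These utilities extract the predicted option letter from the generated text. They first try to parse the model's JSON answer, then fall back to fuzzy-matching the text against the question's answer choices.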

In [9]:
import re

# pip install rapidfuzz
from rapidfuzz import fuzz
from rapidfuzz.utils import default_process

# Example question in the dataset's original input format ("Q:..." followed by
# the answer choices as a Python dict repr and a trailing comma).
example_passage = """
Q:A 67-year-old man with a past medical history of poorly-controlled type 2 diabetes mellitus (T2DM) is brought to the emergency department for acute onset nausea and vomiting. According to the patient, he suddenly experienced vertigo and began vomiting 3 hours ago while watching TV. He reports hiking in New Hampshire with his wife 2 days ago. Past medical history is significant for a myocardial infarction (MI) that was treated with cardiac stenting, T2DM, and hypertension. Medications include lisinopril, aspirin, atorvastatin, warfarin, and insulin. Physical examination demonstrates left-sided facial droop and decreased pinprick sensation at the right arm and leg. What is the most likely etiology of this patient’s symptoms?? {'A': 'Early disseminated Lyme disease', 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)', 'C': 'Hypoperfusion of the anterior spinal artery (ASA)', 'D': 'Labryrinthitis', 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'},
""".strip()

# Free-text answer (not an exact choice) used to exercise the fuzzy matching
example_substr = "stroke at the anterior inferior cerebellar artery"

def get_available_qa_choices(passage: str) -> dict[str, str]:
    """Extract the answer choices embedded in a question passage.

    The choices appear as a Python dict repr such as
    ``{'A': 'Early disseminated Lyme disease', 'B': ...}``.

    Args:
        passage: question text containing the choices dict.

    Returns:
        dict mapping each option key (e.g. "A") to its text; empty if none found.
    """
    import ast  # local import: only this helper needs it

    def _unquote(literal: str) -> str:
        # Undo repr escaping (e.g. \'); fall back to the raw inner text if the
        # matched span is not a valid Python string literal.
        try:
            return ast.literal_eval(literal)
        except (ValueError, SyntaxError):
            return literal[1:-1]

    # Python's repr puts a value in double quotes when it contains an apostrophe
    # (e.g. "Crohn's disease") and escapes quotes otherwise, so accept both quote
    # styles and escaped characters inside the value.
    pattern = r"""'([^']+)':\s*('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")"""
    return {
        match.group(1): _unquote(match.group(2))
        for match in re.finditer(pattern, passage)
    }

def get_best_match(answer_text: str, passage: str) -> str:
    """Return the option key whose choice text best fuzzy-matches ``answer_text``.

    Similarity is rapidfuzz's ``token_set_ratio`` with ``default_process``
    normalization. Ties are resolved in favor of the option listed last in the
    passage. Returns "" when the passage contains no choices.
    """
    choices = get_available_qa_choices(passage)
    if len(choices) == 0:
        return ""  # no choices found

    def _score(option: str) -> float:
        return fuzz.token_set_ratio(choices[option], answer_text, processor=default_process)

    # max() keeps the first maximum it sees, so scanning the options in reverse
    # makes ties go to the option listed last.
    return max(reversed(list(choices)), key=_score)

display(
    get_available_qa_choices(example_passage),
    get_best_match(example_substr, example_passage),
)

{'A': 'Early disseminated Lyme disease',
 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)',
 'C': 'Hypoperfusion of the anterior spinal artery (ASA)',
 'D': 'Labryrinthitis',
 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'}

'E'

In [10]:
# Well-formed JSON answer followed by an <eos> marker. It uses an "answer" key
# instead of "text"; parse_prediction only reads "option".
example_pred_text = """
{
"option": "A",
"answer": "Early disseminated Lyme disease"
}<eos>
""".strip()

def parse_prediction(pred_text: str, passage: str = "") -> str:
    """Parse the predicted answer based on the output text, with text matching as backup.

    Args:
        pred_text (str): text outputted from the language model.
        passage (str, optional): The input passage for text matching in case the LLM does
            not output parseable JSON. Useful for evaluating the model before finetuning.
            Defaults to "" (no passage); in this case the backup prediction will be "".

    Returns:
        str: option (an upper-case letter A-E, or "") predicted by the model.
    """
    # Keep only the span from the first "{" to the last "}" so surrounding
    # chatter (or a trailing <eos>) does not break JSON parsing.
    json_text = pred_text
    start = json_text.find("{")
    if start != -1:
        json_text = json_text[start:]
    end = json_text.rfind("}")
    if end != -1:
        json_text = json_text[: end + 1]

    try:
        option = json.loads(json_text)["option"]
        if match := re.fullmatch(r"[a-eA-E]", option.strip()):
            # true labels are upper-case, so normalize e.g. "b" -> "B"
            return match.group(0).upper()
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
        # TypeError/AttributeError: valid JSON of the wrong shape (a bare
        # number, a list, a non-string "option"). Treat it as unparseable
        # instead of crashing the evaluation.
        pass

    # backup: if a passage is supplied, get the best match
    if passage:
        return get_best_match(pred_text, passage)
    return ""

display(
    parse_prediction(example_pred_text),  # clean JSON answer
    parse_prediction(example_substr, example_passage),  # no JSON -> fuzzy match on passage
    parse_prediction('{"option": "A - Early disseminated Lyme disease"}', example_passage),  # invalid option -> fuzzy match
    parse_prediction(example_substr),  # no JSON and no passage -> ""
)

'A'

'E'

'A'

''

### Run eval

In [19]:
from typing import Callable
from transformers.modeling_utils import PreTrainedModel

def batch_get_preds(model: PreTrainedModel, model_name: str, max_new_tokens: int = 512) -> Callable[[dict], dict]:
    """Build a batched ``Dataset.map`` function that answers each prompt with ``model``.

    The returned function adds three columns to the batch:
    ``{model_name}_pred`` (generated text), ``{model_name}_label`` (parsed
    option letter) and ``{model_name}_correct`` (whether the label equals
    ``true_label``).

    Args:
        model: causal LM used for generation.
        model_name: prefix for the added column names.
        max_new_tokens: generation budget per prompt. Defaults to 512.

    Returns:
        Callable[[dict], dict]: function for ``dataset.map(..., batched=True)``.
    """
    def _get_preds(samples: dict) -> dict:
        # return_dict=True also returns the attention mask, so padding in a batch
        # of different-length prompts is masked out during generation.
        # NOTE(review): assumes the tokenizer pads on the left, as decoder-only
        # generation requires. Confirm tokenizer.padding_side.
        encoded = tokenizer.apply_chat_template(
            [[{"role": "user", "content": text}] for text in samples["input"]],
            add_generation_prompt=True,
            padding=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        input_ids = encoded["input_ids"]

        # Greedy decoding: deterministic, unlike sampling with a tiny temperature.
        output_ids = model.generate(
            input_ids,
            attention_mask=encoded["attention_mask"],
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )

        pred_texts = get_generated_texts(input_ids, output_ids)
        pred_labels = [
            parse_prediction(pred, text)
            for text, pred in zip(samples["input"], pred_texts)
        ]
        samples[f"{model_name}_pred"] = pred_texts
        samples[f"{model_name}_label"] = pred_labels
        samples[f"{model_name}_correct"] = [
            pred_label == true_label
            for true_label, pred_label in zip(samples["true_label"], pred_labels)
        ]
        return samples

    return _get_preds

test_set = dataset["test"].map(
    batch_get_preds(model, "gemma"),
    batched=True,
    batch_size=BATCH_SIZE,
)
gemma_accuracy = sum(test_set["gemma_correct"]) / len(test_set)
print(round(100 * gemma_accuracy, 1))

33.3


In [12]:
# Ground-truth option letters of the test split, to compare with the predictions
test_set["true_label"]

['D', 'D', 'A']