## Constants

In [1]:
# Dataset sampling: sizes of the sampled splits and the seed that makes them reproducible.
SEED = 123
TRAIN_SIZE = 200
TEST_SIZE = 20
EVAL_SIZE = 50  # upper bound on the eval subset passed to the trainer (capped at the test-set size there)

# Batching and logging.
EVAL_BATCH_SIZE = 8
TRAIN_BATCH_SIZE = 1
TRAIN_GRADIENT_ACCUMULATION_STEPS = 4
TRAIN_LOGGING_STEPS = 10  # also used as the evaluation interval during training

## Dataset

In [2]:
# pip install datasets
from datasets import load_dataset

# Raw MedQA questions from the medalpaca collection (a DatasetDict with a single "train" split).
dataset = load_dataset(path="medalpaca/medical_meadow_medqa")

### Preprocessing steps

Preprocessing isn't the focus of this tutorial, so feel free to skim or skip this subsection.

In [3]:
# Show the raw data, then sample small, reproducible train and test sets from its "train" split.
display(dataset)

# NOTE(review): this rebinds `dataset`, so re-running the cell would try to re-split the
# already-sampled train split instead of the raw data — restart from the loading cell instead.
dataset = dataset["train"].train_test_split(
    seed=SEED, shuffle=True, train_size=TRAIN_SIZE, test_size=TEST_SIZE
)
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 10178
    })
})

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 200
    })
    test: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 20
    })
})

In [4]:
# original format of the dataset
def print_sample(sample: dict[str, str]):
    """Print each field of `sample` under a "# <field>" heading, separating fields by a blank line."""
    sections = [f"# {field}\n{value}" for field, value in sample.items()]
    print("\n\n".join(sections))

example = dataset["train"][0]  # running example reused by later cells
print_sample(example)

# input
Q:A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murmur along the right

In [5]:
# reformat the dataset
import json

def reformat_sample(sample: dict[str, str]) -> dict[str, str]:
    """Rewrite one raw MedQA record into the prompt/target format used for finetuning.

    Returns a dict with:
        "input": the question, normalised to start with "Q: " and stripped of one trailing
            comma, with a JSON-answer instruction inserted in front of every
            newline-followed-by-"{" sequence (i.e. right before the options dict).
        "output": the answer serialised as JSON, {"option": <letter>, "text": <answer text>}.
        "true_label": the answer letter alone.

    Assumes the raw "output" has the shape "<letter>: <answer text>" (letter at index 0,
    text from index 3) — TODO confirm for every row of the dataset.
    """
    instruction = (
        'Give your answer as a JSON dictionary in the form of'
        ' {"option": "A-E", "text": "corresponding text"}.'
        ' No yapping.'
    )
    question = sample["input"].removeprefix("Q:").removesuffix(",")
    prompt = ("Q: " + question).replace("\n{", instruction + "\n{")

    raw_answer = sample["output"]
    letter, answer_text = raw_answer[0], raw_answer[3:]
    return {
        "input": prompt,
        "output": json.dumps({"option": letter, "text": answer_text}),
        "true_label": letter,
    }


# Apply the prompt/target reformatting and drop the column it makes redundant.
dataset = (
    dataset
    .map(reformat_sample)
    .remove_columns("instruction")
)

display(dataset)
example = dataset["train"][0]  # refresh the running example in the new format
print_sample(example)

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'true_label'],
        num_rows: 200
    })
    test: Dataset({
        features: ['input', 'output', 'true_label'],
        num_rows: 20
    })
})

# input
Q: A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murmur along the righ

### Finetuning dataset

## Load Gemma 2B instruct model

In [6]:
# pip install torch transformers
# pip install bitsandbytes accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Load the weights in 4-bit so the 2B model fits comfortably on a single GPU.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it", quantization_config=quantization_config
)
print(model.device)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cuda:0


In [7]:
# Smoke test: ask the base model the running example's question.
chat = [{"role": "user", "content": example["input"]}]
input_ids = tokenizer.apply_chat_template(
    chat,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Sampling at a tiny temperature is close to greedy decoding.
output_ids = model.generate(input_ids, max_new_tokens=512, do_sample=True, temperature=1e-3)
print(tokenizer.decode(output_ids[0]))



<bos><start_of_turn>user
Q: A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murm

In [8]:
def get_generated_texts(input_ids: torch.Tensor, output_ids: torch.Tensor, remove_eos: bool = True) -> list[str]:
    """Retrieve only the generated texts based on input and output ids.

    `output_ids` is expected to contain each (padded) prompt followed by the newly
    generated tokens, so the first `len(in_seq)` tokens of every row are dropped.
    In a batch, sequences that finish early are filled up with pad tokens after
    their <eos>; those trailing pad tokens are always stripped so that the <eos>
    removal below also works for the shorter generations.

    Args:
        input_ids (torch.Tensor): batch of input ids.
        output_ids (torch.Tensor): corresponding output ids.
        remove_eos (bool, optional): whether to remove the final <eos> token. Defaults to True.

    Returns:
        list[str]: the batch of text generated by the model based on the input texts.
    """
    pad_token = tokenizer.pad_token
    eos_token = tokenizer.eos_token or "<eos>"

    texts = []
    for in_seq, out_seq in zip(input_ids, output_ids):
        text = tokenizer.decode(out_seq[len(in_seq):])
        # Only strip padding when it is distinguishable from <eos>; otherwise the
        # `remove_eos=False` contract could be violated.
        if pad_token and pad_token != eos_token:
            while text.endswith(pad_token):
                text = text[: -len(pad_token)]
        if remove_eos:
            text = text.removesuffix(eos_token)
        texts.append(text)
    return texts

print(get_generated_texts(input_ids=input_ids, output_ids=output_ids)[0])  # generated part only

{"option": "A-Berry aneurysm rupture", "text": "A berry aneurysm rupture is the most likely diagnosis in this patient due to the patient's history of trauma, recent fall, and symptoms of a ruptured artery."}


## Evaluate Gemma on the test set

### Evaluation utils

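These utilities extract the predicted answer option (a letter A-E) from the generated text, falling back to fuzzy matching against the listed options when the output is not parseable JSON.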

In [9]:
import re

# pip install rapidfuzz
from rapidfuzz import fuzz
from rapidfuzz.utils import default_process

# Fixture: a MedQA question in the raw dataset format (options embedded as a Python dict repr,
# followed by a trailing comma), used to exercise the option-extraction helpers below.
example_passage = """
Q:A 67-year-old man with a past medical history of poorly-controlled type 2 diabetes mellitus (T2DM) is brought to the emergency department for acute onset nausea and vomiting. According to the patient, he suddenly experienced vertigo and began vomiting 3 hours ago while watching TV. He reports hiking in New Hampshire with his wife 2 days ago. Past medical history is significant for a myocardial infarction (MI) that was treated with cardiac stenting, T2DM, and hypertension. Medications include lisinopril, aspirin, atorvastatin, warfarin, and insulin. Physical examination demonstrates left-sided facial droop and decreased pinprick sensation at the right arm and leg. What is the most likely etiology of this patient’s symptoms?? {'A': 'Early disseminated Lyme disease', 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)', 'C': 'Hypoperfusion of the anterior spinal artery (ASA)', 'D': 'Labryrinthitis', 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'},
""".strip()

# Fixture: a free-text answer that only partially matches option E's text.
example_substr = "stroke at the anterior inferior cerebellar artery"

def get_available_qa_choices(passage: str) -> dict[str, str]:
    """Extract the answer options embedded in a MedQA passage.

    The options appear as a Python dict repr such as {'A': '...', 'B': "..."}. The repr
    switches a value to double quotes whenever the text contains an apostrophe
    (e.g. "Hodgkin's lymphoma"), so the dict is parsed with `ast.literal_eval`
    rather than matched with a single-quote-only pattern.

    Args:
        passage (str): question text containing the options dict.

    Returns:
        dict[str, str]: option letter -> option text; {} if no options are found.
    """
    import ast

    # Scan brace-delimited spans from the end: the prompt may also contain the JSON
    # answer template {"option": "A-E", ...}, which must not be mistaken for the options.
    for candidate in reversed(re.findall(r"\{[^{}]*\}", passage)):
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if (
            isinstance(parsed, dict)
            and parsed
            and all(
                isinstance(key, str) and len(key) == 1 and key.isalpha() and isinstance(value, str)
                for key, value in parsed.items()
            )
        ):
            return parsed

    # Fallback for options that do not form a parseable dict: match single-quoted keys
    # with either single- or double-quoted values.
    return {
        match.group(1): match.group(2) if match.group(2) is not None else match.group(3)
        for match in re.finditer(r"""'([^']+)': (?:'([^']+)'|"([^"]+)")""", passage)
    }

def get_best_match(answer_text: str, passage: str) -> str:
    """Return the option letter whose text is most similar to `answer_text`.

    Similarity is rapidfuzz's token-set ratio with its default preprocessing. When
    several options share the best score, the one listed last in the passage wins.
    Returns "" when the passage lists no options.
    """
    options = get_available_qa_choices(passage)
    if not options:
        return ""  # no choices found
    # max() keeps the first maximum it encounters, so walk the letters in reverse
    # to make the last-listed option win ties.
    return max(
        reversed(list(options)),
        key=lambda letter: fuzz.token_set_ratio(options[letter], answer_text, processor=default_process),
    )

display(get_available_qa_choices(passage=example_passage))
display(get_best_match(answer_text=example_substr, passage=example_passage))

{'A': 'Early disseminated Lyme disease',
 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)',
 'C': 'Hypoperfusion of the anterior spinal artery (ASA)',
 'D': 'Labryrinthitis',
 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'}

'E'

In [10]:
# Fixture: a JSON-style model answer with a trailing <eos> token. It uses an "answer"
# key instead of "text"; only the "option" key is read when parsing predictions.
example_pred_text = """
{
"option": "A",
"answer": "Early disseminated Lyme disease"
}<eos>
""".strip()

def parse_prediction(pred_text: str, passage: str = "") -> str:
    """Parse the predicted answer based on the output text, with text matching as backup.

    Args:
        pred_text (str): text outputted from the language model.
        passage (str, optional): The input passage for text matching in case the LLM does
            not output parseable JSON. Useful for evaluating the model before finetuning.
            Defaults to "" (no passage); in this case the backup prediction will be "".

    Returns:
        str: option predicted by the model — an upper-case letter A-E, or "".
    """
    # Keep only the span from the first "{" to the last "}" (when present), so chatter
    # around the JSON object and trailing special tokens such as <eos> are ignored.
    json_text = pred_text
    if (start := json_text.find("{")) != -1:
        json_text = json_text[start:]
    if (end := json_text.rfind("}")) != -1:
        json_text = json_text[: end + 1]

    try:
        parsed = json.loads(json_text)
    except json.JSONDecodeError:
        parsed = None

    # Valid JSON can still have the wrong shape (a number, a list, a non-string
    # "option"); treat those as unparseable instead of raising mid-evaluation.
    option = parsed.get("option") if isinstance(parsed, dict) else None
    if isinstance(option, str) and re.fullmatch(r"[a-eA-E]", option.strip()):
        # Dataset labels are upper-case, so "a" must become "A" to compare equal.
        return option.strip().upper()

    # backup: if a passage is supplied, get the best match
    if passage:
        return get_best_match(pred_text, passage)
    return ""

display(parse_prediction(pred_text=example_pred_text))  # clean JSON answer
display(parse_prediction(pred_text=example_substr, passage=example_passage))  # free text -> fuzzy match
display(parse_prediction(pred_text='{"option": "A - Early disseminated Lyme disease"}', passage=example_passage))
display(parse_prediction(pred_text=example_substr))  # no JSON and no passage -> ""

'A'

'E'

'A'

''

### Run eval

In [11]:
from typing import Callable
from transformers.modeling_utils import PreTrainedModel

def batch_get_preds(model: PreTrainedModel, model_name: str) -> Callable[[dict], dict]:
    """Build a batched `Dataset.map` function that generates and scores predictions.

    Args:
        model (PreTrainedModel): the (possibly PEFT-wrapped) causal LM used for generation.
        model_name (str): prefix for the added columns `{model_name}_pred`,
            `{model_name}_label` and `{model_name}_correct`.

    Returns:
        Callable[[dict], dict]: takes a batch with "input" and "true_label" columns and
        returns it with the generated text, the parsed option letter, and a
        correctness flag added.
    """
    def _get_preds(samples: dict) -> dict:
        # Batched generation with a decoder-only model needs left padding so every
        # prompt ends right where generation starts. The finetuning cell switches the
        # tokenizer to right padding, so force left here and restore it afterwards.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            encoded = tokenizer.apply_chat_template(
                [[{"role": "user", "content": text}] for text in samples["input"]],
                add_generation_prompt=True,
                padding=True,
                return_dict=True,  # also return the attention mask for the padded batch
                return_tensors="pt",
            ).to(model.device)
        finally:
            tokenizer.padding_side = original_padding_side

        output_ids = model.generate(
            **encoded,  # input_ids + attention_mask, so pad positions are masked out
            do_sample=True,
            max_new_tokens=512,
            temperature=1e-3,
        )

        pred_texts = get_generated_texts(encoded["input_ids"], output_ids)
        samples[f"{model_name}_pred"] = pred_texts
        samples[f"{model_name}_label"] = pred_labels = [
            parse_prediction(pred, text)
            for text, pred in zip(samples["input"], pred_texts)
        ]
        samples[f"{model_name}_correct"] = [
            pred_label == true_label
            for true_label, pred_label in zip(samples["true_label"], pred_labels)
        ]

        return samples
    return _get_preds

# Baseline: score the un-finetuned model on the held-out test split.
test_set = dataset["test"]
test_set = test_set.map(
    batch_get_preds(model, "gemma"),
    batched=True,
    batch_size=EVAL_BATCH_SIZE,
)

gemma_accuracy = sum(test_set["gemma_correct"]) / len(test_set)
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Gemma accuracy: 20.0%


## Finetune

In [12]:
# NOTE(review): `train_set` does not appear to be used below — the trainer reads
# `dataset["train"]` directly. Either pass `train_set` there or drop this line.
train_set = dataset["train"]

def create_chat_for_finetuning(samples: dict) -> list[str]:
    """Render every (input, output) pair of a batch as a two-turn chat transcript.

    The user turn holds the question and the assistant turn holds the JSON answer; the
    tokenizer's chat template produces the text. A leading "<bos>" is removed from each
    rendered text — presumably because the trainer's tokenizer adds its own BOS when
    tokenizing (TODO confirm).

    Args:
        samples (dict): batch with parallel "input" and "output" lists.

    Returns:
        list[str]: one rendered chat string per sample.
    """
    conversations = []
    for prompt, answer in zip(samples["input"], samples["output"]):
        conversations.append(
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]
        )
    rendered = tokenizer.apply_chat_template(conversations, tokenize=False)
    return list(map(lambda text: text.removeprefix("<bos>"), rendered))

display(create_chat_for_finetuning(samples=dataset["train"][:2]))  # preview the first two training texts

['<start_of_turn>user\nQ: A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murmur

In [13]:
# pip install peft
from bitsandbytes.nn import Linear4bit, Linear8bitLt
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# Turn on gradient checkpointing, then let peft prepare the 4-bit model for training.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model=model)

def find_all_linear_names(model, linear_types=None):
    """Collect the distinct leaf names of linear layers to use as LoRA target modules.

    Args:
        model: the model whose `named_modules()` are scanned.
        linear_types (tuple[type, ...], optional): layer classes to look for. Defaults to
            bitsandbytes' `(Linear4bit, Linear8bitLt)`, i.e. the quantized linear layers.

    Returns:
        list[str]: sorted, de-duplicated last components of the matching module names
        (e.g. "q_proj"), never including "lm_head".
    """
    if linear_types is None:
        linear_types = (Linear4bit, Linear8bitLt)

    lora_module_names = {
        name.split(".")[-1]
        for name, module in model.named_modules()
        if isinstance(module, linear_types)
    }
    # Never adapt the output head (it can show up as a linear layer, e.g. in 16-bit models).
    lora_module_names.discard("lm_head")
    # Sorted so the result is stable across runs (set iteration order is not).
    return sorted(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

# LoRA adapters on every module name found above.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=modules,
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)

trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable:,} | total: {total:,} | Percentage: {trainable / total:.4%}")

['v_proj', 'gate_proj', 'k_proj', 'q_proj', 'down_proj', 'up_proj', 'o_proj']
Trainable: 78,446,592 | total: 2,584,619,008 | Percentage: 3.0351%


In [25]:
# pip install trl
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

# Training-time settings; the evaluation cells further down switch these back.
tokenizer.padding_side = "right"
model.config.use_cache = False
torch.cuda.empty_cache()

# trl's completion-only collator: the instruction/response markers let it tell the
# user turns apart from the model turns when building the labels.
collator = DataCollatorForCompletionOnlyLM(
    response_template="<start_of_turn>model\n",
    instruction_template="<start_of_turn>user\n",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=test_set.select(range(min(len(test_set), EVAL_SIZE))),
    formatting_func=create_chat_for_finetuning,
    data_collator=collator,
    # NOTE(review): `model` is already wrapped by get_peft_model above; confirm the
    # installed trl version does not apply this config a second time.
    peft_config=lora_config,
    args=SFTConfig(
        output_dir="/tmp/finetuned_gemma_2b",
        num_train_epochs=1,
        max_seq_length=2048,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=TRAIN_GRADIENT_ACCUMULATION_STEPS,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_steps=TRAIN_LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=TRAIN_LOGGING_STEPS,
        save_strategy="epoch",
    ),
)

train_result = trainer.train()
display(train_result._asdict())

Step,Training Loss,Validation Loss
10,0.0559,0.214258


KeyboardInterrupt: 

In [None]:
# Back to inference: eval mode, no gradient checkpointing, KV cache on, and left
# padding for batched generation.
model.eval()
model.gradient_checkpointing_disable()
model.config.use_cache = True
tokenizer.padding_side = "left"

In [None]:
# NOTE(review): the recorded training run was interrupted (KeyboardInterrupt), so the
# saved numbers compare the baseline against a partially trained adapter.
test_set = test_set.map(
    batch_get_preds(model, "finetuned"),
    batched=True,
    batch_size=EVAL_BATCH_SIZE,
)
finetuned_accuracy = sum(test_set["finetuned_correct"]) / len(test_set)

print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")
print(f"Finetuned accuracy: {round(finetuned_accuracy * 100, 1)}%")

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Gemma accuracy: 20.0%
Finetuned accuracy: 20.0%
