## Constants

In [1]:
TRAIN_SIZE = 200
TEST_SIZE = 20
EVAL_SIZE = 50
SEED = 123

USE_4BIT = True
USE_8BIT = False
GENERATE_KWARGS = dict(
    do_sample=True,
    max_new_tokens=512,
    temperature=1e-3,
)
EVAL_BATCH_SIZE = 8
TRAIN_BATCH_SIZE = 1
TRAIN_GRADIENT_ACCUMULATION_STEPS = 4
TRAIN_LOGGING_STEPS = 10

## Load Gemma 2B instruct model

In [6]:
# pip install torch transformers
# pip install bitsandbytes accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=USE_4BIT,
    load_in_8bit=USE_8BIT,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_8bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    quantization_config=quantization_config
)
print(model.device)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cuda:0


## Medical Q & A

In [None]:
from medqa_data import load_train_test_data

train_data, test_data = load_train_test_data(
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
    seed=SEED,
)

display(train_data)
display(test_data)

def print_sample(sample: dict[str, str]):
    message = "\n".join(
        f"\n# {k}\n{v}" for k, v in sample.items()
    )[1:]
    print(message)

print_sample(example := test_data[0])

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Dataset({
    features: ['input', 'output', 'true_label'],
    num_rows: 200
})

Dataset({
    features: ['input', 'output', 'true_label'],
    num_rows: 20
})

# input
Q: A 5-year-old boy presents to the emergency department with a sore throat and trouble breathing. His mother states that his symptoms started last night and have rapidly been worsening. The patient is typically healthy, has received all his childhood immunizations, and currently takes a daily multivitamin. His temperature is 103°F (39.4°C), blood pressure is 100/64 mmHg, pulse is 155/min, respirations are 29/min, and oxygen saturation is 95% on room air. Physical exam is notable for an ill-appearing child who is drooling and is leaning forward to breathe. He does not answer questions and appears very uncomfortable. He will not comply with physical exam to open his mouth for inspection of the oropharynx. Which of the following is the most likely infectious etiology of this patient's symptoms?? Give your answer as a JSON dictionary in the form of {"option": "A-E", "text": "corresponding text"}. No yapping.
{'A': 'Candidia albicans', 'B': 'Epstein-Barr virus', 'C': 'Haemophilus i

In [7]:
question = """
Q:A child is in the nursery one day after birth. A nurse notices a urine-like discharge being expressed through the umbilical stump. What two structures in the embryo are connected by the structure that failed to obliterate during the embryologic development of this child??
{'A': 'Pulmonary artery - aorta', 'B': 'Bladder - yolk sac', 'C': 'Bladder - small bowel', 'D': 'Liver - umbilical vein', 'E': 'Kidney - large bowel'},
Give your answer as a JSON dictionary in the form of {"option": "A-E", "option_text": "corresponding text"}. No yapping.
""".strip()

chat = [{"role": "user", "content": question}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids, **GENERATE_KWARGS)
print(tokenizer.decode(output_ids[0]))



<bos><start_of_turn>user
Q: A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murm

In [11]:
import numpy as np
from medqa_data import create_batch_predict

batch_predict = create_batch_predict(tokenizer, model, "gemma", GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

gemma_accuracy = (np.asarray(test_data["gemma_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Gemma accuracy: 20.0%


## Finetune

In [12]:

def create_chat_for_finetuning(samples: dict) -> list[str]:
    chat_texts = tokenizer.apply_chat_template(
        [
            [
                {"role": "user", "content": input},
                {"role": "assistant", "content": output},
            ]
            for input, output in zip(samples["input"], samples["output"])
        ],
        tokenize=False,
    )
    return [text.removeprefix("<bos>") for text in chat_texts]

display(create_chat_for_finetuning(train_data[:3]))

['<start_of_turn>user\nQ: A 65-year-old man presents to the emergency department for sudden weakness. The patient states that he was at home enjoying his morning coffee when his symptoms began. He says that his left arm suddenly felt very odd and weak thus prompting him to come to the ED. The patient has a past medical history of diabetes, COPD, hypertension, anxiety, alcohol abuse, and PTSD. He recently fell off a horse while horseback riding but claims to not have experienced any significant injuries. He typically drinks 5-7 drinks per day and his last drink was yesterday afternoon. His current medications include insulin, metformin, atorvastatin, lisinopril, albuterol, and fluoxetine. His temperature is 99.5°F (37.5°C), blood pressure is 177/118 mmHg, pulse is 120/min, respirations are 18/min, and oxygen saturation is 93% on room air. On physical exam, you note an elderly man who is mildly confused. Cardiopulmonary exam demonstrates bilateral expiratory wheezes and a systolic murmur

In [13]:
# pip install peft
from bitsandbytes.nn import Linear4bit, Linear8bitLt
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, Linear4bit) or isinstance(module, Linear8bitLt):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        if 'lm_head' in lora_module_names: # needed for 16-bit
            lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable:,} | total: {total:,} | Percentage: {trainable/total*100:.4f}%")

['v_proj', 'gate_proj', 'k_proj', 'q_proj', 'down_proj', 'up_proj', 'o_proj']
Trainable: 78,446,592 | total: 2,584,619,008 | Percentage: 3.0351%


In [25]:
# pip install trl
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

tokenizer.padding_side = "right"
model.config.use_cache=False
torch.cuda.empty_cache()

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="<start_of_turn>user\n",
    response_template="<start_of_turn>model\n",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    model,
    args=SFTConfig(
        output_dir="/tmp/finetuned_gemma_2b",
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=TRAIN_GRADIENT_ACCUMULATION_STEPS,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_seq_length=2048,
        num_train_epochs=1,
        save_strategy="epoch",
        logging_steps=TRAIN_LOGGING_STEPS,
        eval_steps=TRAIN_LOGGING_STEPS,
        eval_strategy="steps",
    ),
    data_collator=collator,
    eval_dataset=test_set.select(range(min(len(test_set), EVAL_SIZE))),
    formatting_func=create_chat_for_finetuning,
    peft_config=lora_config,
    train_dataset=dataset["train"],
    tokenizer=tokenizer,
)

train_result = trainer.train()
display(train_result._asdict())

Step,Training Loss,Validation Loss
10,0.0559,0.214258


KeyboardInterrupt: 

In [None]:
tokenizer.padding_side = "left"
model.config.use_cache=True
model.gradient_checkpointing_disable()
model.eval();

## Evaluate finetuned model

In [None]:
batch_predict = create_batch_predict(tokenizer, model, "finetuned", GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

finetuned_accuracy = (np.asarray(test_data["finetuned_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")
print(f"Finetuned accuracy: {round(finetuned_accuracy * 100, 1)}%")

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Gemma accuracy: 20.0%
Finetuned accuracy: 20.0%
