# Fine-tune LLM for Medical NLI task

This tutorial demonstrates how to fine-tune open-source Large Language Models, such as Google's 2024 gemma 2B- and 7B-parameter and Meta's 2023 Llama-2 7B- and 13B-parameter models, for medical NLP tasks such as NLI (natural language inference).  To load, train and run inference on consumer-grade hardware*, "quantization" and "low rank adaptation" methods are explained and used. The final results using the MedNLI benchmark dataset suggest that gemma-7b-it performed best, slightly outperforming Llama-2-13b.

*this notebook was run on an over-2-year-old laptop with a RTX-3080 16GB GPU

References:
- https://github.com/jgc128/mednli
- https://pytorch.org/blog/finetune-llms/
- https://github.com/TimDettmers/bitsandbytes
- https://github.com/huggingface/peft
- https://www.kaggle.com/code/lucamassaron/fine-tune-gemma-7b-it-for-sentiment-analysis

In [1]:
import os
import gc
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          TrainingArguments,
                          BitsAndBytesConfig)
from peft import LoraConfig, AutoPeftModelForCausalLM
from trl import SFTTrainer
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"pytorch version: {torch.__version__}, device: {device}")

pytorch version: 2.2.1+cu121, device: cuda:0


## Load model from HF hub or checkpoint

In [2]:
# pre-trained LLM to use
model_id = "meta-llama/Llama-2-7b-chat-hf"
model_id = "meta-llama/Llama-2-13b-chat-hf"
model_id = "google/gemma-2b-it"
model_id = "google/gemma-7b-it"
output_dir = os.path.join('models', model_id)

In [3]:
# select to load model from HF hub, or previously checkpoint-saved folder
from_checkpoint = model_id
#from_checkpoint = output_dir
#from_checkpoint = os.path.join(output_dir, 'checkpoint-474')

## Set training arguments

https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments

- optim — The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.

- gradient_accumulation_steps — Number of updates steps to accumulate the gradients for, before performing a backward/update pass.

- lr_scheduler_type — The scheduler type to use.

- learning_rate — The initial learning rate.

- weight_decay — The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights.

- max_grad_norm — Maximum gradient norm (for gradient clipping).

- fp16 — Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.




In [4]:
# whether to train and/or evaluate
do_train = True
do_eval = True

In [5]:
# define arguments for trainer
training_args = TrainingArguments(
    output_dir=output_dir,
    optim="paged_adamw_32bit",
    lr_scheduler_type="cosine",
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    gradient_accumulation_steps=8,
    fp16=True,   # use mixed precision floats

    num_train_epochs=3.0,  #1.0,
    per_device_train_batch_size=2,
    evaluation_strategy='steps',
    save_steps=0.2,                  # checkpoint interval, -1 for none
    logging_steps=0.2,               # logging interval, -1 for none
    eval_steps=0.2,                  # evaluation interval, -1 for non
    eval_accumulation_steps=1,       # 1 for less memory but slower
    prediction_loss_only=True,       # False for full evaluation
    gradient_checkpointing=True,
    report_to="tensorboard"
)

## Get mednli dataset

MedNLI: A dataset annotated by doctors, performing a natural language inference task (NLI). The source of premise sentences, was the MIMIC-III v1.3 (Johnson et al., 2016) database, which 2,078,705 clinical notes written by healthcare professionals in English.  The hypothesis sentences were generated by clinicians. They were asked to write three sentences (hypotheses): 1) A clearly true statement, 2) A clearly false statement, and 3) A statement that might be true or false. This procedure produces three training pairs of sentences for each initial premise with three different labels: entailment, contradiction, and neutral.

Romanov, Alexey and Shivade, Chaitanya, Lessons from Natural Language Inference in the Clinical Domain, 2018. http://arxiv.org/abs/1808.06752,


In [6]:
# helper to read jsonl
import json
def read_jsonl(filename, max_samples=None):
    """helper to read jsonl files as pandas dataframe"""
    lines = []
    with open(filename) as f:
        lines = f.read().splitlines()
    line_dicts = [json.loads(line) for line in lines]
    max_samples = max_samples or len(line_dicts)
    return pd.DataFrame(line_dicts).iloc[:max_samples]

In [7]:
# read in train, dev and test sets
X_train = read_jsonl('mednli/mli_train_v1.jsonl')
y_train = X_train['gold_label']
X_dev = read_jsonl('mednli/mli_dev_v1.jsonl')
y_dev = X_dev['gold_label']
X_test = read_jsonl('mednli/mli_test_v1.jsonl')
y_test = X_test['gold_label']

In [8]:
# Show premise, hypothesis and label for dev examples
X_dev.iloc[:6][['sentence1', 'sentence2', 'gold_label']]

Unnamed: 0,sentence1,sentence2,gold_label
0,"No history of blood clots or DVTs, has never h...",Patient has angina,entailment
1,"No history of blood clots or DVTs, has never h...",Patient has had multiple PEs,contradiction
2,"No history of blood clots or DVTs, has never h...",Patient has CAD,neutral
3,Over the past week PTA he has been more somnol...,He has been less alert over the past week,entailment
4,Over the past week PTA he has been more somnol...,Over the past week he has been alert and orie...,contradiction
5,Over the past week PTA he has been more somnol...,He is disorientated and complains of weakness,neutral


In [9]:
# format premise and hypothesis as chat prompt
def as_test_prompt(ex):
    """prompt for response to test example"""
    return f"""
            Use the following context to determine if the factuality of the
            statement enclosed in square brackets at the end is entailment,
            neutral, or contradiction, and return the answer in 1 word as
            "entailment" or "neutral" or "negative":

            {ex['sentence1']}

            [{ex['sentence2']}]

            Answer:
            """.strip()

# format premise and hypothesis with gold label as training example
def as_prompt(ex):
    """training example"""
    return f"{as_test_prompt(ex)}  {ex['gold_label']}".strip()

In [10]:
# reformat all examples as prompts
X_train = pd.DataFrame(X_train.apply(as_prompt, axis=1), columns=["prompt"])
train_data = Dataset.from_pandas(X_train)
X_dev = pd.DataFrame(X_dev.apply(as_prompt, axis=1), columns=["prompt"])
dev_data = Dataset.from_pandas(X_dev)
X_test = pd.DataFrame(X_test.apply(as_test_prompt, axis=1), columns=["prompt"])
test_data = Dataset.from_pandas(X_test)

## Get quantized pre-trained model

https://huggingface.co/docs/bitsandbytes/main/en/index

bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. 4-bit quantization enables large language model training with several memory-saving techniques that don’t compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training.

- quantization: a technique to reduce the computational and memory costs of running inference by representing the weights and activations with low-precision data types like 8-bit integer (int8) instead of the usual 32-bit floating point (float32). Reducing the number of bits means the resulting model requires less memory storage, consumes less energy (in theory), and operations like matrix multiplication can be performed much faster with integer arithmetic. It also allows to run models on embedded devices, which sometimes only support integer data types.

- LLM.int8(): a quantization method that doesn’t degrade performance which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit and quantized to Int8 before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output.

- nf4: a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1].


In [11]:
# Load model from HF hub or local folder
compute_dtype = getattr(torch, "float16")
if do_train:   # Load and quantize pre-trained model to start fine-tune
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype
    )
    model = AutoModelForCausalLM.from_pretrained(
        from_checkpoint,
        device_map=device,  # "auto" may be slower because offload to cpu
        quantization_config=bnb_config
    )
else:    # Load previously checkpoint-saved model to continue fine-tune or evaluate
    model = AutoPeftModelForCausalLM.from_pretrained(
        from_checkpoint,
        torch_dtype=compute_dtype,
        return_dict=False,
        low_cpu_mem_usage=True,
        device_map=device,  # "auto" may be slower because offload to cpu
    )

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fix missing pad_token if error
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
    model.resize_token_embeddings(len(tokenizer))
model.config.use_cache = False
model.config.pretraining_tp = 1

Loading checkpoint shards: 100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


In [12]:
print(f"CUDA: {torch.cuda.memory_allocated()/1e9:.2f} GB")

CUDA: 6.01 GB



## Helpers for evaluating model predictions 


In [13]:
def predict(X_test, model, tokenizer):
    """Generate model predictions on test set"""
    y_pred = []
    for i in tqdm(range(len(X_test))):
        prompt = X_test.iloc[i]["prompt"]
        input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
        outputs = model.generate(**input_ids,
                                 max_new_tokens=4,  # 8
                                 do_sample=False,   # True
                                 temperature=0.00,  # 0.01
        )
        result = tokenizer.decode(outputs[0][len(input_ids[0]):])   
        answer = result.lower() # result.split("=")[-1].lower()
        if "entailment" in answer:
            y_pred.append("entailment")
        elif "contradiction" in answer:
            y_pred.append("contradiction")
        else:
            y_pred.append("neutral")
    return y_pred

In [14]:
def evaluate(y_true, y_pred, verbose=False):
    """Evaluate accuracy and confusion matrix of model predictions"""
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    
    # Calculate accuracy
    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
    if verbose:
        print(f'Accuracy: {accuracy:.3f}')
                
    # Generate classification report
    class_report = classification_report(y_true=y_true, y_pred=y_pred)
    if verbose:
        print('\nClassification Report:')
        print(class_report)
    
    # Generate confusion matrix
    conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred)
    if verbose:
        print('\nConfusion Matrix:')
        print(conf_matrix)

    return accuracy

In [15]:
# Evaluate pre-trained that has *NOT* been tuned for downstream NLI task
y_pred = predict(X_test, model, tokenizer)
evaluate(y_test, y_pred, verbose=True)

prompt = X_test.iloc[42]["prompt"]
print('Prompt:', prompt)

input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids,
                         max_new_tokens=4,
                         do_sample=False,
                         temperature=0.00
                         )
result = tokenizer.decode(outputs[0][len(input_ids[0]):])
answer = result.lower()
print('Answer:', answer)


100%|██████████| 1422/1422 [06:05<00:00,  3.89it/s]


Accuracy: 0.333

Classification Report:
               precision    recall  f1-score   support

contradiction       0.00      0.00      0.00       474
   entailment       0.00      0.00      0.00       474
      neutral       0.33      1.00      0.50       474

     accuracy                           0.33      1422
    macro avg       0.11      0.33      0.17      1422
 weighted avg       0.11      0.33      0.17      1422


Confusion Matrix:
[[  0   0 474]
 [  0   0 474]
 [  0   0 474]]
Prompt: Use the following context to determine if the factuality of the
            statement enclosed in square brackets at the end is entailment,
            neutral, or contradiction, and return the answer in 1 word as
            "entailment" or "neutral" or "negative":

            He could think of what he wanted to say but was having trouble getting the words out.

            [ The patient is having trouble speaking. ]

            Answer:
Answer: 

the factu



## Train model

- PEFT (Parameter-Efficient Fine-Tuning) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. PEFT can save storage by avoiding full finetuning of models on each of downstream task or dataset. One of the main benefits of using PEFT is the huge savings in compute and storage.

- LoRA (Low-Rank Adaptation) works by attaching extra trainable parameters into a model and decomposing a large weight matrix into two smaller, low-rank matrices (called update matrices). These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are combined.

- QLoRA is a 4-bit quantization method that enables large language model training with several memory-saving techniques that don’t compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training.

https://pytorch.org/blog/finetune-llms/
https://github.com/huggingface/peft
https://huggingface.co/docs/bitsandbytes/main/en/index

Prepare a model for training with a PEFT method such as LoRA by wrapping the base model with PEFT configuration

- r — attention dimension (the “rank”)

- lora_alpha — scaling factor for the weight matrices, a higher alpha assigns more weight to the LoRA activations

- lora_dropout — The dropout probability for Lora layers.

In [16]:
# Set config for PEFT
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

In [17]:
# Set config for SFT Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=dev_data,
    peft_config=peft_config,
    dataset_text_field="prompt",
    tokenizer=tokenizer,
    args=training_args,
    packing=False,
    max_seq_length=1024,
)

Map: 100%|██████████| 11232/11232 [00:00<00:00, 17323.32 examples/s]
Map: 100%|██████████| 1395/1395 [00:00<00:00, 23797.54 examples/s]


In [18]:
if do_train:
    for _ in range(10):    # reclaim memory before training
        with torch.no_grad():
            torch.cuda.empty_cache()
    gc.collect()
    training_stats = trainer.train()

    # Save trained model
    #trainer.model.save_pretrained(output_dir)
    trainer.save_model(output_dir)

 20%|██        | 422/2106 [45:33<3:00:58,  6.45s/it]

{'loss': 1.8024, 'grad_norm': 0.32632938027381897, 'learning_rate': 0.00018521172305285236, 'epoch': 0.6}


                                                    
 20%|██        | 422/2106 [47:59<3:00:58,  6.45s/it]

{'eval_loss': 0.7946441173553467, 'eval_runtime': 145.5527, 'eval_samples_per_second': 9.584, 'eval_steps_per_second': 1.202, 'epoch': 0.6}


 40%|████      | 844/2106 [1:33:14<2:16:52,  6.51s/it]

{'loss': 0.6638, 'grad_norm': 0.3769872784614563, 'learning_rate': 0.00013623384610073693, 'epoch': 1.2}


                                                      
 40%|████      | 844/2106 [1:35:36<2:16:52,  6.51s/it]

{'eval_loss': 0.7909778952598572, 'eval_runtime': 141.786, 'eval_samples_per_second': 9.839, 'eval_steps_per_second': 1.234, 'epoch': 1.2}


 60%|██████    | 1266/2106 [2:20:39<1:30:41,  6.48s/it]

{'loss': 0.5796, 'grad_norm': 0.4613594710826874, 'learning_rate': 7.25118606258684e-05, 'epoch': 1.8}


                                                       
 60%|██████    | 1266/2106 [2:23:01<1:30:41,  6.48s/it]

{'eval_loss': 0.8163220882415771, 'eval_runtime': 141.8245, 'eval_samples_per_second': 9.836, 'eval_steps_per_second': 1.234, 'epoch': 1.8}


 80%|████████  | 1688/2106 [3:08:42<45:12,  6.49s/it]   

{'loss': 0.4816, 'grad_norm': 0.5449468493461609, 'learning_rate': 1.9975221274455323e-05, 'epoch': 2.4}


                                                     
 80%|████████  | 1688/2106 [3:11:07<45:12,  6.49s/it]

{'eval_loss': 0.8947425484657288, 'eval_runtime': 145.2768, 'eval_samples_per_second': 9.602, 'eval_steps_per_second': 1.205, 'epoch': 2.4}


100%|██████████| 2106/2106 [3:55:53<00:00,  6.72s/it]  


{'train_runtime': 14153.7833, 'train_samples_per_second': 2.381, 'train_steps_per_second': 0.149, 'train_loss': 0.7942605955987914, 'epoch': 3.0}


In [None]:
# pd.Series(training_stats.metrics)

In [19]:
print(f"CUDA: {torch.cuda.memory_allocated()/1e9:.2f} GB")
# !nvidia-smi

CUDA: 7.90 GB


## Generates responses and evaluate model

In [20]:
if do_eval:
    for _ in range(10):    # reclaim memory
        with torch.no_grad():
            torch.cuda.empty_cache()
    gc.collect()
    
    y_pred = predict(X_test, trainer.model, tokenizer)
    evaluate(y_test, y_pred, verbose=True)

    evaluation = pd.DataFrame({'prompt': X_test["prompt"],
                               'y_test': y_test,
                               'y_pred': y_pred})
    eval_dir = os.path.join(output_dir, 'logs')
    os.makedirs(eval_dir, exist_ok=True)
    evaluation.to_csv(
        os.path.join(eval_dir, datetime.now().strftime("%Y%m%d-%H%M")),
        index=False
    )

  0%|          | 0/1422 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
100%|██████████| 1422/1422 [18:15<00:00,  1.30it/s]

Accuracy: 0.878

Classification Report:
               precision    recall  f1-score   support

contradiction       0.92      0.94      0.93       474
   entailment       0.86      0.87      0.86       474
      neutral       0.85      0.82      0.84       474

     accuracy                           0.88      1422
    macro avg       0.88      0.88      0.88      1422
 weighted avg       0.88      0.88      0.88      1422


Confusion Matrix:
[[446   8  20]
 [ 15 412  47]
 [ 23  60 391]]



