# MediBot!

The purpose of this notebook is to gain a glimpse into the world of fine-tuning. Although LLM Agents with RAG or LLMs in a zero-shot setting are extremely powerful, we cannot ignore fine-tuning completely. In cases where we cannot cover the instructions through a few examples in context, fine-tuning would help.

Fortunately, due to the advances in fine-tuning, it isn't as resource intensive as it used to be. But more on that later.

Let us try to build a small chatbot that would answer medical questions using the latest fine-tuning techniques.

## Load Libraries

In [1]:
%%capture

# Load Libraries

!pip install huggingface
!pip install transformers
!pip install datasets
!pip install torch
# !pip install adapter-transformers
!pip install peft
!pip install -U bitsandbytes
# !pip install trl

import huggingface
import transformers
import datasets
import torch
import peft

## Load Model and Dataset

Lets now load a simple model
We choose tinyllama here but you can use whatever based on your interests and resources

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login

# Authenticate with Read Token (to download model)
login("HF_TOKEN")

# Load a model for fine-tuning
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto")
model.to(device) # Put the model on GPU if available
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Now, why do we even need to fine-tune the model? These large language models already get a lot of data to look at during pre-training right? Well yes but they look at the data from Seq2Seq perspective. Meaning that they are always trying to predict the output sequence token by token given an input sequence and aren't really trained to chat with you.

Think about it. A lot of information from the internet isn't chatlogs, so why would the model be able to chat with you unless you make it so?

This is why we start off with a chat model in this case, as you might have noticed earlier. Despite that it can be improved to handle extremely domain-specific questions that it might not have seen before.

Let's take an example.

In [3]:
def invoke(query):
  inputs = tokenizer(query, return_tensors="pt").to(device)
  outputs = model.generate(**inputs, max_new_tokens=100)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))

invoke("What is Crohn's disease?")

What is Crohn's disease? How does it affect the body?


The response does not make any sense. This is why we need to fine-tune. But on what?

Normally we have Question-Answer datasets or Instruction datasets to fine-tune

Lets load the Question Answer dataset now

In [4]:
from datasets import load_dataset

# Load the MedQuAD dataset
dataset = load_dataset("lavita/MedQuAD")
dataset = dataset["train"].train_test_split(test_size=0.1)
print(dataset)

README.md:   0%|          | 0.00/2.77k [00:00<?, ?B/s]

(…)-00000-of-00001-e36383d177026d53.parquet:   0%|          | 0.00/10.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/47441 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['document_id', 'document_source', 'document_url', 'category', 'umls_cui', 'umls_semantic_types', 'umls_semantic_group', 'synonyms', 'question_id', 'question_focus', 'question_type', 'question', 'answer'],
        num_rows: 42696
    })
    test: Dataset({
        features: ['document_id', 'document_source', 'document_url', 'category', 'umls_cui', 'umls_semantic_types', 'umls_semantic_group', 'synonyms', 'question_id', 'question_focus', 'question_type', 'question', 'answer'],
        num_rows: 4745
    })
})


So the dataset contains 47441 rows/datapoints with each datapoint has a bunch of metadata along with the question and answer that we need for the fine-tuning.

## Fine-Tuning

Now lets do fine-tuning. But what is it exactly?

Fine-Tuning is the process of adjusting the model weights to make sure that it performs well on the downstream task of our choice. We do this because we want to take advantage of the models that have already been trained on large datasets before so that all we have to do is submit a small scale dataset for our downstream task. However, if you look at these large models, updating the model weights for every single task would mean the death of several servers. We need a way to do this effectively. This is where PEFT comes in.


PEFT or Parameter Efficient Fine Tuning is an approach to fine-tuning where the goal is to achieve the desired result, by not updating every single parameter/weight in the model as that is way too much to do. The following are some ways of doing it.

<ul>
  <li>Adapter Training: An extra layer is added in the architecture which is trained alone for the specific task while other parameters are frozen. Since you can repeat this for multiple downstream tasks, you can have adapter training done for mulitple tasks</li>
  <li>LoRA: Low Rank Adaptation. This is an efficient way of doing the typical fine-tuning. Instead of updating every single weight, we freeze the original weights and train smaller matrices which will then be added back to the original weights, effectively updating them. Mathematically it looks like,
  <br/>
  W' = W + BA; where B, A's dimensions are decided by the rank you choose</li>
</ul>


There are other approaches like Prefix Tuning, Prompt Tuning (not Prompt Engineering) and so on. We will stick to LoRA based fine-tuning for our case here.

### Data Preparation

Lets briefly look at the tokenization process too. What we need to do here is basically format the question answer dataset into a format that can be used for training. This means preparing a user prompt structure along with the response.

We have chosen a model that learns via Causal Language Modeling. What this means is that it predicts the next token based on the previous tokens, i.e., it is autoregressive in nature. This implies that the model does not need explicit labels, only the next tokens. But it need not predict all of them! Just the answer part. Thus, as the input, it takes the whole prompt, response pair and as the label it takes only the answer part. The prompt (with the system prompt and the user query) becomes the context to the model's prediction

As such we prepare the dataset to include both of these.

In [5]:
# Tokenization function
def tokenize_function(examples):
  tokenizer.pad_token = tokenizer.eos_token

  # We format the dataset as per the chat template
  user_prompt = [f"""<|system|>
  You are a helpful medical assistant. Answer questions clearly and concisely.
  <|user|>
  {q}
  <|assistant|>
  """
  for q in examples["question"]
  ]

  answer = [f" {str(a or '')}" for a in examples["answer"]]
  texts = [user_prompt[i] + answer[i] for i in range(len(user_prompt))]

  # Tokenize all texts in the batch
  inputs = tokenizer(texts, padding="max_length", truncation=True, max_length=1000).to(device)

  labels = inputs["input_ids"].copy()

  # Prepare the labels
  for i in range(len(inputs["input_ids"])):
    user_prompt_ids = tokenizer(user_prompt[i], truncation=True, add_special_tokens=False)["input_ids"]
    prompt_length = len(user_prompt_ids)  # Length of user prompt tokens
    labels[i][:prompt_length] = [tokenizer.pad_token_id] * prompt_length  # Mask out the prompt tokens

  inputs["labels"] = labels
  return inputs

# Tokenize dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)


Map:   0%|          | 0/42696 [00:00<?, ? examples/s]

Map:   0%|          | 0/4745 [00:00<?, ? examples/s]

### LoRA Training

To do the training, we have to define the LoRA configuration and initiate the corresponding trainer. But to get a deeper insight into freezing the weights, I wrote a simple script in the following cell to see how the weight freezing would look like

In [6]:
from torch import nn

for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

In [7]:
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model


# PEFT Config
peft_config = LoraConfig(
    r=16, # Rank hyperparameter
    lora_alpha=32, # alpha scaling
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM" # simply states the model is trained to predict next token based on previous tokens
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

model.config.pad_token_id = tokenizer.eos_token_id

# Training Arguments
training_args = TrainingArguments(
    output_dir="./adapter_chat_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"].select(range(1000)), # select only a subset of the dataset for training
    eval_dataset=tokenized_datasets["test"]
)

# Train model
trainer.train()

trainable params: 2,252,800 || all params: 1,102,301,184 || trainable%: 0.2044


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mindra927[0m ([33mindra927-leuphana[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
1,No log,0.497813
2,No log,0.212333


TrainOutput(global_step=93, training_loss=1.1162769768827705, metrics={'train_runtime': 3815.3126, 'train_samples_per_second': 0.786, 'train_steps_per_second': 0.024, 'total_flos': 1.8213890752512e+16, 'train_loss': 1.1162769768827705, 'epoch': 2.928})

On a T4 GPU in colab, this training took around one hour. And we didnt even use the full dataset. This is the reason why we cant just go around fine-tuning in a blunt manner, and properly utilize techniques like Quantization to use the resources available as well as possible.

## Inference

Having fine-tuned the model on a subset of the MedQuAD dataset, lets see how well it performs in conversations.

In [19]:
def invoke(query):
  inputs = tokenizer(query, return_tensors="pt").to(device)
  outputs = model.generate(**inputs, max_new_tokens=100)
  return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[1]

In [20]:
chat_template = """<|system|>
  You are a helpful medical assistant. Answer questions clearly and concisely.
  <|user|>
  {question}
  <|assistant|>
  """

invoke(chat_template.format(question = "What is crohn's disease"))

"\n   Crohn's disease is a chronic inflammatory bowel disease (IBD) that affects the digestive system. It causes inflammation and damage to the digestive tract, including the small intestine, large intestine, and rectum. The symptoms of Crohn's disease can vary depending on the location and severity of the inflammation. Some common symptoms include abdominal pain, diarrhea, fever"

In [21]:
invoke(chat_template.format(question = "What to do for Alzheimer's?"))

"\n  1. Listen to the patient: Listen to the patient's concerns and provide them with information about Alzheimer's disease.\n  2. Provide comfort: Provide comfort to the patient and their family by offering them support and guidance.\n  3. Provide information: Provide information about Alzheimer's disease and its symptoms to the patient and their family.\n  4. Provide education: Provide education about Alzheimer's disease"

In [22]:
invoke(chat_template.format(question = "Are we closer to finding a cure for cancer?"))

'\n   Yes, we are closer to finding a cure for cancer. The development of new treatments and therapies is ongoing, and researchers are making significant progress in understanding the underlying mechanisms of cancer and developing targeted treatments. However, there is still a long way to go before a cure can be found, and many challenges remain in the field of cancer research.'

Although the training wasnt done on the whole dataset due to resource constraints, we still got pretty decent results in our conversations.


That was it for this notebook. Thanks for reading and hope it helps!