<a href="https://colab.research.google.com/github/pattichis/AI4All-Med/blob/main/Session_4_2_Medichat_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medichat Example

We found an open-source medical LLM in this Hugging Face collection: https://huggingface.co/collections/sethuiyer/medical-llms.

The code below follows the example from the model's page: https://huggingface.co/sethuiyer/Medichat-Llama3-8B.

_For this to run faster, go to Runtime > Change Runtime Type, and select a GPU option (with high-RAM if available)._

## **Experiments / Questions**
We learned that a prompt can contain the following components: persona, instruction, context, format, audience, tone, and data (the input the model should apply the instruction to).

1. Which subset of the above list belongs in the sys_message?
  - Hint: This message gets prepended to every user input.
2. Modify the question. How well does the model do on your own medical question?
3. Modify the sys_message. How does this change the performance on the question you ran previously?
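
As a starting point for experiment 3, the sketch below assembles a system message from named prompt components. `SystemPromptParts` and its field names are hypothetical and not part of the model's example code. Treating persona, instruction, format, audience, and tone as fixed, while leaving context and data to the user message, is only one possible split; deciding whether it is the right one is the point of question 1.

```python
from dataclasses import dataclass


@dataclass
class SystemPromptParts:
    """Hypothetical container for prompt components that stay fixed across questions."""
    persona: str = ""
    instruction: str = ""
    output_format: str = ""
    audience: str = ""
    tone: str = ""

    def render(self) -> str:
        # Keep only the components that were filled in, one per line.
        parts = [self.persona, self.instruction, self.output_format,
                 self.audience, self.tone]
        return "\n".join(p.strip() for p in parts if p.strip())


parts = SystemPromptParts(
    persona="You are an AI Medical Assistant.",
    instruction="Be thorough and advise seeking professional help when unsure.",
    audience="The reader is a first-year medical student.",
    tone="Use a calm, factual tone.",
)
sys_message = parts.render()
print(sys_message)
```

The resulting string could then be assigned to `assistant.sys_message` once the assistant has been created.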

In [None]:
!pip install -q bitsandbytes

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
# Load the weights in 8-bit via bitsandbytes to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)


class MedicalAssistant:
    def __init__(self, model_name="sethuiyer/Medichat-Llama3-8B", device="cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # The quantized model is placed on the GPU at load time,
            # so no .to(self.device) call is needed here.
            quantization_config=quantization_config,
        )
        self.sys_message = '''
        You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
        provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
        '''

    def format_prompt(self, question):
        messages = [
            {"role": "system", "content": self.sys_message},
            {"role": "user", "content": question}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return prompt

    def generate_response(self, question, max_new_tokens=512):
        prompt = self.format_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        # Decoding the full sequence returns the prompt followed by the generated answer.
        answer = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        return answer
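
Because `generate_response` decodes the whole output sequence, the printed response later in this notebook repeats the system message and the question before the answer. A minimal sketch of trimming that echo from the decoded string follows. `extract_reply` is a hypothetical helper, and the marker strings are assumptions taken from the output shown in this notebook; other chat templates use different markers.

```python
ASSISTANT_MARKER = "<|im_start|>assistant"
END_MARKER = "<|im_end|>"


def extract_reply(decoded: str) -> str:
    """Return only the text after the last assistant marker (hypothetical helper)."""
    # rpartition splits on the last occurrence; if the marker is missing,
    # `found` is empty and the full text is used unchanged.
    _, found, tail = decoded.rpartition(ASSISTANT_MARKER)
    reply = tail if found else decoded
    # Drop an end-of-turn marker and anything after it, if one was emitted.
    reply = reply.split(END_MARKER, 1)[0]
    return reply.strip()


decoded = (
    "<|im_start|>system\nBe helpful.<|im_end|>\n"
    "<|im_start|>user\nHi<|im_end|>\n"
    "<|im_start|>assistant\nHello there."
)
print(extract_reply(decoded))
```

An alternative that avoids string matching is to slice the prompt tokens off `outputs` before decoding, for example `outputs[:, inputs["input_ids"].shape[1]:]`; which approach fits better depends on whether you want to see the full prompt while experimenting.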

### **NOTE: Run the cell below only once. Each run loads another copy of the model, which can exhaust memory and crash your runtime!**

In [None]:
assistant = MedicalAssistant()

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]
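
Since the model should be loaded only once, experiment 3 can be run by swapping the `sys_message` attribute on the existing `assistant` instead of creating a new one. The sketch below shows one way to do that. `compare_sys_messages` is a hypothetical helper, and `EchoAssistant` is a stand-in with the same interface so the example runs without a GPU; neither comes from the model's example code.

```python
def compare_sys_messages(assistant, question, sys_messages):
    """Run one question under several system messages (hypothetical helper).

    `assistant` only needs a `sys_message` attribute and a
    `generate_response(question)` method, as MedicalAssistant has.
    """
    original = assistant.sys_message
    results = {}
    try:
        for label, message in sys_messages.items():
            assistant.sys_message = message
            results[label] = assistant.generate_response(question)
    finally:
        # Restore the original system message even if generation fails.
        assistant.sys_message = original
    return results


class EchoAssistant:
    """Stand-in with the same interface, so the sketch runs without the model."""

    def __init__(self):
        self.sys_message = "default"

    def generate_response(self, question, max_new_tokens=512):
        return f"[{self.sys_message}] {question}"


demo = EchoAssistant()
results = compare_sys_messages(
    demo,
    "What causes dizziness?",
    {"brief": "Answer in one sentence.", "detailed": "Answer thoroughly."},
)
for label, answer in results.items():
    print(label, "->", answer)
```

In the notebook, passing the real `assistant` in place of `demo` would run the same question under each system message without reloading the model.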

In [None]:
question = '''
Symptoms:
Dizziness, headache, and nausea.

What is the differential diagnosis?
'''
response = assistant.generate_response(question)
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|im_start|>system

        You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
        provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.<|im_end|>
<|im_start|>user

Symptoms:
Dizziness, headache, and nausea.

What is the differential diagnosis?<|im_end|>
<|im_start|>assistant
Differential diagnosis for dizziness, headache, and nausea:

1. Vestibular disorders: Benign paroxysmal positional vertigo (BPPV), labyrinthitis, vestibular neuritis, Meniere's disease, and vestibular migraine.
2. Inner ear disorders: Otosclerosis, otitis media, and ototoxicity.
3. Central nervous system disorders: Labyrinthitis, multiple sclerosis, and brainstem stroke.
4. Cardiovascular disorders: Orthostatic hypotension, vasovagal syncope, and cardiac arrhythmias.
5. Neurovascular disorders: Migraine, tension headache, and cluster headache.
6. Metabolic disorders: Hypoglycemia, hypergl