In [15]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device}")

Using cuda


In [3]:
# open questions from patients
# dataset contains inputs and outputs from a medical chatbot 
# might be useful for finetuning
dataset = load_dataset("Malikeh1375/medical-question-answering-datasets", "all-processed")

In [4]:
for i in range(3):
    print(f"PROMPT: \n {dataset['train'][0]['instruction']} {dataset['train'][i]['input']}")
    print(f"ANSWER: \n {dataset['train'][0]['output']}\n\n")

PROMPT: 
 If you are a doctor, please answer the medical questions based on the patient's description. Hey Just wondering.  I am a 39 year old female, pretty smallMy heart rate is around 97 to 106 at rest, and my BP is 140/90 and twice I get 175/118I did visit a doctor because I  didnt feel well past month or twoThen the doctor gave me a heart medicine to take the pulse down and BP  (its still in further examination.)But I wondering what it can be? Do I need the medicine really?  Is that bad ?
ANSWER: 
 hello and thank you for using chatbot. i carefully read your question and i understand your concern. i will try to explain you something and give you my opinion. we talk about hypertension if we have mean value that exceeds 140 / 90 mmhg. a person might have high value during emotional and physicals trees so it's mandatory to judge on mean values. usaly hypertension does not give any symptoms but left untreated he slowly modifies the heart. according to heart rhythm, the normal rate is 

In [9]:
# load any causal language model 
 
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    return tokenizer, model

model_name = "microsoft/Phi-3-mini-128k-instruct"

tokenizer, model = load_model(model_name)
model.to(device)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3LongRoPEScaledRotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out

In [12]:
def generate_answer(prompt, temperature, top_p):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
            output_ids = model.generate(
                inputs['input_ids'],
                max_new_tokens=500,  
                pad_token_id=0,
                num_return_sequences=1,
                do_sample=True,
                temperature=temperature,
                top_k=50,
                top_p=top_p,
                no_repeat_ngram_size=3,
            )

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text

In [13]:
# see some responses with different temperature
for i in range(3):
    prompt = f"{dataset['train'][0]['instruction']}: {dataset['train'][i]['input']}"
    print(f"PROMPT: \n {prompt}")
    print(f"ANSWER: \n {generate_answer(prompt, .1, .95)}\n\n")   # repeats prompt 

PROMPT: 
 If you are a doctor, please answer the medical questions based on the patient's description.: Hey Just wondering.  I am a 39 year old female, pretty smallMy heart rate is around 97 to 106 at rest, and my BP is 140/90 and twice I get 175/118I did visit a doctor because I  didnt feel well past month or twoThen the doctor gave me a heart medicine to take the pulse down and BP  (its still in further examination.)But I wondering what it can be? Do I need the medicine really?  Is that bad ?
ANSWER: 
 If you are a doctor, please answer the medical questions based on the patient's description.: Hey Just wondering.  I am a 39 year old female, pretty smallMy heart rate is around 97 to 106 at rest, and my BP is 140/90 and twice I get 175/118I did visit a doctor because I  didnt feel well past month or twoThen the doctor gave me a heart medicine to take the pulse down and BP  (its still in further examination.)But I wondering what it can be? Do I need the medicine really?  Is that bad ? 

In [14]:
for i in range(3):
    prompt = f"{dataset['train'][0]['instruction']}: {dataset['train'][i]['input']}"
    print(f"PROMPT: \n {prompt}")
    print(f"ANSWER: \n {generate_answer(prompt, .7, .98)}\n\n")   # repeats prompt 

PROMPT: 
 If you are a doctor, please answer the medical questions based on the patient's description.: Hey Just wondering.  I am a 39 year old female, pretty smallMy heart rate is around 97 to 106 at rest, and my BP is 140/90 and twice I get 175/118I did visit a doctor because I  didnt feel well past month or twoThen the doctor gave me a heart medicine to take the pulse down and BP  (its still in further examination.)But I wondering what it can be? Do I need the medicine really?  Is that bad ?
ANSWER: 
 If you are a doctor, please answer the medical questions based on the patient's description.: Hey Just wondering.  I am a 39 year old female, pretty smallMy heart rate is around 97 to 106 at rest, and my BP is 140/90 and twice I get 175/118I did visit a doctor because I  didnt feel well past month or twoThen the doctor gave me a heart medicine to take the pulse down and BP  (its still in further examination.)But I wondering what it can be? Do I need the medicine really?  Is that bad ? 