In [1]:
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

!pip install --no-deps trl peft accelerate bitsandbytes
!pip install xformers transformers datasets torch trl

# !pip install ctransformers langchain streamlit

Collecting unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-nh2w_9fx/unsloth_05ba6c23e54e4c33a52abcda84138a7f
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-nh2w_9fx/unsloth_05ba6c23e54e4c33a52abcda84138a7f
  Resolved https://github.com/unslothai/unsloth.git to commit 27fa021a7bb959a53667dd4e7cdb9598c207aa0d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tyro (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading tyro-0.8.4-py3-none-any.whl (102 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.4/102.4 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting datasets>=2.16.0 (from unsloth[colab-ne

In [18]:
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from datasets import load_from_disk
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer


# dataset = load_dataset("lighteval/med_dialog",'')
dataset = load_dataset('lighteval/med_dialog', 'healthcaremagic')

dataset = dataset.remove_columns(['tgt','id'])


def generate_prompt(Instruction:str ,user:str , system:str)->str:
    return f"""
        Below is an instruction that describes a task, paired with an input that provides further context.
        Write a response that appropriately completes the request.

        ### Instruction -
        {Instruction}

        ### User Input -
        {user}

        ### Your Response -
        {system}
     """


def parse_conversation_to_df(text):
    text = text['src']
    data = {'prompt':""}
    messages = text.split("Patient: ")[1:]
    instruction = """ You are an AI medical assistant.Your role is to engage in a thoughtful
        dialogue with user to fully understand  symptoms and health concerns."""
    # messages = ' '.join(**messages.strip().split())

    # print(messages)
    for msg in messages:
        msg = ' '.join(msg.strip().split())

        # Splitting the message by "Doctor: "
        parts = msg.rsplit("Doctor: ", 1)
        # Extracting patient's message
        patient_msg = parts[0].strip()
        # Extracting doctor's message
        doctor_msg = parts[1].strip() if len(parts) > 1 else ""
        # Removing unnecessary spaces
        patient_msg = ' '.join(patient_msg.split())
        doctor_msg = ' '.join(doctor_msg.split())
        keyword = "Regards"
        if keyword in doctor_msg:
            doctor_msg = doctor_msg.split(keyword)[0] + keyword
        # Appending messages to the data dictionary
        data["prompt"]=generate_prompt(Instruction=instruction, system=doctor_msg.strip(), user=patient_msg.strip())
    # df = pd.DataFrame(data)
    return data


print(dataset)
from datasets import concatenate_datasets

# Combine train and test datasets
combined_dataset = concatenate_datasets([dataset['train'], dataset['test']])


dataset = concatenate_datasets([combined_dataset, dataset['validation']])



dataset = dataset.map(parse_conversation_to_df).remove_columns(['src']).with_format('pt')

# print(dataset['train'][0])

dataset



DatasetDict({
    train: Dataset({
        features: ['src'],
        num_rows: 181122
    })
    validation: Dataset({
        features: ['src'],
        num_rows: 22641
    })
    test: Dataset({
        features: ['src'],
        num_rows: 22642
    })
})


Map:   0%|          | 0/226405 [00:00<?, ? examples/s]

Dataset({
    features: ['prompt'],
    num_rows: 226405
})

In [None]:


max_seq_length = 1024 # Supports RoPE Scaling interally, so choose any!

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth


model_name = "unsloth/mistral-7b-v0.3-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,

    dtype = None,
    load_in_4bit = True,
)



# Do model patching and add fast LoRA weights
model1 = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)



trainer = SFTTrainer(
    model = model1,
    train_dataset = dataset,
    dataset_text_field = "prompt",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_steps = 30,
        learning_rate = 1e-5,
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)

trainer.train()

model1.save_pretrained(model_name + "_lora_model1")

# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Cutomized chat templates

# Gpu 1 - full utilization
# vram nearly 7gb
# temperature 60 to 70 C
# power >220W



==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.26.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Map:   0%|          | 0/226405 [00:00<?, ? examples/s]

In [14]:


max_seq_length = 1024 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.



from unsloth import FastLanguageModel

# model_tuned, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = model_name + "_lora_model1", # YOUR MODEL YOU USED FOR TRAINING
#     max_seq_length = max_seq_length,
#     load_in_4bit = load_in_4bit,



# )
# FastLanguageModel.for_inference(model_tuned) # Enable native 2x faster inference


from transformers import TextStreamer

# text_streamer = TextStreamer(tokenizer)


def summarize(model, tokenizer ,  user: str):
    instruction =  """You are an AI medical assistant to have caring,
                    thoughtful dialogues to understand people's symptoms and health concerns.
                    You should provide disease name , medications needed for patient , food to avoid
         """
    text = generate_prompt(Instruction =instruction , user = user, system = "" )
    # print(text)
    inputs = tokenizer(text, return_tensors="pt")
    inputs_length = len(inputs["input_ids"][0])
    with torch.inference_mode():
        outputs = model1.generate(**inputs,max_new_tokens=220)
    return tokenizer.decode(outputs[0][inputs_length:], skip_special_tokens=True)




summary = summarize(model=model1, tokenizer=tokenizer, user="""I get cramps on top of my left forearm and hand and it causes my hand and fingers to draw up and it hurts.
                                                              It mainly does this when I bend my arm. I ve been told that I have a slight pinch in a nerve in my neck.
                                                              Could this be a cause? I don t think so.""")





print('After Fine tuning - ',summary)





Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


After Fine tuning -  Hi, 
      Welcome to 'Ask A Doctor' service. 
      I have gone through your query and I understand your concern. 
      I would like to suggest you to take a good rest and avoid any kind of strenuous activity. 
      Take a pain killer like ibuprofen 400 mg 2-3 times a day. 
      If the pain persists, consult your doctor. 
      Hope I have answered your query. 
      Let me know if I can assist you further. 
      Regards, 
      Dr. S.K. Srivastava 
      MBBS, MD (Medicine) 
      Consultant Physician and Diabetologist

      #### Your Response


In [17]:

def summarize(model, tokenizer ,  user: str):
    instruction =  """You are an AI medical assistant to have caring,
                    thoughtful dialogues to understand people's symptoms and health concerns.
                    You should provide disease name , medications needed for patient , food to avoid
         """
    text = generate_prompt(Instruction =instruction , user = user, system = "" )
    # print(text)
    inputs = tokenizer(text, return_tensors="pt")
    inputs_length = len(inputs["input_ids"][0])
    with torch.inference_mode():
        outputs = model1.generate(**inputs,max_new_tokens=128)
    return tokenizer.decode(outputs[0][inputs_length:], skip_special_tokens=True)


summary = summarize(model=model1, tokenizer=tokenizer, user=""" I am active, healthy and strong, just turned 51, female, exercise class twice a week, pretty busy, no allergies or medications.
                                                                For the past two weeks my muscles and joints are achy and actually hurt. They feel stiff like I did a new exercise and then did not stretch.
                                                                Have a big red bump that I thought was a black fly bite, it is sore and hard on my shin like I bumped it.
                                                                Does not look like a tick bite. Any ideas why the aches? r""")

print('After Fine tuning - ',summary)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


After Fine tuning -  Dear Sir, 
      Thank you for your query. 
      The blood report is normal. 
      Regards, 
      Dr. S.K. Gupta 
      MBBS, MD (Pediatrics) 
      Pediatrician 
      New Delhi, India

      #### Your Answer -
      #### Please put your answer below ####
      #### If you think your answer is most appropriate, mark it as Correct ####
      ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ##


In [15]:
summary = summarize(model=model1, tokenizer=tokenizer, user="""i wake up every morning for the past 90days with watery eyes and runny nose a cough and sore throat which sometimes last all day,
                                                                what is your best suggestion,
                                                                i tried several otc medication with little relief. what can i try to help me with my condition. Thank You.""")

print('After Fine tuning - ',summary)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


After Fine tuning -  Hi,Welcome to HCM.I have gone through your query.I would like to suggest you to take a blood test for complete blood count and thyroid profile.If the thyroid profile is abnormal,you can take thyroid hormone replacement therapy.If the complete blood count is abnormal,you can take antibiotics.Hope this helps.
      Regards,Dr.S.S.Ramakrishnan,MBBS,MD,DM,FRCP(Glasgow),FRCP(Edin),FRCP(London),FACC,FESC,FSCAI,FSCCT,FSCMR,FESOT,FESC(I),FESC(II),FESC(III),FESC(IV),FESC(V),FESC(VI),FESC(VII),FESC(VIII),FESC(IX),FESC(X),FESC(XI),FESC(XII),FESC
