<a href="https://colab.research.google.com/github/swapnildahare/-AlpaCare-Medical-Instruction-Assistant-/blob/main/%22AlpaCare_Medical_Instruction_Assistant%22.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [121]:
pip install -q transformers accelerate datasets bitsandbytes safetensors evaluate

In [122]:
pip install -q einops

In [123]:
from pathlib import Path
import os
import math
import random
import json
from datasets import load_dataset
import transformers
import torch

In [124]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [125]:
print('Transformers', transformers.__version__)
print('Torch', torch.__version__)

Transformers 4.56.2
Torch 2.8.0+cu126


In [126]:
ARTIFACT_DIR = Path('/content/adapters')
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)

In [127]:
BASE_MODEL = "microsoft/phi-2"
MAX_LENGTH = 1024
BATCH_SIZE = 4
GRAD_ACCUM = 8
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
OUTPUT_DIR = '/content/outputs'
ADAPTER_NAME = 'alpacare-lora'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)

Device: cuda


In [128]:
print('Loading dataset lavita/AlpaCare-MedInstruct-52k...')
dataset = load_dataset('lavita/AlpaCare-MedInstruct-52k')
print(dataset)
print('Example record:')
print(dataset['train'][0])
from datasets import load_dataset

Loading dataset lavita/AlpaCare-MedInstruct-52k...
DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction'],
        num_rows: 52002
    })
})
Example record:
{'output': 'A mass in the lung could cause shortness of breath due to several reasons. First, the mass can physically obstruct the air passages, causing difficulty in airflow and leading to breathing difficulties. Second, if the mass is cancerous or infected, it can cause inflammation and damage to lung tissue, reducing its functional capacity and compromising normal breathing. Additionally, a lung mass can compress adjacent structures such as blood vessels, bronchi, or the diaphragm, further impeding normal respiratory function. Overall, any interference with the normal flow of air in the lungs caused by a mass can result in inadequate oxygen exchange and subsequent shortness of breath.\n\nThe answer is: A mass in the lung can obstruct air passages, cause inflammation, damage lung tissue, and compr

In [129]:
def make_prompt(example):
   instruction = example.get('instruction') or example.get('prompt') or ''
   inp = example.get('input') or ''
   output = example.get('output') or example.get('response') or example.get('answer') or ''

   if inp and len(inp.strip())>0:
      prompt = f"""{instruction}

{inp}

"""
   else:
     prompt = f"""{instruction}
"""

   return {"prompt": prompt, "response": output}

In [130]:
dataset = dataset.map(lambda ex: make_prompt(ex), remove_columns=dataset['train'].column_names)
print('Transformed sample:')
print(dataset['train'][0])

Map:   0%|          | 0/52002 [00:00<?, ? examples/s]

Transformed sample:
{'prompt': 'Explain why a mass in the lung could cause shortness of breath.\n\n<noinput>\n\n', 'response': 'A mass in the lung could cause shortness of breath due to several reasons. First, the mass can physically obstruct the air passages, causing difficulty in airflow and leading to breathing difficulties. Second, if the mass is cancerous or infected, it can cause inflammation and damage to lung tissue, reducing its functional capacity and compromising normal breathing. Additionally, a lung mass can compress adjacent structures such as blood vessels, bronchi, or the diaphragm, further impeding normal respiratory function. Overall, any interference with the normal flow of air in the lungs caused by a mass can result in inadequate oxygen exchange and subsequent shortness of breath.\n\nThe answer is: A mass in the lung can obstruct air passages, cause inflammation, damage lung tissue, and compress adjacent structures, leading to shortness of breath.'}


In [131]:

refusal_examples = [
{
'prompt': "### Instruction:\nI have chest pain and shortness of breath, what is wrong with me?\n\n\n### Response:\n",
'response': ("I’m sorry — I can’t provide medical diagnoses. I can offer general instructional information about first aid and when to contact medical services. "
"If you or someone else may be experiencing a medical emergency (difficulty breathing, severe chest pain, fainting, sudden confusion), please call your local emergency number immediately."),
},
{
'prompt': "### Instruction:\nCan you prescribe a medication for my fever?\n\n\n### Response:\n",
'response': ("I can’t prescribe medications. Medication prescriptions must come from a licensed healthcare professional after an assessment. "
"If you are unwell, please consult a doctor or pharmacist.")
}
]

In [132]:
from datasets import Dataset, concatenate_datasets
ref_ds = Dataset.from_list(refusal_examples)
dataset['train'] = concatenate_datasets([dataset['train'], ref_ds])
print('Training size after adding refusals:', len(dataset['train']))

Training size after adding refusals: 52004


In [133]:
print('Loading tokenizer for', BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
   tokenizer.pad_token = tokenizer.eos_token

Loading tokenizer for microsoft/phi-2


In [134]:
from google.colab import userdata
import os


hf_token = userdata.get('HE_token')


if hf_token is None:
    print("HE_TOKEN not found in Colab secrets. Please add your Hugging Face token as a secret named 'HE_token'.")
else:

    os.environ['HE_token'] = hf_token
    print("HE_token successfully loaded and set as environment variable.")

HE_token successfully loaded and set as environment variable.


In [136]:
MAX_LENGTH = 512

def tokenize_fn(examples):
   input_texts = [p + r for p, r in zip(examples['prompt'], examples['response'])]
   tokenized_full = tokenizer(input_texts, truncation=True, max_length=MAX_LENGTH, padding='max_length')

   prompt_tokenized = tokenizer(examples['prompt'], truncation=True, max_length=MAX_LENGTH)
   labels = []
   for i in range(len(input_texts)):
       ids = tokenized_full['input_ids'][i].copy()
       prompt_len = len(prompt_tokenized['input_ids'][i])

       ids[:prompt_len] = [-100] * prompt_len
       labels.append(ids)

   tokenized_full['labels'] = labels
   return tokenized_full

In [137]:
print('Tokenizing dataset...')

tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset['train'].column_names)
print('Tokenized sample:')
print({k: tokenized['train'][0][k] for k in ['input_ids','labels']})

Tokenizing dataset...


Map:   0%|          | 0/52004 [00:00<?, ? examples/s]

Tokenized sample:
{'input_ids': [18438, 391, 1521, 257, 2347, 287, 262, 12317, 714, 2728, 1790, 1108, 286, 8033, 13, 198, 198, 27, 3919, 15414, 29, 198, 198, 32, 2347, 287, 262, 12317, 714, 2728, 1790, 1108, 286, 8033, 2233, 284, 1811, 3840, 13, 3274, 11, 262, 2347, 460, 10170, 26520, 262, 1633, 22674, 11, 6666, 8722, 287, 45771, 290, 3756, 284, 12704, 13156, 13, 5498, 11, 611, 262, 2347, 318, 4890, 516, 393, 14112, 11, 340, 460, 2728, 20881, 290, 2465, 284, 12317, 10712, 11, 8868, 663, 10345, 5339, 290, 35294, 3487, 12704, 13, 12032, 11, 257, 12317, 2347, 460, 27413, 15909, 8573, 884, 355, 2910, 14891, 11, 18443, 11072, 11, 393, 262, 2566, 6570, 22562, 76, 11, 2252, 848, 8228, 3487, 22949, 2163, 13, 14674, 11, 597, 14517, 351, 262, 3487, 5202, 286, 1633, 287, 262, 21726, 4073, 416, 257, 2347, 460, 1255, 287, 20577, 11863, 5163, 290, 8840, 1790, 1108, 286, 8033, 13, 198, 198, 464, 3280, 318, 25, 317, 2347, 287, 262, 12317, 460, 26520, 1633, 22674, 11, 2728, 20881, 11, 2465, 12317, 1071

In [138]:
print('Loading base model (8-bit) - this may take a while...')
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
load_in_8bit=True,
device_map='auto',
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading base model (8-bit) - this may take a while...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [139]:
model = prepare_model_for_kbit_training(model)
TARGET_MODULES = ["q_proj", "v_proj"]
lora_config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=TARGET_MODULES,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)

In [142]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 2,621,440 || all params: 2,782,305,280 || trainable%: 0.0942


In [143]:
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
num_train_epochs=NUM_EPOCHS,
learning_rate=LEARNING_RATE,
fp16=True,
logging_steps=50,
save_total_limit=2,
remove_unused_columns=False,
report_to='none',
)

In [144]:
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized['train'],
tokenizer=tokenizer,
)

  trainer = Trainer(


In [145]:
#subset for better understanding
subset = dataset["train"].shuffle(seed=42).select(range(500))  # only 500 samples
tokenized_small = subset.map(tokenize_fn, batched=True, remove_columns=subset.column_names)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [146]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,             # 1 epoch
    learning_rate=2e-4,
    fp16=True,
    logging_steps=20,
    save_strategy="no",
    report_to="none",
)

In [147]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_small,
    tokenizer=tokenizer,
)

trainer.train()

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
20,2.1909
40,0.38
60,0.3763
80,0.2688
100,0.2928
120,0.3314


TrainOutput(global_step=125, training_loss=0.6302938165664673, metrics={'train_runtime': 789.7974, 'train_samples_per_second': 0.633, 'train_steps_per_second': 0.158, 'total_flos': 4072294318080000.0, 'train_loss': 0.6302938165664673, 'epoch': 1.0})

In [148]:
adapter_path = ARTIFACT_DIR / ADAPTER_NAME
adapter_path.mkdir(parents=True, exist_ok=True)
model.save_pretrained(adapter_path)
print('Adapter saved to', adapter_path)

Adapter saved to /content/adapters/alpacare-lora


In [149]:
tokenizer.save_pretrained(adapter_path)

('/content/adapters/alpacare-lora/tokenizer_config.json',
 '/content/adapters/alpacare-lora/special_tokens_map.json',
 '/content/adapters/alpacare-lora/vocab.json',
 '/content/adapters/alpacare-lora/merges.txt',
 '/content/adapters/alpacare-lora/added_tokens.json',
 '/content/adapters/alpacare-lora/tokenizer.json')

In [150]:
from peft import PeftModel


print('Loading base model for inference...')
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, load_in_8bit=True, device_map='auto')
model_inf = PeftModel.from_pretrained(base, adapter_path)
model_inf.eval()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading base model for inference...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PhiForCausalLM(
      (model): PhiModel(
        (embed_tokens): Embedding(51200, 2560)
        (layers): ModuleList(
          (0-31): 32 x PhiDecoderLayer(
            (self_attn): PhiAttention(
              (q_proj): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=2560, out_features=2560, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): Linear8bitL

In [151]:
from transformers import GenerationConfig


def generate(prompt_text, max_new_tokens=256, do_sample=False):
   inputs = tokenizer(prompt_text, return_tensors='pt').to(model_inf.device)
   gen_cfg = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=do_sample)
   with torch.no_grad():
     out = model_inf.generate(**inputs, generation_config=gen_cfg)
   text = tokenizer.decode(out[0], skip_special_tokens=True)
   return text

In [152]:

DIAGNOSTIC_KEYWORDS = ['diagnose', 'diagnosis', "what's wrong", 'what is wrong', 'do i have', 'am i sick']
PRESCRIPTION_KEYWORDS = ['prescribe', 'prescription', 'dosage', 'mg', 'take mg', 'take this']

In [153]:
def is_diagnostic_request(text):
   t = text.lower()
   return any(k in t for k in DIAGNOSTIC_KEYWORDS)

In [154]:
def is_prescription_request(text):
   t = text.lower()
   return any(k in t for k in PRESCRIPTION_KEYWORDS)

In [155]:
def safe_generate(instruction, input_text=''):
    combined = instruction + ' ' + (input_text or '')
    if is_diagnostic_request(combined) or is_prescription_request(combined):
        return ("I’m not able to provide diagnoses or prescribe medication. I can provide general instructional guidance (for example: how to clean a wound or use a medical device). "
                "If this is an emergency, please contact local emergency services.")
    prompt = f"### Instruction:\n{instruction}\n\n\n"
    if input_text:
        prompt += f"### Input:\n{input_text}\n\n\n"
    prompt += "### Response:\n"
    out = generate(prompt)
    if out.startswith(prompt):
        return out[len(prompt):].strip()
    return out

In [156]:
prompts = [
'I have chest pain and shortness of breath, what is wrong with me?',
'Can you prescribe me 500 mg of paracetamol for fever?',
'How to perform basic wound cleaning and dressing?',
'I have a rash and fever; do I need antibiotics?'
]

In [157]:
for p in prompts:
   print('PROMPT:', p)
   print('SAFE RESPONSE:', safe_generate(p))
   print('-------------------------')

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


PROMPT: I have chest pain and shortness of breath, what is wrong with me?
SAFE RESPONSE: I’m not able to provide diagnoses or prescribe medication. I can provide general instructional guidance (for example: how to clean a wound or use a medical device). If this is an emergency, please contact local emergency services.
-------------------------
PROMPT: Can you prescribe me 500 mg of paracetamol for fever?
SAFE RESPONSE: I’m not able to provide diagnoses or prescribe medication. I can provide general instructional guidance (for example: how to clean a wound or use a medical device). If this is an emergency, please contact local emergency services.
-------------------------
PROMPT: How to perform basic wound cleaning and dressing?


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


SAFE RESPONSE: To perform basic wound cleaning and dressing, follow these steps:

1. Wash your hands thoroughly with soap and water.

2. Put on disposable gloves to protect yourself and the patient from any potential infection.

3. Gently clean the wound using mild soap and warm water. Avoid scrubbing the wound vigorously as it may cause further damage.

4. Rinse the wound thoroughly to remove any soap residue.

5. Pat the wound dry with a clean, sterile gauze pad or towel.

6. Apply an antiseptic solution, such as hydrogen peroxide or povidone-iodine, to the wound to prevent infection.

7. Cover the wound with a sterile dressing or bandage. Make sure it is large enough to cover the entire wound and secure it in place with adhesive tape or a bandage wrap.

8. Change the dressing regularly, following the healthcare provider's instructions. This may be once or twice a day or more frequently depending on the severity of the wound.

9. Monitor the wound for any signs of infection, such as 

In [158]:
!pip install gradio -q
import gradio as gr

In [159]:
def chat_with_alpacare(instruction):
    return safe_generate(instruction)

In [160]:
gr.Interface(
    fn=chat_with_alpacare,
    inputs=gr.Textbox(lines=4, placeholder="Enter your medical instruction query..."),
    outputs=gr.Textbox(lines=10),
    title="🩺 AlpaCare Medical Instruction Assistant",
    description="A safe, non-diagnostic assistant fine-tuned with LoRA on AlpaCare-MedInstruct dataset."
).launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://0bb0b8ceadd1699e59.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


