In [None]:
import json
import random

In [None]:
# data
"""
case template = 
    problem : 
        age, chief complaints, histories,
        examination findings,
        reasoning, differentials, investigations, 
        referrals, 
        variables: days, temp, years, bp, risks

ADVERSARIAL (TRICKY) CASES TEMPLATES=
    diagnosis
    mimics
    input 
    reasoning
    plan

inputs : patient age, complaint, histories, examination findings

outputs : reasoning, differentials, investigations, referrals / plan

dataset - instruction : act as a medical AI assistant, diagnose and plan
            input: inputs
            output : outputs

"""    
# same for SURGICAL PATCH DATA GENERATOR(Focus: Fixes Dengue/Malaria and Chickenpox/Scabies confusion_) 
# and 
#  TROPICAL / INFECTIOUS - instruction data 

In [None]:
#merge data 

In [None]:
# FAST INSTALL (Pre-compiled)
# Run this ONLY if the previous install is taking forever
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" accelerate bitsandbytes

In [None]:
# Force install the latest PEFT to fix the "ensure_weight_tying" error
!pip install --upgrade --force-reinstall "peft @ git+https://github.com/huggingface/peft.git"

In [None]:
# 1. Force-align Torch and Torchvision
# We strictly request the version Unsloth is asking for (>=0.24.0)
!pip install --upgrade "torch==2.9.1" "torchvision>=0.24.0" "torchaudio>=2.9.0"

# 2. Re-install Unsloth (Just to be safe after the torch update)
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:

import os
# Force Single GPU Mode
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

In [None]:
# 1. LOAD MODEL
# Token budget shared by the loader below and by the trainer configured later.
max_seq_length = 2048
print("‚¨áÔ∏è Loading Llama-3.1-8B (Senior Resident)...")

# Loads the 4-bit checkpoint and its tokenizer in a single call.
# dtype=None is forwarded as-is -- presumably lets Unsloth choose the precision; TODO confirm.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    load_in_4bit=True,
    dtype=None,
    max_seq_length=max_seq_length,
)


In [None]:
# 2. CONFIGURE ADAPTERS
# Wraps the loaded model with LoRA adapters (rank 16, alpha 16, no dropout, no bias).
# NOTE(review): `model` is rebound from itself, so re-running this cell without
# re-running the load cell would wrap an already-wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
        "gate_proj", "up_proj", "down_proj",      # MLP projections
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)


In [None]:
# 3. DEFINE ROBUST FORMATTING FUNCTION (The Fix)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

def formatting_prompts_func(examples):
    """Render dataset rows into Alpaca-style training strings.

    ``examples`` is indexed by "instruction", "input" and "output". When the
    "instruction" value is a list, the call is treated as a batch and the three
    columns are zipped row by row; otherwise the mapping is treated as one row.

    Returns a list of strings in both cases (a single-element list for one
    row), each ending with ``tokenizer.eos_token``.
    """
    def _render(instruction_text, context_text, response_text):
        # Fill the three template slots and terminate the sample with EOS.
        return alpaca_prompt.format(instruction_text, context_text, response_text) + tokenizer.eos_token

    instruction_col = examples["instruction"]
    context_col = examples["input"]
    response_col = examples["output"]

    # Batch mode: one rendered string per zipped row.
    if isinstance(instruction_col, list):
        return [_render(*row) for row in zip(instruction_col, context_col, response_col)]

    # Single-example mode: still wrapped in a list.
    return [_render(instruction_col, context_col, response_col)]

print("üìÇ Loading 'FINAL_MASTER_DATASET.json'...")
dataset = load_dataset("json", data_files="FINAL_MASTER_DATASET.json", split="train")

In [None]:
# Split into Train/Test
# A fixed seed keeps the 5% hold-out identical on every run. Without it each
# execution reshuffles the rows, so validation losses are not comparable across
# runs and rows evaluated in one run may have been trained on in another.
# Use this same seed wherever this dataset is split.
dataset = dataset.train_test_split(test_size=0.05, seed=3407)

In [None]:
print(f"üìä Final Training on {len(dataset['train'])} cases")
print(f"üìä Evaluating on {len(dataset['test'])} cases")

In [None]:
# 4. TRAINING ARGUMENTS
# 120 optimizer steps, each accumulating 8 single-example micro-batches.
# Evaluation runs every 10 steps and a checkpoint is written every 20 steps.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # --- Pass the Function Here (No manual mapping needed) ---
    formatting_func=formatting_prompts_func,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        output_dir="medical_llama_final_v1",
        # Batch shape and schedule
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=120,
        warmup_steps=5,
        learning_rate=2e-5,
        optim="adamw_8bit",
        # Mixed precision: bf16 when the GPU reports support, otherwise fp16
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        logging_steps=1,
        # Validation Settings
        eval_strategy="steps",
        eval_steps=10,
        save_strategy="steps",
        save_steps=20,
    ),
)

print("üöÄ Starting Final Training...")


In [None]:
trainer.train()
print("‚úÖ FINAL MODEL READY.")

In [None]:
# ============================================================================
# üöë RESCUE MISSION: RESUME TRAINING FROM CHECKPOINT-100
# ============================================================================

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

# 1. LOAD EVERYTHING AGAIN
# Rebuild the base model and LoRA wrapper with the same settings as the
# original run, so the checkpointed adapter weights fit this model.
max_seq_length = 2048
print("‚¨áÔ∏è Reloading Model for Rescue...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    load_in_4bit=True,
    dtype=None,
    max_seq_length=max_seq_length,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# 2. LOAD DATA
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

def formatting_prompts_func(examples):
    """Turn one row or a batch of rows into Alpaca-formatted strings.

    Reads the "instruction", "input" and "output" entries of ``examples``.
    A list-valued "instruction" means a batch (columns are zipped row-wise);
    anything else is handled as a single row. Always returns a list of
    strings, each suffixed with ``tokenizer.eos_token``.
    """
    instruction_col = examples["instruction"]

    if not isinstance(instruction_col, list):
        # Single row: format once and wrap the result in a list.
        single = alpaca_prompt.format(instruction_col, examples["input"], examples["output"])
        return [single + tokenizer.eos_token]

    # Batch: one formatted string per (instruction, input, output) triple.
    rows = zip(instruction_col, examples["input"], examples["output"])
    return [
        alpaca_prompt.format(instruction_text, context_text, response_text) + tokenizer.eos_token
        for instruction_text, context_text, response_text in rows
    ]

print("üìÇ Reloading Dataset...")
dataset = load_dataset("json", data_files="FINAL_MASTER_DATASET.json", split="train")
# Seed the split so it is reproducible. An unseeded split reshuffles the rows
# every time this cell runs, so a resumed run would evaluate on rows the earlier
# steps may already have trained on. Use this same seed wherever this dataset
# is split so the resumed run sees the same train/test partition.
dataset = dataset.train_test_split(test_size=0.05, seed=3407)

# 3. CONFIGURE TRAINER
# Same hyper-parameters as the interrupted run. output_dir must be the folder
# that already holds the saved checkpoints.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    formatting_func=formatting_prompts_func,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        output_dir="medical_llama_final_v1",  # <--- Points to where your checkpoint is
        # Batch shape and schedule
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=120,
        warmup_steps=5,
        learning_rate=2e-5,
        optim="adamw_8bit",
        # Mixed precision: bf16 when the GPU reports support, otherwise fp16
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        logging_steps=1,
        # Evaluation / checkpoint cadence
        eval_strategy="steps",
        eval_steps=10,
        save_strategy="steps",
        save_steps=20,
    ),
)

# 4. RESUME FROM THE LATEST CHECKPOINT
# Training resumes from the newest `checkpoint-*` folder in output_dir, which is
# not necessarily step 100. Resolve it explicitly so the log line reports the
# checkpoint actually used, and fail with a clear message when none exists.
from transformers.trainer_utils import get_last_checkpoint

checkpoint_root = trainer.args.output_dir
# get_last_checkpoint lists the directory, so guard against a missing folder first.
last_checkpoint = get_last_checkpoint(checkpoint_root) if os.path.isdir(checkpoint_root) else None
if last_checkpoint is None:
    raise ValueError(f"No checkpoint found in '{checkpoint_root}'; nothing to resume from.")

print(f"üöÄ Resuming Training from {last_checkpoint}...")
trainer.train(resume_from_checkpoint = last_checkpoint)

print("‚úÖ RESCUE COMPLETE! Model finished.")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# 1. Extract the logged metrics
# Each entry of trainer.state.log_history becomes one row of the frame.
history = pd.DataFrame(trainer.state.log_history)

# If no training or evaluation record was logged, the matching column is missing
# altogether. Add it as all-NaN so the selections below produce empty frames
# instead of raising KeyError.
for required_col in ("step", "loss", "eval_loss"):
    if required_col not in history.columns:
        history[required_col] = float("nan")

# 2. Clean & Merge Data
train_loss = history[history['loss'].notna()][['step', 'loss']].rename(columns={'loss': 'Training Loss'})
val_loss = history[history['eval_loss'].notna()][['step', 'eval_loss']].rename(columns={'eval_loss': 'Validation Loss'})
# Outer merge keeps steps that have only one of the two metrics.
df = pd.merge(train_loss, val_loss, on='step', how='outer')

# 2.5 SAFETY TWEAK: Remove duplicates if any exist
df = df.drop_duplicates(subset=['step']).sort_values('step')

# 3. Plot
sns.set_style("white")
fig, ax = plt.subplots(figsize=(12, 6))

sns.lineplot(x="step", y="Training Loss", data=df, color='#FF9F1C', linewidth=3, label='Training Loss', ax=ax)
sns.lineplot(x="step", y="Validation Loss", data=df, color='#5C9EAD', linewidth=3, label='Validation Loss', ax=ax)

# despine only touches axes that already exist, so it must run after plotting.
sns.despine(ax=ax)

ax.set_title("Final Model Training Convergence (Resumed)", fontsize=16, weight='bold', pad=20, color='#333333')
ax.set_xlabel("Steps", fontsize=14, labelpad=10)
ax.set_ylabel("Loss", fontsize=14, labelpad=10)
ax.legend(frameon=False, fontsize=12)
ax.grid(axis='y', linestyle='--', alpha=0.3)
fig.tight_layout()
plt.show()

# Print the final score
# Take the last logged validation value. The last merged row may be a
# training-only step whose validation loss is NaN.
logged_val = df['Validation Loss'].dropna()
if not logged_val.empty:
    final_val = logged_val.iloc[-1]
    print(f"üèÜ Final Validation Loss: {final_val:.4f} (Lower is better!)")

In [None]:

# 1. Switch to Inference Mode
FastLanguageModel.for_inference(model)

def ask_doctor(input_text):
    """Generate the model's answer for one case description.

    Fills ``alpaca_prompt`` with a fixed instruction, ``input_text`` as the
    input and an empty response slot, generates up to 512 new tokens on the
    "cuda" device, and returns the text after the last "### Response:\\n"
    marker with the tokenizer's EOS token removed.
    """
    instruction = "Act as a Medical Consultant. Diagnose and Plan."
    prompt = alpaca_prompt.format(instruction, input_text, "")

    encoded = tokenizer([prompt], return_tensors = "pt").to("cuda")
    generated = model.generate(**encoded, max_new_tokens = 512, use_cache = True)

    # The decoded text contains the prompt too; keep only what follows the marker.
    decoded = tokenizer.batch_decode(generated)[0]
    answer = decoded.split("### Response:\n")[-1]
    return answer.replace(tokenizer.eos_token, "")



In [None]:
# --- PART A: 5 SEEN CASES (Memory Check) ---
def _nth_line(text, index, default=""):
    """Return line ``index`` of ``text``, or ``default`` when there is no such line."""
    lines = text.splitlines()
    return lines[index] if index < len(lines) else default


print("üîé PART A: MEMORY CHECK (5 Random Seen Cases)")
print("="*60)
# Explicit encoding so non-ASCII characters load the same way on every platform.
with open('FINAL_MASTER_DATASET.json', 'r', encoding='utf-8') as f:
    training_data = json.load(f)

# Pick 5 random cases (fewer if the file holds fewer than 5).
# NOTE(review): this samples from the whole file, so a picked case may belong to
# the held-out split rather than to the rows actually trained on.
seen_indices = random.sample(range(len(training_data)), min(5, len(training_data)))

for case_no, idx in enumerate(seen_indices, start=1):
    case = training_data[idx]
    print(f"\nüìù SEEN CASE {case_no}:")
    # Second input line is the complaint; fall back to the whole input if it has one line.
    print(f"INPUT: {_nth_line(case['input'], 1, default=case['input'])}") # Complaint
    print("-" * 20)
    ai_response = ask_doctor(case['input'])
    # An empty generation has no lines at all; print a blank instead of raising IndexError.
    print(f"ü§ñ AI SAYS:\n{_nth_line(ai_response, 0)}") # Diagnosis line
    print(f"‚úÖ TRUTH:   {_nth_line(case['output'], 0)}")


In [None]:

# --- PART B: 10 UNSEEN CASES (Intelligence Check) ---
print("\n\n" + "="*60)
print("üåç PART B: THE GAUNTLET (10 Unseen Conditions)")
print("="*60)

# Hand-written evaluation cases. Each entry is a dict with two string keys:
#   "name"  - label for the case, shown in the printed header only
#   "input" - newline-separated case text passed to ask_doctor; every entry
#             below uses the line prefixes Patient / Complaint / History / Exam
unseen_exam = [
    # 1. Endocrine
    {"name": "Hypothyroidism", "input": "Patient: 45y female\nComplaint: weight gain and fatigue\nHistory: Gained 10kg in 3 months despite poor appetite. Feels cold all the time. Constipation.\nExam: HR 58 (Bradycardia). Dry skin. Delayed relaxation of deep tendon reflexes."},
    # 2. Hematology
    {"name": "Iron Deficiency Anemia", "input": "Patient: 30y female\nComplaint: feeling tired and dizzy\nHistory: Heavy menstrual periods (menorrhagia). Craving ice (pica). Shortness of breath on exertion.\nExam: Conjunctival pallor. Spoon-shaped nails (koilonychia). Tachycardia."},
    # 3. Pediatrics (ENT)
    {"name": "Acute Otitis Media", "input": "Patient: 4y male\nComplaint: crying and pulling at right ear\nHistory: Had a cold 3 days ago. Now high fever and crying. Not eating.\nExam: Temp 39C. Right tympanic membrane is red, bulging, and immobile."},
    # 4. Dermatology
    {"name": "Cellulitis", "input": "Patient: 50y male\nComplaint: red painful leg\nHistory: Scratched leg in garden 2 days ago. Now lower leg is bright red, hot, and painful. Fevers.\nExam: Erythema extending up the shin. Hot to touch. Tender. Inguinal lymph nodes tender."},
    # 5. Cardiology
    {"name": "Atrial Fibrillation", "input": "Patient: 65y male\nComplaint: palpitations\nHistory: 'Heart feels like a fish flopping in chest'. Mild shortness of breath. History of hypertension.\nExam: Pulse is irregularly irregular. BP 140/90. Chest clear."},
    # 6. Respiratory
    {"name": "Acute Bronchiolitis (RSV)", "input": "Patient: 6 months female\nComplaint: difficulty breathing and cough\nHistory: Runny nose for 2 days. Now wheezing and working hard to breathe. Poor feeding.\nExam: RR 50. Subcostal recession. Widespread wheeze and crackles."},
    # 7. Neurology
    {"name": "Concussion (Mild TBI)", "input": "Patient: 20y male\nComplaint: headache and confusion after hit\nHistory: Hit head during rugby match. Brief loss of consciousness (<30s). Vomited once. Amnesia for the event.\nExam: GCS 15. Pupils equal and reactive. No focal deficits."},
    # 8. Urology
    {"name": "Benign Prostatic Hyperplasia (BPH)", "input": "Patient: 70y male\nComplaint: waking up at night to pee\nHistory: Frequency, urgency, and poor stream (dribbling). Nocturia x4. No pain.\nExam: Abdomen soft. DRE: Smooth, enlarged, non-tender prostate."},
    # 9. Allergy
    {"name": "Anaphylaxis", "input": "Patient: 18y female\nComplaint: swollen lips and difficulty breathing\nHistory: Ate peanuts 10 mins ago. Lips swelled up immediately. Wheezing. Feeling faint.\nExam: Stridor audible. BP 80/50 (Hypotensive). Widespread hives (urticaria)."},
    # 10. MSK
    {"name": "Osteoarthritis (Knee)", "input": "Patient: 60y female\nComplaint: right knee pain\nHistory: Pain worse at end of day and after walking. Stiffness in morning <30 mins. No injury.\nExam: Crepitus (crunching) on movement. Bony swelling. No warmth/redness."}
]

# Run every unseen case through the model and print the complete report so
# both the diagnosis and the reasoning can be checked by eye.
for question_no, case in enumerate(unseen_exam, start=1):
    case_text = case['input']
    print(f"\nüß™ UNSEEN Q{question_no}: {case['name']}")
    print(f"INPUT:\n{case_text}")
    print("-" * 40)

    ai_response = ask_doctor(case_text)

    print("ü§ñ AI REPORT:")
    print(ai_response)
    print("="*60)

In [None]:
import shutil
import os

# Folder holding the exported LoRA adapter and tokenizer files.
LORA_DIR = "medical_llama_final_lora"

print("üì¶ 1. Zipping your Medical AI Brain...")
# The trainer writes checkpoints to its own output_dir, not to LORA_DIR, and
# make_archive fails when its root_dir does not exist. Export the adapter and
# tokenizer here if the folder has not been created yet.
if not os.path.isdir(LORA_DIR):
    model.save_pretrained(LORA_DIR)
    tokenizer.save_pretrained(LORA_DIR)

# This compresses your fine-tuned weights into a zip file
shutil.make_archive("medical_ai_brain", 'zip', LORA_DIR)

print("üì¶ 2. Checking for Dataset...")
if os.path.exists("FINAL_MASTER_DATASET.json"):
    print("‚úÖ Dataset found.")
else:
    print("‚ö†Ô∏è Dataset not found! (Make sure you save it if you haven't already)")

print("\n" + "="*60)
print("üö® DOWNLOAD THESE TWO FILES NOW:")
print("1. medical_ai_brain.zip (This is your trained model)")
print("2. FINAL_MASTER_DATASET.json (This is your data)")
print("="*60)
print("üëâ KAGGLE: Check the 'Output' folder on the right sidebar.")
print("üëâ COLAB: Check the folder icon on the left sidebar.")

In [None]:
import os
from IPython.display import FileLink

# 1. Files to offer for download (paths relative to the working directory)
files_to_download = ["medical_ai_brain.zip", "FINAL_MASTER_DATASET.json"]

print("üëá CLICK THESE LINKS TO DOWNLOAD üëá")
print("="*40)

# Show a clickable link for each file that exists; warn about any that is missing.
for filename in files_to_download:
    if not os.path.exists(filename):
        print(f"‚ö†Ô∏è Could not find {filename} (Did the previous zip script finish?)")
        continue
    display(FileLink(filename))

print("="*40)

In [None]:
# ==========================================
# üîÑ RELOAD & LAUNCH (Session Timeout Fix)
# ==========================================

import os
from unsloth import FastLanguageModel
import gradio as gr

# 1. Point to your uploaded folder
model_path = "/kaggle/input/med-llama" 

print(f"üîÑ Waking up MedLlama from: {model_path}...")

# 2. RELOAD THE BRAIN (Fixes 'NameError: tokenizer not defined')
try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path, 
        max_seq_length = 2048,
        dtype = None,
        load_in_4bit = True,
    )
    FastLanguageModel.for_inference(model)
    print("‚úÖ Model & Tokenizer Reloaded Successfully!")

except Exception as e:
    print(f"‚ùå Error loading model: {e}")
    print("‚ö†Ô∏è If this says 'Unsloth not defined', scroll up and run the pip install cell again!")
    # Re-raise after printing the hints. Continuing would launch the UI with no
    # model (or a stale one) and defer the failure to the first button click.
    raise

# 3. DEFINE THE DOC'S LOGIC (Stable Mode)
def medical_consult(symptoms, history):
    """Build the consult prompt from the two UI fields and return the model's report.

    ``symptoms`` and ``history`` are interpolated into the "### Input:" section
    as "Complaint: ..." and "Patient Data: ..." lines. Generation uses
    ``do_sample=False`` with a repetition penalty of 1.2 and at most 512 new
    tokens. The returned string is the text after the last "### Response:\\n"
    marker with the tokenizer's EOS token removed.
    """
    user_input = f"Complaint: {symptoms}\nPatient Data: {history}"

    # Same Alpaca layout, written as adjacent literals instead of one block string.
    prompt = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n"
        "\n"
        "### Instruction:\n"
        "Act as a Medical Consultant. Diagnose and Plan.\n"
        "\n"
        "### Input:\n"
        f"{user_input}\n"
        "\n"
        "### Response:\n"
    )

    encoded = tokenizer([prompt], return_tensors = "pt").to("cuda")

    # Sampling disabled (Stability > Creativity)
    generated = model.generate(
        **encoded,
        max_new_tokens = 512,
        use_cache = True,
        do_sample = False,
        repetition_penalty = 1.2,
    )

    decoded = tokenizer.batch_decode(generated)[0]
    report = decoded.split("### Response:\n")[-1]
    return report.replace(tokenizer.eos_token, "")

# 4. LAUNCH UI
# Two-column layout: the left column holds the two input boxes and the submit
# button, the right column holds the report box. Components are declared inside
# nested `with` blocks, so their order here is the order of declaration.
print("üöÄ Launching Interface...")
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate")) as demo:
    gr.Markdown("# üè• MedLlama")
    with gr.Row():
        with gr.Column(scale=1):
            symptoms = gr.Textbox(label="Chief Complaint")
            # NOTE(review): rebinds the notebook-level name `history` to a Textbox.
            history = gr.Textbox(label="Patient History + Exam", lines=5)
            submit_btn = gr.Button("Generate Consult", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Report", lines=15)
    # Button click calls medical_consult(symptoms, history) and writes the result to `output`.
    submit_btn.click(fn=medical_consult, inputs=[symptoms, history], outputs=output)

# NOTE(review): share=True presumably requests a publicly reachable link -- confirm
# against the Gradio docs before entering real patient data.
demo.launch(share=True)