In [1]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import gc
import time
from torch.utils.data import DataLoader, TensorDataset
from huggingface_hub import login
from openai import OpenAI
from tqdm import tqdm
import re


In [None]:
import os

# Authenticate against the Hugging Face Hub without storing the secret in the
# notebook: the token is read from the HF_TOKEN environment variable. When the
# variable is unset, None is passed and login() asks for the token interactively.
login(token=os.environ.get("HF_TOKEN"))     # HUGGINGFACE TOKEN

In [3]:
# -------------------------------------
# 1. Configuration
# -------------------------------------

# --- Model and Device Setup ---
# Open-access Mistral-7B-Instruct checkpoint used for the rewriting task.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Performance Tuning ---
# Number of prompts sent through generate() per step; tune it to the available VRAM.
# An earlier note suggested 16 as a starting point for a 7B model on a 4070; this run uses 25.
BATCH_SIZE = 25

# --- File Paths ---
# Source dataset; later cells read its "Description" and "Doctor" columns.
INPUT_CSV = "ai-medical-chatbot.csv"
# Snapshot of partial results, rewritten every 100 batches during inference.
CHECKPOINT_CSV = "partial_simplified_notes.csv"
# Complete result set, written once after inference.
FINAL_CSV = "final_simplified_notes.csv"

In [4]:
# -------------------------------------
# 2. Load Model and Tokenizer (Optimized)
# -------------------------------------

# 4-bit NF4 weight quantization with bfloat16 compute, so the 7B model fits in
# consumer-GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Batched tokenization needs a pad token; reuse EOS when the tokenizer ships without one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Decoder-only models continue from the LAST position of each row, so padding
# must sit on the left. The inference cell also slices generated tokens at the
# padded prompt width, which is only correct with left padding. Pin it here
# instead of relying on the tokenizer's default.
tokenizer.padding_side = "left"

# Load the quantized model in inference mode.
# NOTE: the recorded run warns that `torch_dtype` is deprecated in favour of
# `dtype`; the old keyword is kept so older transformers releases still work.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).eval()

# Optional: wrap with torch.compile when this PyTorch build provides it (>= 2.0).
compile_available = hasattr(torch, "compile")
if compile_available:
    print("Compiling model for faster inference...")
    model = torch.compile(model)

# Only claim compilation when the compile step actually ran.
if compile_available:
    print("Model loaded and compiled successfully.")
else:
    print("Model loaded successfully.")


Loading model: mistralai/Mistral-7B-Instruct-v0.3


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Compiling model for faster inference...
Model loaded and compiled successfully.


In [5]:
# -------------------------------------
# 3. Load Dataset
# -------------------------------------
print(f"Loading dataset from {INPUT_CSV}...")
# Drop rows missing either text column, then renumber the index right away.
# dropna() leaves gaps in the index; without the reset, the "use full dataset"
# branch below would hand later cells a frame whose labels are not 0..n-1, and
# any index-aligned assignment of the generated notes would misalign.
df = (
    pd.read_csv(INPUT_CSV)
    .dropna(subset=["Description", "Doctor"])
    .reset_index(drop=True)
)
print(f"Loaded {len(df)} records to process.")

# Take a reproducible random sample of SAMPLE_SIZE rows (without replacement).
SAMPLE_SIZE = 20000
if len(df) > SAMPLE_SIZE:
    df = df.sample(n=SAMPLE_SIZE, random_state=42).reset_index(drop=True)
    print(f"Randomly sampled {SAMPLE_SIZE} rows.")
else:
    print(f"Dataset has fewer than {SAMPLE_SIZE} rows, using full dataset.")


Loading dataset from ai-medical-chatbot.csv...
Loaded 256916 records to process.
Randomly sampled 20000 rows.


In [6]:
# -------------------------------------
# 4. Build Prompts
# -------------------------------------

def build_prompt(description, doctor_text):
    """Assemble the chat-formatted prompt for one question/answer pair.

    Args:
        description: The patient's question, inserted as context.
        doctor_text: The doctor's original note that should be rewritten.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` with
        ``tokenize=False`` for a single user turn that holds the rewriting
        rules followed by the task text.
    """
    # Fixed rule set; it is sent inside the user turn rather than as a separate system turn.
    rules_text = (
        "You are an expert medical communicator rewriting a doctor's response for a patient. "
        "You MUST follow these rules without exception:\n\n"
        "--- MANDATORY RULES ---\n"
        "1.  **SIMPLIFY & EXPLAIN:** Your only goal is to simplify complex medical terms and explain the core information accurately. NEVER use vague, non-clinical phrases (like 'brain vitamins'); instead, describe the treatment's purpose (e.g., 'nutritional support for brain health').\n\n"
        "2.  **SANITIZE MEDICATIONS (CRITICAL SAFETY RULE):** You MUST NOT use specific medication brand names. ALWAYS replace them with general categories like 'a prescribed pain reliever', 'an antibiotic cream', or 'a lactose-free formula'.\n\n"
        "3.  **IGNORE SENSITIVE DATA (CRITICAL PRIVACY RULE):** The original note may contain personal information like names, emails, or phone numbers. Your job is to IGNORE sentences with this information completely. Rewrite the note focusing only on the medical advice, as if the sensitive data was never there.\n\n"
        "4.  **BE DIRECT & PROFESSIONAL:** NEVER include greetings, introductions, closings, or conversational filler ('Don't worry', 'I hope this helps'). Start the explanation immediately.\n\n"
        "5.  **ENSURE COMPLETENESS:** The final output MUST be a complete thought and MUST NOT end in a partial sentence."
    )

    # Per-record task text carrying the question and the note to rewrite.
    task_text = f"""Rewrite the following doctor's note so a patient can understand it. Use the patient's question as context.

**Patient's Question:**
{description}

**Doctor's Original Note:**
{doctor_text}

**Simplified Version for the Patient:**
"""

    # One user turn; the chat template wraps it in the model's instruction tags.
    conversation = [
        {"role": "user", "content": "\n\n".join((rules_text, task_text))},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False)


print("Building all prompts...")
# Walk the two text columns in lockstep instead of df.iterrows(): same prompts in
# the same row order, without building a Series object for every row.
prompts = [
    build_prompt(description, doctor_text)
    for description, doctor_text in zip(df["Description"], df["Doctor"])
]


# Medication names mapped to the neutral category phrase that replaces them.
# Tramadol and Ultracet are prescription-only, so they must not be described
# as "over-the-counter".
_MEDICATION_CATEGORIES = {
    r'\b(?:Tramadol|Ultracet)\b': 'a prescribed pain reliever',
    r'\b(?:Acetaminophen|Ibuprofen|Paracetamol|Aspirin)\b': 'an over-the-counter pain reliever',
    r'\b(?:Pantoprazole|Levosulpiride)\b': 'a medication to manage acid reflux',
    r'\b(?:Fluoxetine|Sertraline)\b': 'a prescribed antidepressant medication',
    r'\bIsomil\b': 'a special lactose-free formula',
}

# Remaining product names that have no dedicated category phrase.
_UNMAPPED_MEDICATIONS = [
    r'\bOralcon\b', r'\bKalarchikai\b', r'\bBenadryl\b', r'\bMoxikind\b', r'\bClobazam\b',
]

# Matches the remainder of ONE sentence: up to a terminator that is followed by
# whitespace or the end of the text (so "2.5" does not end a sentence, and the
# period of a title such as "Dr." is skipped), or up to the end of the line.
_SENTENCE_TAIL = r"[^\n]*?(?:(?<!\bDr)(?<!\bMr)(?<!\bMs)(?<!\bMrs)[.!?](?=\s|$)|(?=\n)|$)"

# Filler phrases, applied in this order. A filler sentence is removed up to its
# own end only; sign-offs still drop the rest of the line because what follows
# them is a signature.
_FILLER_PATTERNS = (
    [
        opener + _SENTENCE_TAIL
        for opener in (
            r"please keep me informed", r"if you have any questions related to",
            r"please do not hesitate to ask", r"please reach out if you have any questions",
            r"i wish you", r"good luck", r"i trust this clarifies", r"reach out if you require",
            r"take care", r"please note", r"feel free to", r"i'm here to help",
            r"rest assured", r"if you have further questions", r"don't hesitate to",
        )
    ]
    + [r"best regards.*", r"sincerely.*"]
    + [
        r"i hope this[^.!?\n]*?helps" + _SENTENCE_TAIL,
        r"please consult" + _SENTENCE_TAIL,
        r"\bDr\.$",
    ]
)


def clean_note(text: str) -> str:
    """Post-process one generated note: sanitize, redact, trim filler, repair the ending.

    Steps, in order:
      1. Replace known medication names with a category phrase; a run of names
         from the same category joined by "or", "and" or "," collapses into one phrase.
      2. Replace the remaining listed product names with "a prescribed medication".
      3. Redact e-mail addresses, http links and phone-like digit groups.
      4. Remove conversational filler. A filler sentence is removed up to its
         own end, so sentences that follow it on the same line are kept.
      5. Collapse whitespace, drop a dangling trailing list number, and clip an
         unfinished final sentence back to the last complete one.

    Args:
        text: Decoded model output for a single record.

    Returns:
        The cleaned note. An empty string is returned unchanged.
    """
    # --- 1. Sanitize medications with non-repetitive replacements ---
    for pattern, replacement in _MEDICATION_CATEGORIES.items():
        # Raw f-string: "\s" inside a plain f-string is an invalid escape sequence.
        text = re.sub(
            rf'{pattern}(?:\s*(?:or|,|and)\s*{pattern})*',
            replacement,
            text,
            flags=re.IGNORECASE,
        )

    # --- 2. Sanitize any remaining unmapped medications ---
    for med in _UNMAPPED_MEDICATIONS:
        text = re.sub(med, 'a prescribed medication', text, flags=re.IGNORECASE)

    # --- 3. Redact PII ---
    text = re.sub(r'\S+@\S+|\bhttp\S+\b', '[contact information redacted]', text, flags=re.IGNORECASE)
    text = re.sub(r'(\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b|\b\d{10,}\b)', '[phone number redacted]', text)

    # --- 4. Remove conversational filler ---
    for pattern in _FILLER_PATTERNS:
        text = re.sub(pattern, "", text, flags=re.IGNORECASE)

    # --- 5. Final cleanup and repair ---
    text = re.sub(r'\s{2,}', ' ', text).strip()
    text = re.sub(r'\s+\d+\.?\s*$', '', text)  # Remove dangling list numbers
    # Clip an incomplete sentence at the very end, if an earlier sentence end exists.
    if not text.endswith((".", "!", "?")):
        last_sentence_end = max(text.rfind('. '), text.rfind('! '), text.rfind('? '))
        if last_sentence_end > 0:
            text = text[:last_sentence_end + 1]

    return text

Building all prompts...


In [7]:
# -------------------------------------
# 5. Inference with ETA + Checkpoints
# -------------------------------------
total_batches = (len(prompts) + BATCH_SIZE - 1) // BATCH_SIZE
outputs = []

# Generation below samples (do_sample=True); seed torch so a rerun of this cell
# starts from the same random state.
torch.manual_seed(42)

start_time = time.time()

print(f"\nStarting inference with batch size: {BATCH_SIZE}")

for batch_num, i in enumerate(range(0, len(prompts), BATCH_SIZE), start=1):
    batch_prompts = prompts[i:i + BATCH_SIZE]

    # The prompts are chat-template text that already contains the special
    # tokens, so the tokenizer must not add them a second time (duplicate BOS).
    # NOTE(review): with right-side truncation (assumed default, confirm), a prompt
    # longer than max_length loses its tail, including the closing instruction tag.
    # Consider shortening the doctor's note before building the prompt.
    inputs = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=768,
        add_special_tokens=False,
    ).to(device)

    with torch.no_grad():
        generated_outputs = model.generate(
            **inputs,
            max_new_tokens=220,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # Important for open-ended generation
        )

    # Every row of the padded batch has the same width, so one offset separates
    # the prompt from the newly generated tokens for the whole batch.
    prompt_width = inputs["input_ids"].shape[1]
    decoded_batch = [
        generated_text.strip()
        for generated_text in tokenizer.batch_decode(
            generated_outputs[:, prompt_width:], skip_special_tokens=True
        )
    ]
    outputs.extend(clean_note(generated_text) for generated_text in decoded_batch)

    # --- Progress + ETA ---
    elapsed = time.time() - start_time
    avg_time_per_batch = elapsed / batch_num
    remaining_batches = total_batches - batch_num
    eta = avg_time_per_batch * remaining_batches
    print(f"[Batch {batch_num}/{total_batches}] "
          f"Elapsed: {elapsed/60:.1f} min | "
          f"ETA: {eta/60:.1f} min | "
          f"Avg. time/batch: {avg_time_per_batch:.2f}s")

    # --- Save partial progress ---
    # Every 100 batches, write the rows processed so far next to their notes.
    if batch_num % 100 == 0:
        temp_df = df.iloc[:len(outputs)].copy()
        temp_df["Simplified_Note"] = outputs
        temp_df.to_csv(CHECKPOINT_CSV, index=False)
        print(f"💾 Saved checkpoint with {len(outputs)} rows to {CHECKPOINT_CSV}")

    # Release per-batch tensors before the next step to keep GPU memory flat.
    del inputs, generated_outputs, decoded_batch
    gc.collect()
    torch.cuda.empty_cache()



Starting inference with batch size: 25
[Batch 1/800] Elapsed: 0.7 min | ETA: 549.6 min | Avg. time/batch: 41.27s
[Batch 2/800] Elapsed: 1.4 min | ETA: 555.8 min | Avg. time/batch: 41.79s
[Batch 3/800] Elapsed: 2.1 min | ETA: 546.3 min | Avg. time/batch: 41.13s
[Batch 4/800] Elapsed: 2.8 min | ETA: 548.3 min | Avg. time/batch: 41.33s
[Batch 5/800] Elapsed: 3.7 min | ETA: 591.4 min | Avg. time/batch: 44.63s
[Batch 6/800] Elapsed: 4.3 min | ETA: 575.6 min | Avg. time/batch: 43.50s
[Batch 7/800] Elapsed: 5.0 min | ETA: 570.2 min | Avg. time/batch: 43.14s
[Batch 8/800] Elapsed: 5.7 min | ETA: 563.5 min | Avg. time/batch: 42.69s
[Batch 9/800] Elapsed: 6.3 min | ETA: 557.2 min | Avg. time/batch: 42.26s
[Batch 10/800] Elapsed: 7.0 min | ETA: 551.3 min | Avg. time/batch: 41.87s
[Batch 11/800] Elapsed: 7.6 min | ETA: 547.7 min | Avg. time/batch: 41.65s
[Batch 12/800] Elapsed: 8.3 min | ETA: 544.6 min | Avg. time/batch: 41.47s
[Batch 13/800] Elapsed: 9.0 min | ETA: 542.0 min | Avg. time/batch: 4

In [8]:
# -------------------------------------
# 6. Final Save
# -------------------------------------
# Refuse to write a misleading file: every row needs exactly one generated note
# (an interrupted inference cell would leave `outputs` short).
if len(outputs) != len(df):
    raise ValueError(
        f"Expected {len(df)} simplified notes but found {len(outputs)}; "
        "re-run the inference cell to completion before saving."
    )

print("\n✅ Inference complete. Saving final results...")
# Assign the list positionally. Wrapping it in pd.Series would align on index
# labels and silently produce NaN wherever df's index is not 0..n-1.
df["Simplified_Note"] = outputs
df.to_csv(FINAL_CSV, index=False)
print(f"✅ Saved {len(outputs)} final results to {FINAL_CSV}")


✅ Inference complete. Saving final results...
✅ Saved 20000 final results to final_simplified_notes.csv
