<a href="https://colab.research.google.com/github/rajesh8483/medimind_chatbot/blob/main/medimind_chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install gradio

import pandas as pd
import re
import json
import os
import datetime
import joblib
import gradio as gr
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

In [None]:
# Load and preprocess dataset: read the raw CSV, keep only rows that carry
# symptoms, a disease label and a treatment plan, and renumber the index from 0.
df = (
    pd.read_csv("/content/medimind_india_raw_data.csv")
    .dropna(subset=["symptoms", "disease", "treatment_plan"])
    .reset_index(drop=True)
)

In [None]:
def clean_text(text):
    """Normalise free text for the classifier.

    Lowercases the text, removes every character that is neither a word
    character nor whitespace, and trims surrounding whitespace.

    Args:
        text: Value to normalise. ``None`` yields ``""``; any other non-string
            value (e.g. a number read from CSV) is converted with ``str`` first.

    Returns:
        The cleaned string.
    """
    if text is None:
        return ""
    # str() guards against numeric cells, which have no .lower().
    return re.sub(r"[^\w\s]", "", str(text).lower()).strip()

# Normalise the symptom text, then map each distinct disease name to an
# integer class id that the classifier is trained on.
df["symptoms"] = df["symptoms"].map(clean_text)
df["disease"] = df["disease"].astype(str)

label_encoder = LabelEncoder()
df["label"] = label_encoder.fit_transform(df["disease"])
num_labels = label_encoder.classes_.shape[0]

In [None]:
# Tokenizer matching the DistilBERT checkpoint that the classifier is built on.
tokenizer = DistilBertTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased"
)

In [None]:
# Locations where the fine-tuned model and the fitted label encoder are cached.
model_path, encoder_path = "disease_model", "label_encoder.joblib"

In [None]:
# Load or train model.
# A cached model is reused only when BOTH the model directory and the fitted
# label encoder exist, so class ids and disease names stay paired.
if os.path.exists(model_path) and os.path.exists(encoder_path):
    model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
    label_encoder = joblib.load(encoder_path)
else:
    model = TFDistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=num_labels
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"]
    )

    def _to_dataset(frame, shuffle):
        """Tokenise ``frame["symptoms"]``, pair with ``frame["label"]``, batch by 16."""
        encodings = tokenizer(
            frame["symptoms"].tolist(),
            truncation=True, padding=True, max_length=128, return_tensors="tf",
        )
        ds = tf.data.Dataset.from_tensor_slices((dict(encodings), frame["label"].values))
        if shuffle:
            ds = ds.shuffle(len(frame))
        return ds.batch(16)

    # Split ONCE, before building any tf.data pipeline. Splitting a shuffled
    # dataset with take()/skip() is unsafe: shuffle() reshuffles every epoch by
    # default (reshuffle_each_iteration=True), so validation rows would leak
    # into training and the validation accuracy would be meaningless.
    if len(df) > 1:
        train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
        model.fit(
            _to_dataset(train_df, shuffle=True),
            validation_data=_to_dataset(val_df, shuffle=False),
            epochs=3,
        )
    else:
        # Too few rows to hold any out; train on everything available.
        model.fit(_to_dataset(df, shuffle=True), epochs=3)
    model.save_pretrained(model_path)
    joblib.dump(label_encoder, encoder_path)

In [None]:
def clinical_decision(disease):
    """Map a predicted disease label to a canned clinical-support payload.

    Args:
        disease: Disease label from the classifier. It is converted with
            ``str`` and lowercased before matching.

    Returns:
        A dict with ``symptom_analysis`` (differential diagnoses + alert) and
        ``decision_support`` (immediate actions, follow-ups, documentation).
        There are three templates: cardiac / myocardial infarction, diabetes
        mellitus, and a generic fallback that echoes ``disease``.

    NOTE: every "confidence" value below is a fixed placeholder string, not a
    probability produced by the model.
    """
    text = str(disease).lower()

    # Plain substring checks are too broad for treatment templates:
    # "non-cardiac chest pain" contains "cardiac" and "cerebral infarction"
    # contains "infarction", but neither must receive aspirin/heparin/PCI orders.
    negated_cardiac = re.search(r"\bnon[\s-]?cardiac\b", text) is not None
    non_cardiac_infarct = re.search(
        r"\b(cerebral|brain|lacunar|pulmonary|lung|renal|kidney|splenic"
        r"|mesenteric|bowel|intestinal|retinal|spinal)\b",
        text,
    ) is not None
    is_cardiac = ("cardiac" in text and not negated_cardiac) or (
        "infarction" in text and not non_cardiac_infarct
    )
    # Diabetes insipidus is not treated with metformin; only diabetes mellitus
    # labels may use the metformin template.
    is_diabetes_mellitus = "diabetes" in text and "insipidus" not in text

    if is_cardiac:
        return {
            "symptom_analysis": {
                "differential_diagnoses": [
                    {"name": "Acute myocardial infarction (STEMI)", "confidence": "92%"},
                    {"name": "Unstable angina", "confidence": "5%"},
                    {"name": "Aortic dissection", "confidence": "2%"},
                    {"name": "Pulmonary embolism", "confidence": "1%"}
                ],
                "alert": "Immediate cardiology consult recommended."
            },
            "decision_support": {
                "immediate_actions": [
                    "Aspirin 325 mg (chewable) stat.",
                    "Heparin bolus and infusion.",
                    "Prepare for PCI within 90 minutes."
                ],
                "follow_ups": [
                    "Repeat ECG in 30 minutes.",
                    "Monitor troponin every 6 hours."
                ],
                "documentation": "STEMI diagnosed; aspirin and heparin initiated; PCI planned."
            }
        }
    if is_diabetes_mellitus:
        return {
            "symptom_analysis": {
                "differential_diagnoses": [
                    {"name": "Type 2 diabetes mellitus", "confidence": "90%"},
                    {"name": "Type 1 diabetes", "confidence": "5%"},
                    {"name": "Hyperthyroidism", "confidence": "3%"},
                    {"name": "Malignancy", "confidence": "2%"}
                ],
                "alert": "Monitor glucose and refer for diet and lifestyle counseling."
            },
            "decision_support": {
                "immediate_actions": [
                    "Metformin 500 mg once daily.",
                    "Refer to dietician for low-carb diet."
                ],
                "follow_ups": [
                    "Repeat HbA1c in 3 months.",
                    "Order lipid panel and kidney tests."
                ],
                "documentation": "Start metformin and follow a healthy diet."
            }
        }
    # Generic fallback: echo the predicted label unchanged.
    return {
        "symptom_analysis": {
            "differential_diagnoses": [{"name": disease, "confidence": "80%"}],
            "alert": "Further evaluation recommended."
        },
        "decision_support": {
            "immediate_actions": ["Supportive care"],
            "follow_ups": ["Follow-up in 1 week"],
            "documentation": f"Diagnosis: {disease}"
        }
    }


In [None]:
def analyze_patient(name, age, gender, symptoms, history, tests, query):
    """Predict a disease from the clinical text, log the case, and build a report.

    Args:
        name, age, gender: Patient demographics; echoed into the log and report.
        symptoms, history, tests, query: Free-text clinical fields. The
            non-empty ones are joined, normalised with ``clean_text`` and fed
            to the classifier.

    Returns:
        ``(predicted_disease, readable_output)``: the label decoded by
        ``label_encoder`` and a Markdown-formatted clinical summary. If every
        text field is empty, returns ``""`` and a prompt message without
        running the model or writing any record.

    Side effects:
        Appends the case to ``patient_records.jsonl`` and ``patient_records.csv``
        in the working directory.
    """
    fields = (symptoms, history, tests, query)
    # Use the same normalisation as the training symptoms (clean_text also
    # strips punctuation); lowercasing alone feeds the model punctuation
    # tokens it never saw during fine-tuning.
    input_text = clean_text(" ".join(str(part) for part in fields if part))
    if not input_text:
        return "", "Please enter at least one symptom or clinical detail."

    inputs = tokenizer(input_text, return_tensors="tf", truncation=True, padding=True, max_length=128)
    logits = model(inputs).logits
    prediction = int(tf.argmax(logits, axis=1).numpy()[0])
    predicted_disease = label_encoder.inverse_transform([prediction])[0]
    decision = clinical_decision(predicted_disease)

    patient_record = {
        "timestamp": str(datetime.datetime.now()),
        "name": name,
        "age": age,
        "gender": gender,
        "input": {
            "symptoms": symptoms,
            "history": history,
            "tests": tests,
            "query": query
        },
        "predicted_disease": predicted_disease,
        "outputs": decision
    }
    _save_patient_record(patient_record)

    readable_output = _format_clinical_report(
        name, age, gender, symptoms, history, tests, query, decision
    )
    return predicted_disease, readable_output


def _save_patient_record(record, jsonl_path="patient_records.jsonl",
                         csv_path="patient_records.csv"):
    """Append one case to the JSONL log (full record) and the CSV log (flat row)."""
    # NOTE(review): these files store identifiable patient data in plain text;
    # confirm that storage and retention are acceptable for the deployment.
    with open(jsonl_path, "a", encoding="utf-8") as f_json:
        f_json.write(json.dumps(record) + "\n")

    csv_row = {
        "timestamp": record["timestamp"],
        "name": record["name"],
        "age": record["age"],
        "gender": record["gender"],
        **record["input"],  # symptoms, history, tests, query (in that order)
        "predicted_disease": record["predicted_disease"],
    }
    # Append rather than read-concat-rewrite: the cost per call is constant and
    # a failure mid-write cannot truncate the existing history.
    pd.DataFrame([csv_row]).to_csv(
        csv_path, mode="a", header=not os.path.exists(csv_path), index=False
    )


def _format_clinical_report(name, age, gender, symptoms, history, tests, query, decision):
    """Render the Markdown clinical summary shown in the UI."""
    analysis = decision["symptom_analysis"]
    support = decision["decision_support"]
    lines = [
        f"## Clinical Summary for {name} ({age}, {gender})",
        "",
        "**Input**:",
        f"- Symptoms: {symptoms}",
        f"- Medical History: {history}",
        f"- Test Results: {tests}",
        f"- Physician Query: {query}",
        "",
        "**Output (Symptom Analysis)**:",
        *(f"- {d['name']}: {d['confidence']}" for d in analysis["differential_diagnoses"]),
        f"- Alert: {analysis['alert']}",
        "",
        "**Output (Clinical Decision Support)**:",
        "- Immediate Actions:",
        *(f"  - {act}" for act in support["immediate_actions"]),
        "- Follow-Ups:",
        *(f"  - {f}" for f in support["follow_ups"]),
        f"- Documentation: {support['documentation']}",
    ]
    return "\n".join(lines)


In [None]:
def run_gradio(share=True):
    """Build the MediMind Gradio UI around ``analyze_patient`` and launch it.

    Args:
        share: Forwarded to ``Interface.launch``. ``True`` (the previous fixed
            behaviour) requests a public share link. Every submission is written
            to local patient-record files, so pass ``False`` when the app must
            not be reachable from outside.
    """
    import inspect

    # Gradio 5 renamed ``allow_flagging`` to ``flagging_mode``. Use whichever
    # keyword the installed version accepts, so the UI builds on both.
    interface_params = inspect.signature(gr.Interface.__init__).parameters
    flag_option = "flagging_mode" if "flagging_mode" in interface_params else "allow_flagging"

    interface = gr.Interface(
        fn=analyze_patient,
        inputs=[
            gr.Textbox(label="Patient Name"),
            gr.Textbox(label="Age"),
            gr.Radio(["Male", "Female", "Other"], label="Gender"),
            gr.Textbox(label="Symptoms", lines=2, placeholder="E.g., chest pain, sweating..."),
            gr.Textbox(label="Medical History", lines=2),
            gr.Textbox(label="Test Results", lines=2),
            gr.Textbox(label="Physician Query", lines=2, placeholder="What’s the likely diagnosis, and what should we do?")
        ],
        outputs=[
            gr.Textbox(label="Predicted Disease"),
            gr.Textbox(label="Detailed Report", lines=20)
        ],
        title=" MediMind - AI Medical Assistant",
        description="Enter patient details to get disease prediction and clinical decision support.",
        **{flag_option: "never"},
    )
    interface.launch(share=share)


# Launch only when run as a script or notebook, not when imported as a module.
if __name__ == "__main__":
    run_gradio()
