In [None]:
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, util
import torch
import re
import pickle

# --- TherapyBot++ text generator -------------------------------------------
# Fine-tuned FLAN-T5 checkpoint pulled from the Hugging Face Hub.
model_id = "raviix46/flan-t5-therapy-finetuned"
_use_gpu = torch.cuda.is_available()
tokenizer = AutoTokenizer.from_pretrained(model_id)
# On a GPU: fp16 weights with automatic device placement. On CPU: plain fp32.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if _use_gpu else torch.float32,
    low_cpu_mem_usage=True,
    device_map="auto" if _use_gpu else None,
)

# --- Precomputed embedding tables for the Symptom Checker and FAQ tabs ------
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these from a trusted source.
with open("faq_embeddings.pkl", "rb") as f:
    faq_data = pickle.load(f)
with open("symptom_embeddings.pkl", "rb") as f:
    symptom_data = pickle.load(f)

# Therapy conversation as (user_message, bot_reply) tuples, appended by respond().
# NOTE(review): module-level state, so presumably shared by every Gradio session.
chat_history = []

def clean_reply(text, user_input=None):
    """Post-process a raw model reply into a tidier chatbot message.

    Steps, in order:
      1. Remove greetings such as "Hi Charlie," together with the word that
         follows them.
      2. Replace "charlie"/"therapist" with "you", then tidy the
         "thank you for reaching out, you." that this can produce.
      3. Replace "your addiction" with "your condition".
      4. For each word of 5+ characters in ``user_input``, delete every
         occurrence in the reply except the last one.
      5. Strip stock filler phrases ("Let's explore this further", ...).
      6. Drop repeated sentences (case-insensitive) and consecutive
         duplicate lines.
      7. Capitalise the first character and ensure terminal punctuation.

    Args:
        text: Raw reply text to clean.
        user_input: The user's message, if available; only used for step 4.

    Returns:
        The cleaned reply, or "" if nothing survives cleaning.
    """
    # The leading \b keeps the greeting pattern from firing inside other words:
    # without it "They are here." became "T here." ("hey are" was removed).
    text = re.sub(
        r"\b(hi|hello|hey)[, ]+(charlie|[A-Za-z]+)[,\.]?", "", text, flags=re.IGNORECASE
    ).strip()
    text = re.sub(r"\b(charlie|therapist)\b", "you", text, flags=re.IGNORECASE)
    # Tidy "thank you for reaching out, you." left by the substitution above.
    text = re.sub(r"(thank you for reaching out)[\s,]*you[\.]?", r"\1.", text, flags=re.IGNORECASE)
    text = re.sub(r"\byour addiction\b", "your condition", text, flags=re.IGNORECASE)

    if user_input:
        # Keep only the last occurrence of each longer word the user typed, so
        # the reply does not keep echoing the user's own wording.
        keywords = set(re.findall(r"\b\w{5,}\b", user_input.lower()))
        for word in keywords:
            text = re.sub(rf"\b({re.escape(word)})\b(?=.*\b\1\b)", "", text, flags=re.IGNORECASE)

    # Stock filler phrases to strip (word-bounded so they cannot start mid-word).
    repeated_phrases = [
        r"\blet's explore this further\.?", r"\bi can understand .*?\.", r"\bit's completely normal .*?\.",
        r"\bit's okay to .*?\.", r"\bremember,.*?\.", r"\byou are not alone.*?\."
    ]
    for pattern in repeated_phrases:
        text = re.sub(pattern, "", text, flags=re.IGNORECASE).strip()

    # Drop repeated sentences (compared case-insensitively), keeping the first.
    sentences = []
    seen = set()
    for sent in re.split(r'(?<=[.!?]) +', text):
        sent = sent.strip()
        key = sent.lower()
        if key and key not in seen:
            seen.add(key)
            sentences.append(sent)
    final_text = ' '.join(sentences)

    # Collapse consecutive duplicate lines (sentences may still contain "\n").
    lines = final_text.strip().split("\n")
    final_text = "\n".join(line for i, line in enumerate(lines) if i == 0 or line != lines[i - 1])

    if final_text:
        final_text = final_text[0].upper() + final_text[1:]
        if not final_text.endswith(('.', '!', '?')):
            final_text += '.'
    return final_text.strip()

def _format_chat_log(history):
    """Render (user, bot) turns as the text shown in the "Chat Log" box."""
    return "\n\n".join(
        f"\U0001F464 You: {u}\n\n\U0001F916 Bot: {b}\n\n------" for u, b in history
    ).strip()


def _build_prompt(user_message, history, max_length=512):
    """Build the generation prompt from recent turns plus the new message.

    Up to the last three turns of ``history`` are used as context. If the
    prompt exceeds ``max_length`` tokens, the oldest turns are dropped first.
    With ``truncation=True`` the tokenizer cuts from its ``truncation_side``,
    which is the right end unless configured otherwise. Without this trimming,
    the new message and the trailing "Therapist:" cue would be lost first.
    """
    turns = list(history[-3:])
    while True:
        context = "\n".join(f"User: {u}\nTherapist: {b}" for u, b in turns)
        prompt = f"{context}\nUser: {user_message}\nTherapist:"
        if not turns or len(tokenizer(prompt)["input_ids"]) <= max_length:
            return prompt
        turns.pop(0)


def respond(user_message):
    """Generate a TherapyBot++ reply and return the updated chat log.

    The new (user_message, reply) exchange is appended to the module-level
    ``chat_history``.

    Args:
        user_message: Text from the "Your Message" textbox.

    Returns:
        A ``("", chat_log)`` pair. The empty string clears the input box, and
        ``chat_log`` is the whole conversation rendered for display.
    """
    # NOTE(review): chat_history is module-level, so presumably all Gradio
    # sessions share one conversation; per-user history would need gr.State.
    global chat_history
    # Don't prompt the model with an empty message; leave the log unchanged.
    if not user_message or not user_message.strip():
        return "", _format_chat_log(chat_history)

    prompt = _build_prompt(user_message, chat_history)
    # The tokenizer returns CPU tensors, but with device_map="auto" the model
    # weights are on the GPU; move the inputs to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    # T5-family tokenizers normally define a real pad token (distinct from EOS);
    # fall back to EOS only if it is missing.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    output_ids = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=pad_id,
    )
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the text after the last "Therapist:" cue, in case it is echoed.
    reply = clean_reply(full_output.split("Therapist:")[-1].strip(), user_message)
    chat_history.append((user_message, reply))
    return "", _format_chat_log(chat_history)

def clear_chat():
    """Forget the therapy conversation and blank both chat widgets.

    Rebinds the module-level ``chat_history`` to a fresh empty list.

    Returns:
        ("", ""): cleared values for the message input and the chat log.
    """
    global chat_history
    chat_history = list()
    blank = ""
    return blank, blank

# Sentence encoder shared by the symptom checker and the FAQ search. It is
# loaded lazily on first use, so it is read from disk once rather than on
# every button click.
_embedding_model = None


def _get_embedding_model():
    """Return the shared 'all-MiniLM-L6-v2' encoder, loading it on first use."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    return _embedding_model


def identify_disease(user_symptom):
    """Return the (disease, treatment) pair whose stored embedding best matches.

    Reads ``symptom_data['embeddings']``, ``['diseases']`` and
    ``['treatments']``. The best-scoring row index (by cosine similarity) is
    used to look up both the disease and the treatment lists.

    Args:
        user_symptom: Free-text symptom description.

    Returns:
        ``(disease, treatment)``, or ``("", "")`` for empty/whitespace input.
    """
    # An empty query would still "match" some arbitrary row; return blanks.
    if not user_symptom or not user_symptom.strip():
        return "", ""
    input_embedding = _get_embedding_model().encode(user_symptom, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(input_embedding, symptom_data['embeddings'])[0]
    idx = similarities.argmax().item()
    return symptom_data['diseases'][idx], symptom_data['treatments'][idx]


def answer_faq(user_query):
    """Return the stored FAQ answer whose embedding best matches the query.

    Reads ``faq_data['embeddings']`` and ``faq_data['answers']``. The
    best-scoring row index (by cosine similarity) selects the answer.

    Args:
        user_query: Free-text question.

    Returns:
        The matched answer, or "" for empty/whitespace input.
    """
    if not user_query or not user_query.strip():
        return ""
    query_embedding = _get_embedding_model().encode(user_query, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(query_embedding, faq_data['embeddings'])[0]
    idx = similarities.argmax().item()
    return faq_data['answers'][idx]

def clear_symptoms():
    """Blank the symptom input and both result boxes on the Symptom Checker tab."""
    return ("",) * 3

def clear_faq():
    """Blank the question input and the answer box on the FAQ tab."""
    empty = ""
    return empty, empty

# Stylesheet passed to gr.Blocks(css=...) for the whole app:
# - body: dark page background with white text.
# - .centered-text: applied via elem_classes to the centred headings/captions.
# - #submit-btn / #symptom-btn / #faq-btn: bold, rounded, orange (#cc5500)
#   primary action buttons, matched by their elem_id.
# - .gr-markdown > div: makes markdown wrappers transparent. NOTE(review):
#   Gradio's internal class names vary between versions; verify it still applies.
# NOTE(review): no component in this file sets elem_id="disclaimer-box" (the
# disclaimer banner is styled inline), so that rule currently has no effect.
# If it were applied, #222 text would be nearly invisible on the dark background.
custom_css = """
body {
    background-color: #121212;
    color: white;
}
.centered-text {
    text-align: center;
    margin-left: auto;
    margin-right: auto;
}
#submit-btn, #symptom-btn, #faq-btn {
    background-color: #cc5500 !important;
    color: white;
    font-weight: bold;
    border-radius: 8px;
    padding: 10px 20px;
}
.gr-markdown > div {
    background: transparent !important;
}
/* Fix for disclaimer */
#disclaimer-box {
    color: #222 !important;
}
"""

def _tab_header(subtitle):
    """Emit the shared "TherapyBot++" title followed by a tab-specific caption."""
    gr.Markdown("## 🧠 TherapyBot++", elem_classes="centered-text")
    gr.Markdown(subtitle, elem_classes="centered-text")


# Assemble the three-tab Gradio interface, then start the server.
with gr.Blocks(css=custom_css) as demo:
    # Medical disclaimer banner shown above all tabs.
    gr.Markdown(
        """
        <div style="background-color:#2a2a2a; padding: 15px; border-radius: 10px; border: 1px solid #444; text-align: center; font-size: 14px; color: #ddd;">
        ⚠️ <strong>Disclaimer:</strong> This chatbot is intended for <strong>educational and demonstration purposes only</strong>. It is <strong>not a substitute</strong> for professional medical advice, diagnosis, or treatment.
        </div>
        """,
        elem_classes="centered-text",
    )

    # Tab 1: free-form conversation driven by respond() / clear_chat().
    with gr.Tab("💬 Therapy Chatbot"):
        _tab_header("Thoughtfully designed to support your mental wellness.")
        with gr.Row():
            with gr.Column():
                user_input = gr.Textbox(
                    placeholder="How are you feeling?",
                    label="🗣 Your Message",
                    lines=3,
                )
                submit_btn = gr.Button("Submit", elem_id="submit-btn")
                clear_btn = gr.Button("Clear")
            with gr.Column():
                with gr.Accordion("📜 Chat History", open=True):
                    chat_display = gr.Textbox(label="Chat Log", interactive=False, lines=20)
        submit_btn.click(fn=respond, inputs=user_input, outputs=[user_input, chat_display])
        clear_btn.click(fn=clear_chat, outputs=[user_input, chat_display])

    # Tab 2: embedding lookup over symptom_data via identify_disease().
    with gr.Tab("🧬 Symptom Checker"):
        _tab_header("Describe your symptoms to get a possible diagnosis and treatment.")
        with gr.Row():
            with gr.Column():
                symptom_input = gr.Textbox(
                    placeholder="e.g., I have a fever and rash",
                    label="Enter your symptoms",
                )
                symptom_btn = gr.Button("Check", elem_id="symptom-btn")
                symptom_clear = gr.Button("Clear")
            with gr.Column():
                predicted_disease = gr.Textbox(label="Predicted Disease", interactive=False)
                suggested_treatment = gr.Textbox(label="Suggested Treatment", interactive=False)
        symptom_btn.click(
            fn=identify_disease,
            inputs=symptom_input,
            outputs=[predicted_disease, suggested_treatment],
        )
        symptom_clear.click(
            fn=clear_symptoms,
            outputs=[symptom_input, predicted_disease, suggested_treatment],
        )

    # Tab 3: embedding lookup over faq_data via answer_faq().
    with gr.Tab("📘 FAQ Support"):
        _tab_header("Ask your health-related questions to get instant answers.")
        with gr.Row():
            with gr.Column():
                faq_input = gr.Textbox(
                    placeholder="e.g., How do I book an appointment?",
                    label="Ask a Question",
                )
                faq_btn = gr.Button("Get Answer", elem_id="faq-btn")
                faq_clear = gr.Button("Clear")
            with gr.Column():
                faq_output = gr.Textbox(label="Answer", interactive=False)
        faq_btn.click(fn=answer_faq, inputs=faq_input, outputs=faq_output)
        faq_clear.click(fn=clear_faq, outputs=[faq_input, faq_output])

    gr.Markdown("Made by Ravi⚡️", elem_classes="centered-text")

demo.launch()