<a href="https://colab.research.google.com/github/vishvedula/AI_ML/blob/main/Agentic_AI_Clinical_Trial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# === Required Imports ===
import html
import os
import tempfile
import warnings
from typing import TypedDict, List

import streamlit as st
from pydantic import BaseModel, Field
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.document_loaders import CSVLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langgraph.graph import StateGraph, START, END
from langchain.schema import Document
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from streamlit_option_menu import option_menu
warnings.filterwarnings("ignore")

# Browser-tab title/icon and overall page layout for the app.
_PAGE_SETTINGS = {
    "page_title": "InsightRx",
    "page_icon": "🧪",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# === Init ===
# Load environment variables via python-dotenv (GROQ_API_KEY is read with
# os.getenv when the chatbot is built).
load_dotenv()

# === Styling ===
# Custom CSS for the chat bubbles, the sidebar logo panel and the upload
# widgets. Injected once per script run as raw HTML.
_CUSTOM_CSS = """
    <style>
        .main > div {
            padding-top: 1rem;
        }
        .chat-message {
            padding: 1rem;
            border-radius: 10px;
            margin: 1rem 0;
            border-left: 4px solid;
        }
        .user-message { background-color: #f0f2f6; border-left-color: #4CAF50; }
        .assistant-message { background-color: #e8f4f8; border-left-color: #2196F3; }
        .clarify-message { background-color: #e3f2fd; border-left-color: #29b6f6; }
        .clarify-badge { background-color: #e1f5fe; color: #0277bd; border: 1px solid #4fc3f7; }
        .analytical-message { background-color: #f3e5f5; border-left-color: #9C27B0; }
        .analytical-badge { background-color: #f3e5f5; color: #4A148C; border: 1px solid #CE93D8; }

        .logo-container {
            text-align: center;
            padding: 2rem 1rem;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border-radius: 15px;
            margin-bottom: 2rem;
            color: white;
        }

        .upload-container {
            background-color: #f8f9fa;
            padding: 2rem;
            border-radius: 10px;
            border: 2px dashed #dee2e6;
            margin: 1rem 0;
        }

        .upload-success {
            background-color: #d4edda;
            color: #155724;
            padding: 1rem;
            border-radius: 5px;
            border: 1px solid #c3e6cb;
            margin: 1rem 0;
        }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# === Vector DB Loader ===
# Shared embedding model used when building every FAISS index.
# NOTE(review): the model name is the current working directory, so this
# presumably expects sentence-transformers model files to live there --
# confirm this is intentional rather than a placeholder.
_model_location = os.getcwd()
embedding_model = HuggingFaceEmbeddings(model_name=_model_location)

def load_vectorstore_from_csv(csv_path: str):
    """Build a FAISS index from the contents of a CSV file.

    Args:
        csv_path: Path of the CSV file to load.

    Returns:
        A FAISS vector store built from the loaded documents after they are
        split into chunks of up to 1000 characters (100 overlap), embedded
        with the module-level ``embedding_model``.
    """
    rows = CSVLoader(file_path=csv_path).load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    pieces = splitter.split_documents(rows)
    return FAISS.from_documents(pieces, embedding_model)

def initialize_chatbot(_vectorstore):
    """Compile the LangGraph pipeline that answers questions over a vector store.

    The pipeline runs three nodes in a fixed order:
    ``clarify`` -> ``search`` -> ``analytical``.

    This function is deliberately NOT wrapped in ``st.cache_resource``. The
    leading underscore on ``_vectorstore`` excludes the argument from
    Streamlit's cache key, so a cached version hands back the first compiled
    graph (bound to the first vector store) on every later call and in every
    session; newly uploaded data would never be used. Callers already keep
    the compiled graph in ``st.session_state``.

    Args:
        _vectorstore: Vector store exposing ``similarity_search(query, k=...)``.

    Returns:
        The compiled graph. Invoke it with a dict containing the ``State``
        keys; the result carries ``response_content`` and ``message_type``.
    """
    llm = ChatGroq(temperature=0.3, model_name="llama3-70b-8192", api_key=os.getenv("GROQ_API_KEY"))

    class State(TypedDict):
        # Data passed from node to node through the graph.
        question: str
        docs: List[Document]
        answer: str
        message_type: str
        response_content: str

    def clarify_user_query(state: State):
        """Ask the LLM for a clarification of the user's question."""
        # NOTE(review): the values written here are overwritten by the
        # "analytical" node, which always runs afterwards, and the search
        # node uses the original question -- confirm this LLM call is needed.
        prompt = PromptTemplate.from_template("""
        You are an AI assistant helping a user clarify a question about clinical studies.
        Original Question: {question}
        """)
        follow_up = LLMChain(llm=llm, prompt=prompt).run(question=state["question"])
        return { **state, "message_type": "clarify", "response_content": follow_up }

    def query_trials(state: State):
        """Fetch the five documents most similar to the question."""
        try:
            results = _vectorstore.similarity_search(state["question"], k=5)
        except Exception:
            # Best effort: a failed search degrades to "no results" instead of
            # aborting the chat turn. Catching Exception (not a bare except)
            # lets KeyboardInterrupt/SystemExit propagate.
            results = []
        return { **state, "docs": results }

    def generate_answer(state: State):
        """Answer the question from the retrieved documents only."""
        docs = state.get("docs", [])
        if not docs:
            final_answer = "No relevant results found."
        else:
            context = "\n".join(doc.page_content for doc in docs)
            prompt = PromptTemplate.from_template("""
            You are a helpful assistant that answers user questions based on clinical trials data.
            Context:
            {context}
            Question:
            {question}
            Provide a detailed response based only on the context provided. If the context doesn't contain relevant information, say so clearly.
            """)
            final_answer = LLMChain(llm=llm, prompt=prompt).run(context=context, question=state["question"])
        return { **state, "answer": final_answer, "message_type": "analytical", "response_content": final_answer }

    graph = StateGraph(State)
    graph.add_node("clarify", clarify_user_query)
    graph.add_node("search", query_trials)
    graph.add_node("analytical", generate_answer)
    # Separate statements: do not rely on add_edge returning the graph, which
    # is not guaranteed across LangGraph versions.
    graph.add_edge(START, "clarify")
    graph.add_edge("clarify", "search")
    graph.add_edge("search", "analytical")
    graph.add_edge("analytical", END)
    return graph.compile()

# === Load chatbot if file uploaded ===
# Chat history lives in the session so it survives script reruns.
if "chat_messages" not in st.session_state:
    st.session_state.chat_messages = []

# Build the agent graph once per session: from the uploaded CSV when one is
# registered under user_data["file_path"], otherwise from the default file.
if "graph" not in st.session_state:
    has_upload = "user_data" in st.session_state and "file_path" in st.session_state["user_data"]
    csv_source = st.session_state["user_data"]["file_path"] if has_upload else "clinical_trials.csv"
    st.session_state.graph = initialize_chatbot(load_vectorstore_from_csv(csv_source))

# === Sidebar with Logo/Branding ===
with st.sidebar:
    # Branding panel (styled by the .logo-container CSS rule).
    st.markdown("""
    <div class="logo-container">
        <h1>🧪 InsightRx</h1>
        <p><em>AI-Powered Clinical Research Intelligence</em></p>
        <hr style="border-color: rgba(255,255,255,0.3);">
        <h3> 🤖 Smart Agents</h3>
        <p><strong>🧬 Insightify</strong><br>Deep analytical insights</p>
    </div>
    """, unsafe_allow_html=True)

    # Conversation counters, shown only once there is some history.
    st.markdown("### 📊 Quick Stats")
    history = st.session_state.chat_messages
    if not history:
        st.info("No conversations yet. Upload data and start chatting!")
    else:
        asked = sum(1 for entry in history if entry['role'] == 'user')
        st.metric("Total Messages", len(history))
        st.metric("Your Queries", asked)

    # Readiness indicator: the compiled graph is stored in the session.
    st.markdown("### 🎯 Active Status")
    agents_ready = "graph" in st.session_state
    if not agents_ready:
        st.warning("⏳ Loading AI Agents...")
    else:
        st.success("✅ AI Agents Ready")

# === Horizontal NavBar ===
# (label, icon name) pairs, in display order.
_NAV_ITEMS = [
    ("Dashboard", "bar-chart"),
    ("Rules", "clipboard-check"),
    ("Sources", "folder"),
    ("Explorer", "search"),
    ("Graph", "graph-up"),
    ("AI", "robot"),
    ("Authors", "people"),
]
selected = option_menu(
    menu_title=None,
    options=[label for label, _ in _NAV_ITEMS],
    icons=[icon for _, icon in _NAV_ITEMS],
    orientation="horizontal",
    default_index=5,  # "AI" is the page selected by default
)

# === Main Router ===
if selected == "AI":
    st.header("🤖 AI Chat Interface")
    st.markdown("*Interact with our intelligent agents for clinical research insights*")

    # Presentation lookup for assistant messages:
    # message type -> (bubble CSS class, badge CSS class, display name).
    assistant_styles = {
        "clarify": ("clarify-message", "clarify-badge", "🤔 IntentAnalyzer Assistant"),
        "analytical": ("analytical-message", "analytical-badge", "🧬 Insightify Assistant")
    }

    # Replay the stored conversation.
    for msg in st.session_state.chat_messages:
        # Message text comes from the user or from the LLM and is rendered
        # with unsafe_allow_html=True, so escape it to stop typed or generated
        # markup from being injected into the page.
        safe_content = html.escape(str(msg['content']))
        if msg['role'] == 'user':
            st.markdown(f"<div class='chat-message user-message'><b>You:</b><br>{safe_content}</div>", unsafe_allow_html=True)
        else:
            # Unknown or missing types fall back to the "clarify" styling.
            msg_type = msg.get('type', 'clarify')
            css_class, badge_class, name = assistant_styles.get(msg_type, assistant_styles["clarify"])
            st.markdown(f"""
                <div class='chat-message {css_class}'>
                    <div class='message-type-badge {badge_class}'>{name}</div><br>
                    {safe_content}
                </div>
            """, unsafe_allow_html=True)

    # Input form; clear_on_submit empties the text area after either button.
    with st.form("chat_form", clear_on_submit=True):
        user_input = st.text_area("💬 Ask me anything about clinical trials:", height=100, placeholder="e.g., 'Which trials are investigating COVID-19 vaccines?'")
        col1, col2, col3 = st.columns([2, 1, 1])
        with col2:
            submit = st.form_submit_button("Send Message", use_container_width=True)
        with col3:
            clear_chat = st.form_submit_button("Clear Chat", use_container_width=True)

    if clear_chat:
        st.session_state.chat_messages = []
        st.rerun()

    # Ignore whitespace-only submissions.
    if submit and user_input.strip():
        with st.spinner('🤖 Processing your message...'):
            st.session_state.chat_messages.append({"role": "user", "content": user_input})
            result = st.session_state.graph.invoke({"question": user_input, "docs": [], "answer": "", "message_type": "", "response_content": ""})
            st.session_state.chat_messages.append({"role": "assistant", "content": result['response_content'], "type": result['message_type']})
            # Rerun so the new messages are drawn by the replay loop above.
            st.rerun()

elif selected == "Dashboard":
    st.title("📊 Dashboard")
    st.markdown("*Overview of your clinical research analytics*")

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Queries", len(st.session_state.chat_messages))
    with col2:
        st.metric("Data Sources", "1" if "user_data" in st.session_state else "0")
    with col3:
        st.metric("Active Agents", "1")
    with col4:
        st.metric("Success Rate", "98%")

    st.info("📈 Dashboard analytics and KPIs will be displayed here.")

elif selected == "Rules":
    st.title("📜 Rules & Configuration")
    st.markdown("*Configure AI behavior and business rules*")
    st.info("⚙️ Rule-based engine explanations, conditions, and parameters will be configured here.")

elif selected == "Sources":
    st.title("📁 Data Sources Management")
    st.markdown("*Upload and manage your clinical trial datasets*")

    # Upload Section
    st.subheader("🔄 Upload New Data")

    st.markdown("""
    <div class="upload-container">
        <h4>📂 Clinical Trials CSV Upload</h4>
        <p>Upload your clinical trials data in CSV format. Supported sources:</p>
        <ul>
            <li>ClinicalTrials.gov exports</li>
            <li>Custom trial databases</li>
            <li>Research institution datasets</li>
        </ul>
    </div>
    """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader(
        "Choose your CSV file",
        type=["csv"],
        help="Upload clinical trials data from ClinicalTrials.gov or other sources"
    )

    if uploaded_file:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
            f.write(uploaded_file.getvalue())
            st.session_state["user_data"] = {"file_path": f.name}

        st.markdown("""
        <div class="upload-success">
            <strong>✅ Success!</strong> Your file has been uploaded successfully.
            The AI agents will now use your custom dataset for analysis.
        </div>
        """, unsafe_allow_html=True)

        # Show file info
        st.subheader("📋 File Information")
        st.write(f"**Filename:** {uploaded_file.name}")
        st.write(f"**Size:** {uploaded_file.size} bytes")

        if st.button("🔄 Reinitialize AI Agents with New Data"):
            with st.spinner("Rebuilding AI agents with your data..."):
                # Clear the cached graph so it rebuilds with new data
                if "graph" in st.session_state:
                    del st.session_state["graph"]
                st.success("✅ AI agents updated! Go to the AI tab to start chatting.")
                st.rerun()

    # Current Sources Section
    st.subheader("📊 Current Data Sources")
    if "user_data" in st.session_state:
        st.success("🎯 **Custom Dataset Active** - Using your uploaded clinical trials data")
    else:
        st.info("📖 **Default Dataset Active** - Using built-in clinical trials sample data")

    # Data Preview (optional)
    if st.checkbox("🔍 Preview Data"):
        st.info("Data preview functionality can be added here to show sample rows from the loaded dataset.")

elif selected == "Explorer":
    st.title("🔍 Data Explorer")
    st.markdown("*Explore and analyze your clinical trial data*")
    st.info("🔬 Document-level insights, vector embeddings visualization, and data exploration tools.")

elif selected == "Graph":
    st.title("📈 Agent Workflow Visualization")
    st.markdown("*Visual representation of AI agent pipeline*")

    st.subheader("🔄 Current Workflow")
    st.markdown("""
    ```
    User Query → IntentAnalyzer → Vector Search → Insightify → Response
         ↓              ↓              ↓           ↓
    Classification → Clarification → Retrieval → Analysis
    ```
    """)

    st.info("📊 Interactive LangGraph state flow and pipeline visualizations will be displayed here.")

elif selected == "Authors":
    st.title("👨‍🔬 Authors & Credits")
    st.markdown("*Meet the team behind InsightRx*")

    st.markdown("""
    ### 🧪 InsightRx Development Team

    **AI Research & Development**
    - Clinical AI Specialists
    - LangGraph Architecture
    - Vector Database Engineering

    **Data Science**
    - Clinical Trials Analytics
    - Semantic Search Implementation
    - Performance Optimization
    """)

    st.info("👥 Detailed contributor information and authorship data.")

# === Footer ===
# Horizontal rule followed by a centered credits line, shown on every page.
_FOOTER_HTML = """
    <div style="text-align: center; color: #666; padding: 1rem;">
        <p>🧪 InsightRx - Powered by LangGraph & Streamlit</p>
        <p><em>Semantic insights for clinical research teams</em></p>
    </div>
    """
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)