In [None]:
"""
=============================================================================
BIOMEDICAL RAG SYSTEM - FIXED VERSION
Neo4j Knowledge Graph + Multi-LLM Backend
Fixed: Updated to currently supported Groq models (October 2024)
=============================================================================
"""

# =============================================================================
# CELL 1: INSTALL DEPENDENCIES
# =============================================================================

!pip install neo4j requests -q

# =============================================================================
# CELL 2: IMPORT LIBRARIES
# =============================================================================

import json
import os
import warnings
from typing import Optional, Dict, Any, List

import requests
from neo4j import GraphDatabase
# Silence every warning category for the rest of the session so notebook
# output stays readable.
warnings.simplefilter('ignore')

print("‚úì All libraries imported successfully")

# =============================================================================
# CELL 3: NEO4J CONNECTOR CLASS
# =============================================================================

class Neo4jConnector:
    """Connect to Neo4j Aura and retrieve biomedical context.

    Wraps a Neo4j driver: runs Cypher through a short-lived session and renders
    the returned records as compact text for use as LLM prompt context.
    """

    # Properties tried, in this order, when looking for a node's display text.
    _TEXT_KEYS = ("text", "name", "canonical_name")

    def __init__(self, uri: str, username: str, password: str, database: str = "neo4j"):
        """Create the driver for ``uri`` and remember which database to query.

        NOTE(review): no query is issued here, so bad credentials may only
        surface on the first call to ``query`` -- confirm against the driver docs.
        """
        self.driver = GraphDatabase.driver(uri, auth=(username, password))
        self.database = database
        print(f"‚úì Connected to Neo4j at {uri}")

    def close(self):
        """Close the underlying driver."""
        self.driver.close()
        print("‚úì Neo4j connection closed")

    def query(self, cypher: str, parameters: Optional[Dict] = None) -> List[Dict]:
        """Execute a Cypher query and return every record as a plain dict.

        Args:
            cypher: Cypher statement to run.
            parameters: Values for ``$name`` placeholders; ``None`` means none.
        """
        with self.driver.session(database=self.database) as session:
            result = session.run(cypher, parameters or {})
            # Materialise inside the ``with`` so records are read before the
            # session is closed.
            return [record.data() for record in result]

    def get_schema(self) -> str:
        """Return node labels and relationship types as a JSON string.

        Never raises: any failure is returned as a
        ``"Schema retrieval error: ..."`` string instead.
        """
        try:
            labels_query = "CALL db.labels() YIELD label RETURN collect(label) as labels"
            rels_query = "CALL db.relationshipTypes() YIELD relationshipType RETURN collect(relationshipType) as types"

            labels = self.query(labels_query)
            rels = self.query(rels_query)

            schema = {
                "node_labels": labels[0]['labels'] if labels else [],
                "relationship_types": rels[0]['types'] if rels else []
            }
            return json.dumps(schema, indent=2)
        except Exception as e:
            return f"Schema retrieval error: {e}"

    def get_sample_data(self, limit: int = 5) -> str:
        """Return up to ``limit`` arbitrary nodes (labels + properties) as JSON."""
        # ``limit`` is bound as a query parameter rather than formatted into the
        # Cypher text, so a non-integer value can never alter the statement.
        cypher = """
        MATCH (n)
        RETURN labels(n) AS labels, properties(n) AS props
        LIMIT $limit
        """
        results = self.query(cypher, {"limit": int(limit)})
        # default=str: property values json cannot encode natively are rendered
        # with str() instead of raising TypeError.
        return json.dumps(results, indent=2, default=str)

    @classmethod
    def _display_text(cls, props: Dict[str, Any], max_len: int) -> str:
        """Return the first non-empty scalar among text/name/canonical_name, truncated to ``max_len``."""
        for key in cls._TEXT_KEYS:
            value = props.get(key)
            # Only scalar values are usable as text; lists and maps are skipped.
            if value and isinstance(value, (str, int, float, bool)):
                text = str(value)  # numeric/bool properties have no len()
                if len(text) > max_len:
                    return text[:max_len] + "..."
                return text
        return ""

    def search_entity(self, entity_name: str, limit: int = 5) -> str:
        """Search for an entity and its outgoing relationships.

        Matches nodes whose ``text``, ``name`` or ``canonical_name`` contains
        ``entity_name`` (case-insensitive). ``limit`` caps the number of
        returned rows (node/relationship pairs), not the number of nodes.

        Returns:
            One numbered line per row that has display text, followed by an
            indented relationship line when the row has a target with display
            text; or a "No results found" message when nothing matched.
        """
        cypher = """
        MATCH (n)
        WHERE (n.text IS NOT NULL AND toLower(n.text) CONTAINS toLower($entity))
           OR (n.name IS NOT NULL AND toLower(n.name) CONTAINS toLower($entity))
           OR (n.canonical_name IS NOT NULL AND toLower(n.canonical_name) CONTAINS toLower($entity))
        OPTIONAL MATCH (n)-[r]->(m)
        RETURN n, labels(n) AS node_labels, r, type(r) AS rel_type, m, labels(m) AS target_labels
        LIMIT $limit
        """

        results = self.query(cypher, {"entity": entity_name, "limit": limit})

        if not results:
            return f"No results found for '{entity_name}'."

        # Build concise, clean context
        context_parts = []

        for idx, record in enumerate(results, 1):
            node_labels = record.get('node_labels') or []
            label = node_labels[0] if node_labels else "Node"

            node_text = self._display_text(dict(record.get('n') or {}), 200)
            if node_text:
                context_parts.append(f"{idx}. {label}: {node_text}")

            # The OPTIONAL MATCH leaves rel_type / m empty for nodes without
            # outgoing relationships.
            rel_type = record.get('rel_type')
            target = record.get('m')
            if rel_type and target:
                target_text = self._display_text(dict(target), 150)
                if target_text:
                    target_labels = record.get('target_labels') or []
                    target_label = target_labels[0] if target_labels else "Node"
                    context_parts.append(f"   ‚Üí {rel_type}: {target_label} - {target_text}")

        return "\n".join(context_parts)

    def custom_query_to_context(self, cypher: str, parameters: Optional[Dict] = None) -> str:
        """Execute custom Cypher and format the first 10 records as numbered JSON lines."""
        results = self.query(cypher, parameters)

        if not results:
            return "No results returned from query."

        # default=str: values json cannot encode natively are rendered with
        # str() instead of raising TypeError.
        return "\n".join(
            f"{i}. {json.dumps(record, default=str)}"
            for i, record in enumerate(results[:10], 1)  # Limit to 10 results
        )

print("‚úì Neo4jConnector class defined")

# =============================================================================
# CELL 4: RAG LLM BACKEND CLASS - FIXED WITH CURRENT MODELS
# =============================================================================

class BiomedicalRAG:
    """RAG system with CURRENT Groq models and multiple backend support.

    Builds a context-grounded prompt and sends it to one of three backends
    ("groq", "huggingface", "ollama"). Every ``_query_*`` method returns the
    model's text on success and a human-readable error string on failure; none
    of them raises.
    """

    # Model retried once when Groq reports the configured model as decommissioned.
    GROQ_FALLBACK_MODEL = "llama-3.1-8b-instant"

    def __init__(self, backend: str = "groq", api_key: Optional[str] = None, model: Optional[str] = None):
        """Store the backend choice and build the per-backend request settings.

        Args:
            backend: Key into ``self.configs``: "groq", "huggingface" or "ollama".
            api_key: Bearer token for Groq / HuggingFace; when falsy no headers
                are sent.
            model: Model override, applied to the Groq and Ollama entries. The
                HuggingFace entry has a fixed model URL and ignores it.

        Raises:
            KeyError: if ``backend`` is not one of the configured backends.
        """
        self.backend = backend
        self.api_key = api_key

        # Updated model configurations with CURRENTLY SUPPORTED models
        # NOTE(review): these identifiers date from Oct 2024 -- verify them
        # against Groq's current model list before relying on the default.
        self.configs = {
            "groq": {
                "url": "https://api.groq.com/openai/v1/chat/completions",
                "headers": {
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                } if api_key else {},
                "model": model or "llama-3.1-70b-versatile",
                "available_models": [
                    "llama-3.1-70b-versatile",
                    "llama-3.1-8b-instant",
                    "llama-3.2-1b-preview",
                    "llama-3.2-3b-preview",
                    "llama-3.2-11b-vision-preview",
                    "llama-3.2-90b-vision-preview",
                    "gemma2-9b-it",
                    "mixtral-8x7b-32768"
                ]
            },
            "huggingface": {
                "url": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
                "headers": {"Authorization": f"Bearer {api_key}"} if api_key else {}
            },
            "ollama": {
                "url": "http://localhost:11434/api/generate",
                "model": model or "llama3.1"
            }
        }

        current_model = self.configs[backend].get("model", "default")
        print(f"‚úì RAG system initialized with {backend} backend")
        print(f"  Model: {current_model}")

    def _build_prompt(self, question: str, context: str, max_context_length: int = 3000) -> str:
        """Build the prompt, truncating ``context`` to ``max_context_length`` characters."""
        # Truncate context if too long
        if len(context) > max_context_length:
            context = context[:max_context_length] + "\n... (truncated for length)"

        prompt = f"""You are a biomedical AI assistant. Use the provided knowledge graph context to answer the question.

Context from Knowledge Graph:
{context}

Question: {question}

Provide a clear, concise answer based ONLY on the context above. If the context doesn't contain enough information, say so.

Answer:"""

        return prompt

    @staticmethod
    def _groq_completion_text(response) -> str:
        """Extract the assistant message text from a successful chat-completions response."""
        return response.json()["choices"][0]["message"]["content"].strip()

    @staticmethod
    def _groq_error_message(response) -> Optional[str]:
        """Return the error message of a failed response, or None if the body is not a JSON object."""
        try:
            body = response.json()
        except ValueError:  # body is not JSON
            return None
        if not isinstance(body, dict):
            return None
        error = body.get('error', {})
        if isinstance(error, dict):
            return str(error.get('message') or '')
        return str(error)

    def _query_groq(self, prompt: str) -> str:
        """Query the Groq API, retrying once with the fallback model if needed.

        When Groq reports the configured model as decommissioned, the request
        is retried with ``GROQ_FALLBACK_MODEL``. The configured model is only
        replaced after that retry succeeds, so a failed retry leaves the
        configuration untouched.

        Returns:
            The completion text, or an error string starting with
            "Groq API Error" or "Error querying Groq".
        """
        config = self.configs["groq"]

        payload = {
            "model": config["model"],
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 1,
            "stream": False
        }

        try:
            response = requests.post(
                config["url"],
                headers=config["headers"],
                json=payload,
                timeout=30
            )
            if response.status_code == 200:
                return self._groq_completion_text(response)

            print(f"‚ö†Ô∏è Error Status: {response.status_code}")
            error_msg = self._groq_error_message(response)
            if error_msg is None:
                print(f"‚ö†Ô∏è Raw Response: {response.text[:500]}")
                return f"Groq API Error: Status {response.status_code}"
            print(f"‚ö†Ô∏è Error Details: {error_msg}")

            fallback = self.GROQ_FALLBACK_MODEL
            # Retrying is pointless when the decommissioned model *is* the fallback.
            if 'decommissioned' in error_msg.lower() and payload["model"] != fallback:
                print(f"üîÑ Model '{config['model']}' is decommissioned. Trying fallback model...")
                retry = requests.post(
                    config["url"],
                    headers=config["headers"],
                    json={**payload, "model": fallback},
                    timeout=30
                )
                if retry.status_code != 200:
                    return "Groq API Error: Both primary and fallback models failed. Please check https://console.groq.com/docs/models for current models."

                # Persist the switch only now that the fallback is known to work.
                config["model"] = fallback
                print(f"‚úÖ Fallback successful! Using: {config['model']}")
                return self._groq_completion_text(retry)

            return f"Groq API Error: {error_msg}"
        except Exception as e:
            # Boundary handler: network failures and malformed success payloads
            # are reported as text because callers expect a string, never a raise.
            return f"Error querying Groq: {str(e)}"

    def _query_huggingface(self, prompt: str) -> str:
        """Query the HuggingFace inference API; return generated text or an error string."""
        config = self.configs["huggingface"]

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 400,
                "temperature": 0.1,
                "top_p": 0.9,
                "return_full_text": False
            }
        }

        try:
            response = requests.post(
                config["url"],
                headers=config["headers"],
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result = response.json()

            # The response may be a list of generations or a single object;
            # an empty list yields an empty answer.
            if isinstance(result, list):
                result = result[0] if result else {}
            return result.get("generated_text", "").strip()
        except Exception as e:
            return f"Error querying HuggingFace: {str(e)}"

    def _query_ollama(self, prompt: str) -> str:
        """Query a local Ollama instance; return its response text or an error string."""
        config = self.configs["ollama"]

        payload = {
            "model": config["model"],
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.1, "num_predict": 512}
        }

        try:
            response = requests.post(config["url"], json=payload, timeout=60)
            response.raise_for_status()
            return response.json().get("response", "").strip()
        except Exception as e:
            return f"Error querying Ollama: {str(e)}\nMake sure Ollama is running: ollama serve"

    def answer(self, question: str, neo4j_context: str) -> Dict[str, Any]:
        """Generate an answer for ``question`` grounded in ``neo4j_context``.

        Returns:
            Dict with keys ``question``, ``answer``, ``context``, ``backend``,
            ``model``, ``prompt_length`` and ``context_length``. For an
            unrecognised backend, ``answer`` is an "Unknown backend" message
            and ``model`` is "N/A".
        """
        prompt = self._build_prompt(question, neo4j_context)

        handlers = {
            "groq": self._query_groq,
            "huggingface": self._query_huggingface,
            "ollama": self._query_ollama,
        }
        handler = handlers.get(self.backend)
        if handler is None:
            answer = f"Unknown backend: {self.backend}"
        else:
            answer = handler(prompt)

        return {
            "question": question,
            "answer": answer,
            "context": neo4j_context,
            "backend": self.backend,
            # .get() on configs too: the unknown-backend path must not raise.
            "model": self.configs.get(self.backend, {}).get("model", "N/A"),
            "prompt_length": len(prompt),
            "context_length": len(neo4j_context)
        }

print("‚úì BiomedicalRAG class defined with fallback support")

# =============================================================================
# CELL 5: COMPLETE RAG PIPELINE CLASS
# =============================================================================

class CompleteRAGPipeline:
    """End-to-end RAG: Neo4j -> Context -> LLM -> Answer.

    Retrieves context from the knowledge graph (custom Cypher, an explicit
    entity, or an entity guessed from the question) and passes it with the
    question to the RAG backend.
    """

    # Function words longer than four letters that are never a useful search
    # entity; skipped when guessing the entity from the question.
    _QUESTION_WORDS = frozenset({
        "which", "where", "whose", "about", "would", "could", "should",
        "there", "these", "those", "their",
    })

    def __init__(self, neo4j_connector: "Neo4jConnector", rag_system: "BiomedicalRAG"):
        """Store the graph connector and the LLM backend used by ``ask``."""
        self.neo4j = neo4j_connector
        self.rag = rag_system
        print("‚úì Complete RAG pipeline initialized")

    @classmethod
    def _guess_entity(cls, question: str) -> Optional[str]:
        """Return the first word of ``question`` usable as a search term, or None."""
        for word in question.split():
            # Strip punctuation *before* the length check so "What?" is judged
            # as "What" and pure punctuation cannot become an empty search term.
            candidate = word.strip('?,.:;')
            if len(candidate) > 4 and candidate.lower() not in cls._QUESTION_WORDS:
                return candidate
        return None

    def ask(self, question: str, entity: Optional[str] = None,
            custom_cypher: Optional[str] = None, cypher_params: Optional[Dict] = None) -> Dict[str, Any]:
        """Answer ``question`` using the full RAG pipeline.

        Context source, in priority order: ``custom_cypher`` (with
        ``cypher_params``), then ``entity``, then an entity guessed from the
        question. If no entity can be guessed, a fixed "No specific entity
        found" message is used as the context.

        Returns:
            Whatever the RAG backend's ``answer(question, context)`` returns.
        """
        print(f"\nüîç Processing: {question}")

        # Retrieve context from Neo4j
        if custom_cypher:
            print("üìä Executing custom Cypher query...")
            context = self.neo4j.custom_query_to_context(custom_cypher, cypher_params)
        elif entity:
            print(f"üìä Searching knowledge graph for: {entity}")
            context = self.neo4j.search_entity(entity, limit=5)
        else:
            # Auto-extract entity from question
            entity = self._guess_entity(question)
            if entity:
                print(f"üìä Auto-extracted entity: {entity}")
                context = self.neo4j.search_entity(entity, limit=5)
            else:
                context = "No specific entity found. Please provide more context."

        print(f"‚úì Retrieved context ({len(context)} chars)")

        # Generate answer with LLM
        print("ü§ñ Generating answer...")
        result = self.rag.answer(question, context)
        print("‚úì Answer generated\n")

        return result

    def close(self):
        """Close the Neo4j connector."""
        self.neo4j.close()

print("‚úì CompleteRAGPipeline class defined")

# =============================================================================
# CELL 6: CONNECT TO NEO4J
# =============================================================================

print("\n" + "="*70)
print("CONNECTING TO NEO4J AURA")
print("="*70)

# Connection settings are read from the environment when set, so credentials do
# not have to be stored in the notebook; the literals below are the fallback.
NEO4J_URI = os.environ.get("NEO4J_URI", "HIDDEN")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "HIDDEN")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "HIDDEN")
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE", "HIDDEN")

# Shared connector used by every later cell.
neo4j = Neo4jConnector(
    uri=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    database=NEO4J_DATABASE
)

# =============================================================================
# CELL 7: EXPLORE DATABASE
# =============================================================================

# Section banner.
print("\n" + "=" * 70, "DATABASE EXPLORATION", "=" * 70, sep="\n")

# Node labels and relationship types present in the graph (JSON text).
print("\nüìã Database Schema:")
schema = neo4j.get_schema()
print(schema)

# Two arbitrary nodes, to show what the stored properties look like.
print("\nüìä Sample Data:")
sample = neo4j.get_sample_data(2)
print(sample)

# =============================================================================
# CELL 8: INITIALIZE RAG SYSTEM WITH WORKING MODELS
# =============================================================================

print("\n" + "="*70)
print("INITIALIZING RAG SYSTEM")
print("="*70)

# Set GROQ_API_KEY in the environment (preferred, keeps the key out of the
# notebook) or replace the placeholder below with your API key.
GROQ_KEY = os.environ.get("GROQ_API_KEY", "HIDDEN")

# FIXED: Using models that should work as of October 2024
# Try these in order of preference:
# 1. "llama-3.1-8b-instant" (Fast, reliable)
# 2. "gemma2-9b-it" (Alternative)
# 3. "mixtral-8x7b-32768" (If available)

rag = BiomedicalRAG(
    backend="groq",
    api_key=GROQ_KEY,
    model="llama-3.1-8b-instant"  # CHANGED: Using a model that should work
)

# Test API connection
print("\nüß™ Testing Groq API connection...")
test_result = rag._query_groq("Say 'Hello' in one word.")
# _query_groq reports failures as strings with one of these prefixes. Matching
# the prefix (rather than searching for "Error"/"API" anywhere) keeps a valid
# model reply that merely mentions those words from being flagged as a failure.
if not test_result.startswith(("Groq API Error", "Error querying Groq", "HTTP Error")):
    print(f"‚úÖ API working! Response: {test_result}")
else:
    print(f"‚ùå API test failed: {test_result}")
    print("\nüí° Troubleshooting:")
    print("   1. Verify your API key at https://console.groq.com/keys")
    print("   2. Check current available models at https://console.groq.com/docs/models")
    print("   3. The code will automatically try fallback models")

# Create pipeline
pipeline = CompleteRAGPipeline(neo4j, rag)

# =============================================================================
# CELL 9: RUN EXAMPLE QUERIES
# =============================================================================

# Section banner.
print("\n" + "=" * 70, "EXAMPLE QUERIES", "=" * 70, sep="\n")

# Example 1 -- the only example that also reports the model and context size.
print("\n--- Example 1: Disease Mechanism ---")
result1 = pipeline.ask("What causes jaundice in cirrhosis?", entity="cirrhosis")
print(
    f"üìù Question: {result1['question']}",
    f"ü§ñ Model: {result1['model']}",
    f"üìä Context length: {result1['context_length']} chars",
    f"üí° Answer:\n{result1['answer']}",
    sep="\n",
)

# Example 2
print("\n--- Example 2: Drug Mechanism ---")
result2 = pipeline.ask("How does metformin work?", entity="metformin")
print(
    f"üìù Question: {result2['question']}",
    f"üí° Answer:\n{result2['answer']}",
    sep="\n",
)

# Example 3
print("\n--- Example 3: Simple Query ---")
result3 = pipeline.ask("What are common diabetes symptoms?", entity="diabetes")
print(
    f"üìù Question: {result3['question']}",
    f"üí° Answer:\n{result3['answer']}",
    sep="\n",
)

# =============================================================================
# CELL 10: INTERACTIVE MODE (OPTIONAL)
# =============================================================================

# Section banner followed by a copy-and-paste snippet; the snippet is only
# printed, never executed here.
print("\n" + "=" * 70, "INTERACTIVE MODE", "=" * 70, sep="\n")
print("\n‚ú® Uncomment the code below to enable interactive Q&A:")
print("""
# Interactive loop
while True:
    q = input("\\n\\nYour question (or 'quit'): ")
    if q.lower() in ['quit', 'exit', 'q']:
        break
    entity = input("Entity to search (or Enter to auto-detect): ").strip() or None
    result = pipeline.ask(q, entity=entity)
    print(f"\\nüí° Answer:\\n{result['answer']}")
    print(f"\\nüìä Used {result['context_length']} chars of context")
""")

# =============================================================================
# CELL 11: ADVANCED: SWITCH MODELS ON THE FLY
# =============================================================================

# Section banner followed by printed (not executed) examples of building
# pipelines on other models and backends.
print("\n" + "=" * 70, "ADVANCED: MODEL SWITCHING", "=" * 70, sep="\n")
print("\nüîß You can switch models dynamically:")
print("""
# Example: Try different Groq models
rag_fast = BiomedicalRAG(backend="groq", api_key=GROQ_KEY, model="llama-3.1-8b-instant")
pipeline_fast = CompleteRAGPipeline(neo4j, rag_fast)

# Example: Try Gemma
rag_gemma = BiomedicalRAG(backend="groq", api_key=GROQ_KEY, model="gemma2-9b-it")
pipeline_gemma = CompleteRAGPipeline(neo4j, rag_gemma)

# Example: Use HuggingFace instead
# HF_TOKEN = "your_token_here"
# rag_hf = BiomedicalRAG(backend="huggingface", api_key=HF_TOKEN)
# pipeline_hf = CompleteRAGPipeline(neo4j, rag_hf)

# Example: Use local Ollama
# rag_local = BiomedicalRAG(backend="ollama", model="llama3.1")
# pipeline_local = CompleteRAGPipeline(neo4j, rag_local)
""")

# =============================================================================
# CELL 12: CLEANUP
# =============================================================================

# Section banner.
print("\n" + "=" * 70, "CLEANUP", "=" * 70, sep="\n")

# Closing notes: how to shut down and how to call the pipeline again.
print(
    "‚úÖ All systems operational!",
    "üìå Run: pipeline.close() when done",
    "\nüöÄ Quick commands:",
    "   result = pipeline.ask('Your question here', entity='entity_name')",
    "   result = pipeline.ask('Custom query', custom_cypher='MATCH (n) RETURN n LIMIT 5')",
    sep="\n",
)

# Model reminder.
print(
    "\nüìö Currently using Groq models (as of Oct 2024):",
    "   - llama-3.1-8b-instant (fast, recommended)",
    "   - gemma2-9b-it (alternative)",
    "   - mixtral-8x7b-32768 (if available)",
    "\nüí° Note: The code includes automatic fallback if a model is decommissioned!",
    sep="\n",
)

✓ All libraries imported successfully
✓ Neo4jConnector class defined
✓ BiomedicalRAG class defined with fallback support
✓ CompleteRAGPipeline class defined

CONNECTING TO NEO4J AURA
✓ Connected to Neo4j at neo4j+s://62418b31.databases.neo4j.io

DATABASE EXPLORATION

📋 Database Schema:
{
  "node_labels": [
    "[Info Tree Node]",
    "[Info Tree Root]",
    "[Disease]",
    "[Embedding Root]",
    "[Embedding Node, pseudo medical history Embedding Node]",
    "[Embedding Node, diagnosis info Embedding Node]",
    "Disease",
    "Info Tree Node",
    "EntityCanonical"
  ],
  "relationship_types": [
    "tree_node-tree_node",
    "tree_root-tree_node",
    "main_node-tree_root",
    "main_node-embedding_root",
    "embedding_root-embedding_node",
    "CANONICAL_FORM",
    "SIMILAR_TO"
  ]
}

📊 Sample Data:
[
  {
    "labels": [
      "[Info Tree Node]"
    ],
    "props": {
      "text": "Diaphragmatic hernia",
      "original_id": "4:77316bc7-08a7-42a4-afdd-da77a7c4cd66: