# In [None]:
import gradio as gr
import chromadb
import requests

# -----------------------------
# Setup ChromaDB (persistent)
# -----------------------------
# Open the on-disk Chroma store once; every lookup below goes through it.
client = chromadb.PersistentClient(path="/path/chroma_storage_nomic")

# Short label (shown in retrieved context) -> collection handle.
# The second item of each pair is the collection name as stored in Chroma.
collections = {
    label: client.get_collection(stored_name)
    for label, stored_name in (
        ("plasmid", "Plasmidfinder"),
        ("resfinder", "resistancefinder"),
        ("vfdb", "virulencedb"),
        ("mge", "mge"),
    )
}

# -----------------------------
# Nomic Embed Function (via Ollama API)
# -----------------------------
def nomic_embed(texts, timeout=60):
    """Get embeddings from Nomic via the Ollama API (single text or batch).

    Args:
        texts: A single string, or an iterable of strings, to embed.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A list with one entry per input text, in input order. Each entry is
        the value the API returned under the "embedding" key. Empty input
        yields an empty list without contacting the server.

    Raises:
        ConnectionError: If a request times out or returns a non-200 status.
        ValueError: If a response has no "embedding" field.
    """
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        return []

    url = "http://127.0.0.1:11434/api/embeddings"
    headers = {"Content-Type": "application/json"}
    embeddings = []

    # One session for the whole batch so the HTTP connection is reused
    # instead of being re-opened for every text.
    with requests.Session() as session:
        for text in texts:
            data = {
                "model": "nomic-embed-text:latest",
                "prompt": text,
            }
            try:
                # Without a timeout a stalled server would block forever.
                response = session.post(
                    url, headers=headers, json=data, timeout=timeout
                )
            except requests.Timeout as exc:
                # Surface timeouts with the same exception type used for
                # HTTP-level failures below.
                raise ConnectionError(
                    f"Failed to get embeddings: timed out after {timeout} seconds"
                ) from exc

            if response.status_code != 200:
                raise ConnectionError(
                    f"Failed to get embeddings: {response.status_code} {response.text}"
                )

            res_json = response.json()
            if "embedding" not in res_json:
                raise ValueError(f"No embedding returned for text: {text}")
            embeddings.append(res_json["embedding"])

    return embeddings

# -----------------------------
# Retrieve from ChromaDB
# -----------------------------
def query_chromadb(question, n_results=3):
    """Retrieve the top documents from each collection for a question.

    The question is embedded once with `nomic_embed`, and that vector is
    used to query every entry of the module-level `collections` mapping.

    Args:
        question: The user's question text.
        n_results: Maximum number of documents requested per collection.

    Returns:
        A single string with one section per collection that returned at
        least one document. Each section is headed by the collection label
        and sections are separated by a blank line. Returns an empty string
        when nothing was retrieved.
    """
    query_vec = nomic_embed(question)[0]  # single embedding vector

    responses = []
    for name, col in collections.items():
        results = col.query(
            query_embeddings=[query_vec],  # needs list of list
            n_results=n_results,
        )
        # Guard against a missing, None or empty "documents" entry, and
        # against None placeholders inside it, so one collection without
        # stored text cannot crash the whole retrieval step.
        batches = results.get("documents") or [[]]
        docs = [doc for doc in batches[0] if doc is not None]
        if docs:
            responses.append(f"📂 {name}:\n" + "\n".join(docs))
    return "\n\n".join(responses)

# -----------------------------
# Query Llama (Ollama API)
# -----------------------------
def ask_llama(question, context, timeout=120):
    """Send the question plus retrieved context to Llama3.2 via the Ollama API.

    Args:
        question: The user's question text.
        context: Retrieved database text to embed in the prompt.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The stripped text found under the "response" key of the reply
        (empty string if that key is missing or None). If the request
        times out or returns a non-200 status, a human-readable
        "Request failed: ..." string is returned instead of raising.
    """
    prompt = f"""
You are a helpful assistant with expertise in antimicrobial resistance.
Here is some retrieved database context:

{context}

Question: {question}

Answer concisely, using the context where possible.
"""
    url = "http://127.0.0.1:11434/api/generate"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": "llama3.2:1b",   # using llama3.2
        "prompt": prompt,
        "stream": False,
    }

    try:
        # Without a timeout a stalled server would block the chat forever.
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
    except requests.Timeout:
        # Report the timeout the same way as HTTP failures below: as text
        # the caller can show to the user.
        return f"Request failed: timed out after {timeout} seconds"

    if response.status_code != 200:
        return f"Request failed: {response.status_code} {response.text}"

    res_json = response.json()
    # `or ""` covers a "response" key that is present but None.
    return (res_json.get("response") or "").strip()

# -----------------------------
# Chatbot Logic
# -----------------------------
def chatbot(user_message, history):
    """Answer one chat turn using retrieval followed by generation.

    The message is first used to fetch context via `query_chromadb`; the
    message and that context are then passed to `ask_llama`, whose result
    is returned unchanged. `history` is part of the signature but is not
    read here.
    """
    retrieved = query_chromadb(user_message)
    return ask_llama(user_message, retrieved)

# -----------------------------
# Gradio Chat Interface
# -----------------------------
# Wrap `chatbot` in a Gradio chat UI.
demo = gr.ChatInterface(
    fn=chatbot,
    title="🧬 Resistance Gene Assistant (Llama3.2 + Nomic)",
    type="messages",
)

# NOTE(review): share=True presumably asks Gradio for a publicly reachable
# link in addition to the local server — confirm that exposure is intended.
demo.launch(share=True)