In [None]:
"""
Canadian Trails & Parks RAG System - Streamlit Frontend
Optimized for Hugging Face Spaces deployment
"""

import os
import re
import time
import unicodedata
from typing import Dict, List, Optional

import chromadb
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer

# ============================================================================
# PAGE CONFIG (MUST BE FIRST STREAMLIT COMMAND)
# ============================================================================

# Must be the first Streamlit call of the run (see section header above).
# page_icon is the "snow-capped mountain" emoji (U+1F3D4 + VS16); the
# previous value was its UTF-8 bytes mis-decoded as Mac Roman text.
st.set_page_config(
    page_title="Canadian Trails & Parks Explorer",
    page_icon="🏔️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# ============================================================================
# CUSTOM CSS
# ============================================================================

# Page-wide stylesheet injected as raw HTML. The `main-header` and
# `sub-header` classes style the title markup emitted by main(); the
# `.stTextInput > div > div > input` selector depends on the DOM structure
# Streamlit generates for text inputs.
_CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: 700;
        color: #1E88E5;
        text-align: center;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.2rem;
        color: #666;
        text-align: center;
        margin-bottom: 2rem;
    }
    .result-card {
        background-color: #f8f9fa;
        border-left: 4px solid #1E88E5;
        padding: 1rem;
        margin: 1rem 0;
        border-radius: 4px;
    }
    .source-tag {
        background-color: #E3F2FD;
        color: #1976D2;
        padding: 0.2rem 0.5rem;
        border-radius: 4px;
        font-size: 0.85rem;
        font-weight: 500;
    }
    .metric-card {
        background-color: #fff;
        padding: 1rem;
        border-radius: 8px;
        border: 1px solid #e0e0e0;
        text-align: center;
    }
    .stTextInput > div > div > input {
        font-size: 1.1rem;
    }
</style>
"""

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Static settings for the RAG app, read as class attributes (``Config.TOP_K``).

    Groups the vector-store location, the embedding/LLM model identifiers,
    retrieval size, and the city -> region table used for metadata filtering.
    """

    # --- Vector store -------------------------------------------------------
    # Collection picked as the best performer in an earlier evaluation run.
    COLLECTION_NAME = "extra_large_minilm"
    DB_PATH = "./data/vector_db"

    # --- Models -------------------------------------------------------------
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    # Answers are generated through Groq's chat-completions API.
    LLM_PROVIDER = "groq"
    LLM_MODEL = "llama-3.1-8b-instant"

    # --- Retrieval ----------------------------------------------------------
    TOP_K = 5  # chunks fetched per query

    # Lower-case city name -> values of the chunk metadata field "region" that
    # retrieval is restricted to when the city appears in a query. Order
    # matters: the first city found in the query is the one used.
    # NOTE(review): "AB" under banff breaks the "<Province> <Part>" pattern of
    # the other values — confirm it is a real region label in the metadata.
    LOCATION_MAPPING = dict(
        toronto=["Ontario South", "Ontario Central"],
        vancouver=["British Columbia South"],
        montreal=["Quebec South"],
        calgary=["Alberta South"],
        ottawa=["Ontario South"],
        banff=["Alberta South", "AB"],
        edmonton=["Alberta South"],
        winnipeg=["Manitoba"],
        halifax=["Nova Scotia"],
    )

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

@st.cache_resource(show_spinner=False)
def load_embedding_model():
    """Build the sentence-embedding encoder named by ``Config.EMBEDDING_MODEL``.

    ``st.cache_resource`` keeps the returned object across reruns, so the
    model is constructed only on the first call.
    """
    model_id = Config.EMBEDDING_MODEL
    encoder = SentenceTransformer(model_id)
    return encoder

@st.cache_resource(show_spinner=False)
def load_vector_db():
    """Open the persisted ChromaDB store and return its configured collection.

    The store directory is ``<DB_PATH>/<COLLECTION_NAME>`` and contains a
    collection of the same name. Cached with ``st.cache_resource`` so the
    client is created once and reused across reruns.
    """
    name = Config.COLLECTION_NAME
    persistent_client = chromadb.PersistentClient(path=f"{Config.DB_PATH}/{name}")
    return persistent_client.get_collection(name)

def get_groq_api_key():
    """Return the Groq API key from the ``GROQ_API_KEY`` environment variable.

    Hugging Face Spaces exposes configured secrets as environment variables,
    so the same lookup works on Spaces and for local runs.

    Surrounding whitespace is stripped: a newline or space pasted into the
    secret would otherwise end up in the ``Authorization`` header and make
    every request fail.

    Returns:
        The key, or ``""`` when the variable is unset or blank.
    """
    return os.getenv("GROQ_API_KEY", "").strip()

# Province name (lower-case, accent-free) -> values of the chunk metadata
# field "region". Consulted only when no city from the city table matches.
_PROVINCE_REGIONS = {
    "ontario": ["Ontario South", "Ontario Central", "Ontario North"],
    "quebec": ["Quebec South", "Quebec North"],
    "british columbia": ["British Columbia South", "British Columbia North"],
    "alberta": ["Alberta South", "Alberta North"],
    "bc": ["British Columbia South", "British Columbia North"],
    "nova scotia": ["Nova Scotia"],
    "new brunswick": ["New Brunswick"],
}


def _normalize_for_matching(text: str) -> str:
    """Case-fold *text* and drop accents so "Québec" compares equal to "quebec"."""
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch)).casefold()


def _mentions(term: str, text: str) -> bool:
    """Return True if *term* occurs in *text* as whole word(s), not inside a word."""
    return re.search(rf"\b{re.escape(term)}\b", text) is not None


def extract_location_filter(
    query: str,
    city_regions: Optional[Dict[str, List[str]]] = None,
) -> Optional[List[str]]:
    """Map a place named in *query* to region values for metadata filtering.

    Cities are checked first, then provinces; the first match wins.
    Matching is case- and accent-insensitive and on whole words only, so
    "bc" matches "Hiking in BC" but not "subcategory".

    Args:
        query: Free-text user question.
        city_regions: City -> region list table; defaults to
            ``Config.LOCATION_MAPPING``.

    Returns:
        A new list of region labels, or ``None`` if no known place is named.
    """
    if city_regions is None:
        city_regions = Config.LOCATION_MAPPING

    text = _normalize_for_matching(query)

    for table in (city_regions, _PROVINCE_REGIONS):
        for place, regions in table.items():
            if _mentions(_normalize_for_matching(place), text):
                # Copy so callers can't mutate the shared lookup tables.
                return list(regions)

    return None

def _query_collection(collection, query_embedding, where_filter):
    """Run one top-K similarity query; *where_filter* of None means no filter."""
    return collection.query(
        query_embeddings=[query_embedding],
        n_results=Config.TOP_K,
        where=where_filter,
        include=['documents', 'metadatas', 'distances']
    )


def retrieve_documents(query: str, collection, embedding_model) -> Dict:
    """Embed *query* and fetch the closest chunks from the vector store.

    If the query names a known city or province, the search is restricted
    to chunks whose ``region`` metadata is one of the mapped values. When
    that restricted search returns nothing, it is retried without the
    filter so the answer step is not left with an empty context.

    Args:
        query: Free-text user question.
        collection: ChromaDB collection to search.
        embedding_model: Encoder whose ``encode(query)`` result supports
            ``.tolist()`` (a SentenceTransformer in this app).

    Returns:
        Dict with ``results`` (list of dicts with ``rank``, ``content``,
        ``metadata``, ``similarity``, ``distance``), ``retrieval_time``
        (seconds), ``query``, and ``filters_applied`` (the ``where`` filter
        actually used for the returned results, or ``None``).
    """
    start_time = time.perf_counter()

    query_embedding = embedding_model.encode(query).tolist()

    regions = extract_location_filter(query)
    where_filter = {"region": {"$in": regions}} if regions else None

    results = _query_collection(collection, query_embedding, where_filter)
    if where_filter is not None and not results['documents'][0]:
        # The region labels matched no chunk; search everything instead of
        # answering from no context at all.
        where_filter = None
        results = _query_collection(collection, query_embedding, where_filter)

    retrieval_time = time.perf_counter() - start_time

    formatted_results = [
        {
            'rank': rank,
            'content': doc,
            # Chroma may return None for chunks stored without metadata;
            # normalize so callers can always use .get().
            'metadata': metadata or {},
            # NOTE(review): 1 - distance is a similarity only for a
            # cosine-space collection — confirm the collection's metric.
            'similarity': 1 - distance,
            'distance': distance,
        }
        for rank, (doc, metadata, distance) in enumerate(
            zip(
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            ),
            start=1,
        )
    ]

    return {
        'results': formatted_results,
        'retrieval_time': retrieval_time,
        'query': query,
        'filters_applied': where_filter
    }

def _build_context(retrieved_docs: List[Dict]):
    """Format retrieved chunks as numbered ``[Source N]`` entries for the prompt.

    Returns:
        ``(context, sources)``: the joined context text, and the
        ``document_title`` of each chunk that has one, in rank order.
    """
    context_parts = []
    sources = []

    for i, doc in enumerate(retrieved_docs, 1):
        metadata = doc.get('metadata') or {}
        content = doc['content'][:500]  # Truncate to keep the prompt short

        source_info = f"[Source {i}]"
        if metadata.get('document_title'):
            source_info += f" {metadata['document_title']}"
            sources.append(metadata['document_title'])
        if metadata.get('region'):
            source_info += f" - {metadata['region']}"

        context_parts.append(f"{source_info}\n{content}\n")

    return "\n".join(context_parts), sources


def generate_answer(query: str, retrieved_docs: List[Dict]) -> Dict:
    """Ask the Groq-hosted LLM to answer *query* using the retrieved chunks.

    Args:
        query: The user's question (inserted verbatim into the prompt).
        retrieved_docs: Ranked results from ``retrieve_documents``; each needs
            ``content`` and may carry ``metadata`` with ``document_title`` and
            ``region``.

    Returns:
        Dict with ``answer`` (markdown, or a warning message on failure),
        ``generation_time`` (seconds), ``tokens_used`` (0 when unknown) and
        ``sources`` (document titles placed in the context). A missing API
        key, network/HTTP errors and malformed responses are reported in
        ``answer`` rather than raised.
    """
    api_key = get_groq_api_key()

    if not api_key:
        return {
            'answer': "⚠️ **API Key Missing**: Please set GROQ_API_KEY in Hugging Face Spaces secrets.",
            'generation_time': 0,
            'tokens_used': 0,  # same keys as the normal return path
            'sources': []
        }

    context, sources = _build_context(retrieved_docs)

    prompt = f"""You are a helpful assistant for Canadian trails and parks information.

Use the following context to answer the user's question. Cite sources using [Source N] format.

Context:
{context}

User Question: {query}

Instructions:
- Provide a detailed answer based on the context
- Always cite sources using [Source N] format
- If context is insufficient, say so
- Be specific about trail names and locations

Answer:"""

    start_time = time.perf_counter()

    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            },
            json={
                "model": Config.LLM_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.3,
                "max_tokens": 800
            },
            timeout=30
        )
        response.raise_for_status()
        result = response.json()

        answer = result['choices'][0]['message']['content']
        # Usage accounting is informational; don't fail the answer without it.
        tokens = (result.get('usage') or {}).get('total_tokens', 0)

    except requests.exceptions.RequestException as e:
        answer = f"⚠️ **Error generating answer**: {str(e)}\n\nPlease check your API key and try again."
        tokens = 0
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # Body was not JSON or lacked the expected choices/message fields.
        answer = f"⚠️ **Error generating answer**: unexpected response format ({type(e).__name__}: {e})"
        tokens = 0

    generation_time = time.perf_counter() - start_time

    return {
        'answer': answer,
        'generation_time': generation_time,
        'tokens_used': tokens,
        'sources': sources
    }

# ============================================================================
# MAIN APP
# ============================================================================

def _select_query(text: str) -> None:
    """Button callback: load *text* into the search box.

    Streamlit runs ``on_click`` callbacks before the script reruns, which is
    when a widget's own session-state key (``main_query_input``) may be
    written. Writing that key here is what lets both the example buttons
    and Clear change what the keyed text input shows.
    """
    st.session_state.query_input = text
    st.session_state.main_query_input = text


def _render_sidebar() -> None:
    """Draw the sidebar: about blurb, example-query buttons, system info, links."""
    with st.sidebar:
        st.markdown("### ⚙️ About")
        st.markdown("""
        This RAG system helps you discover Canadian trails and parks using:
        - **277,468** trail records
        - **60+** Parks Canada locations
        - **Free AI** (Groq + Llama 3.1)
        - **Vector search** for intelligent retrieval
        """)

        st.markdown("---")

        st.markdown("### 💡 Example Queries")
        example_queries = [
            "What are hiking trails in British Columbia?",
            "Find wheelchair accessible trails in Ontario",
            "Tell me about Banff National Park",
            "What trails allow bicycles in Quebec?",
            "Find beginner-friendly trails near Toronto"
        ]

        for eq in example_queries:
            # Clicking only fills the search box; the user still presses Search.
            st.button(eq, key=eq, use_container_width=True,
                      on_click=_select_query, args=(eq,))

        st.markdown("---")

        st.markdown("### 📊 System Info")
        st.markdown(f"""
        - **Vector DB**: {Config.COLLECTION_NAME}
        - **Embeddings**: MiniLM-L6-v2
        - **LLM**: Llama 3.1 8B (Groq)
        - **Top-K**: {Config.TOP_K}
        """)

        st.markdown("---")

        st.markdown("### 🔗 Links")
        # TODO: placeholder URLs — point these at the real repo and docs.
        st.markdown("[GitHub Repo](https://github.com) • [Documentation](https://docs.example.com)")


def _render_metrics(retrieval_results: Dict, generation_results: Dict) -> None:
    """Show total/retrieval/generation timings and the source count in four columns."""
    retrieval_time = retrieval_results['retrieval_time']
    generation_time = generation_results['generation_time']
    metrics = [
        ("⚡ Total Time", f"{retrieval_time + generation_time:.2f}s"),
        ("📚 Sources", len(retrieval_results['results'])),
        ("🎯 Retrieval", f"{retrieval_time:.2f}s"),
        ("💬 Generation", f"{generation_time:.2f}s"),
    ]
    # Each st.markdown call is a separate element, so an opening <div> in one
    # call cannot wrap an st.metric rendered by another; render metrics plainly.
    for column, (label, value) in zip(st.columns(len(metrics)), metrics):
        with column:
            st.metric(label, value)


def _render_sources(results: List[Dict]) -> None:
    """Show each retrieved chunk in an expander with its metadata and a content preview."""
    for doc in results:
        metadata = doc.get('metadata') or {}

        with st.expander(
            f"**{doc['rank']}. {metadata.get('document_title', 'Unknown')}** "
            f"(Similarity: {doc['similarity']:.3f})"
        ):
            cols = st.columns(3)
            with cols[0]:
                if metadata.get('region'):
                    st.markdown(f"📍 **Region**: {metadata['region']}")
            with cols[1]:
                if metadata.get('document_type'):
                    st.markdown(f"🏷️ **Type**: {metadata['document_type']}")
            with cols[2]:
                if metadata.get('difficulty'):
                    st.markdown(f"⚡ **Difficulty**: {metadata['difficulty']}")

            st.markdown("**Content:**")
            content = doc['content']
            st.markdown(content[:500] + "..." if len(content) > 500 else content)

            if metadata.get('surface'):
                st.markdown(f"*Surface: {metadata['surface']}*")


def _render_results(retrieval_results: Dict, generation_results: Dict) -> None:
    """Show metrics, the generated answer, then the retrieved sources."""
    st.markdown("---")
    _render_metrics(retrieval_results, generation_results)
    st.markdown("---")

    st.markdown("### 💡 Answer")
    st.markdown(generation_results['answer'])

    st.markdown("---")
    st.markdown("### 📖 Retrieved Sources")
    _render_sources(retrieval_results['results'])


def _render_error(error: Exception) -> None:
    """Show a failed search and a checklist of likely causes."""
    st.error(f"❌ **Error**: {str(error)}")
    st.markdown("Please check:")
    st.markdown(f"- Vector database is available at `{Config.DB_PATH}/{Config.COLLECTION_NAME}`")
    st.markdown("- GROQ_API_KEY is set in Hugging Face Spaces secrets")
    st.markdown("- All dependencies are installed")


def main():
    """Render one run of the app: header, sidebar, search box, results, footer."""
    # Header
    st.markdown('<h1 class="main-header">🏔️ Canadian Trails & Parks Explorer</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Powered by AI • 277K+ Trails • 60+ Parks</p>', unsafe_allow_html=True)

    _render_sidebar()

    # Last query chosen via a button; the text box itself is driven by its
    # widget key ("main_query_input"), which the button callbacks write to.
    if 'query_input' not in st.session_state:
        st.session_state.query_input = ""

    # No value= argument: the widget's content comes from its session-state
    # key, so the example and Clear callbacks can change it.
    query = st.text_input(
        "🔍 Ask about Canadian trails and parks:",
        placeholder="e.g., What are the best hiking trails near Vancouver?",
        key="main_query_input"
    )

    col1, col2, _ = st.columns([2, 1, 1])
    with col1:
        search_button = st.button("🚀 Search", type="primary", use_container_width=True)
    with col2:
        st.button("🗑️ Clear", use_container_width=True,
                  on_click=_select_query, args=("",))

    # Whitespace-only input counts as empty.
    query = query.strip()

    if search_button and query:
        try:
            with st.spinner("🔍 Searching knowledge base..."):
                embedding_model = load_embedding_model()
                collection = load_vector_db()
                retrieval_results = retrieve_documents(query, collection, embedding_model)
                generation_results = generate_answer(query, retrieval_results['results'])
            _render_results(retrieval_results, generation_results)
        except Exception as e:  # top-level UI boundary: show the error, not a traceback
            _render_error(e)
    elif search_button:
        st.warning("⚠️ Please enter a query")

    # Footer
    st.markdown("---")
    st.markdown(
        '<p style="text-align: center; color: #666; font-size: 0.9rem;">'
        'Built with ❤️ using Streamlit • Data from OpenStreetMap & Parks Canada • '
        'Free AI by Groq'
        '</p>',
        unsafe_allow_html=True
    )

if __name__ == "__main__":
    main()