In [33]:
import glob
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from sklearn.manifold import TSNE
from chromadb import PersistentClient
from typing import List
import os
import shutil
import groq

import os
import re
import json
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Any

from langchain_core.documents import Document
groq_client = groq.Groq(api_key=os.getenv("GROQ_API_KEY"))
OPENSOURCE_OSS_MODEL = "openai/gpt-oss-120b"


In [34]:

# -----------------------------
# 1) Parsing & chunking
# -----------------------------

EVENT_HEADING_RE = re.compile(r"^###\s+(.*)\s*$", re.MULTILINE)
FIELD_RE = {
    "event_type": re.compile(r"^\s*-\s*Event Type:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "category": re.compile(r"^\s*-\s*Category:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "age_tags": re.compile(r"^\s*-\s*Age Tags:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "instructor": re.compile(r"^\s*-\s*(Instructor|Facilitator):\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "date_range": re.compile(r"^\s*-\s*Date Range:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "time_slots": re.compile(r"^\s*-\s*Time Slots:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "duration": re.compile(r"^\s*-\s*Duration:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
    "spots": re.compile(r"^\s*-\s*Spots:\s*(.+?)\s*$", re.MULTILINE | re.IGNORECASE),
}

CENTER_RE = re.compile(r"^##\s+(.+?)\s*$", re.MULTILINE)
LOCATION_RE = re.compile(r"^\*\*Location:\*\*\s*(.+?)\s*$", re.MULTILINE)
TYPE_RE = re.compile(r"^\*\*Type:\*\*\s*(.+?)\s*$", re.MULTILINE)

PAGE_RE = re.compile(r"^#\s+PAGE\s+\d+\s+—\s+(.+?)\s*$", re.MULTILINE)


def _safe_find(regex: re.Pattern, text: str) -> Optional[str]:
    m = regex.search(text)
    return m.group(1).strip() if m else None


def _safe_find2(regex: re.Pattern, text: str) -> Optional[str]:
    m = regex.search(text)
    return m.group(2).strip() if m else None


def parse_center_metadata(md_text: str, source: str) -> Dict[str, Optional[str]]:
    center_name = _safe_find(CENTER_RE, md_text)
    location = _safe_find(LOCATION_RE, md_text)
    center_type = _safe_find(TYPE_RE, md_text)

    city, state = None, None
    if location:
        # "Salem, Massachusetts" or "Plymouth, Massachusetts"
        parts = [p.strip() for p in location.split(",")]
        if len(parts) >= 2:
            city, state = parts[0], parts[1]

    return {
        "source": source,
        "center_name": center_name,
        "center_type": center_type,
        "city": city,
        "state": state,
    }


def split_event_blocks(md_text: str) -> List[Tuple[str, str]]:
    """
    Returns list of (event_title, event_block_text).
    Event blocks start with '### ' and continue until next '### ' or end.
    """
    matches = list(EVENT_HEADING_RE.finditer(md_text))
    blocks: List[Tuple[str, str]] = []
    for i, m in enumerate(matches):
        title = m.group(1).strip()
        start = m.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(md_text)
        block = md_text[start:end].strip()
        blocks.append((title, block))
    return blocks


def parse_event_metadata(event_title: str, block: str) -> Dict[str, Optional[str]]:
    # Event Type may appear as "Event Type:" or "Category:" depending on file style
    event_type = _safe_find(FIELD_RE["event_type"], block) or _safe_find(FIELD_RE["category"], block)

    age_tags = _safe_find(FIELD_RE["age_tags"], block)
    instructor = _safe_find2(FIELD_RE["instructor"], block)
    date_range = _safe_find(FIELD_RE["date_range"], block)
    time_slots = _safe_find(FIELD_RE["time_slots"], block)
    duration = _safe_find(FIELD_RE["duration"], block)
    spots = _safe_find(FIELD_RE["spots"], block)

    return {
        "event_title": event_title,
        "event_type": event_type,
        "age_tags": age_tags,
        "instructor": instructor,
        "date_range": date_range,
        "time_slots": time_slots,
        "duration": duration,
        "spots": spots,
    }


def build_event_documents(md_text: str, source: str) -> List[Document]:
    center_md = parse_center_metadata(md_text, source)
    docs: List[Document] = []

    for title, block in split_event_blocks(md_text):
        meta = parse_event_metadata(title, block)
        combined_meta = {**center_md, **meta, "doc_type": "event"}

        # Use compact but rich page_content for embeddings
        content = (
            f"Center: {center_md.get('center_name')} ({center_md.get('center_type')})\n"
            f"Location: {center_md.get('city')}, {center_md.get('state')}\n"
            f"Event: {title}\n"
            f"Event Type: {meta.get('event_type')}\n"
            f"Age Tags: {meta.get('age_tags')}\n"
            f"Instructor: {meta.get('instructor')}\n"
            f"Date Range: {meta.get('date_range')}\n"
            f"Time Slots: {meta.get('time_slots')}\n"
            f"Duration: {meta.get('duration')}\n"
            f"Spots: {meta.get('spots')}\n\n"
            f"Raw Block:\n{block}\n"
        ).strip()

        docs.append(Document(page_content=content, metadata=combined_meta))

    return docs

def _normalize_activity_heading(s: str) -> str:
    """
    Normalizes headings for robust matching.
    - Uppercases
    - Collapses whitespace
    - Removes most punctuation (keeps alphanumerics and spaces)
    """
    if not s:
        return ""
    s = s.strip().upper()
    s = re.sub(r"\s+", " ", s)
    # remove punctuation/symbols except spaces and alphanumerics
    s = re.sub(r"[^A-Z0-9 ]+", "", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s





In [35]:
# Load in everything in the knowledgebase using LangChain's loaders

# Get the project root directory
# In notebooks, we need to find the project root relative to current working directory
current_dir = os.getcwd()
# If we're in helper/, go up one level; otherwise assume we're at project root
if os.path.basename(current_dir) == "helper":
    project_root = os.path.dirname(current_dir)
else:
    project_root = current_dir

documents_path = os.path.join(project_root, "documents")

# point this to the documents folder
folders = glob.glob(os.path.join(documents_path, "*"))

# create list of events md file list
events_md_files: List[str] = []
# create list of activity type md file list
activitytype_md_files: List[str] = []


for folder in folders:
    if os.path.isdir(folder):
        doc_type = os.path.basename(folder)
        loader = DirectoryLoader(folder, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={'encoding': 'utf-8'})
        folder_docs = loader.load()
        print(f"Loaded {len(folder_docs)} documents from {doc_type}")
        for doc in folder_docs:
            # load events in events_md_files list (append file path, not Document object)
            if doc_type == "Events":
                events_md_files.append(doc.metadata.get("source", ""))
            # load activity types in activitytype_md_files list (append file path, not Document object)
            elif doc_type == "activityType":
                activitytype_md_files.append(doc.metadata.get("source", ""))
            doc.metadata["doc_type"] = doc_type



Loaded 6 documents from activityType
Loaded 0 documents from prompts
Loaded 11 documents from Events
Loaded 0 documents from Reviews


In [36]:
import re
from typing import List, Optional, Tuple
from langchain_core.documents import Document

_EVENT_BLOCK_RE = re.compile(r"(?m)^###\s+(.+?)\s*$")


# ---------------------------
# Normalization helpers
# ---------------------------

def normalize_event_type(s: Optional[str]) -> Optional[str]:
    """Upper-case and strip symbols like ® / ™ and extra punctuation."""
    if not s:
        return None
    s = s.strip().upper()
    s = s.replace("®", "").replace("™", "")
    s = re.sub(r"\s+", " ", s)
    # Keep alphanumerics, spaces, '/', '&', '-', '+'
    s = re.sub(r"[^A-Z0-9 /&\-\+]", "", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s or None


# ---------------------------
# Age extraction + bucketing
# ---------------------------

AGE_KEYWORDS = {
    "kids": ["kids", "children", "youth"],
    "teens": ["teen", "teens"],
    "young_adults": ["young adult", "college"],
    "adults": ["adult", "adults"],
    "seniors": ["senior", "older adult", "55+", "60+", "65+"],
    "all": ["all ages", "family"],
}

def _bucket_from_age_range(min_age: Optional[int], max_age: Optional[int]) -> List[str]:
    """
    Convert numeric min/max into broad groups.
    Rules are intentionally simple and deterministic.
    """
    if min_age is None and max_age is None:
        return []

    # If only one side exists, treat it as both ends
    if min_age is None:
        min_age = max_age
    if max_age is None:
        max_age = min_age

    groups = set()

    # kids: <= 12
    if min_age <= 12:
        groups.add("kids")

    # teens: 13-17
    if max_age >= 13 and min_age <= 17:
        groups.add("teens")

    # young adults: 18-25
    if max_age >= 18 and min_age <= 25:
        groups.add("young_adults")

    # adults: 26-59 (or any adult mention)
    if max_age >= 26 and min_age <= 59:
        groups.add("adults")

    # seniors: 60+
    if max_age >= 60:
        groups.add("seniors")

    return sorted(groups)

def extract_age_range(text: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Extracts numeric age range from patterns like:
      - "Ages: 6–10"
      - "Age 8-12"
      - "Ages 16 - 18"
      - "Ages: 18+"
      - "55+"
    Returns (min_age, max_age) where max_age can be None for open-ended.
    """
    t = text

    # Ages: 6–10 / 6-10 / 6 — 10
    m = re.search(r"(?i)\bages?\s*[:\-]?\s*(\d{1,2})\s*[–—-]\s*(\d{1,2})\b", t)
    if m:
        return int(m.group(1)), int(m.group(2))

    # Age: 8-12
    m = re.search(r"(?i)\bage\s*[:\-]?\s*(\d{1,2})\s*[–—-]\s*(\d{1,2})\b", t)
    if m:
        return int(m.group(1)), int(m.group(2))

    # Ages: 18+
    m = re.search(r"(?i)\bages?\s*[:\-]?\s*(\d{1,2})\s*\+\b", t)
    if m:
        return int(m.group(1)), None

    # Standalone 55+ / 60+
    m = re.search(r"(?i)\b(\d{2})\s*\+\b", t)
    if m:
        return int(m.group(1)), None

    return None, None


def extract_age_groups(text: str) -> List[str]:
    """
    Prefer numeric extraction if present; fall back to keyword detection.
    Default to adults if nothing found.
    """
    text_l = text.lower()

    min_age, max_age = extract_age_range(text)
    groups = set(_bucket_from_age_range(min_age, max_age))

    # keyword-based additions
    for g, kws in AGE_KEYWORDS.items():
        if any(kw in text_l for kw in kws):
            groups.add(g)

    # If "all" appears, keep it and optionally drop others (your choice)
    if "all" in groups:
        return ["all"]

    if not groups:
        groups.add("adults")

    return sorted(groups)




def infer_intensity_from_text(text: str) -> Optional[str]:
    """
    Heuristic fallback if event block itself contains intensity cues.
    Prefer using activity-type definitions instead (recommended).
    """
    t = text.lower()
    # Strong cues first
    if any(x in t for x in ["low impact", "gentle", "restorative", "beginner", "chair", "arthritis"]):
        return "low"
    if any(x in t for x in ["high intensity", "interval", "boot camp", "fast-paced", "challenging"]):
        return "high"
    if any(x in t for x in ["moderate", "all levels", "level 2", "level 2/3"]):
        return "moderate"
    return None




# Reuse the same normalizer you use for event_type so keys match
def normalize_activity_heading(s: Optional[str]) -> Optional[str]:
    if not s:
        return None
    s = s.strip().upper()
    s = s.replace("®", "").replace("™", "")
    s = re.sub(r"\s+", " ", s)
    # keep alphanumerics, spaces, '/', '&', '-', '+'
    s = re.sub(r"[^A-Z0-9 /&\-\+]", "", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s or None

def normalize_intensity(raw: Optional[str]) -> Optional[str]:
    """
    Map text like 'Low–Moderate', 'Level 2/3', etc. to low/moderate/high.
    Keep it intentionally simple.
    """
    if not raw:
        return None
    t = raw.strip().lower()
    if "low" in t or "gentle" in t:
        return "low"
    if "high" in t or "challenging" in t:
        return "high"
    if "moderate" in t or "medium" in t or "level 2" in t or "level 2/3" in t:
        return "moderate"
    # If it says "level 1" treat as low
    if "level 1" in t:
        return "low"
    # If it says "level 3" treat as high
    if "level 3" in t:
        return "high"
    return None

_INTENSITY_RE = re.compile(r"(?mi)^\s*\*\*Intensity:\*\*\s*(.+?)\s*$")

def normalize_age_focus(age_focus: Optional[str]) -> Optional[str]:
    if not age_focus:
        return None
    return ",".join(extract_age_groups(age_focus))

def normalize_city(city: Optional[str]) -> Optional[str]:
    if not city:
        return None
    return city.strip().lower()

def normalize_state(state: Optional[str]) -> Optional[str]:
    if not state:
        return None
    return state.strip().lower()


In [37]:
from langchain_text_splitters import MarkdownHeaderTextSplitter

def build_activitytype_documents(
    md_text: str,
    source: str,
) -> Tuple[List[Document], Dict[str, str]]:
    """
    1) Chunk aquatics.md (or other activitytype md) using LangChain header splitter (##, ###)
    2) For each chunk, extract activity_heading and parse '**Intensity:** ...'
    3) Return:
       - activity_docs: Documents with metadata {source, activity_heading, activity_heading_norm, intensity}
       - intensity_map: dict[activity_heading_norm] -> 'low'|'moderate'|'high'
    """
    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("##", "h2"), ("###", "h3")],
        strip_headers=False,
    )
    chunks = splitter.split_text(md_text)

    docs: List[Document] = []
    intensity_map: Dict[str, str] = {}

    for c in chunks:
        meta = c.metadata or {}
        # split heading from source file name, strip .md
        heading = source.split(".")[0].strip()
        heading_norm = normalize_activity_heading((meta.get("h3") or meta.get("h2") or heading).strip())

        # Skip chunks that aren't actual activity headings
        if not heading_norm:
            continue
        if heading_norm.startswith("PAGE "):  # safety if a brochure sneaks in
            continue

        # Parse intensity line if present
        m = _INTENSITY_RE.search(c.page_content)
        intensity = normalize_intensity(m.group(1)) if m else None

        if intensity and heading_norm not in intensity_map:
            intensity_map[heading_norm] = intensity

        print("*** Printing metadata: ")
        print("source: ", source)
        print("activity_heading: ", heading)
        print("activity_heading_norm: ", heading_norm)
        print("intensity: ", intensity)

        docs.append(
            Document(
                page_content=c.page_content.strip(),
                metadata={
                    "source": source,
                    "activity_heading": heading,
                    "activity_heading_norm": heading_norm,
                    "intensity": intensity,  # may be None if line not present
                },
            )
        )

    return docs, intensity_map


In [None]:
# ---------------------------
# Updated build_event_documents
# ---------------------------

def build_event_documents(
    md_text: str,
    source: str,
    activity_intensity_map: Optional[dict] = None,
    city: Optional[str] = None,
    state: Optional[str] = None,
) -> List[Document]:
    """
    Split a brochure into one Document per event (each '### ...' block).

    Metadata:
    - event_name
    - event_type (normalized)
    - age_min, age_max
    - age_contains (bucketed groups)
    - city, state if present
    - intensity (low/moderate/high)  [prefer from activity_intensity_map]
    """
    matches = list(_EVENT_BLOCK_RE.finditer(md_text))
    if not matches:
        return []

    docs: List[Document] = []

    for i, m in enumerate(matches):
        event_name = m.group(1).strip()
        start = m.start()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(md_text)
        block = md_text[start:end].strip()

        # Extract Event Type
        event_type_raw = None
        mt = re.search(r"(?mi)^\s*-\s*Event Type:\s*(.+?)\s*$", block)
        if mt:
            event_type_raw = mt.group(1).strip()

        event_type = normalize_event_type(event_type_raw)

        # Ages numeric + buckets
        age_min, age_max = extract_age_range(block)
        age_contains_list = extract_age_groups(block)
        # Convert list to string for ChromaDB (which doesn't support list metadata)
        age_contains = ", ".join(age_contains_list) if age_contains_list else None

        # Intensity: prefer map from activityType docs (keyed by normalized event_type)
        intensity = None
        if activity_intensity_map and event_type:
            intensity = activity_intensity_map.get(event_type)
        if not intensity:
            intensity = infer_intensity_from_text(block)

        docs.append(
            Document(
                page_content=block,
                metadata={
                    "source": source,
                    "event_name": event_name,
                    "event_type": event_type,           # normalized (e.g., "AQUA ZUMBA")
                    "event_type_raw": event_type_raw,   # optional debugging
                    "age_min": age_min,
                    "age_max": age_max,
                    "city": city,
                    "state": state,
                    "age_contains": age_contains,       # comma-separated string (e.g., "kids" or "teens, adults")
                    "intensity": intensity,             # low/moderate/high or None
                },
            )
        )

    return docs


In [39]:
import shutil

# -----------------------------
# 2) Vectorstores (Chroma)
# -----------------------------

# Initialize embeddings
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small"
)

# --- Optional reranker (open-source) ---
USE_RERANKER = False
try:
    if USE_RERANKER:
        from sentence_transformers import CrossEncoder
        _reranker = CrossEncoder("BAAI/bge-reranker-base")
    else:
        _reranker = None
except Exception:
    _reranker = None

@dataclass
class RagStores:
    events: Chroma
    activity_types: Chroma

# cleanup vector stores object
def cleanup_vectorstores(stores: RagStores):
    stores.events.delete_collection()
    stores.activity_types.delete_collection()

def build_vectorstores(
    event_md_files: List[str],
    activitytype_md_files: List[str],
    persist_dir: str = "./rag_chroma",
) -> RagStores:

    # cleanup vector stores
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir)

    # create persist_dir if it doesn't exist give read and write permissions
    os.makedirs(persist_dir, exist_ok=True, mode=0o777)



    activity_docs: List[Document] = []
    activity_intensity_map: Dict[str, str] = {}

    print("Number of activity type files: ", len(activitytype_md_files))
    print("Number of event files: ", len(event_md_files))

    for path in activitytype_md_files:
        with open(path, "r", encoding="utf-8") as f:
            md = f.read()
        docs, intensity_map = build_activitytype_documents(md, source=os.path.basename(path))
        activity_docs.extend(docs)
        activity_intensity_map.update(intensity_map)

    # Build docs
    event_docs: List[Document] = []
    for (path) in event_md_files:
        with open(path, "r", encoding="utf-8") as f:
            md = f.read()
        # for each file, retrieve the city and state from the Location line
        center_md = parse_center_metadata(md, path)
        city = center_md.get('city')
        state = center_md.get('state')
   
        event_docs.extend(build_event_documents(md, source=os.path.basename(path), activity_intensity_map=activity_intensity_map, city=city, state=state))

    # Separate collections to keep retrieval clean
    
    events_store = Chroma.from_documents(
        documents=event_docs,
        embedding=embeddings,
        collection_name="events"
    )
    activity_store = Chroma.from_documents(
        documents=activity_docs,
        embedding=embeddings,
        collection_name="activity_types"
    )

    print("Number of events: ", events_store._collection.count())
    print("Number of activity types: ", activity_store._collection.count())
    
    return RagStores(events=events_store, activity_types=activity_store)


def load_vectorstores(persist_dir: str = "./rag_chroma") -> RagStores:
    events_store = Chroma(
        embedding_function=embeddings,
        collection_name="events"
    )
    activity_store = Chroma(
        embedding_function=embeddings,
        collection_name="activity_types"
    )
    return RagStores(events=events_store, activity_types=activity_store)

In [54]:
def build_chroma_where(input_filter: dict) -> dict | None:
    """
    Build a Chroma-compatible where clause.

    Rules:
    - Single-value fields → equality
    - List-valued fields → $or
    - Multiple fields → wrap in $and
    - Exactly ONE top-level operator
    """

    and_clauses = []

    for key, value in input_filter.items():
        if value is None:
            continue

        # List → OR
        if isinstance(value, list):
            if len(value) == 1:
                and_clauses.append({key: value[0]})
            elif len(value) > 1:
                and_clauses.append({
                    "$or": [{key: v} for v in value]
                })

        # Scalar → equality
        else:
            and_clauses.append({key: value})

    if not and_clauses:
        return None

    # Single clause → return directly
    if len(and_clauses) == 1:
        return and_clauses[0]

    # Multiple → AND
    return {"$and": and_clauses}


In [55]:
from typing import List
from langchain_core.documents import Document

def retrieve_activity_types(stores: RagStores, user_question: str, input_filter: Dict[str, Any], k: int = 5, oversample: int = 8) -> List[Document]:
    raw_k = max(k * oversample, 20)
    print("In retrieve_activity_types **** input_filter: ", input_filter)

   # build filter for chroma dict with and clause 
    filter = build_chroma_where(input_filter)

    raw = stores.activity_types.similarity_search(user_question, k=raw_k, filter=filter)
    print("In retrieve_activity_types **** raw: ", raw)

    seen = set()
    out: List[Document] = []
    for d in raw:
        heading = (d.metadata.get("activity_heading_norm") or "").strip()
        key = heading.lower()
        if not key or key in seen:
            continue
        seen.add(key)
        out.append(d)
        if len(out) >= k:
            break

    print("In retrieve_activity_types **** out: ", out)
    return out


In [59]:
from typing import List, Optional
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from sympy import li

def retrieve_events_for_activity_type(
    stores: RagStores,
    user_question: str,
    input_filter: Dict[str, Any],
    k: int = 10
) -> List[Document]:

    k = 10
    print("In retrieve_events_for_activity_type **** input_filter: ", input_filter)
    
   # build filter for chroma dict with and clause 
    filter = build_chroma_where(input_filter)

    print("In retrieve_events_for_activity_type **** output filter: ", filter)

    # similarity search with priority to city, interests
    # how to prioritize city, age_focus metadata filters?
    
    events = stores.events.similarity_search(
        query=user_question,
        k=k,
        filter=filter
    )
    print("In retrieve_events_for_activity_type **** events: ", events)
    return events


In [42]:
# construct vector stores
stores = build_vectorstores(events_md_files, activitytype_md_files, persist_dir="./rag_chroma")


Number of activity type files:  6
Number of event files:  11
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  ATHLETICS
intensity:  None
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  BOOT CAMP / BOOT CAMP BURN
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  CARDIO STRENGTH / CIRCUITS
intensity:  moderate
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  DUMBBELL / SURGE / UPBEAT STRENGTH
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  CYCLING
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  SENIOR CIRCUITS / SILVERSNEAKERS
intensity:  low
*** Printing metadata: 
source:  dancing.md
activity_heading:  dancing
activity_heading_norm:  DANCI

In [43]:
print("activity types indexed:", stores.activity_types._collection.count())

sample = stores.activity_types.similarity_search("aquatics", k=50)
print([d.metadata.get("activity_heading") for d in sample])

activity_query = "I'm interested in cardio swim classes around Framingham"

user_profile = {
    "city": "Framingham",
    "state": "Massachusetts",
    "age_focus": "teens",
    "interests": "cardio, swim",
}

user_profile_str = json.dumps(user_profile)
# 1) Retrieve top activity types




activity types indexed: 378
['aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics', 'aquatics']


In [44]:
#events_query_parts = [activity_query]
events_query_parts = []

events_query_parts.append(f"intensity: Moderate")
events_query_parts.append(f"city: Framingham")
#events_query_parts.append(f"state: Massachusetts")
events_query_parts.append(f"interests: Cardio, Swim")
events_query_parts.append(f"age_contains: teens")
#events_query_parts.append(f"Activity type: AQUA FIT")
#events_query_parts.append(f"Activity type: AQUA CARDIO")
#events_query_parts.append(f"Activity type: AQUA CARDIO DANCE")
events_query = "\n".join(events_query_parts)

print(stores.events._collection.count())

#events = retrieve_events_for_activity_type(stores, events_query, k=3)
#print("**** events: ", events)


1015


In [45]:
# retrieve events for activity type
#activity_type = "AQUA CARDIO"
#events = retrieve_events_for_activity_type(stores, activity_type, k=3)

#print("**** events: ", events)
#for event in events:
#    print("**** event: ", event)

In [46]:



# -----------------------------
# 3) Retrieval + (optional) rerank
# -----------------------------

def rerank(query: str, docs: List[Document], top_n: int = 5) -> List[Document]:
    if not docs:
        return []
    if _reranker is None:
        return docs[:top_n]

    pairs = [(query, d.page_content) for d in docs]
    scores = _reranker.predict(pairs)
    ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
    return [d for _s, d in ranked[:top_n]]






# -----------------------------
# 4) Answer synthesis
# -----------------------------

def format_event_card(d: Document) -> str:
    m = d.metadata
    return (
        f"- **{m.get('event_title')}** ({m.get('event_type')}) — {m.get('city')}, {m.get('state')} @ {m.get('center_name')}\n"
        f"  - Age: {m.get('age_tags')}\n"
        f"  - When: {m.get('date_range')} | {m.get('time_slots')} | {m.get('duration')}\n"
        f"  - Instructor: {m.get('instructor')} | Spots: {m.get('spots')}\n"
        f"  - Source: {m.get('source')}\n"
    )


def build_context_block(events: List[Document], activity_defs: List[Document]) -> str:
    parts = ["## Retrieved Events\n"]
    parts.extend([format_event_card(d) for d in events])

    if activity_defs:
        parts.append("\n## Activity Definitions\n")
        for d in activity_defs:
            parts.append(f"- Source: {d.metadata.get('source')} | Section: {d.metadata.get('activity_heading')}\n{d.page_content}\n")

    return "\n".join(parts).strip()


# -----------------------------
# 5) Chat model call (plug in your model)
# -----------------------------
# You can wire this into OpenAI / Groq / Ollama.
# Below is a generic interface: pass `generate_fn(prompt:str)->str`.

from typing import List, Dict, Optional
from langchain_core.documents import Document


In [52]:
from typing import List, Dict, Optional
from langchain_core.documents import Document

def answer_user(
    stores: RagStores,
    user_question: str,
    user_profile: Dict[str, str]
) -> str:
    """
    Two-stage retrieval:
      1) user intent -> ACTIVITY TYPE definitions (activityType RAG)
      2) activity types + filters -> EVENTS (brochure RAG)
    Uses: city/state, age, intensity, interests, and user_question.
    """
    profile = user_profile or {}
    RETRIEVAL_K = 5

    def to_str_safe(v):
        if isinstance(v, list):
            return ", ".join(str(x) for x in v if x)
        return str(v).strip() if v else ""

    # -------------------------
    # 0) Extract user constraints
    # -------------------------
    print("user_profile: ", profile)
 

    interests = to_str_safe(",".join(profile.get("interests", "")))       # free text
    # TODO Add age group later
    intensity = normalize_intensity(profile.get("intensity", ""))    
    activity_query_parts =  dict({})

    if interests:
        activity_query_parts['activity_heading'] = interests
    if intensity:
        activity_query_parts['intensity'] = intensity

    print("activity_query_parts: ", activity_query_parts)

    # Retrieve more than 5 so dedupe returns multiple headings
    activity_type_docs = retrieve_activity_types(stores, user_question, activity_query_parts, k=5)

    print("**** activity_type_docs: ", activity_type_docs)
    # similarity search with priority to city, interests
    # how to prioritize city, age_focus metadata filters?

    events_query_parts = dict({})

    #location = to_str_safe(profile.get("location", ""))          # e.g., "Lexington, MA"
    age_focus = normalize_age_focus(profile.get("age_focus"))       
    city = normalize_city(profile.get("city"))
    state = normalize_state(profile.get("state"))

    if age_focus:
        events_query_parts['age_contains'] = age_focus
    if city:
        events_query_parts['city'] = city
    if state:
        events_query_parts['state'] = state

    events = retrieve_events_for_activity_type(stores, user_question, 
                                                events_query_parts, k=RETRIEVAL_K)

    # Choose top N headings (increase to 3 for better recall)
    chosen_headings: List[str] = []
    for d in activity_type_docs:
        h = (d.metadata.get("activity_heading_norm") or "").strip()
        if h and h not in chosen_headings:
            chosen_headings.append(h)
        if len(chosen_headings) >= 3:
            break

    chosen_headings_str = ",".join([h for h in chosen_headings if h])
    print("****chosen_headings_str: ", chosen_headings_str)
    if chosen_headings_str:
        events_query_parts['event_type'] = chosen_headings_str


    events = retrieve_events_for_activity_type(
        stores=stores,
        user_question=user_question,
        input_filter=events_query_parts,    
        k=RETRIEVAL_K
    )
        
    print("Retrieved events: ", events)
    # -------------------------
    # 3) De-dupe events
    # -------------------------
    seen = set()
    deduped_events: List[Document] = []
    for e in events:
        key = (
            (e.metadata.get("source") or "").strip().lower(),
            (e.metadata.get("event_name") or e.metadata.get("event_title") or "").strip().lower(),
            (e.metadata.get("event_type") or "").strip().lower(),
            (e.metadata.get("city") or "").strip().lower(),
            (e.metadata.get("state") or "").strip().lower(),
        )
        if key in seen:
            continue
        seen.add(key)
        deduped_events.append(e)

    # Optional: post-filter if filters weren’t supported in vectorstore
    if intensity:
        deduped_events = [e for e in deduped_events if (e.metadata.get("intensity") == intensity)]

    top_events = deduped_events[:20]

    # -------------------------
    # 4) Include activity definitions for reasoning (intensity/benefits)
    # -------------------------
    activity_defs = [
        d for d in activity_type_docs
        if (d.metadata.get("activity_heading") or "").strip() in chosen_headings
    ]

    return build_context_block(top_events, activity_defs)


In [48]:
USER_PROFILE_SCHEMA = {
    "location": "string | null",
    "age_focus": "kids | teens | young_adults | adults | seniors | null",
    "interests": "list[str]",
    "time_prefs": "list[str]",
    "city": "string | null",
    "state": "string | null",
    "budget_sensitivity": "low | medium | high | null",
}

# create pydantic model class for User profile schema:


from pydantic import BaseModel
from typing import Dict, Any, List, Tuple, Optional

class UserProfile(BaseModel):
    location: str | None = None
    age_focus: str | None = None
    interests: list[str] | None = None
    time_prefs: list[str] | None = None
    city: str | None = None
    state: str | None = None
    budget_sensitivity: str | None = None

PROFILE_SYSTEM_PROMPT = """
You are a profile extraction assistant.

Your task:
- Extract structured user preferences from casual chat text.
- Output ONLY valid JSON.
- Do NOT guess.
- If unsure, return null or empty lists.

Allowed fields ONLY:
- location: string | null  (US city/state if present)
- age_focus: kids | teens | young_adults | adults | seniors | null
- interests: list of strings from: aquatics, athletics, dancing, cooking, drawing
- time_prefs: list of strings from: mornings, afternoons, evenings, weekends
- city: string | null (US city if present)
- state: string | null (US state if present)
- budget_sensitivity: low | medium | high | null

Rules:
- Never invent a location.
- Do not add fields not listed.
- If the user mentions multiple age groups, pick the dominant one; otherwise null.
"""

PROFILE_USER_PROMPT_TEMPLATE = """
Existing profile:
{existing_profile_json}

Recent user messages (for context):
{recent_user_messages}

New user message:
"{user_message}"

Return ONLY a JSON object with any fields you can confidently update.
If nothing can be updated, return:
{{"location": null, "age_focus": null, "interests": [], "time_prefs": [], "city": null, "state": null, "budget_sensitivity": null}}
""".strip()

def llm_call_profile(system_prompt: str, user_prompt: str) -> str:
    """
    CALL YOUR LLM HERE (Groq/OpenAI/Ollama) for profile extraction.
    Must return raw text containing ONLY JSON.
    """
    resp = groq_client.chat.completions.create(
       model=OPENSOURCE_OSS_MODEL,
       messages=[{"role":"system","content":system_prompt},{"role":"user","content":user_prompt}],
       temperature=0
    )
    return resp.choices[0].message.content
    raise NotImplementedError("Wire llm_call_profile() to your LLM provider.")



def merge_profiles(existing: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """Non-destructive merge: keep existing unless new provides a confident update."""
    merged = dict(existing or {})

    for k, v in (new or {}).items():
        if v is None:
            continue
        if isinstance(v, list):
            cur = set(merged.get(k, []) or [])
            merged[k] = sorted(cur.union([x for x in v if isinstance(x, str) and x.strip()]))
        elif isinstance(v, str):
            if v.strip():
                merged[k] = v.strip()
        else:
            merged[k] = v

    # Normalize allowed enums (extra safety)
    allowed_age = {"kids", "teens", "young_adults", "adults", "seniors"}
    if merged.get("age_focus") not in allowed_age:
        merged["age_focus"] = None

    allowed_interests = {"aquatics", "athletics", "dancing", "cooking", "drawing"}
    merged["interests"] = [x for x in merged.get("interests", []) if x in allowed_interests]

    allowed_time = {"mornings", "afternoons", "evenings", "weekends"}
    merged["time_prefs"] = [x for x in merged.get("time_prefs", []) if x in allowed_time]

    if merged.get("city"):
        merged["city"] = merged["city"].strip()
    if merged.get("state"):
        merged["state"] = merged["state"].strip()

    allowed_budget = {"low", "medium", "high"}
    if merged.get("budget_sensitivity") not in allowed_budget:
        merged["budget_sensitivity"] = None

    return merged

def get_recent_user_messages(history: List[Tuple[str, str]], n: int = 4) -> List[str]:
    """History is List[(user, assistant)]. Returns last n user messages."""
    if not history:
        return []
    users = []
    for u, _a in history:
        if isinstance(u, str) and u.strip():
            users.append(u.strip())
    return users[-n:]

# ---------------------------
# 3) Build compact retrieval query (don’t embed full transcript)
# ---------------------------

def build_retrieval_query(message: str, profile: Dict[str, Any], history: List[Tuple[str, str]]) -> str:
    recent = get_recent_user_messages(history, n=3)
    parts = [message.strip()]

    if profile.get("location"):
        parts.append(f"Location: {profile['location']}")
    if profile.get("age_focus"):
        parts.append(f"Age: {profile['age_focus']}")
    if profile.get("interests"):
        parts.append("Interests: " + ", ".join(profile["interests"]))
    if profile.get("time_prefs"):
        parts.append("Time prefs: " + ", ".join(profile["time_prefs"]))
    if profile.get("budget_sensitivity"):
        parts.append(f"Budget: {profile['budget_sensitivity']}")

    if recent:
        parts.append("Recent user context: " + " | ".join([u[:120] for u in recent]))

    return "\n".join(parts)



In [None]:
# import gradio chat interface
import gradio as gr

ANSWER_SYSTEM_PROMPT = "You are an Activity Recommendation Assistant. \
\
Your job is to help users discover suitable activities, classes, or event types \
based on their age, interests, physical comfort level, time availability, and goals. \
You must follow these rules: \
1. Use ONLY the information provided in the context and user messages. \
2. Do NOT invent activities, classes, or benefits that are not explicitly stated. \
3. If information is missing, ask a clarifying question instead of guessing. \
4. Be respectful of physical limitations and accessibility needs. \
5. Do NOT provide medical advice. Phrase benefits in general wellness terms. \
6. When recommending activities, include: \
   - Activity name \
   - Typical intensity level \
   - Typical session length \
   - Recommended weekly frequency \
   - Why it fits the user’s preferences \
7. If multiple activities fit, rank them from best to least suitable. \
8. If nothing fits well, explain why and suggest alternatives. \
\
Your tone should be friendly, practical, and encouraging. \
\
Return: \
- 3–5 recommended events, ranked \
- short rationale per event (age, interests, schedule, location) \
- keep it concise and practical \
""".strip()


#ANSWER_SYSTEM_PROMPT = """
#You are an assistant that recommends recreation events.

#Use ONLY the provided retrieved context for facts (events, schedules, locations).
#If the user asks for something not found in context, say you don’t have it and suggest what to ask next.

#Return:
#- 3–5 recommended events, ranked
# - short rationale per event (age, interests, schedule, location)
#- keep it concise and practical
#""".strip()


def llm_call_answer(system_prompt: str, user_prompt: str) -> str:
    resp = groq_client.chat.completions.create(
       model=OPENSOURCE_OSS_MODEL,
       messages=[{"role":"system","content":system_prompt},{"role":"user","content":user_prompt}],
       temperature=0
    )
    return resp.choices[0].message.content
    raise NotImplementedError("Wire llm_call_profile() to your LLM provider.")


stores = build_vectorstores(events_md_files, activitytype_md_files, persist_dir="./rag_chroma")

# groq client to openai opensource openai oss model
# use the system prompt and user prompt to generate the response

# Store profile state (simple approach - in production use Gradio's state management)
_user_profile_state = {}

def convert_gradio_history(history):
    """Convert Gradio 6.x history format (list of dicts) to tuple format."""
    if not history:
        return []
    converted = []
    i = 0
    while i < len(history):
        msg = history[i]
        if isinstance(msg, dict):
            # Gradio 6.x format: {"role": "user"/"assistant", "content": "..."}
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "user":
                # Look for corresponding assistant message (next item)
                assistant_content = ""
                if i + 1 < len(history) and isinstance(history[i + 1], dict):
                    next_msg = history[i + 1]
                    if next_msg.get("role") == "assistant":
                        assistant_content = next_msg.get("content", "")
                        i += 1  # Skip the assistant message
                converted.append((content, assistant_content))
            elif role == "assistant":
                # Standalone assistant message (shouldn't happen, but handle it)
                converted.append(("", content))
        elif isinstance(msg, tuple):
            # Already in tuple format
            converted.append(msg)
        i += 1
    return converted

def chat(message: str, history):
    """Chat function compatible with Gradio 6.x ChatInterface."""
    global _user_profile_state
    
    # Convert Gradio history format to tuple format for internal use
    history_tuples = convert_gradio_history(history)
    
    profile = dict(_user_profile_state or {})

    # A) Ask LLM to update profile using recent user turns + current message
    recent_msgs = get_recent_user_messages(history_tuples, n=4)
    profile_prompt = PROFILE_USER_PROMPT_TEMPLATE.format(
        existing_profile_json=json.dumps(profile, indent=2),
        recent_user_messages="\n".join([f"- {m}" for m in recent_msgs]) if recent_msgs else "(none)",
        user_message=message,
    )

    try:
        extracted_raw = llm_call_profile(PROFILE_SYSTEM_PROMPT, profile_prompt)
        extracted = json.loads(extracted_raw)
        print("Extracted profile: ", extracted)
        profile = merge_profiles(profile, extracted)
        _user_profile_state.update(profile)  # Update global state
        print("Updated profile: ", profile)
    except Exception:
        # Fail-safe: keep existing profile if LLM extraction fails
        pass

    # Build retrieval query
    retrieval_query = build_retrieval_query(message, profile, history_tuples)

    context = answer_user(stores, retrieval_query, profile)
    print("context: ", context)

    # D) Build answer prompt (history can be used here, not in embeddings)
    # Keep history minimal to avoid token bloat: last 6 turns
    last_turns = history_tuples[-6:] if history_tuples else []
    chat_history_text = "\n".join([f"User: {u}\nAssistant: {a}" for (u, a) in last_turns if u and a])

    answer_user_prompt = f"""
    You are continuing a conversation with a user about activities and events. 

    User question:
    {message}

    User profile (structured):
    {json.dumps(profile, indent=2)}

    Recent chat history:
    {chat_history_text if chat_history_text else "(none)"}

    Retrieved context:
    {context}
 
    Instructions: 
    - Recommend the top 2 suitable activity types from the knowledge provided. 
    - Explain clearly why each recommendation fits the user. 
    - Get the exact event name from the context provided and which location the event is located at. 
    - Include intensity, session length, and typical weekly frequency. 
    - Cite the activity source using [filename | activity name]. 
    - If the knowledge is insufficient, say what is missing. 
       """.strip()


    #answer_user_prompt = f"""
    #User question:
    #{message}

    #User profile (structured):
    #  {json.dumps(profile, indent=2)}

    #Recent chat history:
    #  {chat_history_text if chat_history_text else "(none)"}

    #Retrieved context:
    #{context}

    #Task:
   # Recommend 3–5 best matching events from context, ranked.
    #For each: include event name, center, when, age fit, and why it matches.
    #""".strip()

    assistant_text = llm_call_answer(ANSWER_SYSTEM_PROMPT, answer_user_prompt)

    # Update global profile state
    _user_profile_state.update(profile)
    
    # Return only the message string (Gradio 6.x expects just the message)
    return assistant_text


# gradio chat interface
demo = gr.ChatInterface(chat)
demo.launch()

Number of activity type files:  6
Number of event files:  11
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  ATHLETICS
intensity:  None
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  BOOT CAMP / BOOT CAMP BURN
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  CARDIO STRENGTH / CIRCUITS
intensity:  moderate
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  DUMBBELL / SURGE / UPBEAT STRENGTH
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  CYCLING
intensity:  high
*** Printing metadata: 
source:  athletics.md
activity_heading:  athletics
activity_heading_norm:  SENIOR CIRCUITS / SILVERSNEAKERS
intensity:  low
*** Printing metadata: 
source:  dancing.md
activity_heading:  dancing
activity_heading_norm:  DANCI



Extracted profile:  {'location': None, 'age_focus': 'adults', 'interests': ['aquatics'], 'time_prefs': [], 'city': None, 'state': None, 'budget_sensitivity': None}
Updated profile:  {'age_focus': 'adults', 'interests': ['aquatics'], 'time_prefs': [], 'budget_sensitivity': None}
user_profile:  {'age_focus': 'adults', 'interests': ['aquatics'], 'time_prefs': [], 'budget_sensitivity': None}
activity_query_parts:  {'activity_heading': 'aquatics'}
In retrieve_activity_types **** input_filter:  {'activity_heading': 'aquatics'}
In retrieve_activity_types **** raw:  [Document(metadata={'activity_heading': 'aquatics', 'activity_heading_norm': 'AQUATICS', 'source': 'aquatics.md'}, page_content='# Aquatics Activities  \nAquatics activities are water-based exercise programs that use buoyancy and water resistance to reduce joint stress while improving cardiovascular health, strength, flexibility, and balance. These programs are widely used for rehabilitation, senior fitness, and low-impact conditio