In [2]:
import json
import logging
import os

from dotenv import load_dotenv
from openai import OpenAI
from pymongo import MongoClient, UpdateOne
from rich.logging import RichHandler

from bson.objectid import ObjectId
from concurrent.futures import ThreadPoolExecutor, as_completed

# Send all log records through Rich so messages and tracebacks render nicely
# in the notebook output.
logging.basicConfig(
    handlers=[RichHandler(rich_tracebacks=True)],
    datefmt="[%X]",
    format="%(message)s",
    level=logging.INFO,
)

# load_dotenv() populates the process environment; every setting below is then
# read from it and is None when the variable is not set.
load_dotenv()

# --- MongoDB connection and collection names ---
MONGODB_URI = os.environ.get("MONGODB_URI")
Database_Name = os.environ.get("DATABASE")
edps_claims_collection = os.environ.get("EDPS_CLAIMS_COLLECTION")
pharmacy_claims_collection = os.environ.get("PHARMACY_CLAIMS_COLLECTION")
eligibility_collection = os.environ.get("ELIGIBILITY_COLLECTION")
mbi_crosswalk_collection = os.environ.get("MBI_CROSSWALK_COLLECTION")

# --- LLM client ---
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
LLM_MODEL = "gpt-4o-mini"
llm_client = OpenAI(api_key=OPENAI_API_KEY)

# MemberID -> MBI lookup; starts empty and is filled in by a later cell.
MBI_CROSSWALK_MAP = {}


In [3]:
def get_mbi_crosswalk_map():
    """Build the MemberID -> MBI lookup from the crosswalk collection.

    Returns:
        dict: one entry per crosswalk document that has both a "MemberID"
        and an "MBI" value that is not None. If anything goes wrong while
        reading, the error is logged and an empty dict is returned.
    """
    member_to_mbi = {}

    try:
        with MongoClient(MONGODB_URI) as mongo:
            crosswalk = mongo[Database_Name][mbi_crosswalk_collection]

            for record in crosswalk.find({}, {"_id": 0, "created_dt": 0}):
                member_id = record.get("MemberID")
                mbi = record.get("MBI")

                # Half-filled rows are useless for translation; leave them out.
                if member_id is None or mbi is None:
                    continue

                member_to_mbi[member_id] = mbi

    except Exception as e:
        logging.error(f"Failed to load MBI crosswalk map: {e}")
        return {}

    return member_to_mbi


In [4]:
# Fields requested from each EDPS medical claim document.
_MEDICAL_CLAIM_PROJECTION = {
    "_id": 0,
    "Diagnosis.Diag_Codes": 1,
    "ServiceLine.LXServiceNo": 1,
    "ServiceLine.BilledCPT_Code": 1,
    "ServiceLine.BilledCPTDesc": 1,
    "ServiceLine.Line_SvcDate": 1,
    "Claim.ClaimID": 1,
    "Claim.POS": 1,
    "Type_of_Bill": 1,
    "Provider.BillProv_NPI": 1,
    "Provider.BillProv_LastName": 1,
    "Member.Subscriber_ID": 1,
    "Member.Subscriber_DOB": 1,
    "Member.Subscriber_Gender": 1,
}

# Fields requested from each pharmacy claim document.
_PHARMACY_CLAIM_PROJECTION = {
    "_id": 0,
    "Member ID": 1,
    "NDC": 1,
    "Product Label Name": 1,
    "Fill Date": 1,
    "Days Supply": 1,
    "Metric Quantity": 1,
    "Prescriber ID": 1,
    "Prescriber Name": 1,
    "Total Billed": 1,
}


def _format_fill_date(value):
    """Render a fill date as YYYY-MM-DD; values without strftime are stringified."""
    if not value:
        return None
    if hasattr(value, "strftime"):
        return value.strftime("%Y-%m-%d")
    # Not a date object (e.g. already a string): keep it instead of crashing.
    return str(value)


def _format_pharmacy_claim(pc):
    """Map one raw pharmacy claim document onto the camelCase shape used downstream."""
    ndc = pc.get("NDC")
    return {
        "memberId": pc.get("Member ID"),
        # A missing NDC stays None rather than becoming the string "None".
        "ndc": None if ndc is None else str(ndc),
        "drugName": pc.get("Product Label Name"),
        "fillDate": _format_fill_date(pc.get("Fill Date")),
        "daysSupply": pc.get("Days Supply"),
        "quantityDispensed": pc.get("Metric Quantity"),
        "prescriberNPI": pc.get("Prescriber ID"),
        "prescriberName": pc.get("Prescriber Name"),
        "totalBilled": pc.get("Total Billed"),
    }


def load_members_with_claims_from_docs(eligibility_docs):
    """Attach medical and pharmacy claims to each eligibility document.

    Medical (EDPS) claims are looked up by the member's MBI when
    MBI_CROSSWALK_MAP has an entry for the member, otherwise by the member id
    itself. Pharmacy claims are always looked up by the member id.

    Args:
        eligibility_docs: iterable of eligibility documents; each must have a
            "memberId" key.

    Returns:
        list[dict]: one dict per eligibility document with the keys
        "memberId", "eligibility", "medicalClaims" and "pharmacyClaims".
        An empty list is returned when there is nothing to load or when the
        MongoDB client cannot be created.

    Raises:
        KeyError: if an eligibility document has no "memberId".
        Exception: errors raised while querying are not caught here.
    """
    if not eligibility_docs:
        return []

    client = None
    try:
        client = MongoClient(MONGODB_URI)
        db = client[Database_Name]
    except Exception as e:
        logging.error(f"Error connecting to MongoDB: {e}")
        if client is not None:
            client.close()
        return []

    members = []

    # The context manager closes the client even if a query raises; this
    # function runs once per batch, so an unclosed client would leak each time.
    with client:
        medical_collection = db[edps_claims_collection]
        pharmacy_collection = db[pharmacy_claims_collection]

        for el in eligibility_docs:
            member_id = el["memberId"]

            # Fall back to the member id when the crosswalk has no MBI for it.
            mbi_lookup_id = MBI_CROSSWALK_MAP.get(member_id, member_id)

            medical_claims = list(
                medical_collection.find(
                    {"Member.Subscriber_ID": mbi_lookup_id},
                    _MEDICAL_CLAIM_PROJECTION,
                )
            )

            pharmacy_claims = [
                _format_pharmacy_claim(pc)
                for pc in pharmacy_collection.find(
                    {"Member ID": member_id},
                    _PHARMACY_CLAIM_PROJECTION,
                )
            ]

            members.append(
                {
                    "memberId": member_id,
                    "eligibility": el,
                    "medicalClaims": medical_claims,
                    "pharmacyClaims": pharmacy_claims,
                }
            )

    return members


In [None]:
def _build_suspect_prompt(members):
    """Render the V28 suspecting prompt with the member batch embedded as JSON."""
    # default=str stringifies values json cannot encode natively (datetimes,
    # ObjectIds, ...), so the model sees plain JSON instead of Python reprs.
    members_json = json.dumps(members, default=str)

    return f"""
  You are a clinical AI assistant.

  Return your answer as **strict raw JSON only**.
  Do NOT include markdown formatting, backticks, comments, or explanations.
  The response MUST be valid JSON and MUST NOT contain ```json or ```.

  Only return diagnosis codes (ICD-10) that map to **V28 Medicare Advantage Risk Adjustment Model HCCs**.  
  Do not include any diagnosis codes that are not part of the V28 HCC model.

  Output format:
  [
    {{
      "memberId": "...",
      "suspectType": "...",
      "suspectDiagnosis": {{
        "code": "...",
        "description": "...",
        "hccCategory": "..."
      }},
      "confidenceScore": 0.85,
      "priority": "...",
      "evidence": {{
        "summary": "...",
        "details": ["...", "..."]
      }},
      "suggestedAction": "..."
    }}
  ]

  Members: {members_json}
  """


def _strip_code_fences(text):
    """Remove a surrounding ``` / ```json markdown fence if the model added one."""
    text = text.strip()
    if text.startswith("```"):
        first_newline = text.find("\n")
        # Drop the whole opening fence line (which may carry a language tag).
        text = text[first_newline + 1:] if first_newline != -1 else text[3:]
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()


def call_llm_for_suspects(members):
    """Ask the LLM for suspect diagnoses for a batch of members.

    Args:
        members: list of member dicts (eligibility plus claims) for one batch.

    Returns:
        list[dict]: the suspect records parsed from the model's JSON reply.
        An empty list is returned when `members` is empty, when the API call
        fails, or when the reply is not usable JSON; the reason is logged.
    """
    # With nothing to analyse there is no point paying for a call, and the
    # model could invent suspects for an empty member list.
    if not members:
        logging.info("No members supplied; skipping LLM call")
        return []

    # Bound up front so the JSONDecodeError handler can always log it.
    output_text = None

    try:
        response = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": "You are a clinical assistant."},
                {"role": "user", "content": _build_suspect_prompt(members)},
            ],
            # NOTE(review): 1.0 makes replies non-deterministic; consider a
            # lower value for this extraction-style task.
            temperature=1.0,
        )

        output_text = response.choices[0].message.content
        # Models sometimes wrap JSON in a markdown fence despite instructions.
        suspects = json.loads(_strip_code_fences(output_text or ""))

    except json.JSONDecodeError:
        logging.error("Failed to decode LLM response")
        logging.error(output_text)
        return []

    except Exception as e:
        logging.error(f"Error calling LLM: {e}")
        return []

    # Normalise the shape so callers can always iterate over dict records.
    if isinstance(suspects, dict):
        suspects = [suspects]
    if not isinstance(suspects, list):
        logging.error("LLM response was not a JSON list of suspects")
        return []

    valid = [s for s in suspects if isinstance(s, dict)]
    if len(valid) != len(suspects):
        logging.warning(
            f"Dropped {len(suspects) - len(valid)} non-object entries from LLM response"
        )
    return valid


In [6]:
# def save_suspects_to_mongo(suspects):
#     if not suspects:
#         logging.info("No suspects to save")
#         return

#     try:
#         client = MongoClient(MONGODB_URI)
#         db = client[Database_Name]
#         db["ui.member.suspects"].insert_many(suspects)
#         logging.info(f"Saved {len(suspects)} suspect records")
#     except Exception as e:
#         logging.error(f"Error saving suspects: {e}")

def _suspect_upsert_filter(suspect):
    """Build the upsert key for one suspect: member id plus diagnosis code when present."""
    key = {"memberId": suspect["memberId"]}
    diagnosis = suspect.get("suspectDiagnosis")
    code = diagnosis.get("code") if isinstance(diagnosis, dict) else None
    if code:
        key["suspectDiagnosis.code"] = code
    return key


def save_suspects_to_mongo(suspects):
    """Upsert suspect records into the "ui.member.suspects" collection.

    Each record is keyed by its memberId together with its
    suspectDiagnosis.code, so several suspects for the same member are stored
    side by side instead of overwriting one another, while re-saving the same
    member/diagnosis pair updates the existing document. Records without a
    diagnosis code fall back to being keyed by memberId alone.

    Args:
        suspects: list of suspect dicts; entries that are not dicts or that
            have no "memberId" are skipped with a warning.

    Returns:
        None. Failures are logged rather than raised.
    """
    if not suspects:
        logging.info("No suspects to save")
        return

    try:
        operations = []

        for s in suspects:
            if not isinstance(s, dict) or not s.get("memberId"):
                logging.warning(f"Skipping suspect without memberId: {s}")
                continue

            operations.append(
                UpdateOne(_suspect_upsert_filter(s), {"$set": s}, upsert=True)
            )

        if not operations:
            logging.info("No valid suspects to write")
            return

        # The context manager closes the client after the write; this function
        # runs once per batch, so an unclosed client would leak each time.
        with MongoClient(MONGODB_URI) as client:
            collection = client[Database_Name]["ui.member.suspects"]
            result = collection.bulk_write(operations, ordered=False)

        logging.info(
            f"Bulk write complete: "
            f"matched={result.matched_count}, "
            f"modified={result.modified_count}, "
            f"upserted={len(result.upserted_ids)}"
        )

    except Exception as e:
        logging.error(f"Error saving suspects: {e}")


In [None]:

# Inspect which members already have suspect records and how many of their ids
# translate to an MBI. Only counts are printed: member ids and MBIs are
# identifiers that should not end up in saved notebook output.

# The crosswalk is otherwise filled by a later cell; load it here when it is
# still empty so this cell gives meaningful results on a fresh "Run All".
if not MBI_CROSSWALK_MAP:
    MBI_CROSSWALK_MAP = get_mbi_crosswalk_map()

# Defined up front so the cell still completes when MongoDB is unreachable.
suspect_member_ids = set()

try:
    # The client is closed when the block exits; `client` and `db_client`
    # remain defined afterwards but refer to a closed connection.
    with MongoClient(MONGODB_URI) as client:
        db_client = client[Database_Name]
        logging.info("Connected to MongoDB")

        suspect_member_ids = {
            doc["memberId"]
            for doc in db_client["ui.member.suspects"].find(
                {}, {"_id": 0, "memberId": 1}
            )
            if doc.get("memberId") is not None
        }
except Exception as e:
    logging.error(f"Error connecting to MongoDB: {e}")

print(f"Loaded {len(suspect_member_ids)} existing suspect member IDs.")

# Translate member ids to MBIs where the crosswalk knows them; ids without an
# entry are kept unchanged.
suspect_member_ids_new = {
    MBI_CROSSWALK_MAP.get(member_id, member_id)
    for member_id in suspect_member_ids
}

translated_count = sum(
    1 for member_id in suspect_member_ids if member_id in MBI_CROSSWALK_MAP
)
print(
    f"{translated_count} of {len(suspect_member_ids)} member IDs were "
    f"translated to an MBI via the crosswalk."
)

In [22]:
# def process_all_members(batch_size: int = 100):
#     try:
#         client = MongoClient(MONGODB_URI)
#         db = client[Database_Name]
#         logging.info("Connected to MongoDB")
#     except Exception as e:
#         logging.error(f"Error connecting to MongoDB: {e}")
#         return

#     cursor = db[eligibility_collection].find(
#         {}, {"_id": 0, "createdAt": 0, "updatedAt": 0},
#         no_cursor_timeout=True
#     ).batch_size(batch_size)

#     batch = []
#     total_processed = 0

#     for el in cursor:
#         batch.append(el)

#         if len(batch) >= batch_size:
#             logging.info(f"Processing batch of {len(batch)}...")
#             members = load_members_with_claims_from_docs(batch)
#             logging.info(f"Loaded claims for {len(members)} members")
#             suspects = call_llm_for_suspects(members)
#             logging.info(f"Generated {len(suspects)} suspects")
#             save_suspects_to_mongo(suspects)
#             logging.info(f"Saved {len(suspects)} suspects to MongoDB")

#             total_processed += len(batch)
#             logging.info(f"Total processed so far: {total_processed}")

#             batch = []

#     # final batch
#     if batch:
#         logging.info(f"Processing final batch of {len(batch)}...")
#         members = load_members_with_claims_from_docs(batch)
#         logging.info(f"Loaded claims for {len(members)} members")
#         suspects = call_llm_for_suspects(members)
#         logging.info(f"Generated {len(suspects)} suspects")
#         save_suspects_to_mongo(suspects)
#         logging.info(f"Saved {len(suspects)} suspects to MongoDB")

#         total_processed += len(batch)

#     logging.info(f"Processing complete. Total: {total_processed}")

def process_batch(batch_docs):
    """Run one batch end to end: load claims, query the LLM, persist suspects.

    Meant to be executed by a worker thread.

    Args:
        batch_docs: eligibility documents that make up this batch.

    Returns:
        int: how many eligibility documents the batch contained.
    """
    batch_members = load_members_with_claims_from_docs(batch_docs)
    save_suspects_to_mongo(call_llm_for_suspects(batch_members))
    return len(batch_docs)

def _collect_batch_result(future):
    """Return a finished batch's document count, or 0 (after logging) if it raised."""
    try:
        return future.result()
    except Exception as e:
        logging.error(f"Batch failed: {e}")
        return 0


def process_all_members(batch_size=100, max_workers=8):
    """Run suspect generation over all eligibility documents not yet processed.

    Eligibility documents are read in `_id` order, `batch_size` at a time, and
    each batch is handed to `process_batch` on a thread pool. Documents whose
    "mbi" matches a member that already has a record in "ui.member.suspects"
    (after translating member ids through MBI_CROSSWALK_MAP) are skipped.

    A batch that raises is logged and counted as 0 processed documents; the
    remaining batches still run. The MongoDB client and the thread pool are
    always released, including when reading fails part-way.

    Args:
        batch_size: number of eligibility documents per batch.
        max_workers: number of worker threads; at most about three times this
            many batches are queued or running at once.

    Returns:
        None. Progress and the final count are written to the log.

    NOTE(review): the exclusion list is sent inside every query as a `$nin`
    array, so it grows with the suspects collection; confirm this stays within
    MongoDB's document size limit at production scale.
    """
    try:
        client = MongoClient(MONGODB_URI)
        db = client[Database_Name]
        logging.info("Connected to MongoDB")
    except Exception as e:
        logging.error(f"Error connecting to MongoDB: {e}")
        return

    max_in_flight = max_workers * 3
    last_id = None
    total_processed = 0

    # On exit the pool is shut down first (waiting for running batches), then
    # the client is closed.
    with client, ThreadPoolExecutor(max_workers=max_workers) as executor:
        suspect_member_ids = {
            doc["memberId"]
            for doc in db["ui.member.suspects"].find({}, {"memberId": 1})
            if doc.get("memberId") is not None
        }

        # Translate to MBIs where the crosswalk knows the member; keep the raw
        # id otherwise.
        excluded_mbis = list(
            {
                MBI_CROSSWALK_MAP.get(member_id, member_id)
                for member_id in suspect_member_ids
            }
        )

        in_flight = set()

        while True:
            query = {"mbi": {"$nin": excluded_mbis}}
            if last_id is not None:
                # Keyset pagination: continue after the last document seen.
                query["_id"] = {"$gt": last_id}

            batch_docs = list(
                db[eligibility_collection]
                .find(query, {"createdAt": 0, "updatedAt": 0})
                .sort("_id", 1)
                .limit(batch_size)
            )

            if not batch_docs:
                break

            last_id = batch_docs[-1]["_id"]

            # `_id` is only needed for pagination; keep it out of the payload
            # that is forwarded to the LLM.
            for doc in batch_docs:
                doc.pop("_id", None)

            in_flight.add(executor.submit(process_batch, batch_docs))
            logging.info(f"Submitted batch of {len(batch_docs)} docs to thread pool...")

            # Throttle: once too many batches are queued, wait for just one to
            # finish before reading more, so the workers never sit idle.
            if len(in_flight) > max_in_flight:
                finished = next(as_completed(in_flight))
                in_flight.discard(finished)
                total_processed += _collect_batch_result(finished)

        # Wait for whatever is still running.
        for finished in as_completed(in_flight):
            total_processed += _collect_batch_result(finished)

    logging.info(f"Processing complete. Total processed: {total_processed}")

In [None]:
# Fill the module-level MemberID -> MBI lookup that the batch functions read.
logging.info("Loading MBI crosswalk map into memory...")
MBI_CROSSWALK_MAP = get_mbi_crosswalk_map()
logging.info("Loaded {} crosswalk records.".format(len(MBI_CROSSWALK_MAP)))

In [None]:
# Kick off the full run with 4 eligibility documents per batch.
logging.info("Starting batch processing...")
process_all_members(batch_size=4)
logging.info("All Done!")