# In [None]:  (Jupyter cell marker left over from the notebook export)
# auto_icd_hcc_retriever.py
# -*- coding: utf-8 -*-
"""Module to extract patient age/treatment from transcriptions, map treatments to ICD-10 and HCC codes,
and produce a structured dataframe of results.
"""
from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass, asdict
from typing import Dict, Optional, List

import pandas as pd

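# langchain and openai imports are deferred until the classes are instantiated, so this module
# can be imported (e.g. for testing) without those packages installed.
# Required packages (install in your environment):
# pip install pandas openai langchain langchain-core langchain-community langchain-openai \
#     langchain-text-splitters langchain-huggingface sentence-transformers chromadb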

# Attach a default INFO-level handler to the root logger (a no-op if the root
# logger already has handlers), then obtain this module's named logger.
# NOTE(review): configuring logging at import time also affects any program
# that imports this module; consider moving basicConfig into main().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class PatientExtraction:
    """One structured result row for a single clinical transcription.

    The declaration order below is significant: it defines the positional
    constructor signature and the key order produced by ``dataclasses.asdict``,
    which in turn becomes the column order of the output DataFrame.
    """

    # Patient age from the extraction step; None when not available.
    patient_age: Optional[int] = None
    # Treatment / primary procedure text; the sentinel "Unknown" when missing.
    patient_treatment: str = "Unknown"
    # Specialty label taken from the input CSV row.
    medical_specialty: Optional[str] = None
    # ICD-10 code string produced by the mapping step.
    icd10: Optional[str] = None
    # HCC code string produced by the RAG retriever.
    hcc: Optional[str] = None


class OpenAIClientWrapper:
    """
    Minimal wrapper around the OpenAI chat-completions client used in the original notebook.

    The API key comes from the ``api_key`` argument or, failing that, the
    ``OPENAI_API_KEY`` environment variable. The ``openai`` package is imported
    when the wrapper is constructed (not when this module is imported); if that
    fails, construction is retried on first use by :meth:`get_client`.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o-mini"):
        """
        Args:
            api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY`` env var.
            model: Chat model used for both extraction and ICD-10 mapping.

        Raises:
            RuntimeError: if no API key is available.
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise RuntimeError("OPENAI_API_KEY not set. Set the environment variable or pass api_key.")
        self.model = model
        # Best-effort eager construction; on failure get_client() retries so the
        # real error surfaces at the first actual API call.
        try:
            from openai import OpenAI  # type: ignore
            self._client = OpenAI(api_key=self.api_key)
        except Exception as e:
            logger.warning("OpenAI client unavailable at import time: %s", e)
            self._client = None

    def get_client(self):
        """Return the OpenAI client, constructing it now if __init__ could not."""
        if self._client is None:
            from openai import OpenAI  # type: ignore
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    @staticmethod
    def _coerce_age(value: object) -> Optional[int]:
        """Normalize an LLM-supplied age to a non-negative ``int`` or ``None``.

        The extraction prompt tells the model to use 'Unknown' for missing
        fields, so strings such as 'Unknown' (and nulls, booleans, negatives or
        other non-numeric values) must map to None instead of leaking into the
        integer age column.
        """
        if value is None or isinstance(value, bool):
            return None
        try:
            if isinstance(value, str):
                age = int(value.strip())
            elif isinstance(value, (int, float)):
                age = int(value)
            else:
                return None
        except (ValueError, OverflowError):  # e.g. "Unknown", NaN, inf
            return None
        return age if age >= 0 else None

    @staticmethod
    def _coerce_treatment(value: object) -> str:
        """Return the stripped treatment text, or "Unknown" for null/blank/non-string values."""
        if isinstance(value, str) and value.strip():
            return value.strip()
        return "Unknown"

    def _parse_patient_data(self, response) -> Dict:
        """Pull ``patient_age`` / ``patient_treatment`` out of a chat-completions response.

        Prefers the JSON arguments of the first tool call and falls back to
        parsing the assistant message content as JSON. If neither yields a JSON
        object, returns ``{"patient_age": None, "patient_treatment": "Unknown"}``.
        """
        args = None
        try:
            message = response.choices[0].message
        except (AttributeError, IndexError, TypeError):
            message = None

        if message is not None:
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                try:
                    args = json.loads(tool_calls[0].function.arguments)
                except (AttributeError, IndexError, TypeError, ValueError):
                    args = None
            if not isinstance(args, dict):
                # Fallback: the model may have answered in plain JSON content instead.
                try:
                    args = json.loads(getattr(message, "content", None) or "")
                except (TypeError, ValueError):
                    args = None

        if not isinstance(args, dict):
            logger.warning("Failed to parse response from LLM; returning Unknown.")
            return {"patient_age": None, "patient_treatment": "Unknown"}

        return {
            "patient_age": self._coerce_age(args.get("patient_age")),
            "patient_treatment": self._coerce_treatment(args.get("patient_treatment")),
        }

    def extract_age_and_treatment(self, transcription: str) -> Dict:
        """
        Use an LLM function call to extract patient age and treatment.

        Returns:
            A dict with ``patient_age`` (non-negative int or None) and
            ``patient_treatment`` (non-empty str, "Unknown" when missing).
        """
        client = self.get_client()
        prompt_system = (
            "You are a healthcare data extraction assistant. Extract the patient's age and the "
            "recommended treatment or primary procedure from the given clinical transcription. "
            "Always return both fields. If missing, set the field value to 'Unknown'."
        )
        # Function call schema to encourage structured output
        function_definition = [
            {
                "type": "function",
                "function": {
                    "name": "get_patient_data",
                    "description": "Extract patient age and treatment.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "patient_age": {"type": "integer", "description": "Age of the patient."},
                            "patient_treatment": {"type": "string", "description": "Treatment of the patient."},
                        },
                        "required": ["patient_age", "patient_treatment"],
                    },
                },
            }
        ]

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": prompt_system},
                {"role": "user", "content": f"Transcription: {transcription}"},
            ],
            tools=function_definition,
            # Force the structured tool call rather than letting the model choose
            # whether to call it (the default "auto" may return free text).
            tool_choice={"type": "function", "function": {"name": "get_patient_data"}},
        )
        return self._parse_patient_data(response)

    def map_to_icd10(self, treatment: str, temperature: float = 0.2) -> str:
        """
        Ask the LLM for ICD-10 codes relevant to a treatment.

        Args:
            treatment: Treatment / procedure text. Non-string, blank or
                "unknown" (any case) input short-circuits without an API call.
            temperature: Sampling temperature for the completion.

        Returns:
            The model's reply (requested as a comma-separated list of codes),
            stripped, or "Unknown" when there is no usable input or reply.
        """
        if not isinstance(treatment, str):
            return "Unknown"
        cleaned = treatment.strip()
        if not cleaned or cleaned.lower() == "unknown":
            return "Unknown"

        client = self.get_client()
        prompt = (
            f"Provide the most relevant ICD-10 codes for the following treatment or procedure: "
            f"'{treatment}'. Return **only** the codes as a comma-separated list. If no code is found, return 'Unknown'."
        )
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        try:
            content = response.choices[0].message.content
        except (AttributeError, IndexError, TypeError):
            logger.exception("Failed to get ICD codes; returning 'Unknown'.")
            return "Unknown"
        # content may be None or whitespace-only; never return an empty code string.
        codes = (content or "").strip()
        return codes or "Unknown"


class HCCCodeRetriever:
    """
    RAG-based retriever for HCC codes using a local CSV crosswalk.

    The CSV is loaded with LangChain's ``CSVLoader``, split into chunks,
    embedded into an in-memory Chroma vector store, and the retrieved chunks
    are passed to an OpenAI chat model prompted to answer with only the HCC code.
    """

    def __init__(
        self,
        csv_path: str,
        llm_model: str = "gpt-4o-mini",
        embedding_model: str = "text-embedding-3-large",
        openai_api_key: Optional[str] = None,
        top_k: Optional[int] = None,
    ):
        """
        Args:
            csv_path: Path to the HCC crosswalk CSV.
            llm_model: OpenAI chat model used to answer from the retrieved context.
            embedding_model: Stored on the instance but NOT used for embedding;
                the vector store is built with ``HuggingFaceEmbeddings()`` and
                its default model.
            openai_api_key: OpenAI key; falls back to ``OPENAI_API_KEY``.
            top_k: Number of chunks retrieved per question. ``None`` (default)
                retrieves every chunk, i.e. the whole crosswalk goes into each
                prompt. That is fine for small files, but set this for large
                crosswalks to stay within the model's context window.

        Raises:
            RuntimeError: if no OpenAI API key is available.
            ValueError: if ``top_k`` is < 1, or the CSV yields no documents.
        """
        self.csv_path = csv_path
        self.llm_model = llm_model
        self.embedding_model = embedding_model
        self.top_k = top_k
        self.openai_api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        if not self.openai_api_key:
            raise RuntimeError("OPENAI_API_KEY required for HCC retrieval.")
        if top_k is not None and top_k < 1:
            raise ValueError("top_k must be a positive integer or None.")

        # Lazy imports to simplify testability
        import uuid

        from langchain_community.document_loaders.csv_loader import CSVLoader
        from langchain_text_splitters import RecursiveCharacterTextSplitter
        from langchain_openai import ChatOpenAI
        from langchain_core.prompts import PromptTemplate
        from langchain_core.output_parsers import StrOutputParser
        from langchain_core.runnables import RunnablePassthrough
        from langchain_huggingface import HuggingFaceEmbeddings
        from langchain_community.vectorstores import Chroma

        # Load and split documents (CSVLoader yields one document per CSV row).
        docs = CSVLoader(self.csv_path).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        splits = splitter.split_documents(docs)
        if not splits:
            raise ValueError(f"No rows could be loaded from HCC crosswalk CSV: {self.csv_path}")

        # Build vectorstore. A unique collection name keeps separate instances in
        # the same process from sharing (and duplicating rows in) Chroma's
        # default in-memory "langchain" collection.
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=HuggingFaceEmbeddings(),
            collection_name=f"hcc-crosswalk-{uuid.uuid4().hex}",
        )
        # The store holds chunks, not rows, so bound k by the chunk count.
        k = len(splits) if top_k is None else min(top_k, len(splits))
        retriever = vectorstore.as_retriever(search_kwargs={"k": k})

        # Build prompt + chain
        llm = ChatOpenAI(model=self.llm_model, api_key=self.openai_api_key, temperature=0.0)

        template = """Answer ONLY with the HCC code (no explanatory text) based on the following context:
{context}

Question: {question}
"""
        prompt = PromptTemplate.from_template(template)
        # Piping the retriever into _format_docs puts plain chunk text into the
        # prompt instead of the repr of a list of Document objects.
        self.rag_chain = (
            {"context": retriever | self._format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )

    @staticmethod
    def _format_docs(docs) -> str:
        """Join the text of retrieved documents, one chunk per paragraph."""
        return "\n\n".join(doc.page_content for doc in docs)

    def get_hcc_code(self, key_word: str) -> str:
        """Return ONLY the HCC code text for the provided keyword.

        Returns "Unknown" for a missing/blank/non-string keyword, an empty model
        answer, or any failure while running the RAG chain (which is logged).
        """
        if not isinstance(key_word, str) or not key_word.strip():
            return "Unknown"
        key_word = key_word.strip()
        question = f"What is the HCC code for: {key_word}? Return only the HCC code."
        try:
            raw_result = self.rag_chain.invoke(question)
        except Exception:
            # Deliberate best-effort boundary: network/LLM errors must not abort a batch.
            logger.exception("HCC retrieval failed for key_word=%s", key_word)
            return "Unknown"
        return (raw_result or "").strip() or "Unknown"


class TranscriptionProcessor:
    """
    High-level orchestrator to process a CSV of transcriptions and produce
    a structured pandas DataFrame with age, treatment, ICD-10, and HCC codes.

    Each input row can trigger up to three model calls (age/treatment
    extraction, ICD-10 mapping, HCC retrieval). Calls are skipped when there is
    nothing meaningful to send: a blank transcription or an unknown treatment.
    """

    def __init__(
        self,
        transcription_csv_path: str,
        hcc_csv_path: str,
        openai_api_key: Optional[str] = None,
    ):
        """
        Args:
            transcription_csv_path: CSV with a ``transcription`` column and an
                optional ``medical_specialty`` column.
            hcc_csv_path: HCC crosswalk CSV used to build the retriever.
            openai_api_key: OpenAI key; both clients fall back to ``OPENAI_API_KEY``.
        """
        self.transcription_csv_path = transcription_csv_path
        self.openai_client = OpenAIClientWrapper(api_key=openai_api_key)
        self.hcc_retriever = HCCCodeRetriever(csv_path=hcc_csv_path, openai_api_key=openai_api_key)

    @staticmethod
    def _clean_text(value: object) -> str:
        """Return *value* as a stripped string; None/NaN (empty CSV cells) become ""."""
        if isinstance(value, str):
            return value.strip()
        if value is None or pd.isna(value):
            return ""
        return str(value).strip()

    def _process_row(self, row: pd.Series) -> Dict:
        """Run extraction, ICD-10 mapping and HCC lookup for one CSV row; return it as a dict."""
        transcription = self._clean_text(row.get("transcription"))
        # Empty cells arrive as NaN from read_csv; store None for a missing specialty.
        medical_specialty = self._clean_text(row.get("medical_specialty")) or None

        if transcription:
            extracted = self.openai_client.extract_age_and_treatment(transcription)
        else:
            # Nothing to extract from; don't spend an API call on "Transcription: nan".
            extracted = {"patient_age": None, "patient_treatment": "Unknown"}

        patient_age = extracted.get("patient_age")
        # The extractor may hand back None/blank; normalize so slicing below is safe.
        treatment = self._clean_text(extracted.get("patient_treatment")) or "Unknown"

        icd10 = self.openai_client.map_to_icd10(treatment)
        if treatment.lower() == "unknown":
            # Asking for the HCC code of "Unknown" only invites a hallucinated code.
            hcc_code = "Unknown"
        else:
            # Send a short primary diag string to the HCC retriever to avoid very long prompts.
            hcc_code = self.hcc_retriever.get_hcc_code(treatment[:200])

        return asdict(
            PatientExtraction(
                patient_age=patient_age,
                patient_treatment=treatment,
                medical_specialty=medical_specialty,
                icd10=icd10,
                hcc=hcc_code,
            )
        )

    def run(self) -> pd.DataFrame:
        """Process every row of the transcription CSV.

        Returns:
            One row per input row, with the ``PatientExtraction`` fields as
            columns (present even when the input has no rows).

        Raises:
            ValueError: if the input CSV has no ``transcription`` column.
        """
        from dataclasses import fields

        df = pd.read_csv(self.transcription_csv_path)
        if "transcription" not in df.columns:
            raise ValueError(
                f"Input CSV {self.transcription_csv_path!r} has no 'transcription' column."
            )

        processed: List[Dict] = [self._process_row(row) for _, row in df.iterrows()]
        # Explicit columns keep the output schema stable for an empty input.
        columns = [f.name for f in fields(PatientExtraction)]
        return pd.DataFrame(processed, columns=columns)


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: process transcriptions and write the results CSV.

    Args:
        argv: Arguments to parse in place of ``sys.argv[1:]`` (argparse's
            default when None), so the pipeline can also be invoked
            programmatically or from tests.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Process clinical transcriptions to extract ICD-10 and HCC codes.")
    parser.add_argument("--transcriptions", required=True, help="Path to CSV with columns: transcription, medical_specialty")
    parser.add_argument("--hcc_csv", required=True, help="Path to HCC crosswalk CSV")
    parser.add_argument("--output", default="structured_output.csv", help="CSV output path")
    args = parser.parse_args(argv)

    # No key is passed here; both clients read OPENAI_API_KEY from the environment.
    processor = TranscriptionProcessor(args.transcriptions, args.hcc_csv)
    logger.info("Starting transcription processing...")
    df = processor.run()
    df.to_csv(args.output, index=False)
    logger.info("Wrote structured output to %s", args.output)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
