# In [None]:
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List
from urllib.parse import quote

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb
from dotenv import load_dotenv
from sklearn.metrics import confusion_matrix
from sqlalchemy import create_engine, text
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_time_string(fmt="%m%d_%H%M"):
    """Return the current local time rendered with ``fmt``.

    Args:
        fmt: A ``strftime`` format string. The default produces
            month/day, an underscore, then hour/minute (e.g. ``0131_2359``).

    Returns:
        The formatted timestamp as a string.
    """
    current_moment = datetime.now()
    return current_moment.strftime(fmt)

# Load environment variables
load_dotenv()

# Run-wide settings.
#   wandb_*            - where the Weights & Biases run is recorded.
#   model_name         - Hugging Face model id used for both LLM steps.
#   label_mapping      - care-level category -> integer label id.
#   location_mapping   - raw unit display name -> care-level category.
#   location_exclusions - raw unit display names whose transitions are skipped
#                        on purpose (empty by default).
#   batch_size         - number of encounters fetched per batch.
#   max_length         - token budget passed to generation.
config = {
    "wandb_project": "clinical-concept-extraction-and-prediction",
    "wandb_entity": "llm-lab",
    # get_time_string must be *called*; interpolating the bare function would
    # put "<function get_time_string at 0x...>" into the run name.
    "wandb_run_name": f"{get_time_string()}_concept-extraction-and-prediction",
    "model_name" : "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "label_mapping": {
      "ED": 0,
      "Med-Surg": 1,
      "Observation": 2,
      "ICU": 3,
      "Stepdown": 4
    },
    "location_mapping": {
      "Corning Hospital 2e": "Med-Surg",
      "Corning Hospital 2w": "Med-Surg",
      "Corning Hospital Emergency Department": "ED",
      "Corning Hospital Icu": "ICU",
      "Cortland Hospital 2c": "Med-Surg",
      "Cortland Hospital 2s": "Med-Surg",
      "Cortland Hospital Icu": "ICU",
      "Cortland Hospital Maternity": "Med-Surg",
      "Guthrie Rph Towanda Med/Surg Unit": "Med-Surg",
      "Guthrie Towanda Campus, Med/Surg Unit": "Med-Surg",
      "Robert Packer Hospital 2Icu": "ICU",
      "Robert Packer Hospital 4w": "Observation",
      "Robert Packer Hospital 6m": "Med-Surg",
      "Robert Packer Hospital 6w": "Med-Surg",
      "Robert Packer Hospital 7Icu": "Stepdown",
      "Robert Packer Hospital 7m": "Med-Surg",
      "Robert Packer Hospital 7nw": "Med-Surg",
      "Robert Packer Hospital 8nw": "Med-Surg",
      "Robert Packer Hospital 9sw": "Med-Surg",
      "Robert Packer Hospital Emergency Department": "ED",
      "Robert Packer Hospital JC": "Med-Surg",
      "RPH 6 NORTHWEST": "Med-Surg",
      "RPH 6 Southwest": "Med-Surg",
      "Troy Community Hospital 1m": "Med-Surg",
      "Robert Packer Hospital Preprocedure": "Med-Surg",
      "Robert Packer Hospital Recovery": "Med-Surg",
      "Guthrie Rph Towanda Emergency Department": "ED",
      "Guthrie Rph Towanda Skilled Nursing Unit": "Med-Surg",
      "Guthrie Towanda Campus, Emergency Department": "ED",
      "Guthrie Towanda Campus, Skilled Nursing Unit": "Med-Surg",
      "Towanda Memorial Hospital Skilled Nursing": "Med-Surg",
      "Troy Community Hospital Emergency Dept": "ED",
      "Troy Community Hospital Preprocedure": "Med-Surg",
      "Cortland Hospital Emergency Department": "ED",
      "Cortland Hospital Preprocedure": "Med-Surg",
      "Cortland Hospital Recovery": "Med-Surg",
      "Guthrie Rph Towanda Acute Rehab Unit": "Med-Surg",
      "Cortland Hospital Ultrasound": "Med-Surg",
      "CH Recovery": "Med-Surg",
      "Corning Hospital Observation Unit": "Observation",
      "Robert Packer Hospital": "Med-Surg",
      "Guthrie Towanda Campus, Preprocedure": "Med-Surg"
    },
    # Read by preprocess_encounters; the key must exist or every lookup there
    # raises KeyError. Add raw unit names here to skip them silently.
    "location_exclusions": [],
    "batch_size": 10, # Number of encounters to fetch and process in each batch
    "max_length": 2048,
}

# Initialize wandb
wandb.login(key=os.getenv('WANDB_API_KEY'))
wandb.init(project=config["wandb_project"], entity=config["wandb_entity"], name=config["wandb_run_name"])

# Connect to the source PostgreSQL database.
# The user name and password are percent-encoded before being placed in the
# URL: a raw "@", ":" or "/" inside a credential would otherwise be parsed as
# a URL delimiter and point the engine at the wrong host or database.
source_db_url = (
    "postgresql+psycopg2://"
    f"{quote(os.getenv('ARTISIGHT_DB_USER', ''), safe='')}"
    f":{quote(os.getenv('ARTISIGHT_DB_PASSWORD', ''), safe='')}"
    f"@{os.getenv('ARTISIGHT_DB_HOST')}:{os.getenv('ARTISIGHT_DB_PORT')}"
    f"/{os.getenv('ARTISIGHT_DB_NAME')}"
)
source_engine = create_engine(source_db_url)

# Load the local LLM model (tokenizer and weights come from the same model id)
tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
model = AutoModelForCausalLM.from_pretrained(config["model_name"])

# Largest encounter id handed out so far by fetch_batch_encounter_data.
# ``None`` means nothing has been fetched yet, so the next call starts from
# the lowest encounter id.
_last_fetched_encounter_id = None


def fetch_batch_encounter_data(*, restart: bool = False) -> pd.DataFrame:
    """
    Fetch the next batch of encounter notes from the source database.

    Each call returns the notes of up to ``config['batch_size']`` finished
    encounters whose id is greater than the largest id returned by the
    previous call, so repeated calls walk through the table and eventually
    return an empty frame. (Without this cursor every call returned the same
    first batch and a "loop until empty" caller never terminated.)

    Only encounters that have at least one qualifying note are counted
    towards the batch, so an empty result really means "nothing left" rather
    than "this slice of encounters happened to have no notes".

    Args:
        restart: When True, forget the cursor and start again from the
            lowest encounter id.

    Returns:
        One row per (encounter, note), ordered by encounter id and note date.
        Empty (but with the full column set) when no encounters remain.

    Note:
        Assumes ``encounters.id`` is unique and orderable; the query already
        ordered by it.
    """
    global _last_fetched_encounter_id
    if restart:
        _last_fetched_encounter_id = None

    # Values are passed as bound parameters; only this fixed SQL fragment is
    # spliced into the statement text.
    params = {"batch_size": config["batch_size"]}
    cursor_filter = ""
    if _last_fetched_encounter_id is not None:
        cursor_filter = "AND encounters.id > :last_encounter_id"
        params["last_encounter_id"] = _last_fetched_encounter_id

    # NOTE: this is an f-string, so the literal JSON path '{}' is written '{{}}'.
    query = text(f"""
    WITH batch_encounters AS (
        SELECT encounters.id AS encounter_id
        FROM data_service_fhir_app_encounter encounters
        JOIN data_service_fhir_app_patient batch_patients ON encounters.patient_id = batch_patients.id
        WHERE encounters.document ->> 'status' = 'finished'
            {cursor_filter}
            AND EXISTS (
                SELECT 1
                FROM data_service_fhir_app_documentreference_encounter dr
                JOIN data_service_fhir_app_documentreference d ON dr.documentreference_id = d.id
                JOIN data_service_fhir_app_documentreference_binary db ON d.id = db.documentreference_id
                JOIN data_service_fhir_app_binary b ON db.binary_id = b.id
                WHERE dr.encounter_id = encounters.id
                    AND b.processed_data IS NOT NULL
                    AND d.document ->> 'docStatus' IN ('final', 'amended')
                    AND d.document -> 'type' ->> 'text' IN ('Progress Notes', 'Consults', 'H&P')
            )
        ORDER BY encounters.id
        LIMIT :batch_size
    )
    SELECT 
        encounters.id AS encounter_id,
        encounters.fhir_id AS fhir_encounter_id,
        encounters.patient_id,
        patients.fhir_id AS fhir_patient_id,
        ((encounters.document#>> '{{}}')::jsonb->'period'->>'start')::timestamp AS encounter_start,
        ((encounters.document#>> '{{}}')::jsonb->'period'->>'end')::timestamp AS encounter_end,
        encounters.document ->> 'status' AS encounter_status,
        docs.id AS documentreference_id,
        docs.fhir_id AS fhir_documentreference_id,
        binaries.id AS binary_id,
        docs.document -> 'type' ->> 'text' AS note_type,
        (docs.document ->> 'date')::timestamp AS note_date,
        binaries.last_updated AS note_last_updated,
        encounters.document -> 'location' as locations,
        (encounters.document -> 'location' -> -1 -> 'location' ->> 'display') AS current_location,
        CASE 
            WHEN jsonb_array_length(encounters.document -> 'location') > 1 
            THEN (encounters.document -> 'location' -> -2 -> 'location' ->> 'display')
            ELSE NULL
        END AS previous_location,
        binaries.processed_data
    FROM batch_encounters
    JOIN data_service_fhir_app_encounter encounters ON batch_encounters.encounter_id = encounters.id
    JOIN data_service_fhir_app_documentreference_encounter docref ON encounters.id = docref.encounter_id
    JOIN data_service_fhir_app_documentreference docs ON docref.documentreference_id = docs.id
    JOIN data_service_fhir_app_documentreference_binary docbin ON docs.id = docbin.documentreference_id
    JOIN data_service_fhir_app_binary binaries ON docbin.binary_id = binaries.id
    JOIN data_service_fhir_app_patient patients ON encounters.patient_id = patients.id
    WHERE binaries.processed_data IS NOT NULL
        AND docs.document ->> 'docStatus' IN ('final', 'amended')
        AND docs.document -> 'type' ->> 'text' IN ('Progress Notes', 'Consults', 'H&P')
    ORDER BY encounters.id, note_date
    """)

    with source_engine.connect() as connection:
        result = connection.execute(query, params)
        columns = list(result.keys())
        rows = result.fetchall()

    if rows:
        # Column 0 is encounter_id. Taken from the raw rows (not the
        # DataFrame) so the cursor stays a plain Python value that the DB
        # driver can bind on the next call.
        _last_fetched_encounter_id = max(row[0] for row in rows)

    return pd.DataFrame(rows, columns=columns)

def _build_location_timeline(locations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return the encounter's dated location entries, sorted by start time."""
    timeline = []
    for loc in locations:
        # Entries without a start time cannot be placed on the timeline.
        if "period" in loc and "start" in loc["period"]:
            timeline.append({
                "location": loc["location"]["display"],
                # Drop the timezone so the value compares against note_date.
                # NOTE(review): assumes note_date is timezone-naive — confirm.
                "timestamp": pd.to_datetime(loc["period"]["start"]).tz_localize(None)
            })
    timeline.sort(key=lambda entry: entry["timestamp"])
    return timeline


def preprocess_encounters(df_encounters: pd.DataFrame) -> List[Dict[str, Any]]:
    """
    Preprocess encounter data to generate next location prediction samples.

    For every encounter, each change from one location to a different one
    yields one sample: the notes written up to the moment of the move
    (assembled with ``assemble_notes``) and the care-level label of the
    destination.

    Args:
        df_encounters: One row per (encounter, note). The columns read here
            are ``encounter_id``, ``patient_id``, ``encounter_start``,
            ``encounter_end``, ``locations``, ``note_type``, ``note_date``,
            ``note_last_updated``, plus whatever ``assemble_notes`` needs.

    Returns:
        A list of sample dicts. An encounter that raises while being
        processed is reported on stdout and contributes no samples.
    """
    location_mapping = config["location_mapping"]
    # The exclusion list is optional; a hard config["location_exclusions"]
    # lookup raised KeyError for every transition when the key was absent.
    location_exclusions = config.get("location_exclusions", ())
    note_types = ["H&P", "Progress Notes", "Consults"]

    all_outputs = []
    for encounter_id, df_encounter in df_encounters.groupby("encounter_id"):
        try:
            df_encounter = df_encounter.sort_values("note_date")

            # The JSON column may arrive as text or already decoded into a
            # list, depending on the database driver; accept both.
            locations = df_encounter["locations"].iloc[0]
            if isinstance(locations, (str, bytes, bytearray)):
                locations = json.loads(locations)
            if not isinstance(locations, list) or len(locations) < 2:
                continue

            location_changes = _build_location_timeline(locations)

            outputs = []
            # Walk consecutive timeline entries; only real moves matter.
            for current, upcoming in zip(location_changes, location_changes[1:]):
                current_location = current["location"]
                next_location = upcoming["location"]
                if current_location == next_location:
                    continue

                # Skip destinations that are excluded or not yet mapped.
                if next_location in location_exclusions or next_location not in location_mapping:
                    continue

                label = location_mapping[next_location]
                transition_time = upcoming["timestamp"]

                # Filter notes up to the transition time
                notes = df_encounter[
                    df_encounter["note_type"].isin(note_types) &
                    (df_encounter["note_date"] <= transition_time)
                ]
                if notes.empty:
                    continue

                outputs.append({
                    "patient_id": df_encounter["patient_id"].iloc[0],
                    "encounter_id": encounter_id,
                    "encounter_start": df_encounter["encounter_start"].iloc[0],
                    "encounter_end": df_encounter["encounter_end"].iloc[0],
                    "input": assemble_notes(notes),
                    "label": label,
                    "label_id": config["label_mapping"][label],
                    "db_record_last_updated": notes["note_last_updated"].iloc[-1],
                    "current_location": current_location,
                    "current_time": transition_time,
                    "next_location": next_location,
                    "transition_time": transition_time
                })

            # Added only after the whole encounter succeeded, so a failure
            # part-way through never leaves partial samples behind.
            all_outputs.extend(outputs)
        except Exception as e:
            print(f"Error processing encounter {encounter_id}: {e}")

    return all_outputs

def assemble_notes(df_encounter: pd.DataFrame) -> str:
    """Build the model input text from an encounter's notes.

    The result is the last Progress Notes/Consults row, then the one before
    it (each followed by a newline), then the first H&P row with no trailing
    newline. Missing pieces are simply left out.

    Args:
        df_encounter: Note rows with ``note_type`` and ``processed_data``
            columns; row order decides which notes count as "last".

    Returns:
        The concatenated note text, or an empty string if no row matches.
    """
    kinds = df_encounter["note_type"]
    texts = df_encounter["processed_data"]

    follow_ups = texts[kinds.isin(["Progress Notes", "Consults"])]
    admission = texts[kinds == "H&P"]

    # Reverse so the most recent note comes first, then keep at most two.
    pieces = [note + "\n" for note in follow_ups.iloc[::-1].iloc[:2]]

    if not admission.empty:
        pieces.append(admission.iloc[0])

    return "".join(pieces)

def extract_clinical_concepts(patient_data: str) -> Dict[str, Any]:
    """
    Extract structured clinical concepts from unstructured patient data.

    The notes are sent to the local LLM with a request for JSON, and the
    model's reply is parsed.

    Args:
        patient_data: Free-text clinical notes for one sample.

    Returns:
        The parsed JSON value produced by the model, or an empty dict when
        the reply contains no parseable JSON.
    """
    # TODO: - [ ] Enhance this prompt.
    prompt = f"""
    Extract clinical concepts from the following patient data:
    {patient_data}
    
    Output the extracted concepts in JSON format.
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # max_new_tokens bounds only the reply. max_length would also count the
    # prompt, and clinical notes alone routinely exceed 500 tokens.
    outputs = model.generate(**inputs, max_new_tokens=500)

    # A causal LM returns prompt + continuation; keep only the continuation,
    # otherwise the text starts with the prompt and can never parse as JSON.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    try:
        return json.loads(completion)
    except json.JSONDecodeError:
        pass

    # Models often wrap the JSON in prose or code fences; retry on the
    # outermost {...} span before giving up.
    start, end = completion.find("{"), completion.rfind("}")
    if 0 <= start < end:
        try:
            return json.loads(completion[start:end + 1])
        except json.JSONDecodeError:
            pass

    print("Error: Unable to parse extracted concepts as JSON.")
    return {}

def predict_next_location(extracted_concepts: Dict[str, Any]) -> str:
    """
    Predict the next location for the patient encounter based on extracted concepts.

    Args:
        extracted_concepts: JSON-serialisable concepts for one sample.

    Returns:
        The care-level label (a key of ``config['label_mapping']``) that
        appears first in the model's reply. If the reply names none of them,
        the stripped reply text is returned unchanged.
    """
    # Offer each category once. location_mapping.values() repeats the same
    # few categories once per hospital unit.
    labels = list(config["label_mapping"])

    # TODO: - [ ] Enhance this prompt.
    prompt = f"""
    Given the following extracted clinical concepts:
    {json.dumps(extracted_concepts, indent=2)}
    
    Predict the most likely next location for this patient encounter.
    Choose from: {', '.join(labels)}
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(**inputs, max_length=config["max_length"])

    # Decode only the newly generated tokens; the raw output starts with the
    # prompt, which itself lists every label.
    completion = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Pick the label mentioned earliest in the reply. The lookarounds require
    # whole-word matches so that e.g. "ED" is not found inside "Med-Surg".
    earliest = None
    for label in labels:
        match = re.search(
            rf"(?<!\w){re.escape(label)}(?!\w)", completion, flags=re.IGNORECASE
        )
        if match and (earliest is None or match.start() < earliest[0]):
            earliest = (match.start(), label)

    return earliest[1] if earliest else completion

def process_batch(preprocessed_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Process a batch of preprocessed encounter data.

    Each sample goes through concept extraction and next-location
    prediction; the outcome is logged to wandb and collected.

    Args:
        preprocessed_data: Sample dicts providing at least ``input``,
            ``encounter_id``, ``current_location``, ``next_location`` and
            ``label``.

    Returns:
        One result dict per sample. ``actual_next_location`` holds the
        care-level label (the value the prediction is compared with and the
        same value logged to wandb); the raw unit name is kept under
        ``actual_next_location_name``.
    """
    results = []
    for sample in preprocessed_data:
        extracted_concepts = extract_clinical_concepts(sample['input'])
        predicted_location = predict_next_location(extracted_concepts)

        actual_label = sample['label']
        num_concepts = len(extracted_concepts)
        is_correct = predicted_location == actual_label

        results.append({
            "encounter_id": sample['encounter_id'],
            "current_location": sample['current_location'],
            # Store the label, not the raw unit name, so the stored result
            # agrees with is_correct and with what is logged below.
            "actual_next_location": actual_label,
            "actual_next_location_name": sample['next_location'],
            "predicted_next_location": predicted_location,
            "num_extracted_concepts": num_concepts,
            "is_correct": is_correct
        })

        # Log to wandb
        wandb.log({
            "encounter_id": sample['encounter_id'],
            "num_extracted_concepts": num_concepts,
            "actual_next_location": actual_label,
            "predicted_next_location": predicted_location,
            "is_correct": is_correct
        })

    return results

def analyze_results(results: List[Dict[str, Any]]):
    """
    Analyze and visualize the results.

    Computes overall accuracy and a confusion matrix over the care-level
    labels, then logs the matrix image and summary statistics to wandb.

    Args:
        results: Result dicts carrying ``is_correct``,
            ``actual_next_location`` and ``predicted_next_location``.
            Nothing is computed or logged when the list is empty.
    """
    if not results:
        # An empty frame has no 'is_correct' column; bail out instead of
        # failing with a KeyError.
        print("No results to analyze.")
        return

    df_results = pd.DataFrame(results)

    # Calculate accuracy
    accuracy = df_results['is_correct'].mean()

    # One axis entry per category. location_mapping.values() would repeat
    # each category once per hospital unit.
    labels = list(config['label_mapping'])

    # Actual values may be raw unit names or already labels; raw names are
    # translated and anything else passes through unchanged.
    location_mapping = config['location_mapping']
    actual = df_results['actual_next_location'].map(
        lambda name: location_mapping.get(name, name)
    )
    predicted = df_results['predicted_next_location']

    # Calculate confusion matrix
    cm = confusion_matrix(actual, predicted, labels=labels)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel('Actual Location')
    plt.xlabel('Predicted Location')
    # Laid out after the axis labels exist so they are included.
    plt.tight_layout()

    # Log plot to wandb
    wandb.log({"confusion_matrix": wandb.Image(plt)})
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()

    # Log statistics to wandb
    wandb.log({
        "accuracy": accuracy,
        "num_samples": len(df_results)
    })
    
# Main execution
all_results = []
batch_number = 1
# Encounter ids already handled. If a fetch yields nothing new, the loop
# stops instead of re-processing the same encounters forever.
processed_encounter_ids = set()

while True:
    print(f"Processing batch {batch_number}")

    # Fetch a batch of encounter data
    df_batch = fetch_batch_encounter_data()

    if df_batch.empty:
        print("No more encounters to process.")
        break

    # Keep only encounters that have not been processed yet
    df_batch = df_batch[~df_batch["encounter_id"].isin(processed_encounter_ids)]
    if df_batch.empty:
        print("Fetch returned only encounters that were already processed; stopping.")
        break
    processed_encounter_ids.update(df_batch["encounter_id"].unique().tolist())

    # Preprocess the batch
    preprocessed_data = preprocess_encounters(df_batch)

    # Process the batch
    batch_results = process_batch(preprocessed_data)
    all_results.extend(batch_results)

    print(f"Processed {len(batch_results)} samples in batch {batch_number}")
    batch_number += 1

# Analyze results (skipped when no samples were produced)
if all_results:
    analyze_results(all_results)
else:
    print("No samples were produced; skipping analysis.")

# Close wandb run
wandb.finish()