# In[ ]:
import os
import json
import wandb
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from sqlalchemy import create_engine, Column, Integer, String, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# Initialize wandb
wandb.init(project="clinical-llm-application", name="concept-extraction-and-prediction")

# Database URLs are read from the environment so real credentials never need
# to be committed with the notebook. The fallbacks are placeholders and must be
# replaced (or overridden via SOURCE_DB_URL / LOCAL_DB_URL) before running.
source_db_url = os.environ.get(
    "SOURCE_DB_URL",
    "postgresql://username:password@localhost:5432/source_database",
)
# Engine for the source PostgreSQL database that holds patient encounters.
source_engine = create_engine(source_db_url)

# Engine for the local PostgreSQL database that stores model output.
# NOTE: create_engine does not create the database itself; it must already exist.
local_db_url = os.environ.get(
    "LOCAL_DB_URL",
    "postgresql://username:password@localhost:5432/local_model_output",
)
local_engine = create_engine(local_db_url)
# Declarative base for the ORM model(s) persisted in the local database.
Base = declarative_base()

class ModelOutput(Base):
    """ORM row holding the model output for one processed encounter.

    Rows are written by ``process_encounter`` and read back by
    ``analyze_results``.
    """

    __tablename__ = 'model_outputs'
    # Integer primary key for the row.
    id = Column(Integer, primary_key=True)
    # Encounter identifier as passed to process_encounter (stored as a string column).
    encounter_id = Column(String)
    # JSON value returned by extract_clinical_concepts ({} when parsing failed).
    extracted_concepts = Column(JSON)
    # Free-text location string returned by predict_next_location.
    predicted_location = Column(String)

# Create the model_outputs table if it does not exist yet, then open the single
# session shared by process_encounter and analyze_results.
Base.metadata.create_all(local_engine)
Session = sessionmaker(bind=local_engine)
local_session = Session()

# Load the local LLM model. The path can be overridden with LLM_MODEL_PATH;
# the fallback is a placeholder that must be replaced before running.
model_name = os.environ.get("LLM_MODEL_PATH", "path/to/your/local/llm/model")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def fetch_patient_data(encounter_id):
    """
    Fetch one patient encounter row from the source database.

    Args:
        encounter_id: Value matched against ``patient_encounters.encounter_id``.
            It is converted with ``str()`` and bound as a string, matching the
            quoted literal the query has always used.

    Returns:
        The first matching row, or ``None`` when no row matches.
    """
    # Local import keeps this fix self-contained; text() wraps the SQL as an
    # executable statement with named bind parameters.
    from sqlalchemy import text

    # Bind the ID instead of interpolating it into the SQL: interpolation
    # allowed SQL injection and broke on IDs containing a quote. SQLAlchemy 2.x
    # also rejects plain strings in Connection.execute.
    query = text("SELECT * FROM patient_encounters WHERE encounter_id = :encounter_id")
    with source_engine.connect() as connection:
        result = connection.execute(query, {"encounter_id": str(encounter_id)})
        # Fetch while the connection is still open.
        return result.fetchone()

def extract_clinical_concepts(patient_data, max_new_tokens=500):
    """
    Extract structured clinical concepts from unstructured patient data.

    Args:
        patient_data: Record to describe to the model; it is interpolated into
            the prompt via f-string formatting.
        max_new_tokens: Maximum number of tokens to generate, not counting the
            prompt.

    Returns:
        The first JSON object or array found in the model's generated text, or
        ``{}`` if none could be parsed.
    """
    prompt = f"""
    Extract clinical concepts from the following patient data:
    {patient_data}
    
    Output the extracted concepts in JSON format.
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # max_new_tokens rather than max_length: max_length counts the prompt too,
    # so a long patient record could leave no budget for the answer.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation for a causal LM. Decode only the
    # continuation; the echoed prompt would otherwise make JSON parsing fail.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Models often wrap JSON in prose or code fences, so scan for the first
    # position where a complete JSON object/array can be decoded.
    decoder = json.JSONDecoder()
    for start, char in enumerate(generated_text):
        if char not in "{[":
            continue
        try:
            value, _ = decoder.raw_decode(generated_text, start)
        except json.JSONDecodeError:
            continue
        return value

    print("Error: Unable to parse extracted concepts as JSON.")
    return {}

def predict_next_location(extracted_concepts, max_new_tokens=50):
    """
    Predict the next location for the patient encounter based on extracted concepts.

    Args:
        extracted_concepts: JSON-serializable concepts, embedded in the prompt
            with ``json.dumps(..., indent=2)``.
        max_new_tokens: Maximum number of tokens to generate, not counting the
            prompt.

    Returns:
        The model's generated continuation, with surrounding whitespace stripped.
    """
    prompt = f"""
    Given the following extracted clinical concepts:
    {json.dumps(extracted_concepts, indent=2)}
    
    Predict the most likely next location for this patient encounter.
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # The prompt embeds the whole concept JSON and can easily exceed 50 tokens,
    # so a max_length=50 cap could leave no room to generate; budget new tokens only.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the echoed prompt so only the model's answer is returned.
    prompt_length = inputs["input_ids"].shape[-1]
    predicted_location = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    return predicted_location.strip()

def process_encounter(encounter_id):
    """
    Process a single patient encounter.

    Fetches the encounter, extracts concepts, predicts the next location,
    logs a summary to wandb and persists a ModelOutput row via the shared
    local_session.

    Args:
        encounter_id: Identifier passed to fetch_patient_data and stored on the row.

    Returns:
        Tuple ``(extracted_concepts, predicted_location)``.

    Raises:
        Any exception raised while committing, after the session has been
        rolled back.
    """
    patient_data = fetch_patient_data(encounter_id)
    # NOTE(review): fetch_patient_data returns None (fetchone) when the
    # encounter does not exist; that None is currently passed into the prompt.
    extracted_concepts = extract_clinical_concepts(patient_data)
    predicted_location = predict_next_location(extracted_concepts)

    # Log to wandb
    wandb.log({
        "encounter_id": encounter_id,
        "num_extracted_concepts": len(extracted_concepts),
        "predicted_location": predicted_location
    })

    # Store results in local database
    model_output = ModelOutput(
        encounter_id=encounter_id,
        extracted_concepts=extracted_concepts,
        predicted_location=predicted_location
    )
    local_session.add(model_output)
    try:
        local_session.commit()
    except Exception:
        # local_session is shared across calls: without a rollback, one failed
        # commit would leave it unusable for every subsequent encounter.
        local_session.rollback()
        raise

    return extracted_concepts, predicted_location

def analyze_results():
    """
    Analyze and visualize the results.

    Reads every ModelOutput row from the local database. Computes the number of
    encounters, the mean number of extracted concepts per encounter and the
    distribution of predicted locations, then logs a bar chart and the
    summary statistics to wandb. Does nothing (apart from printing a notice)
    when no results are stored yet.
    """
    from collections import Counter

    results = local_session.query(ModelOutput).all()
    if not results:
        # np.mean([]) would warn and yield nan, and the chart would be empty.
        print("No model outputs to analyze.")
        return

    # Calculate some statistics
    num_encounters = len(results)
    # The JSON column is nullable; count a NULL value as zero concepts instead
    # of letting len(None) raise.
    avg_concepts_per_encounter = np.mean(
        [len(r.extracted_concepts or ()) for r in results]
    )
    # str() keeps NULL predictions as a visible "None" category; the
    # categorical x-axis expects string labels.
    location_distribution = Counter(str(r.predicted_location) for r in results)

    # Create a bar plot of predicted locations
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(list(location_distribution.keys()), list(location_distribution.values()))
    ax.set_title("Distribution of Predicted Next Locations")
    ax.set_xlabel("Location")
    ax.set_ylabel("Frequency")
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()

    # Log plot to wandb
    wandb.log({"location_distribution": wandb.Image(fig)})

    # Log statistics to wandb
    wandb.log({
        "num_encounters": num_encounters,
        "avg_concepts_per_encounter": avg_concepts_per_encounter
    })
    

# Example usage (you can put this in a separate cell for execution).
# Replace these placeholder IDs with real encounter IDs. Do not leave a literal
# ``...`` in the list: it is the Ellipsis object and would be processed as an ID.
encounter_ids = ["1", "2", "3"]
for encounter_id in encounter_ids:
    extracted_concepts, predicted_location = process_encounter(encounter_id)
    print(f"Processed encounter {encounter_id}")
    print(f"Extracted concepts: {json.dumps(extracted_concepts, indent=2)}")
    print(f"Predicted next location: {predicted_location}")
    print("---")

# Analyze results (you can put this in a separate cell for execution)
analyze_results()

# Release the shared local DB session and close the wandb run
# (you can put this in a final cell for execution).
local_session.close()
wandb.finish()