In [1]:
import os
import pandas as pd

from pyspark.sql import SparkSession
import pyspark.sql.functions as sql_f
from pyspark.sql.types import *
from pyspark.sql.functions import to_date, datediff, floor, col, avg, substring

from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import unix_timestamp

spark = SparkSession.builder.getOrCreate()


## 2) Creating the spark dataframes

In [None]:
path = '/synthea_output/'

In [13]:
path = '/content/output/csv/'


In [14]:
#patient
observations = spark.read.csv(path+"observations.csv", header=True)
patient = spark.read.csv(path+"patients.csv", header=True)

#medical
careplans = spark.read.csv(path+"careplans.csv", header=True)
conditions = spark.read.csv(path+"conditions.csv", header=True)
procedures=spark.read.csv(path+"procedures.csv", header=True)
encounters = spark.read.csv(path+"encounters.csv", header=True)
medications = spark.read.csv(path+"medications.csv", header=True)

#insurance and hospital
payer_transitions=spark.read.csv(path+"payer_transitions.csv", header=True)
payers=spark.read.csv(path+"payers.csv", header=True)
providers=spark.read.csv(path+"providers.csv", header=True)
organizations=spark.read.csv(path+"organizations.csv", header=True)

## 3) Cleaning dataframes and renaming variables

In [15]:
# renaming columns
patient = (
    patient.withColumnRenamed("Id", "patient_id")
           .withColumnRenamed("MARITAL", "patient_marital")
           .withColumnRenamed("RACE", "patient_race")
           .withColumnRenamed("ETHNICITY", "patient_ethnicity")
           .withColumnRenamed("GENDER", "patient_gender")
           .withColumnRenamed("ZIP", "patient_zip")
)

encounters = (
    encounters.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("Id", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "encounter_description")
              .withColumnRenamed("CODE", "encounter_code")
              .withColumnRenamed("START", "encounter_start")
              .withColumn("encounter_start", to_date("encounter_start"))
              .withColumnRenamed("STOP", "encounter_stop")
              .withColumn("encounter_stop", to_date("encounter_stop"))
              .withColumn("PATIENT COST", col("TOTAL_CLAIM_COST") - col("PAYER_COVERAGE"))
              .withColumnRenamed("PAYER", "payer_id")
              .withColumnRenamed("ORGANIZATION", "organization_id")
              .withColumnRenamed("PROVIDER", "provider_id")
)

careplans = (
    careplans.withColumnRenamed("PATIENT", "patient_id")
             .withColumnRenamed("Id", "careplan_id")
             .withColumnRenamed("ENCOUNTER", "encounter_id")
             .withColumnRenamed("DESCRIPTION", "careplan_descriptions")
             .withColumnRenamed("CODE", "careplan_code")
)

procedures = (
    procedures.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("ENCOUNTER", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "procedure_descriptions")
              .withColumnRenamed("CODE", "procedure_code")
              .withColumnRenamed("DATE", "procedure_date")
              .withColumnRenamed("BASE_COST", "procedure_cost")
)

conditions = (
    conditions.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("ENCOUNTER", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "condition_description")
              .withColumnRenamed("CODE", "condition_code")
              .withColumnRenamed("START", "condition_start")
              .withColumnRenamed("END", "condition_end")
)

observations = (
    observations.withColumnRenamed("PATIENT", "patient_id")
                .withColumnRenamed("ENCOUNTER", "encounter_id")
                .withColumnRenamed("DATE", "observation_date")
                .withColumn("observation_date", to_date("observation_date"))
)

medications = (
    medications.withColumnRenamed("START", "medication_start")
               .withColumn("medication_start", to_date("medication_start"))
               .withColumnRenamed("STOP", "medication_stop")
               .withColumn("medication_stop", to_date("medication_stop"))
               .withColumnRenamed("PATIENT", "patient_id")
               .withColumnRenamed("PAYER", "payer_id")
               .withColumnRenamed("ENCOUNTER", "encounter_id")
               .withColumnRenamed("CODE", "medication_code")
               .withColumnRenamed("DESCRIPTION", "medication_description")
)

payer_transitions = (
    payer_transitions.withColumnRenamed("PATIENT", "patient_id")
                     .withColumnRenamed("PAYER", "payer_id")
)

payers = (
    payers.withColumnRenamed("Id", "payer_id")
          .withColumnRenamed("NAME", "payer_name")
          .withColumnRenamed("OWNERSHIP", "payer_ownership")
)

providers = (
    providers.withColumnRenamed("Id", "provider_id")
             .withColumnRenamed("SPECIALITY", "provider_specialty")
)

organizations = (
    organizations.withColumnRenamed("Id", "organization_id")
                 .withColumnRenamed("NAME", "organization_name")
                 .withColumnRenamed("ZIP", "organization_zip")
                 .withColumn("organization_zip", substring(col("organization_zip").cast("string"), 1, 5))
)

In [16]:
#can do this too
encounters = (
    encounters
    .join(payers.select("payer_id", "payer_name", "payer_ownership"), on="payer_id", how="left")
    .join(organizations.select("organization_id", "organization_name", "organization_zip"), on="organization_id", how="left")
    .join(providers.select("provider_id", "provider_specialty"), on="provider_id", how="left")
    .join(procedures.select("encounter_id", "procedure_descriptions", "procedure_code"), on="encounter_id", how="left")
    .join(patient.select("patient_id", "BIRTHDATE", "patient_marital", "patient_race", "patient_ethnicity", "patient_gender", "patient_zip"), on="patient_id", how="left")
    .withColumn("age_at_encounter", floor(datediff(col("encounter_start"), col("BIRTHDATE")) / 365.25))
)

In [17]:
encounters.select("encounter_description").show(20, truncate=False)

+------------------------------------------+
|encounter_description                     |
+------------------------------------------+
|Encounter for symptom (procedure)         |
|Encounter for problem (procedure)         |
|General examination of patient (procedure)|
|General examination of patient (procedure)|
|Encounter for problem (procedure)         |
|General examination of patient (procedure)|
|Consultation for treatment (procedure)    |
|Consultation for treatment (procedure)    |
|Patient encounter procedure (procedure)   |
|General examination of patient (procedure)|
|General examination of patient (procedure)|
|General examination of patient (procedure)|
|General examination of patient (procedure)|
|General examination of patient (procedure)|
|Encounter for check up (procedure)        |
|Encounter for check up (procedure)        |
|Encounter for check up (procedure)        |
|Encounter for check up (procedure)        |
|Encounter for check up (procedure)        |
|Encounter

In [18]:
encounters.select("procedure_descriptions").show(20, truncate=False)

+-----------------------------------------------------------------------------------------------+
|procedure_descriptions                                                                         |
+-----------------------------------------------------------------------------------------------+
|NULL                                                                                           |
|NULL                                                                                           |
|NULL                                                                                           |
|NULL                                                                                           |
|NULL                                                                                           |
|NULL                                                                                           |
|NULL                                                                                           |
|NULL               

## Trying to split by topic using LDA and train regession model per topic

In [19]:
modeling_df = encounters.select(
    col("PATIENT COST").cast("double").alias("label"),
    col("age_at_encounter").cast("double"),
    col("patient_marital"),
    col("patient_race"),
    col("patient_ethnicity"),
    col("patient_gender"),
    col("ENCOUNTERCLASS"),
    col("payer_ownership"),
    col("payer_name"),
    col("organization_zip"),
    col("organization_name"),
    col("procedure_code"),
    col("procedure_descriptions"),
    col("encounter_description"),
    col("encounter_code"),
).na.drop()


In [20]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer
from pyspark.ml.clustering import LDA
from pyspark.sql.functions import col, lower, regexp_replace, concat_ws


In [21]:
#clean the text first
modeling_df = modeling_df.filter(
    col("encounter_description").isNotNull() &
    col("procedure_descriptions").isNotNull()
).withColumn(
    "clean_encounter_text",
    lower(regexp_replace(col("encounter_description"), "[^a-zA-Z\\s]", ""))
).withColumn(
    "clean_procedure_text",
    lower(regexp_replace(col("procedure_descriptions"), "[^a-zA-Z\\s]", ""))
)


modeling_df = modeling_df.withColumn(
    "combined_text",
    concat_ws(" | ",  # Separator to distinguish sources
        col("clean_encounter_text"),
        col("clean_procedure_text")
    )
)

In [22]:
# Additional Stopwords
stopwords=StopWordsRemover.loadDefaultStopWords("english") + [
    "patient", "doctor", "visit", "care", "provider", "encounter", "hospital","room","admission","procedure"
]

# Text Preprocessing Pipeline
tokenizer = Tokenizer(inputCol="combined_text", outputCol="words")
remover = StopWordsRemover(inputCol="words", outputCol="filtered",stopWords=stopwords)
vectorizer = CountVectorizer(inputCol="filtered", outputCol="features")

In [23]:
# Fit and Transform Data
tokenized_df = tokenizer.transform(modeling_df)
filtered_df = remover.transform(tokenized_df)
vectorizer_model = vectorizer.fit(filtered_df)
vectorized_df = vectorizer_model.transform(filtered_df)

In [24]:
# Train LDA Model
lda = LDA(k=11, maxIter=10, featuresCol="features")
lda_model = lda.fit(vectorized_df)

In [25]:
# Show Topics
topics = lda_model.describeTopics(5)
topics.show(truncate=False)

+-----+------------------------+-----------------------------------------------------------------------------------------------------------------+
|topic|termIndices             |termWeights                                                                                                      |
+-----+------------------------+-----------------------------------------------------------------------------------------------------------------+
|0    |[127, 178, 282, 44, 211]|[0.0040251848984337426, 0.003963807050217229, 0.003943115615965452, 0.0038631460855705836, 0.0038372914935389467]|
|1    |[0, 7, 12, 13, 15]      |[0.14506356257823413, 0.08443604568770224, 0.06841066273664953, 0.06825200922854656, 0.04398986820597827]        |
|2    |[0, 1, 2, 4, 3]         |[0.17890218665107918, 0.09413377257078004, 0.08132200902760102, 0.06699595329117693, 0.0668014509224225]         |
|3    |[0, 1, 9, 4, 22]        |[0.08097086983448099, 0.06453014492134326, 0.05891449735564754, 0.05889740978202827, 0

In [26]:
vocab = vectorizer_model.vocabulary
topics_with_words = topics.rdd.map(
    lambda row: (row['topic'], [vocab[i] for i in row['termIndices']])
).toDF(["topic", "top_words"])

topics_with_words.show(truncate=False)

+-----+-----------------------------------------------------+
|topic|top_words                                            |
+-----+-----------------------------------------------------+
|0    |[plan, nine, episiotomy, disorders, product]         |
|1    |[|, problem, dialysis, renal, environment]           |
|2    |[|, check, examination, dental, general]             |
|3    |[|, check, using, dental, plaque]                    |
|4    |[urine, renal, intrauterine, certification, gingivae]|
|5    |[|, , unit, intensive, administration]               |
|6    |[|, consultation, report, test, alcohol]             |
|7    |[health, |, social, needs, assessment]               |
|8    |[symptom, respiratory, testing, signs, symptoms]     |
|9    |[ward, discharge, peripheral, prone, intrauterine]   |
|10   |[manual, bone, pregnancy, titer, treatment]          |
+-----+-----------------------------------------------------+



In [38]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
import numpy as np

modeling_df = lda_model.transform(vectorized_df)

get_topic = udf(lambda v: int(np.argmax(v)), IntegerType())
modeling_df = modeling_df.withColumn("topic", get_topic(col("topicDistribution")))

In [39]:
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType

numeric_columns = [ "age_at_encounter"]

for column in numeric_columns:
    modeling_df = modeling_df.withColumn(column, col(column).cast(DoubleType()))


In [40]:
from pyspark.ml.feature import StringIndexer

categorical_cols = ['patient_marital', 'patient_race', 'patient_ethnicity',
                   'patient_gender', 'ENCOUNTERCLASS',
                   'payer_ownership',"payer_name","organization_name", "organization_zip", 'procedure_code',"encounter_code"]

for cat_col in categorical_cols:
    indexer = StringIndexer(inputCol=cat_col, outputCol=cat_col + "_index", handleInvalid="keep")
    modeling_df = indexer.fit(modeling_df).transform(modeling_df)


In [41]:
from pyspark.ml.feature import VectorAssembler

# Define final feature list
final_features = numeric_columns + [col + "_index" for col in categorical_cols] + ["topic"]

# Drop rows with nulls in important columns
modeling_df = modeling_df.na.drop(subset=final_features + ["label"])

# Assemble into feature vector
assembler = VectorAssembler(inputCols=final_features, outputCol="final_features")
modeling_df = assembler.transform(modeling_df)


In [42]:
from pyspark.ml.regression import LinearRegression

models = {}
topics = modeling_df.select("topic").distinct().rdd.flatMap(lambda x: x).collect()

for t in topics:
    topic_df = modeling_df.filter(col("topic") == t)
    lr = LinearRegression(featuresCol="final_features", labelCol="label",regParam=0.1) #might not need regParam with a bigger dataset
    model = lr.fit(topic_df)
    models[t] = model


In [43]:
from functools import reduce
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionByName, dfs)

predictions_list = []

for t, model in models.items():
    topic_df = modeling_df.filter(col("topic") == t)
    preds = model.transform(topic_df)
    predictions_list.append(preds)

final_predictions = unionAll(*predictions_list)
final_predictions.select("label", "prediction", "topic").show(10, truncate=False)


+------------------+------------------+-----+
|label             |prediction        |topic|
+------------------+------------------+-----+
|142.0             |2528.5285189819697|1    |
|142.0             |2824.8259316420126|1    |
|142.0             |2330.9969105419414|1    |
|142.0             |2182.8482042119203|1    |
|142.0             |2182.8482042119203|1    |
|143.20000000000005|2380.3798126519487|1    |
|1250.85           |1385.6669808209908|1    |
|9694.170000000002 |2386.491407506738 |1    |
|9694.170000000002 |2534.64011383676  |1    |
|0.0               |3706.443357840976 |1    |
+------------------+------------------+-----+
only showing top 10 rows



In [45]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")

rmse_per_topic = {}
for t in topics:
    topic_predictions = final_predictions.filter(col("topic") == t)
    rmse = evaluator.evaluate(topic_predictions)
    rmse_per_topic[t] = rmse

for topic, rmse in rmse_per_topic.items():
    print(f"RMSE for topic {topic}: {rmse}")



RMSE for topic 1: 1878.223558554715
RMSE for topic 6: 285.8947994782201
RMSE for topic 3: 2284.193572808087
RMSE for topic 5: 2558.592566886567
RMSE for topic 9: 0.02059086411394987
RMSE for topic 4: 0.09092650811243945
RMSE for topic 8: 93.81029666191061
RMSE for topic 7: 719.7302145007224
RMSE for topic 2: 899.7499943821864
RMSE for topic 0: 0.0
