In [1]:
import os
import pandas as pd
import seaborn as sns

In [2]:
!pip install pyspark

[0m

In [3]:
# standard libraries
import pandas as pd
import matplotlib.pyplot as plt

# PySpark libraries
from pyspark.sql import SparkSession
import pyspark.sql.functions as sql_f
from pyspark.sql.types import *
from pyspark.sql.functions import to_date, datediff, floor, col, avg, substring

# DL/ML libraries
from sklearn.metrics import accuracy_score
spark = SparkSession.builder.getOrCreate()


In [4]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import unix_timestamp

for working in GCP

## 1) Loading the data

In [11]:
# @title Synthea Patient Generator (CSV Version)
import os
from IPython.display import clear_output

# Configuration
num_patients = 10  # @param {type:"integer"}
state = "Massachusetts"  # @param ["Massachusetts", "California", "New York", "Texas", "Florida"]
age_range = "30-85"  # @param {type:"string"}
seed = 12345  # @param {type:"integer"}

# Install Java
!sudo apt-get update
!sudo apt-get install -y openjdk-11-jdk-headless
clear_output()
print("✅ Java installed")

# Download Synthea
!wget -q https://github.com/synthetichealth/synthea/releases/download/master-branch-latest/synthea-with-dependencies.jar
clear_output()
print("✅ Synthea downloaded")

# Generate patients (using proper string substitution)
!java -jar synthea-with-dependencies.jar \
  -p {num_patients} \
  -s {seed} \
  -a "{age_range}" \
  --exporter.baseDirectory "./output" \
  --exporter.fhir.export=False \
  --exporter.csv.export=True \
  {state}

# Verify output
csv_output_path = "./output/csv"
if os.path.exists(csv_output_path):
    csv_files = [f for f in os.listdir(csv_output_path) if f.endswith('.csv')]
    if csv_files:
        print(f"\n🎉 Success! Generated {len(csv_files)} CSV files:")
        for file in csv_files[:5]:  # Show first 5 files
            print(f"- {file}")
        print(f"\nTotal records across all CSV files: {num_patients} patients")
    else:
        print("\n⚠ CSV directory exists but contains no CSV files")
else:
    print("\n❌ Generation failed. Common fixes:")
    print("1. Try reducing patient count (start with 10)")
    print("2. Check Java version:")
    !java -version
    print("3. Disk space:")
    !df -h

✅ Synthea downloaded
SLF4J: No SLF4J providers were found.
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#noProviders for further details.
Scanned 88 modules and 152 submodules.
Loading submodule modules/allergies/allergy_panel.json
Loading submodule modules/allergies/drug_allergy_incidence.json
Loading submodule modules/allergies/environmental_allergy_incidence.json
Loading submodule modules/allergies/food_allergy_incidence.json
Loading submodule modules/allergies/immunotherapy.json
Loading submodule modules/allergies/outgrow_env_allergies.json
Loading submodule modules/allergies/outgrow_food_allergies.json
Loading submodule modules/allergies/severe_allergic_reaction.json
Loading submodule modules/anemia/anemia_sub.json
Loading submodule modules/breast_cancer/chemotherapy_breast.json
Loading submodule modules/breast_cancer/hormone_diagnosis.json
Loading submodule modules/breast_cancer/hormonetherapy_breast.json
Loading submodul

In [12]:
!hdfs dfs -mkdir -p /synthea_output
!hdfs dfs -put ./output/csv/*.csv /synthea_output

put: `/synthea_output/allergies.csv': File exists
put: `/synthea_output/careplans.csv': File exists
put: `/synthea_output/claims.csv': File exists
put: `/synthea_output/claims_transactions.csv': File exists
put: `/synthea_output/conditions.csv': File exists
put: `/synthea_output/devices.csv': File exists
put: `/synthea_output/encounters.csv': File exists
put: `/synthea_output/imaging_studies.csv': File exists
put: `/synthea_output/immunizations.csv': File exists
put: `/synthea_output/medications.csv': File exists
put: `/synthea_output/observations.csv': File exists
put: `/synthea_output/organizations.csv': File exists
put: `/synthea_output/patients.csv': File exists
put: `/synthea_output/payer_transitions.csv': File exists
put: `/synthea_output/payers.csv': File exists
put: `/synthea_output/procedures.csv': File exists
put: `/synthea_output/providers.csv': File exists
put: `/synthea_output/supplies.csv': File exists


In [5]:
path = '/synthea_output/'

## 2) Creating the spark dataframes 

In [38]:
#patient
observations = spark.read.csv(path+"observations.csv", header=True)
patient = spark.read.csv(path+"patients.csv", header=True) 

#medical
careplans = spark.read.csv(path+"careplans.csv", header=True)
conditions = spark.read.csv(path+"conditions.csv", header=True)
procedures=spark.read.csv(path+"procedures.csv", header=True)
encounters = spark.read.csv(path+"encounters.csv", header=True)
medications = spark.read.csv(path+"medications.csv", header=True)

#insurance and hospital
payer_transitions=spark.read.csv(path+"payer_transitions.csv", header=True)
payers=spark.read.csv(path+"payers.csv", header=True)
providers=spark.read.csv(path+"providers.csv", header=True)
organizations=spark.read.csv(path+"organizations.csv", header=True)

                                                                                

In [7]:
encounters.columns

['Id',
 'START',
 'STOP',
 'PATIENT',
 'ORGANIZATION',
 'PROVIDER',
 'PAYER',
 'ENCOUNTERCLASS',
 'CODE',
 'DESCRIPTION',
 'BASE_ENCOUNTER_COST',
 'TOTAL_CLAIM_COST',
 'PAYER_COVERAGE',
 'REASONCODE',
 'REASONDESCRIPTION']

## 3) Cleaning dataframes and renaming variables 

In [39]:
# renaming columns

patient = (
    patient.withColumnRenamed("Id", "patient_id")
           .withColumnRenamed("MARITAL", "patient_marital")
           .withColumnRenamed("RACE", "patient_race")
           .withColumnRenamed("ETHNICITY", "patient_ethnicity")
           .withColumnRenamed("GENDER", "patient_gender")
           .withColumnRenamed("ZIP", "patient_zip")
)

encounters = (
    encounters.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("Id", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "encounter_description")
              .withColumnRenamed("CODE", "encounter_code")
              .withColumnRenamed("START", "encounter_start")
              .withColumn("encounter_start", to_date("encounter_start"))
              .withColumnRenamed("STOP", "encounter_stop")
              .withColumn("encounter_stop", to_date("encounter_stop"))
              .withColumn("PATIENT COST", col("TOTAL_CLAIM_COST") - col("PAYER_COVERAGE"))
              .withColumnRenamed("PAYER", "payer_id")
              .withColumnRenamed("ORGANIZATION", "organization_id")
              .withColumnRenamed("PROVIDER", "provider_id")
)

careplans = (
    careplans.withColumnRenamed("PATIENT", "patient_id")
             .withColumnRenamed("Id", "careplan_id")
             .withColumnRenamed("ENCOUNTER", "encounter_id")
             .withColumnRenamed("DESCRIPTION", "careplan_descriptions")
             .withColumnRenamed("CODE", "careplan_code")
)

procedures = (
    procedures.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("ENCOUNTER", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "procedure_descriptions")
              .withColumnRenamed("CODE", "procedure_code")
              .withColumnRenamed("DATE", "procedure_date")
              .withColumnRenamed("BASE_COST", "procedure_cost")
)

conditions = (
    conditions.withColumnRenamed("PATIENT", "patient_id")
              .withColumnRenamed("ENCOUNTER", "encounter_id")
              .withColumnRenamed("DESCRIPTION", "condition_description")
              .withColumnRenamed("CODE", "condition_code")
              .withColumnRenamed("START", "condition_start")
              .withColumnRenamed("END", "condition_end")
)

observations = (
    observations.withColumnRenamed("PATIENT", "patient_id")
                .withColumnRenamed("ENCOUNTER", "encounter_id")
                .withColumnRenamed("DATE", "observation_date")
                .withColumn("observation_date", to_date("observation_date"))
)

medications = (
    medications.withColumnRenamed("START", "medication_start")
               .withColumn("medication_start", to_date("medication_start"))
               .withColumnRenamed("STOP", "medication_stop")
               .withColumn("medication_stop", to_date("medication_stop"))
               .withColumnRenamed("PATIENT", "patient_id")
               .withColumnRenamed("PAYER", "payer_id")
               .withColumnRenamed("ENCOUNTER", "encounter_id")
               .withColumnRenamed("CODE", "medication_code")
               .withColumnRenamed("DESCRIPTION", "medication_description")
)

payer_transitions = (
    payer_transitions.withColumnRenamed("PATIENT", "patient_id")
                     .withColumnRenamed("PAYER", "payer_id")
)

payers = (
    payers.withColumnRenamed("Id", "payer_id")
          .withColumnRenamed("NAME", "payer_name")
          .withColumnRenamed("OWNERSHIP", "payer_ownership")
)

providers = (
    providers.withColumnRenamed("Id", "provider_id")
             .withColumnRenamed("SPECIALITY", "provider_specialty")
)

organizations = (
    organizations.withColumnRenamed("Id", "organization_id")
                 .withColumnRenamed("NAME", "organization_name")
                 .withColumnRenamed("ZIP", "organization_zip")
                 .withColumn("organization_zip", substring(col("organization_zip").cast("string"), 1, 5))
)

In [40]:
#can do this too 
encounters = (
    encounters
    .join(payers.select("payer_id", "payer_name", "payer_ownership"), on="payer_id", how="left")
    .join(organizations.select("organization_id", "organization_name", "organization_zip"), on="organization_id", how="left")
    .join(providers.select("provider_id", "provider_specialty"), on="provider_id", how="left")
    .join(procedures.select("encounter_id", "procedure_descriptions", "procedure_code"), on="encounter_id", how="left")
    .join(patient.select("patient_id", "BIRTHDATE", "patient_marital", "patient_race", "patient_ethnicity", "patient_gender", "patient_zip"), on="patient_id", how="left")
    .withColumn("age_at_encounter", floor(datediff(col("encounter_start"), col("BIRTHDATE")) / 365.25))
)

In [10]:
encounters.columns

['patient_id',
 'encounter_id',
 'provider_id',
 'organization_id',
 'payer_id',
 'encounter_start',
 'encounter_stop',
 'ENCOUNTERCLASS',
 'encounter_code',
 'encounter_description',
 'BASE_ENCOUNTER_COST',
 'TOTAL_CLAIM_COST',
 'PAYER_COVERAGE',
 'REASONCODE',
 'REASONDESCRIPTION',
 'PATIENT COST',
 'payer_name',
 'payer_ownership',
 'organization_name',
 'organization_zip',
 'provider_specialty',
 'procedure_descriptions',
 'procedure_code',
 'BIRTHDATE',
 'patient_marital',
 'patient_race',
 'patient_ethnicity',
 'patient_gender',
 'patient_zip',
 'age_at_encounter']

## Split by topic using LDA and train regession model per topic 

### LDA using encounter description

In [41]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer
from pyspark.ml.clustering import LDA
from pyspark.sql.functions import col, lower, regexp_replace


In [42]:
#clean the text first 
encounters = encounters.filter(col("encounter_description").isNotNull())
encounters = encounters.withColumn(
    "clean_text",
    lower(regexp_replace(col("encounter_description"), "[^a-zA-Z\\s]", ""))
)

In [43]:
# Additional Stopwords 
#ARE WE REMOVING TOO MANY STOPWORDS 
stopwords=StopWordsRemover.loadDefaultStopWords("english") + [
    "patient", "doctor", "visit", "care", "provider", "encounter", "hospital","room","admission","procedure"
]

# Text Preprocessing Pipeline
tokenizer = Tokenizer(inputCol="clean_text", outputCol="words")
remover = StopWordsRemover(inputCol="words", outputCol="filtered",stopWords=stopwords)
vectorizer = CountVectorizer(inputCol="filtered", outputCol="features")

In [44]:
# Fit and Transform Data
tokenized_df = tokenizer.transform(encounters)
filtered_df = remover.transform(tokenized_df)
vectorizer_model = vectorizer.fit(filtered_df)
vectorized_df = vectorizer_model.transform(filtered_df)

                                                                                

In [45]:
# Train LDA Model
#NEED TO INCREASE K WHEN INCREASING PATIENTS 
lda = LDA(k=2, maxIter=10, featuresCol="features")
lda_model = lda.fit(vectorized_df)

                                                                                

In [46]:
# Show Topics
topics = lda_model.describeTopics(5)
topics.show(truncate=False)

+-----+---------------+---------------------------------------------------------------------------------------------------------+
|topic|termIndices    |termWeights                                                                                              |
+-----+---------------+---------------------------------------------------------------------------------------------------------+
|0    |[1, 0, 2, 5, 6]|[0.3204907381968033, 0.30975259100756763, 0.05537340064752298, 0.04370772966588497, 0.029847597655479296]|
|1    |[2, 0, 1, 3, 4]|[0.438065940196969, 0.08362018771314891, 0.07016182811901082, 0.05882397489558225, 0.04377200215132048]  |
+-----+---------------+---------------------------------------------------------------------------------------------------------+



In [47]:
vocab = vectorizer_model.vocabulary
topics_with_words = topics.rdd.map(
    lambda row: (row['topic'], [vocab[i] for i in row['termIndices']])
).toDF(["topic", "top_words"])

topics_with_words.show(truncate=False)

                                                                                

+-----+------------------------------------------------------+
|topic|top_words                                             |
+-----+------------------------------------------------------+
|0    |[examination, general, check, problem, hospice]       |
|1    |[check, general, examination, regimetherapy, prenatal]|
+-----+------------------------------------------------------+



In [48]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
import numpy as np

encounters = lda_model.transform(vectorized_df)

get_topic = udf(lambda v: int(np.argmax(v)), IntegerType())
encounters = encounters.withColumn("topic", get_topic(col("topicDistribution")))

In [49]:
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType

numeric_columns = ["BASE_ENCOUNTER_COST", "TOTAL_CLAIM_COST", "PAYER_COVERAGE", "age_at_encounter"]

for column in numeric_columns:
    encounters = encounters.withColumn(column, col(column).cast(DoubleType()))


In [50]:
from pyspark.ml.feature import StringIndexer

categorical_columns = [
    "patient_gender", "patient_race", "patient_ethnicity",
    "patient_marital", "provider_specialty", "organization_zip","patient_zip",'REASONDESCRIPTION', 
 'payer_ownership',
 'organization_name',
]

for cat_col in categorical_columns:
    indexer = StringIndexer(inputCol=cat_col, outputCol=cat_col + "_index", handleInvalid="keep")
    encounters = indexer.fit(encounters).transform(encounters)


                                                                                

In [51]:
from pyspark.ml.feature import VectorAssembler

# Define final feature list
final_features = numeric_columns + [col + "_index" for col in categorical_columns] + ["topic"]

# Drop rows with nulls in important columns
encounters = encounters.na.drop(subset=final_features + ["PATIENT COST"])

# Assemble into feature vector
assembler = VectorAssembler(inputCols=final_features, outputCol="final_features")
encounters = assembler.transform(encounters)


In [52]:
from pyspark.ml.regression import LinearRegression

models = {}
topics = encounters.select("topic").distinct().rdd.flatMap(lambda x: x).collect()

for t in topics:
    topic_df = encounters.filter(col("topic") == t)
    lr = LinearRegression(featuresCol="final_features", labelCol="PATIENT COST",regParam=0.1) #might not need regParam with a bigger dataset 
    model = lr.fit(topic_df)
    models[t] = model


                                                                                

In [53]:
from functools import reduce
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionByName, dfs)

predictions_list = []

for t, model in models.items():
    topic_df = encounters.filter(col("topic") == t)
    preds = model.transform(topic_df)
    predictions_list.append(preds)

final_predictions = unionAll(*predictions_list)
final_predictions.select("patient_id","PATIENT COST", "prediction", "topic").show(10, truncate=False)


+------------------------------------+------------------+------------------+-----+
|patient_id                          |PATIENT COST      |prediction        |topic|
+------------------------------------+------------------+------------------+-----+
|4113255f-4e35-506a-ddef-4429caa17ffc|85.55             |87.16762483511698 |1    |
|4113255f-4e35-506a-ddef-4429caa17ffc|180.79999999999995|181.01157733737256|1    |
|4113255f-4e35-506a-ddef-4429caa17ffc|142.58            |140.90154432465138|1    |
|4113255f-4e35-506a-ddef-4429caa17ffc|948.35            |946.9532141920059 |1    |
|92675303-ca5b-136a-169b-e764c5753f06|871.06            |872.57503113922   |1    |
|92675303-ca5b-136a-169b-e764c5753f06|808.91            |812.0693849090785 |1    |
|4113255f-4e35-506a-ddef-4429caa17ffc|142.58            |140.92146436156523|1    |
|4113255f-4e35-506a-ddef-4429caa17ffc|85.55             |85.33199619772024 |1    |
|92675303-ca5b-136a-169b-e764c5753f06|3536.75           |3536.8345181877817|1    |
|926

                                                                                

In [54]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(
    labelCol="PATIENT COST",
    predictionCol="prediction",
    metricName="rmse"  
)

rmse = evaluator.evaluate(final_predictions)
print(f"✅ RMSE: {rmse:.2f}")




✅ RMSE: 1.72


                                                                                

### LDA using encounter description + demographic information
#### actually maybe i shouldn't merge them into 1 text, it doesnt work

In [56]:
from pyspark.sql.functions import concat_ws, col, lower, regexp_replace

# Create demographics text
encounters = encounters.withColumn(
    "demographic_text",
    concat_ws(" ", "patient_gender", "patient_race", "patient_ethnicity", "patient_marital","age_at_encounter")
)

# Combine with encounter_description
encounters = encounters.withColumn(
    "full_text",
    concat_ws(" ", "encounter_description", "demographic_text")
)

# Clean it
encounters = encounters.withColumn(
    "clean_text_demo",
    lower(regexp_replace(col("full_text"), "[^a-zA-Z\\s]", ""))
)


In [57]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer
from pyspark.ml.clustering import LDA

# Tokenize
tokenizer = Tokenizer(inputCol="clean_text_demo", outputCol="words_demo")
words_df = tokenizer.transform(encounters)

# Remove stopwords
stopwords = StopWordsRemover.loadDefaultStopWords("english") + [
    "patient", "doctor", "visit", "care", "provider", "hospital", "room", "procedure"
]
remover = StopWordsRemover(inputCol="words_demo", outputCol="filtered_demo", stopWords=stopwords)
filtered_df = remover.transform(words_df)

# Vectorize
vectorizer_demo = CountVectorizer(inputCol="filtered_demo", outputCol="features_demo")
vector_model_demo = vectorizer_demo.fit(filtered_df)
vectorized_df = vector_model_demo.transform(filtered_df)

# Train LDA
lda_demo = LDA(k=5, maxIter=10, featuresCol="features_demo",topicDistributionCol="topicDistribution_demo")
lda_model_demo = lda_demo.fit(vectorized_df)

# Add topicDistribution and topic
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
import numpy as np

vectorized_df = lda_model_demo.transform(vectorized_df)
get_topic = udf(lambda v: int(np.argmax(v)), IntegerType())
vectorized_df = vectorized_df.withColumn("topic_demo", get_topic(col("topicDistribution_demo")))


                                                                                

In [61]:
# Show Topics
topics_demo = lda_model_demo.describeTopics(5)
vocab_demo = vector_model_demo.vocabulary
topics_with_words_demo = topics_demo.rdd.map(
    lambda row: (row['topic'], [vocab_demo[i] for i in row['termIndices']])
).toDF(["topic", "top_words"])

topics_with_words_demo.show(truncate=False)

                                                                                

+-----+----------------------------------------------------+
|topic|top_words                                           |
+-----+----------------------------------------------------+
|0    |[unit, environment, check, department, death]       |
|1    |[m, white, general, examination, nonhispanic]       |
|2    |[black, f, nonhispanic, m, regimetherapy]           |
|3    |[m, white, encounter, check, nonhispanic]           |
|4    |[active, administration, immunity, vaccine, produce]|
+-----+----------------------------------------------------+



In [62]:
final_features_demo = numeric_columns + [col + "_index" for col in categorical_columns] + ["topic_demo"]

encounters_demo = vectorized_df.na.drop(subset=final_features_demo + ["PATIENT COST"])

# Assemble feature vector
from pyspark.ml.feature import VectorAssembler

assembler_demo = VectorAssembler(inputCols=final_features_demo, outputCol="final_features_demo")
encounters_demo = assembler_demo.transform(encounters_demo)

In [63]:
from pyspark.ml.regression import LinearRegression

models_demo = {}
topics_demo = encounters_demo.select("topic_demo").distinct().rdd.flatMap(lambda x: x).collect()

for t in topics_demo:
    topic_df = encounters_demo.filter(col("topic_demo") == t)
    lr = LinearRegression(featuresCol="final_features_demo", labelCol="PATIENT COST", regParam=0.1)
    model = lr.fit(topic_df)
    models_demo[t] = model


                                                                                

In [36]:
from functools import reduce
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionByName, dfs)

predictions_demo_list = []

for t, model in models_demo.items():
    topic_df = encounters_demo.filter(col("topic_demo") == t)
    preds = model.transform(topic_df)
    predictions_demo_list.append(preds)

final_predictions_demo = unionAll(*predictions_demo_list)
final_predictions_demo.select("patient_id","PATIENT COST", "prediction", "topic_demo").show(10, truncate=False)


[Stage 401:>                                                        (0 + 1) / 1]

+------------------------------------+------------+------------------+----------+
|patient_id                          |PATIENT COST|prediction        |topic_demo|
+------------------------------------+------------+------------------+----------+
|92675303-ca5b-136a-169b-e764c5753f06|988.17      |987.1867254089184 |1         |
|4113255f-4e35-506a-ddef-4429caa17ffc|0.0         |8.011062295192914 |1         |
|4113255f-4e35-506a-ddef-4429caa17ffc|237.36      |243.266052491055  |1         |
|4113255f-4e35-506a-ddef-4429caa17ffc|919.9       |917.0025606007193 |1         |
|92675303-ca5b-136a-169b-e764c5753f06|917.0       |917.0193329347114 |1         |
|92675303-ca5b-136a-169b-e764c5753f06|871.06      |873.2203742749697 |1         |
|92675303-ca5b-136a-169b-e764c5753f06|1161.11     |1159.3679832971934|1         |
|92675303-ca5b-136a-169b-e764c5753f06|921.58      |922.7738067665181 |1         |
|92675303-ca5b-136a-169b-e764c5753f06|1057.58     |1057.444488004062 |1         |
|92675303-ca5b-1

                                                                                

In [37]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(
    labelCol="PATIENT COST",
    predictionCol="prediction",
    metricName="rmse"  
)

rmse_demo = evaluator.evaluate(final_predictions_demo)
print(f"✅ RMSE_demo: {rmse_demo:.2f}")




✅ RMSE_demo: 2.17


                                                                                

### LDA using encounter description + demographic information + procedure information

In [68]:
from pyspark.sql.functions import concat_ws, col, lower, regexp_replace

# Combine with procedure_description with encounter_description
encounters = encounters.withColumn(
    "enc_proc",
    concat_ws(" ", "encounter_description", "procedure_descriptions")
)

# Clean it
encounters = encounters.withColumn(
    "clean_encproc",
    lower(regexp_replace(col("enc_proc"), "[^a-zA-Z\\s]", ""))
)


In [70]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer
from pyspark.ml.clustering import LDA

# Tokenize
tokenizer = Tokenizer(inputCol="clean_encproc", outputCol="words_encproc")
words_df = tokenizer.transform(encounters)

# Remove stopwords
stopwords = StopWordsRemover.loadDefaultStopWords("english") + [
    "patient", "doctor", "visit", "care", "provider", "hospital", "room", "procedure"
]
remover = StopWordsRemover(inputCol="words_encproc", outputCol="filtered_encproc", stopWords=stopwords)
filtered_df = remover.transform(words_df)

# Vectorize
vectorizer_encproc = CountVectorizer(inputCol="filtered_encproc", outputCol="features_encproc")
vector_model_encproc = vectorizer_encproc.fit(filtered_df)
vectorized_df = vector_model_encproc.transform(filtered_df)

# Train LDA
lda_encproc = LDA(k=5, maxIter=10, featuresCol="features_encproc",topicDistributionCol="topicDistribution_encproc")
lda_model_encproc = lda_encproc.fit(vectorized_df)

# Add topicDistribution and topic
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
import numpy as np

vectorized_df = lda_model_encproc.transform(vectorized_df)
get_topic = udf(lambda v: int(np.argmax(v)), IntegerType())
vectorized_df = vectorized_df.withColumn("topic_encproc", get_topic(col("topicDistribution_encproc")))

                                                                                

In [71]:
# Show Topics
topics_encproc = lda_model_encproc.describeTopics(5)
vocab_encproc = vector_model_encproc.vocabulary
topics_with_words_encproc = topics_encproc.rdd.map(
    lambda row: (row['topic'], [vocab_encproc[i] for i in row['termIndices']])
).toDF(["topic", "top_words"])

topics_with_words_encproc.show(truncate=False)

+-----+-------------------------------------------------------------+
|topic|top_words                                                    |
+-----+-------------------------------------------------------------+
|0    |[contraceptive, subcutaneous, insertion, bilateral, ligation]|
|1    |[administration, produce, regimetherapy, prenatal, immunity] |
|2    |[encounter, check, dental, using, removal]                   |
|3    |[general, examination, assessment, health, needs]            |
|4    |[screening, examination, general, depression, admission]     |
+-----+-------------------------------------------------------------+



In [72]:
final_features_encproc = numeric_columns + [col + "_index" for col in categorical_columns] + ["topic_encproc"]

encounters_encproc = vectorized_df.na.drop(subset=final_features_encproc + ["PATIENT COST"])

# Assemble feature vector
from pyspark.ml.feature import VectorAssembler

assembler_encproc = VectorAssembler(inputCols=final_features_encproc, outputCol="final_features_encproc")
encounters_encproc = assembler_encproc.transform(encounters_encproc)

In [73]:
from pyspark.ml.regression import LinearRegression

models_encproc = {}
topics_encproc = encounters_encproc.select("topic_encproc").distinct().rdd.flatMap(lambda x: x).collect()

for t in topics_encproc:
    topic_df = encounters_encproc.filter(col("topic_encproc") == t)
    lr = LinearRegression(featuresCol="final_features_encproc", labelCol="PATIENT COST", regParam=0.1)
    model = lr.fit(topic_df)
    models_encproc[t] = model

                                                                                

In [75]:
from functools import reduce
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionByName, dfs)

predictions_encproc_list = []

for t, model in models_encproc.items():
    topic_df = encounters_encproc.filter(col("topic_encproc") == t)
    preds = model.transform(topic_df)
    predictions_encproc_list.append(preds)

final_predictions_encproc = unionAll(*predictions_encproc_list)
final_predictions_encproc.select("patient_id","PATIENT COST", "prediction", "topic_encproc").show(10, truncate=False)


[Stage 904:>                                                        (0 + 1) / 1]

+------------------------------------+------------+------------------+-------------+
|patient_id                          |PATIENT COST|prediction        |topic_encproc|
+------------------------------------+------------+------------------+-------------+
|4113255f-4e35-506a-ddef-4429caa17ffc|278.58      |278.55545603768104|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|224.48      |224.51680304515932|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.387558020153|1            |
|4113255f-4e35-506a-ddef-4429caa17ffc|19615.09    |19614.38755802

                                                                                

In [76]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(
    labelCol="PATIENT COST",
    predictionCol="prediction",
    metricName="rmse"  
)

rmse_encproc = evaluator.evaluate(final_predictions_encproc)
print(f"✅ RMSE_encproc: {rmse_encproc:.2f}")




✅ RMSE_encproc: 1.92


                                                                                