In [0]:
%pip install openpyxl
%pip install lifelines

In [0]:
import sys
import os
from pyspark.sql.functions import mean, stddev, expr, percentile_approx, round
from pyspark.sql.functions import datediff, expr, when, col, month, year, avg
from pyspark.sql.functions import count, round, sum
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
from pyspark.sql import functions as F

sys.path.append("../src")
from utils import *

In [0]:
# Relative paths to the synthetic Excel workbooks used as inputs by the cells that follow.
cohort_1_link="../Dummy_Data/Cohort_1_synth.xlsx"
cohort_union_df_link = "../Dummy_Data/Cohort_Synt1.xlsx"

In [0]:
# Read the cohort workbook with pandas, then hand it to Spark: this frame is later used as
# the right-hand side of a Spark join, which does not accept a pandas DataFrame.
# NOTE(review): spark.createDataFrame infers the schema from the pandas dtypes; columns that
# are completely empty or hold mixed types may need an explicit schema — confirm on real data.
cohort_union_df = spark.createDataFrame(pd.read_excel(cohort_union_df_link))

In [0]:
# Read the pathway workbook with pandas, then convert it to a Spark DataFrame: every
# downstream cell uses the Spark API on it (join / filter / withColumn / groupBy).
# NOTE(review): spark.createDataFrame infers the schema from the pandas dtypes; columns that
# are completely empty or hold mixed types may need an explicit schema — confirm on real data.
df_wlmds = spark.createDataFrame(pd.read_excel(cohort_1_link))

In [0]:
# Treatment-function-code lookup. `data` and `schema` are not defined in this notebook
# (presumably they come from `from utils import *` — TODO confirm).
tfc_df = spark.createDataFrame(data, schema=schema)

# Inner join on the treatment function code: pathways whose code is absent from the
# lookup are dropped.
tfc_match = df_wlmds.wlmds_treatment_function_code == tfc_df.TFC
df_wlmds = df_wlmds.join(tfc_df, on=tfc_match, how="inner")

In [0]:
# Keep pathways whose RTT clock started after 31 March 2022 and stopped before 1 April 2024.
rtt_started_in_window = col("epp_rtt_start_date") > '2022-03-31'
rtt_ended_in_window = col("epp_rtt_end_date") < '2024-04-01'
df_wlmds = df_wlmds.filter(rtt_started_in_window & rtt_ended_in_window)

In [0]:
# Waiting time in whole days between RTT start and RTT end; rows with a wait of
# zero or fewer days are discarded.
wait_days_expr = datediff(col("epp_rtt_end_date"), col("epp_rtt_start_date"))
df_wlmds = df_wlmds.withColumn("waiting_time_days", wait_days_expr)
df_wlmds = df_wlmds.filter(col("waiting_time_days") > 0)

# Band the waiting time (thresholds are in days; `between` is inclusive at both ends).
wait_days = col("waiting_time_days")
waiting_bucket_expr = (
    when(wait_days < 84, "< 12 weeks")
    .when(wait_days.between(84, 126), "12-18 weeks")
    .when(wait_days.between(127, 245), "19-36 weeks")
    .when(wait_days.between(246, 364), "37-52 weeks")
    .when(wait_days > 364, "> 52 weeks")
)
df_wlmds = df_wlmds.withColumn("waiting_bucket", waiting_bucket_expr)

# Time to First Contact

In [0]:
# Rows whose epp_source starts with "wlmds_sus".
ttfc = df_wlmds.filter(col("epp_source").startswith("wlmds_sus"))

# Show how many distinct (patient, treatment function, RTT start) combinations there are
# before narrowing down to the first outpatient activity.
pathway_keys = ttfc.groupBy("epp_pid", "epp_tfc", "epp_rtt_start_date").count()
display(pathway_keys.count())

# Keep only the first outpatient activity of the first RTT period.
is_first_period = col("epp_rtt_period_sequence") == 1
is_first_outpatient = col("epp_sequenced_activity_type") == 'first_op'
ttfc = ttfc.filter(is_first_period & is_first_outpatient)

# Days from RTT start to that activity.
ttfc = ttfc.withColumn("days_between", datediff(col("epp_activity_date"), col("epp_rtt_start_date")))

# Mean and (approximate) median per specialty, longest average first.
avg_median_days_per_specialty = (
    ttfc.groupBy("Specialty")
    .agg(
        F.avg("days_between").alias("avg_days_between"),
        F.percentile_approx("days_between", 0.5).alias("median_days_between"),
    )
    .orderBy("avg_days_between", ascending=False)
)

# Small aggregate, safe to collect for plotting.
pdf_avg_median_days = avg_median_days_per_specialty.toPandas()

# Mean and median drawn as overlaid bars on the same axes.
fig, ax = plt.subplots(figsize=(10, 6))
specialties = pdf_avg_median_days['Specialty']
ax.bar(specialties, pdf_avg_median_days['avg_days_between'], color='skyblue', label='Average Days')
ax.bar(specialties, pdf_avg_median_days['median_days_between'], color='orange', alpha=0.7, label='Median Days')
ax.set_xlabel('Specialty')
ax.set_ylabel('Days Between')
ax.set_title('Average and Median Days Between epp_rtt_start_date and epp_activity_date per Specialty')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend()
fig.tight_layout()

plt.show()

In [0]:
# Per-specialty summary of days to first contact, each rounded to whole days:
# mean, approximate median, standard deviation and inter-quartile range.
days_iqr = F.percentile_approx("days_between", 0.75) - F.percentile_approx("days_between", 0.25)
desc_stats = ttfc.groupBy("Specialty").agg(
    F.round(F.mean("days_between")).alias("mean_days_between"),
    F.round(F.percentile_approx("days_between", 0.5)).alias("median_days_between"),
    F.round(F.stddev("days_between")).alias("stddev_days_between"),
    F.round(days_iqr).alias("iqr_days_between"),
)

display(desc_stats)

In [0]:
# Collect the descriptive statistics to pandas and display them serialised as CSV text
# (header row included, index column omitted).
display(desc_stats.toPandas().to_csv(index=False))

In [0]:
# Time to first contact, split by specialty AND referral priority.
# The first-outpatient rows are re-derived here so the cell does not rely on the state of
# `ttfc` left behind by earlier cells.
ttfc = df_wlmds.filter(col("epp_source").startswith("wlmds_sus"))
display(ttfc.groupBy("epp_pid", "epp_tfc", "epp_rtt_start_date").count().count())
ttfc = ttfc.filter((col("epp_rtt_period_sequence") == 1) & (col("epp_sequenced_activity_type") == 'first_op'))

# Days from RTT start to the first outpatient activity.
ttfc = ttfc.withColumn("days_between", datediff(col("epp_activity_date"), col("epp_rtt_start_date")))

# Mean and approximate median per (specialty, referral priority), longest average first.
avg_median_days_per_specialty_priority = ttfc.groupBy("Specialty", "epp_referral_priority").agg(
    expr("avg(days_between) as avg_days_between"),
    expr("percentile_approx(days_between, 0.5) as median_days_between")
).orderBy("avg_days_between", ascending=False)

# Small aggregate, safe to collect for plotting.
pdf_avg_median_days_priority = avg_median_days_per_specialty_priority.toPandas()

# One figure per referral priority. (The original opened an extra figure before the loop,
# which rendered as an empty plot.)
priority_column = pdf_avg_median_days_priority['epp_referral_priority']
priorities = priority_column.unique()
for priority in priorities:
    # `== NaN` never matches, so a missing priority needs an explicit null mask.
    mask = priority_column.isna() if pd.isna(priority) else priority_column == priority
    subset = pdf_avg_median_days_priority[mask]

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.bar(subset['Specialty'], subset['avg_days_between'], alpha=0.7, label=f'Average Days - Priority {priority}')
    ax.bar(subset['Specialty'], subset['median_days_between'], alpha=0.7, label=f'Median Days - Priority {priority}')
    ax.set_xlabel('Specialty')
    ax.set_ylabel('Days Between')
    ax.set_title(f'Average and Median Days Between epp_rtt_start_date and epp_activity_date per Specialty - Priority {priority}')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.legend()
    fig.tight_layout()
    plt.show()
    # Release the figure so a long list of priorities does not accumulate open figures.
    plt.close(fig)

# Deaths

In [0]:
# Pathways where a registered death date falls between RTT start and RTT end (inclusive).
died_while_waiting = df_wlmds.REG_DATE_OF_DEATH.between(df_wlmds.epp_rtt_start_date, df_wlmds.epp_rtt_end_date)
df_wlmds_death = df_wlmds.filter(df_wlmds.REG_DATE_OF_DEATH.isNotNull() & died_while_waiting)

# Collapse both frames to one row per distinct key combination; the trailing `count`
# column holds the number of source rows that shared those keys.
total_keys = [
    "epp_pid", "epp_tfc", "epp_rtt_start_date", "epp_rtt_end_date",
    "Specialty", "waiting_bucket", "waiting_time_days", "epp_referral_priority",
]
death_keys = [
    "epp_pid", "epp_tfc", "epp_rtt_start_date", "epp_rtt_end_date",
    "Specialty", "REG_DATE_OF_DEATH", "waiting_bucket", "waiting_time_days", "epp_referral_priority",
]
total = df_wlmds.groupBy(*total_keys).count()
death_grouped_df = df_wlmds_death.groupBy(*death_keys).count()

# Quarterly cohort bands keyed on the RTT start date; each band is labelled by the month
# in which the quarter ends. Bounds are inclusive.
cohort_bands = [
    ("2022-01-01", "2022-03-31", "mar22"),
    ("2022-04-01", "2022-06-30", "jun22"),
    ("2022-07-01", "2022-09-30", "sep22"),
    ("2022-10-01", "2022-12-31", "dec22"),
    ("2023-01-01", "2023-03-31", "mar23"),
    ("2023-04-01", "2023-06-30", "jun23"),
    ("2023-07-01", "2023-09-30", "sep23"),
    ("2023-10-01", "2023-12-31", "dec23"),
    ("2024-01-01", "2024-03-31", "mar24"),
]
band_expr = None
for band_start, band_end, band_label in cohort_bands:
    in_band = (col("epp_rtt_start_date") >= band_start) & (col("epp_rtt_start_date") <= band_end)
    band_expr = when(in_band, band_label) if band_expr is None else band_expr.when(in_band, band_label)

death_grouped_df = death_grouped_df.withColumn("cohort_band_conc", band_expr)

# Rows that fall in no band get a null label and are removed.
death_grouped_df = death_grouped_df.filter(col("cohort_band_conc").isNotNull())

In [0]:
# Attach cohort attributes to each death pathway: match on cohort band and patient id.
# Left join, so death pathways without a cohort match are kept with null cohort columns.
same_band = death_grouped_df.cohort_band_conc == cohort_union_df.cohort_band
same_patient = death_grouped_df.epp_pid == cohort_union_df.Patient_ID
df_conc_linked = death_grouped_df.join(cohort_union_df, on=same_band & same_patient, how="left")

In [0]:
# Print the column names and types of the linked death/cohort frame.
df_conc_linked.printSchema()

In [0]:
# Show the death pathways (one row per distinct key combination, with cohort band).
display(death_grouped_df)

In [0]:
# Rows per specialty: all pathways, and pathways with a death during the wait.
total_count_per_specialty = total.groupBy("Specialty").count()
death_count_per_specialty = death_grouped_df.groupBy("Specialty").count()

# Share of pathways with a death, per specialty (inner join: only specialties that have
# at least one death appear), highest share first.
deaths_by_specialty = death_count_per_specialty.withColumnRenamed("count", "death_count")
totals_by_specialty = total_count_per_specialty.withColumnRenamed("count", "total_count")
percentage_death_per_specialty = (
    deaths_by_specialty.join(totals_by_specialty, on="Specialty")
    .withColumn("death_percentage", F.round((F.col("death_count") / F.col("total_count")) * 100, 2))
    .select("Specialty", "death_percentage")
    .orderBy("death_percentage", ascending=False)
)

# Small aggregate, safe to collect for plotting.
pdf = percentage_death_per_specialty.toPandas()

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(pdf['Specialty'], pdf['death_percentage'], color='skyblue')
ax.set_xlabel('Specialty')
ax.set_ylabel('Death Percentage')
ax.set_title('Death Percentage per Specialty')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()

# Write each percentage just above its bar.
for bar_position, percentage in enumerate(pdf['death_percentage']):
    ax.text(bar_position, percentage, str(percentage), ha='center', va='bottom')

plt.show()

In [0]:
# Deaths per specialty (denominator) and per (waiting bucket, specialty) (numerator).
total_death_count_per_specialty = death_grouped_df.groupBy("Specialty").count().withColumnRenamed("count", "total_death_count")
death_count_per_bucket_specialty = death_grouped_df.groupBy("waiting_bucket", "Specialty").count().withColumnRenamed("count", "death_count")

# Share of each specialty's deaths that falls in each waiting bucket.
percentage_death_per_bucket_specialty = death_count_per_bucket_specialty.join(
    total_death_count_per_specialty,
    on="Specialty"
).withColumn(
    "death_percentage_of_total",
    round((F.col("death_count") / F.col("total_death_count")) * 100, 2)
).select("waiting_bucket", "Specialty", "death_percentage_of_total").orderBy("Specialty", "waiting_bucket")

# Small aggregate, safe to collect for plotting.
pdf = percentage_death_per_bucket_specialty.toPandas()

# One row per specialty, one column per waiting bucket, in clinical (not alphabetical) order.
# `reindex` is used instead of `pdf_pivot[[...]]` so that a bucket with no deaths at all
# becomes a column of zeros rather than raising KeyError; (specialty, bucket) pairs with
# no deaths are likewise filled with 0 so the stacked bars are complete.
death_bucket_order = ['< 12 weeks', '12-18 weeks', '19-36 weeks', '37-52 weeks', '> 52 weeks']
pdf_pivot = pdf.pivot(index='Specialty', columns='waiting_bucket', values='death_percentage_of_total')
pdf_pivot = pdf_pivot.reindex(columns=death_bucket_order).fillna(0)

ax = pdf_pivot.plot(kind='bar', stacked=True, figsize=(12, 8), colormap='tab20')
ax.set_xlabel('Specialty')
ax.set_ylabel('Death Percentage of Total')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
ax.legend(title='Waiting Bucket', bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

In [0]:
# Per (specialty, referral priority): pathways with a death during the wait, and all pathways.
priority_keys = ["Specialty", "epp_referral_priority"]
pathways_with_deaths = (
    death_grouped_df.groupBy(*priority_keys).count().withColumnRenamed("count", "pathways_with_deaths")
)
total_pathways = total.groupBy(*priority_keys).count().withColumnRenamed("count", "total_pathways")

# Inner join on both keys, then express deaths as a percentage of all pathways.
result = pathways_with_deaths.join(total_pathways, priority_keys)
death_share = (F.col('pathways_with_deaths') / F.col('total_pathways')) * 100
result = result.withColumn('percentage_pathways_with_deaths', death_share)

# Collect, highest percentage first.
result_pd = result.orderBy(F.col('percentage_pathways_with_deaths').desc()).toPandas()

# Grouped bar chart: one group per specialty, one bar per referral priority.
result_pivot = result_pd.pivot(
    index='Specialty',
    columns='epp_referral_priority',
    values='percentage_pathways_with_deaths',
)
ax = result_pivot.plot(kind='bar', figsize=(12, 8))
ax.set_xlabel('Specialty')
ax.set_ylabel('Percentage of Pathways with Deaths')
ax.legend(title='Referral Priority')
plt.setp(ax.get_xticklabels(), rotation=90)
plt.show()

In [0]:
# Days from RTT start to the registered date of death.
death_grouped_df = death_grouped_df.withColumn("time_to_death", datediff("REG_DATE_OF_DEATH", "epp_rtt_start_date"))

# Overall summary statistics (count, mean, stddev, min, quartiles, max).
summary_stats = death_grouped_df.select("time_to_death").summary()
display(summary_stats)

# Per-specialty distribution summary.
time_to_death_by_specialty = death_grouped_df.groupBy("Specialty").agg(
    F.mean("time_to_death").alias("mean_time_to_death"),
    F.expr("percentile(time_to_death, 0.5)").alias("median_time_to_death"),
    F.stddev("time_to_death").alias("stddev_time_to_death"),
    F.min("time_to_death").alias("min_time_to_death"),
    F.max("time_to_death").alias("max_time_to_death")
).orderBy("Specialty")

# Shown once; the original also called .show(), which ran the same query a second time
# only to print the same table as plain text.
display(time_to_death_by_specialty)

# Collected copy of the per-specialty summary.
time_to_death_pd = time_to_death_by_specialty.toPandas()

# The boxplot needs row-level data, but only two columns of it — collect just those
# instead of the whole Spark frame.
boxplot_pd = death_grouped_df.select("Specialty", "time_to_death").toPandas()

fig, ax = plt.subplots(figsize=(12, 8))
sns.boxplot(x="Specialty", y="time_to_death", data=boxplot_pd, ax=ax)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_xlabel('Specialty')
ax.set_ylabel('Time to Death (days)')
ax.set_title('Distribution of Time to Death by Specialty')
fig.tight_layout()
plt.show()

In [0]:
from lifelines import CoxPHFitter
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

# Step 1: time to death (days) and an event flag on the death pathways.
# NOTE(review): death_grouped_df only contains pathways with a registered death, so
# death_status is 1 on every row and no observation is censored. The Kaplan-Meier, log-rank
# and Cox results below therefore describe timing among people who died, not survival of
# the waiting list as a whole. A proper survival analysis would start from `total` — confirm
# the intent before interpreting these outputs.
df = death_grouped_df.withColumn(
    'time_to_death', datediff(col('REG_DATE_OF_DEATH'), col('epp_rtt_start_date'))
).withColumn(
    'death_status', when(col('REG_DATE_OF_DEATH').isNotNull(), 1).otherwise(0)
)

# Collect once, keeping only the columns the analyses below use.
analysis_columns = ["waiting_time_days", "death_status", "epp_referral_priority", "time_to_death", "Specialty", "waiting_bucket"]
df_pd = df.select(*analysis_columns).toPandas()

# Ordered categorical so the waiting buckets always appear in clinical order.
waiting_bucket_order = ['< 12 weeks', '12-18 weeks', '19-36 weeks', '37-52 weeks', '> 52 weeks']
df_pd['waiting_bucket'] = pd.Categorical(df_pd['waiting_bucket'], categories=waiting_bucket_order, ordered=True)

# Step 2: descriptive statistics of time to death by specialty and referral priority.
time_to_death_stats = df_pd.groupby(['Specialty', 'epp_referral_priority'])['time_to_death'].describe()
print(time_to_death_stats.to_string())

# Step 3a: distribution of time to death by specialty and referral priority.
plt.figure(figsize=(10,6))
sns.boxplot(x='Specialty', y='time_to_death', data=df_pd, hue='epp_referral_priority')
plt.xticks(rotation=90)
plt.title('Time to Death by Specialty and Referral Priority')
plt.show()

# Step 3b: distribution of time to death by waiting bucket and referral priority.
plt.figure(figsize=(10,6))
sns.boxplot(x='waiting_bucket', y='time_to_death', data=df_pd, hue='epp_referral_priority')
plt.xticks(rotation=90)
plt.title('Time to Death by Waiting Bucket and Referral Priority')
plt.show()

# Step 4: Kaplan-Meier curve on time to death.
kmf = KaplanMeierFitter()
kmf.fit(df_pd['time_to_death'], event_observed=df_pd['death_status'])
kmf.plot()
plt.title('Survival Analysis: Time to Death')
plt.show()

# Step 5: log-rank test comparing 'urgent' and 'routine' referrals.
# Skipped when either group is empty (e.g. the priority labels are spelled differently
# in the data), because the test is meaningless without both groups.
group1 = df_pd[df_pd['epp_referral_priority'] == 'urgent']
group2 = df_pd[df_pd['epp_referral_priority'] == 'routine']
if group1.empty or group2.empty:
    print("Log-rank test skipped: no rows for 'urgent' and/or 'routine' referral priority")
else:
    results = logrank_test(group1['time_to_death'], group2['time_to_death'], event_observed_A=group1['death_status'], event_observed_B=group2['death_status'])
    print("Log-rank test results:", results)

# Step 6: Kaplan-Meier curve on waiting time, all death pathways together.
kmf = KaplanMeierFitter()
kmf.fit(df_pd["waiting_time_days"], event_observed=df_pd['death_status'])
kmf.plot_survival_function()
plt.xlabel('Waiting Time (Days)')
plt.ylabel('Survival Probability')
plt.show()

# Step 7: one Kaplan-Meier curve per waiting bucket (buckets with no rows are skipped).
plt.figure(figsize=(12, 8))
ax = plt.gca()
for bucket in waiting_bucket_order:
    bucket_df = df_pd[df_pd['waiting_bucket'] == bucket]
    if bucket_df.empty:
        continue
    kmf = KaplanMeierFitter()
    kmf.fit(bucket_df["waiting_time_days"], event_observed=bucket_df['death_status'])
    kmf.plot_survival_function(ax=ax, label=bucket)

plt.xlabel('Waiting Time (Days)')
plt.ylabel('Survival Probability')
plt.legend(title='Waiting Bucket')  # the curves are per waiting bucket, not per specialty
plt.show()

# Step 8: Cox proportional-hazards model with waiting bucket as a categorical covariate.
# Only the columns the model needs are passed, with the bucket as plain strings so that
# buckets absent from the data do not become all-zero dummy columns.
# NOTE(review): waiting_bucket is derived from waiting_time_days (the duration), so this
# covariate is determined by the outcome; expect convergence warnings — confirm the design.
cox_input = df_pd[['waiting_time_days', 'death_status']].assign(
    waiting_bucket=df_pd['waiting_bucket'].astype(str)
)
cph = CoxPHFitter()
cph.fit(cox_input, duration_col='waiting_time_days', event_col='death_status', formula="C(waiting_bucket)") # C() treats waiting_bucket as categorical

# print_summary() prints the table itself and returns None, so it must not be wrapped in print().
cph.print_summary()

# Step 9: Spearman rank correlation between waiting time and time to death.
# Both columns have to be in the feature vector; with a single input column the
# correlation matrix is just the 1x1 matrix [[1.0]].
corr_columns = ["waiting_time_days", "time_to_death"]
assembler = VectorAssembler(inputCols=corr_columns, outputCol="features")
df_vector = assembler.transform(df.select(*corr_columns).na.drop()).select("features")

correlation = Correlation.corr(df_vector, "features", "spearman")
spearman_rho = correlation.head()[0].toArray()[0, 1]
print(f"Spearman correlation (waiting_time_days vs time_to_death): {spearman_rho:.3f}")

In [0]:
def _bucket_share_by(linked_df, category):
    """Return a Spark frame with one row per `category` value and one column per waiting bucket.

    Each cell is the percentage (2 d.p.) of that category value's rows that fall in the
    bucket; combinations with no rows are filled with 0.
    """
    bucket_counts = linked_df.groupBy(category, "waiting_bucket").agg(F.count("*").alias("count"))
    category_totals = bucket_counts.groupBy(category).agg(F.sum("count").alias("total_count"))
    shares = bucket_counts.join(category_totals, category, "left").withColumn(
        "percentage",
        F.round((F.col("count") / F.col("total_count")) * 100, 2),
    )
    pivoted = shares.groupBy(category).pivot("waiting_bucket").agg(F.first("percentage"))
    return pivoted.fillna(0)


# Demographic breakdowns to tabulate against the waiting bucket.
categories = ["AgeBand_5yr", "Sex", "Ethnic_Group", "IMD_Decile"]

# Category name -> pivoted percentage table (Spark DataFrame).
pivot_tables = {cat: _bucket_share_by(df_conc_linked, cat) for cat in categories}

display(pivot_tables["AgeBand_5yr"])

    

In [0]:
# Waiting buckets in display order (shortest wait first) for consistent plots.
bucket_order = ["< 12 weeks", "12-18 weeks", "19-36 weeks", "37-52 weeks", "> 52 weeks"]


def plot_stacked_bar(df, category):
    """Draw a stacked bar chart of waiting-bucket percentages, one bar per index value.

    Args:
        df: pandas DataFrame indexed by the category values, with waiting buckets as columns.
        category: name of the category, used for the title and the x-axis label.

    Bucket columns missing from `df` (no rows fell in that bucket) are added as zeros
    instead of raising KeyError, so every chart shows the same five buckets in order.
    """
    ordered = df.reindex(columns=bucket_order, fill_value=0)
    ax = ordered.plot(
        kind='bar',
        stacked=True,
        figsize=(12, 6),
        colormap='viridis'
    )
    ax.set_title(f"Percentage Distribution by {category}")
    ax.set_xlabel(category)
    ax.set_ylabel("Percentage (%)")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.legend(title="Waiting Bucket", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()


# Collect each pivot table to pandas, indexed by its category column.
pivot_dict_pd = {cat: pivot_tables[cat].toPandas().set_index(cat) for cat in pivot_tables}

# One chart per category.
for cat in pivot_dict_pd:
    plot_stacked_bar(pivot_dict_pd[cat], cat)

# Cancellation

In [0]:
# Reload the unfiltered pathway data for the cancellation analysis.
# The original line had a stray quote (a syntax error) and asked Spark to read the path as
# parquet, but cohort_1_link points at an .xlsx workbook; it is read the same way as at the
# top of the notebook and converted to a Spark DataFrame.
df_wlmds = spark.createDataFrame(pd.read_excel(cohort_1_link))

# Treatment-function-code lookup. `data` and `schema` are not defined in this notebook
# (presumably they come from `from utils import *` — TODO confirm).
tfc_df = spark.createDataFrame(data, schema=schema)
df_wlmds = df_wlmds.join(tfc_df, df_wlmds.wlmds_treatment_function_code == tfc_df.TFC)


In [0]:
# Rows whose epp_source contains "wlmds_", with the RTT period inside 1 April 2022 – 31 March 2024.
# The end bound is exclusive at 1 April 2024 so that pathways ending on 31 March 2024 are
# kept, matching the window used for the earlier analyses (the original `< '2024-03-31'`
# silently dropped the last day).
sel_wlmds = df_wlmds.filter(
    col("epp_source").contains("wlmds_")
).filter(
    (col("epp_rtt_start_date") >= '2022-04-01') & (col("epp_rtt_end_date") < '2024-04-01')
)
display(sel_wlmds)

In [0]:
# Step 1: 0/1 flags derived from the attendance status text.
# A null status yields 0 on all three flags (the `when` condition is not true).
sel_wlmds = sel_wlmds.withColumn('is_dna', when(col('sus_attendance_status').contains('dna'), 1).otherwise(0))
sel_wlmds = sel_wlmds.withColumn('is_patient_canc', when(col('sus_attendance_status').contains('cancel(pat)'), 1).otherwise(0))
sel_wlmds = sel_wlmds.withColumn('is_hospital_canc', when(col('sus_attendance_status').contains('cancel(hos)'), 1).otherwise(0))

# Step 2: per specialty, the number of distinct pathways with at least one DNA / patient
# cancellation / hospital cancellation, and the number of distinct pathways overall.
# `when(...)` without `otherwise` is null for non-matching rows, and countDistinct ignores nulls.
# F.countDistinct is used because `countDistinct` is never imported by name in this notebook.
result = sel_wlmds.groupBy('Specialty').agg(
    F.countDistinct(when(col('is_dna') == 1, col('epp_pathway_id'))).alias('dna_count'),
    F.countDistinct(when(col('is_patient_canc') == 1, col('epp_pathway_id'))).alias('patient_canc_count'),
    F.countDistinct(when(col('is_hospital_canc') == 1, col('epp_pathway_id'))).alias('hospital_canc_count'),
    F.countDistinct('epp_pathway_id').alias('unique_pathways')
)

# Step 3: express each count as a percentage of the specialty's distinct pathways.
result = result.withColumn('dna_percentage', (col('dna_count') / col('unique_pathways')) * 100)
result = result.withColumn('patient_canc_percentage', (col('patient_canc_count') / col('unique_pathways')) * 100)
result = result.withColumn('hospital_canc_percentage', (col('hospital_canc_count') / col('unique_pathways')) * 100)

# Small aggregate, safe to collect for plotting.
pdf_result = result.toPandas()

# Draw on an explicit axes: DataFrame.plot opens its own figure when no `ax` is given, so
# the original plt.figure(...) call produced an extra empty figure and its size was ignored.
fig, ax = plt.subplots(figsize=(14, 7))
pdf_result.plot(kind='barh', x='Specialty', y=['dna_percentage', 'patient_canc_percentage', 'hospital_canc_percentage'], stacked=False, ax=ax)
ax.set_ylabel('Specialty')
ax.set_xlabel('Percentage')
fig.tight_layout()
plt.show()

In [0]:
# Calendar year and month of the attendance.
sel_wlmds = sel_wlmds.withColumn('year', year(col('sus_attendance_date_time')))
sel_wlmds = sel_wlmds.withColumn('month', month(col('sus_attendance_date_time')))

# Distinct pathways with a DNA / patient cancellation / hospital cancellation per
# (specialty, year, month). Relies on the is_* flag columns added to sel_wlmds earlier.
# F.countDistinct is used because `countDistinct` is never imported by name in this notebook.
result_by_time = sel_wlmds.groupBy('Specialty', 'year', 'month').agg(
    F.countDistinct(when(col('is_dna') == 1, col('epp_pathway_id'))).alias('dna_count'),
    F.countDistinct(when(col('is_patient_canc') == 1, col('epp_pathway_id'))).alias('patient_canc_count'),
    F.countDistinct(when(col('is_hospital_canc') == 1, col('epp_pathway_id'))).alias('hospital_canc_count')
)

result_by_time.show()

In [0]:
# Per (specialty, year, month): distinct pathways overall and distinct pathways with a
# patient / hospital cancellation.
# NOTE(review): 'total_appointments' counts distinct epp_pathway_id values, i.e. pathways
# rather than individual appointments — confirm that this is the intended denominator.
# F.countDistinct is used because `countDistinct` is never imported by name in this notebook.
cancellation_rate = sel_wlmds.groupBy('Specialty', 'year', 'month').agg(
    F.countDistinct('epp_pathway_id').alias('total_appointments'),
    F.countDistinct(when(col('is_patient_canc') == 1, col('epp_pathway_id'))).alias('patient_canc_count'),
    F.countDistinct(when(col('is_hospital_canc') == 1, col('epp_pathway_id'))).alias('hospital_canc_count')
)

# Cancellation counts as a percentage of the distinct pathways in the same group.
cancellation_rate = cancellation_rate.withColumn(
    'patient_canc_rate', (col('patient_canc_count') / col('total_appointments')) * 100
).withColumn(
    'hospital_canc_rate', (col('hospital_canc_count') / col('total_appointments')) * 100
)

cancellation_rate.show()

In [0]:
# Row-level mean of each 0/1 flag per specialty, i.e. the fraction of rows flagged as
# DNA, patient cancellation or hospital cancellation.
rate_columns = [
    F.avg('is_dna').alias('avg_dna_rate'),
    F.avg('is_patient_canc').alias('avg_patient_canc_rate'),
    F.avg('is_hospital_canc').alias('avg_hospital_canc_rate'),
]
correlation_analysis = sel_wlmds.groupBy('Specialty').agg(*rate_columns)

# Small aggregate, safe to collect for plotting.
correlation_analysis_pd = correlation_analysis.toPandas()

# Grouped bars: one group per specialty, one bar per rate.
ax = correlation_analysis_pd.plot(x='Specialty', kind='bar', figsize=(12, 8), colormap='tab20')
ax.set_xlabel('Specialty')
ax.set_ylabel('Average Rate')
ax.set_title('Average DNA, Patient Cancellation, and Hospital Cancellation Rates by Specialty')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
ax.legend(title='Rate Type', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

In [0]:
# Waiting time in days between the WLMDS RTT start and end dates.
# NOTE(review): this uses wlmds_rtt_* dates while sel_wlmds was selected on epp_rtt_* dates —
# confirm that mixing the two is intended.
sel_wlmds = sel_wlmds.withColumn(
    "waiting_time",
    datediff(col("wlmds_rtt_end_date"), col("wlmds_rtt_start_date"))
)

# Drop negative waits (rows with a null waiting_time are dropped as well).
sel_wlmds = sel_wlmds.filter(col("waiting_time") >= 0)

# Waiting buckets; thresholds are whole weeks expressed in days, upper bounds inclusive.
sel_wlmds = sel_wlmds.withColumn(
    "waiting_bucket",
    when((col("waiting_time") <= 18 *7), "< 18 weeks")
    .when((col("waiting_time") > 18*7) & (col("waiting_time") <= 36*7), "19-36 weeks")
    .when((col("waiting_time") > 36*7) & (col("waiting_time") <= 52*7), "37-52 weeks")
    .otherwise("> 52 weeks")
)

# One aggregation instead of four group-bys joined together: the is_* flags are 0/1, so
# their sum is the number of flagged rows. A group with no flagged rows now gets a count of
# 0 (and a percentage of 0.0) instead of the null/NaN the left joins produced.
impact_analysis = sel_wlmds.groupBy("waiting_bucket", "Specialty").agg(
    F.count("*").alias("total_count"),
    F.sum("is_hospital_canc").alias("hospital_cancellation_count"),
    F.sum("is_patient_canc").alias("patient_cancellation_count"),
    F.sum("is_dna").alias("dna_count"),
).withColumn(
    "hospital_cancellation_percentage",
    round((col("hospital_cancellation_count") / col("total_count")) * 100, 2)
).withColumn(
    "patient_cancellation_percentage",
    round((col("patient_cancellation_count") / col("total_count")) * 100, 2)
).withColumn(
    "dna_percentage",
    round((col("dna_count") / col("total_count")) * 100, 2)
).select("waiting_bucket", "Specialty", "hospital_cancellation_percentage", "patient_cancellation_percentage", "dna_percentage").orderBy("Specialty", "waiting_bucket")

# Small aggregate, safe to collect for plotting.
impact_analysis_pd = impact_analysis.toPandas()

# Rows: specialty. Columns: (metric, waiting bucket).
impact_analysis_pivot = impact_analysis_pd.pivot(index='Specialty', columns='waiting_bucket', values=['hospital_cancellation_percentage', 'patient_cancellation_percentage', 'dna_percentage'])

# One panel per metric on a shared figure. The original issued three independent .plot()
# calls, which opened three figures and applied the axis labels, title and legend to the
# last one only. Buckets are put in waiting-time order rather than alphabetical order.
cancellation_bucket_order = ["< 18 weeks", "19-36 weeks", "37-52 weeks", "> 52 weeks"]
panels = [
    ('hospital_cancellation_percentage', 'Hospital Cancellation Percentage'),
    ('patient_cancellation_percentage', 'Patient Cancellation Percentage'),
    ('dna_percentage', 'DNA Percentage'),
]
fig, axes = plt.subplots(len(panels), 1, figsize=(12, 18), sharex=True)
for ax, (metric, panel_title) in zip(axes, panels):
    panel_data = impact_analysis_pivot[metric].reindex(columns=cancellation_bucket_order)
    panel_data.plot(kind='bar', ax=ax, colormap='tab20', title=panel_title, legend=False)
    ax.set_ylabel('Percentage')

# A single legend serves all panels because they share the same bucket colours.
axes[0].legend(title='Waiting Bucket', bbox_to_anchor=(1.05, 1), loc='upper left')
axes[-1].set_xlabel('Specialty')
plt.setp(axes[-1].get_xticklabels(), rotation=45, ha='right')
fig.suptitle('Impact Analysis per Waiting Bucket for Each Specialty')
fig.tight_layout()

plt.show()

In [0]:
# Sum the monthly counts per (Specialty, year, month) so that each combination is unique.
count_columns = ["dna_count", "patient_canc_count", "hospital_canc_count"]
result_by_time_agg = result_by_time.groupBy("Specialty", "year", "month").agg(
    *[F.sum(name).alias(name) for name in count_columns]
)

# Small aggregate, safe to collect; make sure the counts are plain integers.
result_by_time_pd = result_by_time_agg.toPandas()
result_by_time_pd[count_columns] = result_by_time_pd[count_columns].astype(int)


def _monthly_pivot(values_column):
    """Specialty x (year, month) table of `values_column` from result_by_time_pd."""
    return result_by_time_pd.pivot_table(index="Specialty", columns=["year", "month"], values=values_column)


pivot_dna = _monthly_pivot("dna_count")
pivot_patient_canc = _monthly_pivot("patient_canc_count")
pivot_hospital_canc = _monthly_pivot("hospital_canc_count")

# One annotated heatmap per measure, each with its own colour map and title.
heatmaps = [
    (pivot_dna, "Blues", 'DNA Count by Specialty, Year, and Month'),
    (pivot_patient_canc, "Oranges", 'Patient Cancellation Count by Specialty, Year, and Month'),
    (pivot_hospital_canc, "Reds", 'Hospital Cancellation Count by Specialty, Year, and Month'),
]
for pivot_table, colour_map, heatmap_title in heatmaps:
    plt.figure(figsize=(12, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".0f", cmap=colour_map, cbar=False)
    plt.title(heatmap_title)
    plt.show()

# Plain-text dump of the same three tables.
for pivot_table, _, _ in heatmaps:
    print(pivot_table.to_string())