In [0]:
import sys
import os
sys.path.append("../src")

In [0]:
from pyspark.sql import functions as F
from utils import *

In [0]:
# Locations of the synthetic (dummy) cohort extracts.
cohort_1_link, cohort_2_link, cohort_3_link = (
    "../Dummy_Data/Cohort_1_synth.xlsx",
    "../Dummy_Data/Cohort_2_synth.xlsx",
    "../Dummy_Data/Cohort_3_synth.xlsx",
)

In [0]:
import pandas as pd
from pyspark.sql import SparkSession

# Load cohort 2 with pandas, then hand it to Spark: every downstream cell uses the
# Spark DataFrame API (Column-based .filter, .count(), .select, .alias, self-joins),
# which a plain pandas DataFrame does not provide.
# getOrCreate() returns the already-running session when there is one (e.g. Databricks).
cohort2 = SparkSession.builder.getOrCreate().createDataFrame(pd.read_excel(cohort_2_link))
# Drop rows whose wlmds_status is "30".
# NOTE(review): the meaning of status 30 is not documented here — confirm.
cohort2 = cohort2.filter(F.col("wlmds_status") != "30")

In [0]:
# Number of rows left in cohort 2 after the status filter.
cohort2_row_count = cohort2.count()
display(cohort2_row_count)

In [0]:
# Number of distinct patients in the filtered cohort.
unique_patients = (
    cohort2
    .select("wlmds_patient_id")
    .dropDuplicates()
    .count()
)
display(unique_patients)

In [0]:
# Check for overlapping RTT pathways: count distinct patients who have two different
# pathways (different wlmds_pathway_id) in the same treatment function whose
# [start, end] date ranges overlap (each starts before the other ends).
# Rows with a null start/end date never satisfy the comparisons, so they cannot match.
pathways_a = cohort2.alias("df1")
pathways_b = cohort2.alias("df2")

overlap_condition = (
    (F.col("df1.wlmds_patient_id") == F.col("df2.wlmds_patient_id"))
    & (F.col("df1.wlmds_treatment_function_code") == F.col("df2.wlmds_treatment_function_code"))
    & (F.col("df1.wlmds_rtt_start_date_conc3") < F.col("df2.wlmds_rtt_end_date_conc4"))
    & (F.col("df1.wlmds_rtt_end_date_conc4") > F.col("df2.wlmds_rtt_start_date_conc3"))
    & (F.col("df1.wlmds_pathway_id") != F.col("df2.wlmds_pathway_id"))
)

overlapping_pathways = (
    pathways_a.join(pathways_b, overlap_condition)
    .select("df1.wlmds_patient_id")
    .distinct()
    .count()
)

display(overlapping_pathways)

In [0]:
# Compare the wait-band mix of patients with overlapping RTT pathways against patients
# without any overlapping pathway.
# NOTE(review): unlike the same-specialty check above, this overlap definition does NOT
# require the two pathways to share wlmds_treatment_function_code — confirm intent.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Patient-level attributes carried forward for the demographic breakdowns below.
patient_cols = [
    "wlmds_patient_id", "ndl_wait_band", "ndl_age_band", "ndl_imd_quantile",
    "ndl_ethnicity", "ndl_ltc", "Sex", "Frailty_level",
]

# Identify patients with overlapping pathways (two different pathways whose date ranges overlap)
overlapping_pathways = cohort2.alias("df1").join(
    cohort2.alias("df2"),
    (F.col("df1.wlmds_patient_id") == F.col("df2.wlmds_patient_id")) &
    (F.col("df1.wlmds_rtt_start_date_conc3") < F.col("df2.wlmds_rtt_end_date_conc4")) &
    (F.col("df1.wlmds_rtt_end_date_conc4") > F.col("df2.wlmds_rtt_start_date_conc3")) &
    (F.col("df1.wlmds_pathway_id") != F.col("df2.wlmds_pathway_id"))
).select(*[f"df1.{c}" for c in patient_cols]).distinct()

# Everyone else. .distinct() makes this group count the same unit as the overlapping group
# (distinct patient/attribute rows) rather than one row per pathway record, so the two
# percentage denominators are comparable.
non_overlapping_pathways = (
    cohort2.join(overlapping_pathways, on="wlmds_patient_id", how="left_anti")
    .select(*patient_cols)
    .distinct()
)

# Count rows in each waiting time band.
# NOTE(review): a patient with rows in several wait bands is counted once per band.
overlapping_counts = overlapping_pathways.groupBy("ndl_wait_band").count()
non_overlapping_counts = non_overlapping_pathways.groupBy("ndl_wait_band").count()

# Group totals for the percentage calculation
total_overlapping = overlapping_pathways.count()
total_non_overlapping = non_overlapping_pathways.count()

# Percentage of each group's rows falling in each wait band
overlapping_percentages = overlapping_counts.withColumn("percentage", (F.col("count") / total_overlapping) * 100)
non_overlapping_percentages = non_overlapping_counts.withColumn("percentage", (F.col("count") / total_non_overlapping) * 100)

# Convert to pandas for plotting
overlapping_df = overlapping_percentages.toPandas()
non_overlapping_df = non_overlapping_percentages.toPandas()

# Outer merge so a band present in only one group is plotted as 0% instead of being dropped.
merged_df = overlapping_df.merge(
    non_overlapping_df,
    on="ndl_wait_band",
    how="outer",
    suffixes=("_overlapping", "_non_overlapping"),
).fillna({
    "count_overlapping": 0,
    "percentage_overlapping": 0,
    "count_non_overlapping": 0,
    "percentage_non_overlapping": 0,
})

# Order wait bands from shortest to longest
wait_time_order = ['<= 18 weeks', '> 18 weeks', '> 36 weeks', '> 52 weeks']
merged_df["ndl_wait_band"] = pd.Categorical(merged_df["ndl_wait_band"], categories=wait_time_order, ordered=True)
merged_df = merged_df.sort_values("ndl_wait_band")

# Grouped bar chart: one pair of bars per wait band
x = np.arange(len(merged_df))
width = 0.4

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width / 2, merged_df["percentage_overlapping"], width, label='Overlapping Pathways', alpha=0.7)
ax.bar(x + width / 2, merged_df["percentage_non_overlapping"], width, label='Non-Overlapping Pathways', alpha=0.7)

ax.set_xlabel('Waiting Time Band')
ax.set_ylabel('Percentage of Patients')
ax.set_title('Percentage of Patients in Each Waiting Time Band')
ax.set_xticks(x)
ax.set_xticklabels(merged_df["ndl_wait_band"].astype(str))
ax.legend()
plt.show()

In [0]:
def calculate_wait_band_distribution_anna2(df, input_cols):
    """
    Cross-tabulate ``ndl_wait_band`` against each column in ``input_cols``.

    The input is first collapsed to one row per distinct combination of
    ``input_cols`` + ``wlmds_patient_id`` + ``ndl_wait_band``; every count below is a
    count of those combinations, not of raw input rows. Then, for each column in
    ``input_cols`` separately, null values are relabelled "unknown", the combinations
    are pivoted by wait band, and a count and a percentage of the row total are added
    for every wait-band value. The per-column tables are stacked with ``unionByName``.

    Args:
        df (DataFrame): Spark DataFrame containing ``wlmds_patient_id``,
            ``ndl_wait_band`` and every column named in ``input_cols``.
        input_cols (list): Column names to break the wait-band distribution down by.

    Returns:
        DataFrame: one row per (Variable, value) with columns ``value``, one count
        column per wait-band value (null counts filled with 0), ``total_count``, one
        ``<band>_percentage`` column per wait-band value (count / total_count * 100),
        and ``Variable`` (the name of the source column). Counts and percentages are
        rounded to 2 decimal places.

    Raises:
        ValueError: if ``input_cols`` is empty (there would be no tables to union).
    """
    from functools import reduce

    input_cols = list(input_cols)
    if not input_cols:
        raise ValueError("input_cols must contain at least one column name")

    # One row per distinct patient / wait band / attribute combination.
    personal_ch = df.groupBy(input_cols + ["wlmds_patient_id", "ndl_wait_band"]).agg(
        F.count("*").alias("count")
    )

    # Relabel missing attribute values so they appear as an explicit "unknown" category.
    for var in input_cols:
        personal_ch = personal_ch.withColumn(
            var, F.when(F.col(var).isNull(), "unknown").otherwise(F.col(var))
        )

    per_variable_tables = []
    for var in input_cols:
        # Combinations per (value, wait band), with each wait band becoming a column.
        grouped_counts = (
            personal_ch.groupBy(var)
            .pivot("ndl_wait_band")
            .agg(F.count("*").alias("count"))
            .withColumnRenamed(var, "value")
        )
        total_counts = (
            personal_ch.groupBy(var)
            .agg(F.count("*").alias("total_count"))
            .withColumnRenamed(var, "value")
        )
        grouped_counts = grouped_counts.join(total_counts, on="value", how="left")

        # Pivot columns are everything except the key and the total.
        band_cols = [c for c in grouped_counts.columns if c not in ("value", "total_count")]
        for c in band_cols:
            # A (value, band) pair with no rows pivots to null; treat it as 0.
            band_count = F.coalesce(F.col(c), F.lit(0))
            grouped_counts = grouped_counts.withColumn(
                f"{c}_percentage", F.round((band_count / F.col("total_count")) * 100, 2)
            ).withColumn(c, F.round(band_count, 2))

        per_variable_tables.append(grouped_counts.withColumn("Variable", F.lit(var)))

    return reduce(lambda left, right: left.unionByName(right), per_variable_tables)

# Attribute columns to break the wait-band distribution down by.
columns = [
    "ndl_age_band",
    "ndl_imd_quantile",
    "ndl_ethnicity",
    "ndl_ltc",
    "Sex",
    "Frailty_level",
]

# Wait-band distribution for patients with overlapping pathways.
group_stats = calculate_wait_band_distribution_anna2(
    df=overlapping_pathways,
    input_cols=columns,
)
display(group_stats)

In [0]:
# CSV text of the overlapping-group table (e.g. for copying out of the notebook).
# Keep the default newline row terminator: lineterminator='' glued each row onto the end
# of the previous one, producing a single unparseable line. print() renders the newlines.
group_stats_str = group_stats.toPandas().to_csv(index=False, sep=',')
print(group_stats_str)

In [0]:
# Same wait-band breakdown for patients without overlapping pathways.
group_stats2 = calculate_wait_band_distribution_anna2(
    df=non_overlapping_pathways,
    input_cols=columns,
)
display(group_stats2)

In [0]:
# CSV text of the non-overlapping-group table (e.g. for copying out of the notebook).
# Keep the default newline row terminator: lineterminator='' glued each row onto the end
# of the previous one, producing a single unparseable line. print() renders the newlines.
group_stats_str = group_stats2.toPandas().to_csv(index=False, sep=',')
print(group_stats_str)

In [0]:
# Compare the share of patients in the "<= 18 weeks" band between the overlapping
# (group_stats) and non-overlapping (group_stats2) groups, per Variable/value.
# Select only the join keys and the one metric, so the join does not carry duplicate
# column names (both inputs share every pivot and total column).
overlapping_df = group_stats.select(
    "Variable", "value", F.col("<= 18 weeks_percentage").alias("Overlapping_≤18wks")
)
non_overlapping_df = group_stats2.select(
    "Variable", "value", F.col("<= 18 weeks_percentage").alias("NonOverlapping_≤18wks")
)

# Inner join: only categories present in both groups are compared.
comparison_df = overlapping_df.join(non_overlapping_df, on=["Variable", "value"], how="inner")

# For the <= 18 weeks band a HIGHER share is better, so improvement is
# non-overlapping minus overlapping. The difference is in percentage points.
comparison_df = comparison_df.withColumn(
    "Change",
    F.col("NonOverlapping_≤18wks") - F.col("Overlapping_≤18wks"),
)

# Format as text; cast explicitly because concat is a string function.
comparison_df = comparison_df.withColumn(
    "Change",
    F.concat(F.round(F.col("Change"), 2).cast("string"), F.lit("% improvement")),
)

# Select and reorder relevant columns
final_df_less_18weeks = comparison_df.select(
    "Variable", "value", "Overlapping_≤18wks", "NonOverlapping_≤18wks", "Change"
)

In [0]:
# Compare the share of patients in the "> 52 weeks" band between the overlapping
# (group_stats) and non-overlapping (group_stats2) groups, per Variable/value.
# Select only the join keys and the one metric, so the join does not carry duplicate
# column names (both inputs share every pivot and total column).
overlapping_df = group_stats.select(
    "Variable", "value", F.col("> 52 weeks_percentage").alias("Overlapping_>52wks")
)
non_overlapping_df = group_stats2.select(
    "Variable", "value", F.col("> 52 weeks_percentage").alias("NonOverlapping_>52wks")
)

# Inner join: only categories present in both groups are compared.
comparison_df = overlapping_df.join(non_overlapping_df, on=["Variable", "value"], how="inner")

# For the > 52 weeks band a LOWER share is better, so improvement is overlapping minus
# non-overlapping (a reduction in long waits). Computing non-overlapping minus overlapping
# here would label an increase in long waits as an "improvement".
# The difference is in percentage points.
comparison_df = comparison_df.withColumn(
    "Change",
    F.col("Overlapping_>52wks") - F.col("NonOverlapping_>52wks"),
)

# Format as text; cast explicitly because concat is a string function.
comparison_df = comparison_df.withColumn(
    "Change",
    F.concat(F.round(F.col("Change"), 2).cast("string"), F.lit("% improvement")),
)

# Select and reorder relevant columns
final_df_more_52weeks = comparison_df.select(
    "Variable", "value", "Overlapping_>52wks", "NonOverlapping_>52wks", "Change"
)

In [0]:
# Render both comparison tables with the notebook's rich display, as the other cells do;
# DataFrame.show() would truncate values such as the "Change" text to 20 characters.
display(final_df_less_18weeks)
display(final_df_more_52weeks)


In [0]:
# Plot the "<= 18 weeks" and "> 52 weeks" comparisons side by side for selected categories.
import matplotlib.pyplot as plt
import numpy as np

# Category labels to plot, matched against "value" across every Variable, in plot order.
cols = ["No LTCs", "Single LTC", "Multimorbidities", "black_background", "asian_background", "<= 10"]
_plot_order = {label: rank for rank, label in enumerate(cols)}


def _selected_as_pandas(spark_df):
    """Return rows of ``spark_df`` whose ``value`` is in ``cols`` as pandas, ordered as in ``cols``.

    Spark gives no row-order guarantee after joins/filters, so sort explicitly to keep
    the bar order stable between runs.
    """
    pdf = spark_df.filter(F.col("value").isin(cols)).toPandas()
    return (
        pdf.assign(_order=pdf["value"].map(_plot_order))
        .sort_values(["_order", "Variable"])
        .drop(columns="_order")
        .reset_index(drop=True)
    )


def _plot_group_comparison(ax, pdf, overlapping_col, non_overlapping_col, ylabel, title, bar_width=0.4):
    """Draw paired Overlapping / Non-Overlapping bars for each row of ``pdf`` on ``ax``."""
    positions = np.arange(len(pdf))
    ax.bar(positions - bar_width / 2, pdf[overlapping_col], width=bar_width, label="Overlapping", color="orange")
    ax.bar(positions + bar_width / 2, pdf[non_overlapping_col], width=bar_width, label="Non-Overlapping", color="orangered")
    ax.set_xticks(positions)
    ax.set_xticklabels(pdf["value"], rotation=20, ha="right")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(axis="y", linestyle="--", alpha=0.6)


filtered_df = _selected_as_pandas(final_df_less_18weeks)
filtered_df_2 = _selected_as_pandas(final_df_more_52weeks)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left subplot: share of patients in the <= 18 weeks band
_plot_group_comparison(
    axes[0], filtered_df, "Overlapping_≤18wks", "NonOverlapping_≤18wks",
    "≤18 Weeks Completion Rate (%)", "Short-Wait Completion Rates Comparison",
)

# Right subplot: share of patients in the > 52 weeks band
_plot_group_comparison(
    axes[1], filtered_df_2, "Overlapping_>52wks", "NonOverlapping_>52wks",
    "Long-Wait Percentage (%)", "Long-Wait Percentage Comparison",
)

plt.tight_layout()
plt.show()