## Imports

In [0]:
import sys
#sys.path.append("../../")
#from Config import *

sys.path.append("../src/")

from utils import *

In [0]:
# openpyxl is the engine pandas uses for .xlsx files; pd.read_excel on the dummy
# Excel extracts below needs it.
# NOTE(review): pin a version (e.g. `%pip install openpyxl==<x.y.z>`) so re-runs are reproducible.
%pip install openpyxl

## Cohort 1 Investigations

In [0]:
# Dummy data: read the synthetic Cohort 1 Excel extract with pandas and lift it into Spark.
# For real data, use the commented-out parquet loader in the following cell instead.
cohort_1_path = "../Dummy_Data/Cohort_1_synth.xlsx"
cohort_1_df = spark.createDataFrame(pd.read_excel(cohort_1_path))

In [0]:
# real data

# cohort_1_df = spark.read.format("parquet").load(cohort_1_link)
#display(cohort_1_df)

In [0]:
from pyspark.sql import functions as F


def _wait_band_proportions(df, category_col):
    """Count records per (`category_col`, wait_band_category) and add each band's share.

    Returns a Spark DataFrame with columns: category_col, wait_band_category, count,
    total_count and proportion = count / total_count (shares sum to 1 per category).
    """
    counts = df.groupBy(category_col, "wait_band_category").count()
    # F.sum is Spark's aggregate. A bare `sum("count")` only works when a star import has
    # shadowed Python's builtin `sum`; otherwise it raises TypeError.
    totals = counts.groupBy(category_col).agg(F.sum("count").alias("total_count"))
    return counts.join(totals, category_col).withColumn(
        "proportion", F.col("count") / F.col("total_count")
    )


def _plot_wait_band_proportions(proportions_pd, category_col, xlabel):
    """Grouped bar chart of wait-band proportion per `category_col` value (pandas input).

    Category order follows the row order of `proportions_pd` (first appearance).
    """
    plt.figure(figsize=(12, 8))
    sns.barplot(
        data=proportions_pd,
        y="proportion",
        x=category_col,
        hue="wait_band_category",
        order=proportions_pd[category_col].unique(),
    )
    plt.xlabel(xlabel, fontsize=18)
    plt.ylabel("Proportion", fontsize=18)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.legend(title="Wait Band Category", fontsize=16)
    plt.show()


# Binary wait-band flag used by every chart in this cell. Null ndl_wait_band values are
# labelled "Unknown": with only `otherwise`, a null comparison falls through and the
# record would be silently counted as a > 18 week wait.
cohort_1_df = cohort_1_df.withColumn(
    "wait_band_category",
    F.when(F.col("ndl_wait_band").isNull(), "Unknown")
    .when(F.col("ndl_wait_band") == "<= 18 weeks", "Less than 18 weeks")
    .otherwise("More than 18 weeks"),
)

# --- Record counts per specialty, largest bars first ---
wait_band_counts_specialty = cohort_1_df.groupBy("Specialty", "wait_band_category").count()
wait_band_counts_specialty_pd = (
    wait_band_counts_specialty.select("Specialty", "wait_band_category", "count")
    .toPandas()
    .sort_values(by="count", ascending=False)
)

plt.figure(figsize=(12, 8))
sns.barplot(
    data=wait_band_counts_specialty_pd,
    x="count",
    y="Specialty",
    hue="wait_band_category",
    order=wait_band_counts_specialty_pd["Specialty"].unique(),
)
plt.ylabel("Specialty", fontsize=18)
plt.xlabel("Count", fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.legend(title="Wait Band Category", fontsize=16)
plt.show()

# --- Wait-band proportions by whether the provider / provider site code contains "RR8" ---
# A null code does not satisfy `contains`, so it falls into "Does not contain RR8".
cohort_1_df = cohort_1_df.withColumn(
    "provider_category",
    F.when(F.col("wlmds_provider").contains("RR8"), "Contains RR8").otherwise("Does not contain RR8"),
)
cohort_1_df = cohort_1_df.withColumn(
    "provider_site_category",
    F.when(F.col("wlmds_provider_site").contains("RR8"), "Contains RR8").otherwise("Does not contain RR8"),
)

wait_band_counts_provider = _wait_band_proportions(cohort_1_df, "provider_category")
wait_band_counts_provider_pd = (
    wait_band_counts_provider.select("provider_category", "wait_band_category", "proportion")
    .toPandas()
    .sort_values(by="proportion", ascending=False)
)
_plot_wait_band_proportions(wait_band_counts_provider_pd, "provider_category", "Provider")

wait_band_counts_provider_site = _wait_band_proportions(cohort_1_df, "provider_site_category")
wait_band_counts_provider_site_pd = (
    wait_band_counts_provider_site.select("provider_site_category", "wait_band_category", "proportion")
    .toPandas()
    .sort_values(by="proportion", ascending=False)
)
_plot_wait_band_proportions(wait_band_counts_provider_site_pd, "provider_site_category", "Provider Site")

In [0]:
# Median wait length per LTC bracket, restricted to the Gynaecology specialty.
gynaecology_df = cohort_1_df.filter(cohort_1_df["Specialty"] == "Gynaecology")

median_wait_length = (
    gynaecology_df
    .groupBy("ndl_ltc")
    .agg(median("ndl_wait_length").alias("median_wait_length"))
)

display(median_wait_length)

In [0]:
# Median wait length per LTC bracket, broken down two ways over the whole cohort:
# by ndl_start_date_count and by ndl_start_bucket.
_median_wait_col = median("ndl_wait_length").alias("median_wait_length")

median_wait_length_start_date_count = (
    cohort_1_df.groupBy("ndl_ltc", "ndl_start_date_count").agg(_median_wait_col)
)
median_wait_length_start_bucket_count = (
    cohort_1_df.groupBy("ndl_ltc", "ndl_start_bucket").agg(_median_wait_col)
)

display(median_wait_length_start_date_count)
display(median_wait_length_start_bucket_count)

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import functions as F
from pyspark.sql.functions import countDistinct, col

# One row per (patient, treatment function code, specialty) pathway that has more than one
# distinct "wlmds_rtt_start_date_conc3". `has_irtt` is 1 when any record of the pathway has
# "IRTT" in wlmds_type_changes_last, else 0 (a null value counts as 0).
distinct_rtt_start_date_count = (
    cohort_1_df.groupBy("wlmds_patient_id", "wlmds_treatment_function_code", "Specialty")
    .agg(
        countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_rtt_start_date_count"),
        F.max(
            F.when(col("wlmds_type_changes_last").contains("IRTT"), 1).otherwise(0)
        ).alias("has_irtt"),
    )
    .filter("distinct_rtt_start_date_count > 1")
)

# Per specialty: number of such pathways ("record_count") and how many of them include an
# IRTT record. Both come from the SAME pathway population, so irtt_proportion lies in [0, 1].
# (Dividing cohort-wide IRTT record counts by the pathway count mixed two populations and
# could yield "proportions" above 1.)
specialty_record_count = (
    distinct_rtt_start_date_count.groupBy("Specialty")
    .agg(
        F.count(F.lit(1)).alias("record_count"),
        F.sum("has_irtt").alias("irtt_record_count"),
    )
    .withColumn("irtt_proportion", col("irtt_record_count") / col("record_count"))
)

specialty_record_count_pd = specialty_record_count.orderBy("record_count", ascending=False).toPandas()

# Bars: pathway count per specialty, largest first.
fig, ax = plt.subplots(figsize=(12, 8))
sns.barplot(
    data=specialty_record_count_pd,
    x="Specialty",
    y="record_count",
    order=specialty_record_count_pd["Specialty"],
    ax=ax,
)
ax.set_xlabel("Specialty")
ax.set_ylabel("Record Count")
ax.set_title("Count of Records where Patients have >1 Distinct RTT Start Date for Each Specialty")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

# Secondary axis: draw the proportion at the bar positions (0..n-1) explicitly, so each
# point sits over its own specialty's bar whatever seaborn's categorical handling is.
ax2 = ax.twinx()
ax2.plot(
    range(len(specialty_record_count_pd)),
    specialty_record_count_pd["irtt_proportion"],
    color="red",
    marker="o",
)
ax2.set_ylabel("Proportion of Records with an Admission")

fig.tight_layout()
plt.show()

In [0]:
from pyspark.sql.functions import countDistinct, col

# Pathways (patient x treatment function code) with more than one distinct RTT start date,
# computed over the whole cohort.
distinct_rtt_start_date_count = (
    cohort_1_df.groupBy("wlmds_patient_id", "wlmds_treatment_function_code")
    .agg(countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_rtt_start_date_count"))
    .filter("distinct_rtt_start_date_count > 1")
)

# Records in the "Other - Medical Services" specialty that belong to one of those pathways.
filtered_records = cohort_1_df.filter(col("Specialty") == "Other - Medical Services")
joined_records = filtered_records.join(
    distinct_rtt_start_date_count,
    on=["wlmds_patient_id", "wlmds_treatment_function_code"],
    how="inner",
)

# Number of distinct pathways per treatment function code. Within one code a pathway is
# identified by its patient id, so countDistinct(patient) counts groups. A plain .count()
# counted records instead, inflating every pathway that has several rows.
group_frequency = joined_records.groupBy("wlmds_treatment_function_code").agg(
    countDistinct("wlmds_patient_id").alias("group_frequency")
)

group_frequency_pd = group_frequency.orderBy("group_frequency", ascending=False).toPandas()

plt.figure(figsize=(12, 8))
sns.barplot(
    data=group_frequency_pd,
    x='wlmds_treatment_function_code',
    y='group_frequency',
    order=group_frequency_pd['wlmds_treatment_function_code'],
)
plt.xlabel('Treatment Function Code')
plt.ylabel('Group Frequency')
plt.title('Frequency of Groups with >1 Distinct RTT Start Date for Each Treatment Function Code (Specialty: Other - Medical Services)')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [0]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Display order of LTC brackets on the x-axis.
order = ["No LTCs", "Single LTC", "Comorbidities", "Multimorbidities"]

median_wait_length_start_date_count_pd = median_wait_length_start_date_count.toPandas()

fig, ax = plt.subplots(figsize=(12, 8))
sns.barplot(
    data=median_wait_length_start_date_count_pd,
    x='ndl_ltc',
    y='median_wait_length',
    hue='ndl_start_date_count',
    order=order,
    ax=ax,
)

# Mean of the per-start-date-count medians for each bracket, re-indexed to `order` so that
# point i is drawn over bar group i. groupby alone returns brackets alphabetically, which
# does not match `order`. Brackets with no data become NaN (a gap in the line).
average_median_wait_length = (
    median_wait_length_start_date_count_pd.groupby('ndl_ltc')['median_wait_length'].mean()
    .reindex(order)
    .rename_axis('ndl_ltc')
    .reset_index()
)
ax.plot(
    range(len(order)),
    average_median_wait_length['median_wait_length'],
    color='red',
    marker='o',
    label='Average Median Wait Length',
)

ax.set_xlabel('LTC Bracket')
ax.set_ylabel('Wait Length (Days)')
# median_wait_length_start_date_count is grouped over the full cohort_1_df (no specialty
# filter), so the title no longer claims "(Gynaecology)".
ax.set_title('Median Waiting Time Across LTC Bracket and Start Date Count')
ax.tick_params(axis='x', rotation=45)
ax.legend(title="NDL Start Date Count")
fig.tight_layout()

plt.show()

In [0]:
from pyspark.sql.functions import col, when
import matplotlib.pyplot as plt
import seaborn as sns

# Box-and-whisker plots of wait length for four categorical breakdowns, one panel each
# on a 2x2 grid ("Specialty" and "ndl_age" are deliberately not included here).
columns = ["ndl_ethnicity", "ndl_ltc", "ndl_imd_quantile", "Frailty_level"]

fig, axes = plt.subplots(2, 2, figsize=(20, 16))
axes = axes.flatten()

for ax, column in zip(axes, columns):
    plot_df = cohort_1_df.select(column, "ndl_wait_length").toPandas()

    # Categories are listed from highest to lowest median wait.
    group_medians = plot_df.groupby(column)["ndl_wait_length"].median()
    order = group_medians.sort_values(ascending=False).index

    sns.boxplot(ax=ax, data=plot_df, x=column, y='ndl_wait_length', order=order)
    ax.set_xlabel(column.replace('_', ' ').title(), fontsize=18)
    ax.set_ylabel('Wait Length (Days)', fontsize=18)
    ax.tick_params(axis='x', rotation=45, labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

plt.tight_layout()
plt.show()

In [0]:
from pyspark.sql.functions import col, median, desc
from matplotlib.patches import Patch
import seaborn as sns

# Bars whose median wait exceeds this many days are recoloured red.
HIGHLIGHT_THRESHOLD_DAYS = 100

# Median wait per age band (null bands dropped) and per specialty, longest first.
df_filtered = cohort_1_df.filter(col("ndl_age_band").isNotNull())
median_wait_length_age_band = (
    df_filtered.groupBy("ndl_age_band")
    .agg(median("ndl_wait_length").alias("median_wait_length"))
    .orderBy(desc("median_wait_length"))
)
median_wait_length_specialty = (
    cohort_1_df.groupBy("Specialty")
    .agg(median("ndl_wait_length").alias("median_wait_length"))
    .orderBy(desc("median_wait_length"))
)

median_wait_length_age_band_pd = median_wait_length_age_band.toPandas()
median_wait_length_specialty_pd = median_wait_length_specialty.toPandas()

# Youngest-to-oldest display order for the age bands.
# NOTE(review): labels missing from this list (e.g. if the data uses "85+" rather than
# "84+") are silently left off the chart — confirm against the ndl_age_band values.
age_band_order = ["<= 10", "11-17", "18-24", "25-34", "35-44", "45-54", "55-64", "65-74", "75-84", "84+"]


def _highlight_long_waits(ax, threshold=HIGHLIGHT_THRESHOLD_DAYS):
    """Recolour horizontal bars on `ax` whose length (median wait) exceeds `threshold`."""
    for bar in ax.patches:
        if bar.get_width() > threshold:
            bar.set_color('red')


fig, axes = plt.subplots(1, 2, figsize=(24, 10))

# Left panel: median wait per age band.
sns.barplot(ax=axes[0], data=median_wait_length_age_band_pd, y='ndl_age_band', x='median_wait_length', order=age_band_order)
axes[0].set_ylabel('NDL Age Band', fontsize=18)
axes[0].set_xlabel('Median Wait Length (Days)', fontsize=18)
axes[0].tick_params(axis='y', rotation=0, labelsize=18)
axes[0].tick_params(axis='x', labelsize=18)
_highlight_long_waits(axes[0])

# Right panel: median wait per specialty.
sns.barplot(ax=axes[1], data=median_wait_length_specialty_pd, y='Specialty', x='median_wait_length', order=median_wait_length_specialty_pd['Specialty'])
axes[1].set_ylabel('Specialty', fontsize=18)
axes[1].set_xlabel('Median Wait Length (Days)', fontsize=18)
axes[1].tick_params(axis='y', rotation=0, labelsize=18)
axes[1].tick_params(axis='x', labelsize=18)
_highlight_long_waits(axes[1])

# Explicit red swatch. A labels-only fig.legend() pairs the label with the first artist it
# finds (an ordinary bar), so its swatch need not match the highlight colour.
fig.legend(
    handles=[Patch(color='red', label=f'Median waits over {HIGHLIGHT_THRESHOLD_DAYS} days are highlighted')],
    loc='upper right',
    bbox_to_anchor=(0.99, 0.15),
    fontsize=18,
)

plt.tight_layout()
plt.show()

In [0]:
from pyspark.sql.functions import median

# Median wait length per ethnicity group within the Gynaecology specialty.
gynaecology_df = cohort_1_df.where(cohort_1_df["Specialty"] == "Gynaecology")

median_wait_length = (
    gynaecology_df
    .groupBy("ndl_ethnicity")
    .agg(median("ndl_wait_length").alias("median_wait_length"))
)

display(median_wait_length)

In [0]:
import matplotlib.pyplot as plt

# Pull the wait-length column to the driver for plotting.
# - Nulls are dropped first: plt.hist cannot bin None values and would raise.
# - toPandas() is used instead of .rdd, which is unsupported under Spark Connect.
waiting_times = (
    cohort_1_df.select("ndl_wait_length")
    .dropna()
    .toPandas()["ndl_wait_length"]
)

plt.figure(figsize=(10, 6))
plt.hist(waiting_times, bins=30, edgecolor='black')
plt.title('Histogram of Waiting Times')
plt.xlabel('Waiting Time')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

In [0]:
import matplotlib.pyplot as plt
import pandas as pd
from pyspark.sql import functions as F


def calculate_wait_band_percentages(df, wait_band_col, speciality_col):
    """Share (in percent) of each wait band within each specialty.

    Returns a Spark DataFrame keyed by (speciality_col, wait_band_col) with columns
    wait_band_count, total_count and percentage = wait_band_count / total_count * 100.
    """
    per_band = (
        df.groupBy(speciality_col, wait_band_col)
        .count()
        .withColumnRenamed("count", "wait_band_count")
    )
    per_specialty = (
        df.groupBy(speciality_col)
        .count()
        .withColumnRenamed("count", "total_count")
    )
    share = (F.col("wait_band_count") / F.col("total_count")) * 100
    return per_band.join(per_specialty, on=speciality_col).withColumn("percentage", share)


df_speciality_wait_band = calculate_wait_band_percentages(cohort_1_df, "ndl_wait_band", "Specialty")
pdf_speciality_wait_band = (
    df_speciality_wait_band
    .select("Specialty", "ndl_wait_band", "percentage")
    .toPandas()
)

# One row per specialty, one column per wait band; absent combinations become 0%.
pivot_df = (
    pdf_speciality_wait_band
    .pivot(index='Specialty', columns='ndl_wait_band', values='percentage')
    .fillna(0)
)

ax = pivot_df.plot(kind='bar', stacked=True, figsize=(10, 7))
ax.set_title('Wait Band Percentages as Stacked Bars for Speciality')
ax.set_xlabel('Speciality')
ax.set_ylabel('Percentage')
ax.legend(title='Wait Band', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [0]:
from pyspark.sql.functions import countDistinct

# How many patients have records under more than one distinct TFC.
# NOTE(review): this reads a column named "tfc", while other cells use
# "wlmds_treatment_function_code" — confirm "tfc" exists in cohort_1_df.
patient_tfc_counts = (
    cohort_1_df
    .groupBy("wlmds_patient_id")
    .agg(countDistinct("tfc").alias("distinct_tfc_count"))
)

patients_with_multiple_tfc = patient_tfc_counts.where(col("distinct_tfc_count") > 1)
num_patients_with_multiple_tfc = patients_with_multiple_tfc.count()

num_patients_with_multiple_tfc

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# Wait-length distribution as a split violin coloured by Sex, with quartile lines inside.
df_pandas = cohort_1_df.select("ndl_wait_length", "Sex").toPandas()

fig, ax = plt.subplots(figsize=(10, 6))
sns.violinplot(
    data=df_pandas,
    y="ndl_wait_length",
    hue="Sex",
    split=True,
    palette="Set2",
    inner="quart",
    gap=0,
    ax=ax,
)
ax.set_title("Violin Plot of Waiting Length by Sex")
ax.set_xlabel("Sex")
ax.set_ylabel("Waiting Length (Days)")

plt.show()

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql.functions import col

# Share of each ethnicity within every (Specialty, wait band) combination.
df_ethnicity_specialty_band = cohort_1_df.groupBy("Specialty", "ndl_ethnicity", "ndl_wait_band").count()
df_total_specialty_band = (
    cohort_1_df.groupBy("Specialty", "ndl_wait_band").count()
    .withColumnRenamed("count", "total_count")
)

df_proportion_band = df_ethnicity_specialty_band.join(df_total_specialty_band, on=["Specialty", "ndl_wait_band"])
df_proportion_band = df_proportion_band.withColumn("proportion", col("count") / col("total_count"))

df_proportion_band_pandas = df_proportion_band.select("Specialty", "ndl_ethnicity", "ndl_wait_band", "proportion").toPandas()

# Population shares per ethnicity group, from
# https://www.ons.gov.uk/peoplepopulationandcommunity/culturalidentity/ethnicity/bulletins/ethnicgroupenglandandwales/census2021
ethnicity_baseline_mapping = {
    "asian_background": 0.097,
    "black_background": 0.056,
    "white_background": 0.79,
    "mixed_background": 0.034,
    "other_background": 0.023
}

wait_band_order = ["<= 18 weeks", "> 18 weeks", "> 36 weeks", "> 52 weeks"]

# Baseline ethnicities, largest population share first, and their baseline values.
baseline_order = sorted(ethnicity_baseline_mapping, key=ethnicity_baseline_mapping.get, reverse=True)
baseline_values = [ethnicity_baseline_mapping[e] for e in baseline_order]

specialties = df_proportion_band_pandas["Specialty"].unique()
for specialty in specialties:
    data = df_proportion_band_pandas[df_proportion_band_pandas["Specialty"] == specialty]

    # Fix the x order explicitly: baseline ethnicities first (the red line's order), then any
    # other categories present in the data. The line is drawn at positions
    # 0..len(baseline)-1, so each point sits over its own ethnicity's bars. Previously the
    # bars followed first appearance after a proportion sort while the line followed
    # baseline order, so the points did not line up with their bars.
    extra_categories = [
        e for e in data["ndl_ethnicity"].dropna().unique() if e not in ethnicity_baseline_mapping
    ]
    x_order = baseline_order + extra_categories

    plt.figure(figsize=(14, 8))
    sns.barplot(
        x="ndl_ethnicity",
        y="proportion",
        hue="ndl_wait_band",
        data=data,
        palette="Set2",
        hue_order=wait_band_order,
        order=x_order,
    )
    plt.plot(range(len(baseline_order)), baseline_values, marker='o', color='r', label='Baseline Population Percentage')

    plt.title(f"Proportion of Each Ethnicity for Specialty: {specialty}", fontsize=20)
    plt.xlabel("Ethnicity", fontsize=18)
    plt.ylabel("Proportion", fontsize=18)
    plt.xticks(rotation=45, ha="right", fontsize=18)
    plt.yticks(fontsize=18)

    plt.legend(fontsize=18)
    plt.show()

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql.functions import col

# Share of each ethnicity within each specialty.
df_ethnicity_specialty = cohort_1_df.groupBy("Specialty", "ndl_ethnicity").count()
df_total_specialty = (
    cohort_1_df.groupBy("Specialty")
    .count()
    .withColumnRenamed("count", "total_count")
)
df_proportion = (
    df_ethnicity_specialty
    .join(df_total_specialty, on="Specialty")
    .withColumn("proportion", col("count") / col("total_count"))
)

df_proportion_pandas = df_proportion.select("Specialty", "ndl_ethnicity", "proportion").toPandas()

# population data from https://www.ons.gov.uk/peoplepopulationandcommunity/culturalidentity/ethnicity/bulletins/ethnicgroupenglandandwales/census2021
ethnicity_baseline_mapping = {
    "asian_background": 0.097,
    "black_background": 0.056,
    "white_background": 0.79,
    "mixed_background": 0.034,
    "other_background": 0.023
}

# Observed share minus population baseline (NaN for ethnicities with no baseline entry).
df_proportion_pandas = df_proportion_pandas.assign(
    baseline=lambda frame: frame["ndl_ethnicity"].map(ethnicity_baseline_mapping),
    difference=lambda frame: frame["proportion"] - frame["baseline"],
)

fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(data=df_proportion_pandas, x="Specialty", y="difference", hue="ndl_ethnicity", palette="Set2", ax=ax)
ax.set_title("Difference Between Proportion of Ethnicity and Baseline Across Each Specialty")
ax.set_xlabel("Specialty")
ax.set_ylabel("Difference in Proportion")
plt.xticks(rotation=45, ha="right")

plt.show()

In [0]:
from pyspark.sql.functions import countDistinct, expr
import matplotlib.pyplot as plt
import seaborn as sns

# Per (patient, frailty level): number of distinct RTT start dates and the approximate
# median wait length across that patient's records.
df_distinct_counts = (
    cohort_1_df
    .groupBy("wlmds_patient_id", "Frailty_level")
    .agg(
        countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"),
        expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_wait_length"),
    )
)
# NOTE: a distinct_count == 1 filter existed but is commented out, so every patient is kept.
# To restrict to repeat start dates: df_distinct_counts.filter(col("distinct_count") != 1)

df_distinct_counts_pandas = df_distinct_counts.toPandas()

# Each bar is seaborn's default estimator (mean) over the per-patient medians in that group.
fig, ax1 = plt.subplots(figsize=(14, 8))
sns.barplot(
    data=df_distinct_counts_pandas,
    x="distinct_count",
    y="median_wait_length",
    hue="Frailty_level",
    hue_order=["Fit", "Mild", "Moderate", "Severe"],
    ax=ax1,
)
ax1.set_xlabel("Count of Distinct 'wlmds_rtt_start_date_conc3'")
ax1.set_ylabel("Median 'ndl_wait_length'", color='b')
ax1.tick_params(axis='y', labelcolor='b')
ax1.set_title("Frequency of Distinct 'wlmds_rtt_start_date_conc3' Counts and Median 'ndl_wait_length'")

plt.xticks(rotation=45, ha="right")
plt.show()

In [0]:
from pyspark.sql.functions import col, when, lit

# Two contrasting sub-populations:
#   1. Fit, Age < 61, no long-term conditions
#   2. Severe or Moderate frailty, Age > 60, Single LTC / Comorbidities / Multimorbidities
condition_1 = (
    (col("Frailty_level") == "Fit")
    & (col("Age") < 61)
    & (col("ndl_ltc") == "No LTCs")
)
condition_2 = (
    col("Frailty_level").isin("Severe", "Moderate")
    & (col("Age") > 60)
    & col("ndl_ltc").isin("Comorbidities", "Multimorbidities", "Single LTC")
)

df_condition_1 = cohort_1_df.filter(condition_1)
df_condition_2 = cohort_1_df.filter(condition_2)


def _per_patient_summary(frame, subsection_label):
    """Per-patient distinct RTT start-date count and approximate median wait, tagged with a label."""
    summary = frame.groupBy("wlmds_patient_id").agg(
        countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"),
        expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_wait_length"),
    )
    return summary.withColumn("Subsection", lit(subsection_label))


df_distinct_counts_1 = _per_patient_summary(df_condition_1, "Fit, Age < 61, No LTCs")
df_distinct_counts_2 = _per_patient_summary(
    df_condition_2, "Severe/Moderate, Age > 60, Comorbidities/Multimorbidities/Single LTC"
)

df_distinct_counts_combined = df_distinct_counts_1.union(df_distinct_counts_2)
df_distinct_counts_combined_pandas = df_distinct_counts_combined.toPandas()

fig, ax1 = plt.subplots(figsize=(14, 8))
sns.barplot(
    data=df_distinct_counts_combined_pandas,
    x="distinct_count",
    y="median_wait_length",
    hue="Subsection",
    ax=ax1,
)
ax1.set_xlabel("Count of Distinct 'wlmds_rtt_start_date_conc3'")
ax1.set_ylabel("Median 'ndl_wait_length'", color='b')
ax1.tick_params(axis='y', labelcolor='b')
ax1.set_title("Median 'ndl_wait_length' for Each Count of Distinct 'wlmds_rtt_start_date_conc3' Values by Subsection")

plt.xticks(rotation=45, ha="right")
plt.show()

In [0]:
from pyspark.sql.functions import col, when, countDistinct, expr, lit
import matplotlib.pyplot as plt
import seaborn as sns

# Two contrasting sub-populations:
#   1. Fit, Age < 61, no long-term conditions
#   2. Severe or Moderate frailty, Age > 60, Single LTC / Comorbidities / Multimorbidities
condition_1 = (col("Frailty_level") == "Fit") & (col("Age") < 61) & (col("ndl_ltc") == "No LTCs")
condition_2 = ((col("Frailty_level").isin("Severe", "Moderate")) & (col("Age") > 60) &
               (col("ndl_ltc").isin("Comorbidities", "Multimorbidities", "Single LTC")))

df_condition_1 = cohort_1_df.filter(condition_1)
df_condition_2 = cohort_1_df.filter(condition_2)

# Per patient within each sub-population: distinct RTT start dates and approximate median wait.
df_distinct_counts_1 = df_condition_1.groupBy("wlmds_patient_id").agg(
    countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"),
    expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_wait_length")
).withColumn("Subsection", lit("Fit, Age < 61, No LTCs"))

df_distinct_counts_2 = df_condition_2.groupBy("wlmds_patient_id").agg(
    countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"),
    expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_wait_length")
).withColumn("Subsection", lit("Severe/Moderate, Age > 60, Comorbidities/Multimorbidities/Single LTC"))

df_distinct_counts_combined = df_distinct_counts_1.union(df_distinct_counts_2)

# Patients with more than one distinct RTT start date anywhere in the cohort. This does not
# depend on the specialty, so it is built once instead of on every loop iteration.
multi_start_patients = (
    cohort_1_df.groupBy("wlmds_patient_id")
    .agg(countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"))
    .filter(col("distinct_count") > 1)
    .select("wlmds_patient_id")
)

distinct_specialties = cohort_1_df.select("Specialty").distinct().collect()
specialties = [row["Specialty"] for row in distinct_specialties]

for specialty in specialties:
    if specialty is None:
        # `col("Specialty") == None` never matches in Spark, so a null specialty could only
        # give an empty chart (and the old `.lower()` call raised AttributeError on it).
        continue

    # Multi-start patients with at least one record in this specialty, as DISTINCT ids.
    # Joining raw records instead repeated each patient's summary row once per record,
    # weighting the bar means by record count.
    specialty_patients = (
        cohort_1_df.filter(col("Specialty") == specialty)
        .select("wlmds_patient_id")
        .distinct()
        .join(multi_start_patients, "wlmds_patient_id")
    )

    # NOTE(review): distinct_count / median_wait_length are computed over ALL specialties;
    # only the patient set is specialty-specific here. Confirm that is the intended view.
    df_specialty_combined = df_distinct_counts_combined.join(specialty_patients, "wlmds_patient_id")
    df_specialty_combined_pandas = df_specialty_combined.select(
        "distinct_count", "median_wait_length", "Subsection"
    ).toPandas()

    if df_specialty_combined_pandas.empty:
        continue  # no qualifying patients: skip instead of drawing an empty figure

    fig, ax1 = plt.subplots(figsize=(14, 8))
    sns.barplot(x="distinct_count", y="median_wait_length", hue="Subsection", data=df_specialty_combined_pandas, ax=ax1)
    ax1.set_xlabel("Count of Distinct 'wlmds_rtt_start_date_conc3'")
    ax1.set_ylabel("Median 'ndl_wait_length'", color='b')
    ax1.tick_params(axis='y', labelcolor='b')
    ax1.set_title(f"Median 'ndl_wait_length' for Each Count of Distinct 'wlmds_rtt_start_date_conc3' Values by Subsection ({specialty})")
    plt.xticks(rotation=45, ha="right")
    plt.show()

In [0]:
from pyspark.sql.functions import countDistinct, expr

# Per (specialty, patient, frailty level): number of distinct RTT start dates and the
# approximate median wait length across those records.
df_distinct_counts = (
    cohort_1_df
    .groupBy("Specialty", "wlmds_patient_id", "Frailty_level")
    .agg(
        countDistinct("wlmds_rtt_start_date_conc3").alias("distinct_count"),
        expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_wait_length"),
    )
)
# NOTE: a distinct_count == 1 filter existed but is commented out, so every patient is kept.

df_distinct_counts_pandas = df_distinct_counts.toPandas()

# One chart per specialty; bars show seaborn's mean of the per-patient medians.
specialties = df_distinct_counts_pandas["Specialty"].unique()
for specialty in specialties:
    in_specialty = df_distinct_counts_pandas["Specialty"] == specialty
    df_specialty = df_distinct_counts_pandas.loc[in_specialty]

    fig, ax1 = plt.subplots(figsize=(14, 8))
    sns.barplot(
        data=df_specialty,
        x="distinct_count",
        y="median_wait_length",
        hue="Frailty_level",
        ax=ax1,
    )
    ax1.set_xlabel("Count of Distinct 'wlmds_rtt_start_date_conc3'")
    ax1.set_ylabel("Median 'ndl_wait_length'", color='b')
    ax1.tick_params(axis='y', labelcolor='b')
    ax1.set_title(f"Frequency of Distinct 'wlmds_rtt_start_date_conc3' Counts and Median 'ndl_wait_length' for {specialty}")

    plt.xticks(rotation=45, ha="right")
    plt.show()

## Cohort 2 Investigations

In [0]:
# Dummy data: read the synthetic Cohort 2 Excel extract with pandas and lift it into Spark.
# For real data, use the commented-out parquet loader in the following cell instead.
cohort_2_path = "../Dummy_Data/Cohort_2_synth.xlsx"
cohort_2_df = spark.createDataFrame(pd.read_excel(cohort_2_path))

In [0]:
# real data
# cohort_2_df = spark.read.format("parquet").load(cohort_2_link)

In [0]:
# Map each pathway's status (wlmds_status) to a named "Reason".
# Rules are evaluated top to bottom and the FIRST match wins (chained `when`), so a more
# specific rule must precede any general rule it overlaps with.
reason_rules = [
    # treatment + admitted: status 30 with an IRTT type change
    ((cohort_2_df.wlmds_status == 30) & (col("wlmds_type_changes_last").contains("IRTT")), "treatment_admitted"),
    # treatment + not admitted: status 30 without an IRTT type change
    ((cohort_2_df.wlmds_status == 30) & (~col("wlmds_type_changes_last").contains("IRTT")), "treatment_non_admitted"),
    # non treatment, start monitoring: status 31 or 32
    (cohort_2_df.wlmds_status.isin(31, 32), "non_treatment_monitoring"),
    # non treatment, patient declines: status 35
    (cohort_2_df.wlmds_status == 35, "non_treatment_patient_declines"),
    # non treatment, decision not to treat: status 34
    (cohort_2_df.wlmds_status == 34, "non_treatment_decision_not_to_treat"),
    # non treatment, DNA: status 33
    (cohort_2_df.wlmds_status == 33, "non_treatment_dna"),
    # non treatment, other: any status not in 30-36 or 99
    (~cohort_2_df.wlmds_status.isin(30, 31, 32, 33, 34, 35, 36, 99), "non_treatment_other"),
    # non treatment, death: status 36
    (cohort_2_df.wlmds_status == 36, "non_treatment_death"),
    # unknown with an imputed (week_end_date) end. This must come BEFORE plain "unknown":
    # otherwise every status-99 row matches "unknown" first and this label is never assigned.
    ((cohort_2_df.wlmds_status == 99) & (col("wlmds_end_date_type_last") == "week_end_date"), "unknown_imputed_end"),
    # unknown: any remaining status 99
    (cohort_2_df.wlmds_status == 99, "unknown"),
]
# NOTE(review): status-30 rows with a null wlmds_type_changes_last match neither treatment
# rule (contains() on null is null) and end up with a null Reason — confirm this is intended.

# Parallel lists, kept for any later code that reads them.
conditions = [cond for cond, _ in reason_rules]
labels = [label for _, label in reason_rules]

reason_expr = when(conditions[0], labels[0])
for cond, label in zip(conditions[1:], labels[1:]):
    reason_expr = reason_expr.when(cond, label)

cohort_2_df = cohort_2_df.withColumn("Reason", reason_expr)

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

# Keep only pathways whose last end-date type is "rtt_end_date" (other types, such as
# "week_end_date", are excluded).
filtered_df = cohort_2_df.filter(col("wlmds_end_date_type_last") == "rtt_end_date")

# Approximate median wait length per Reason.
median_wait_length_df = filtered_df.groupBy("Reason").agg(
    expr("percentile_approx(ndl_wait_length, 0.5)").alias("median_ndl_wait_length")
)
median_wait_length_pd = median_wait_length_df.toPandas()

# x-axis order.
# NOTE(review): "non_treatment_other" is not listed, so rows with that Reason are not
# plotted — confirm this is intended.
order = [
    "treatment_admitted",
    "treatment_non_admitted",
    "non_treatment_monitoring",
    "non_treatment_patient_declines",
    "non_treatment_decision_not_to_treat",
    "non_treatment_dna",
    "non_treatment_death",
    "unknown"
]

# Bars whose median exceeds this many days are drawn in red.
highlight_threshold_days = 100

plt.figure(figsize=(12, 8))
barplot = sns.barplot(x="Reason", y="median_ndl_wait_length", data=median_wait_length_pd, order=order)

# Colour each bar by its own height. The previous per-reason lookup raised IndexError when a
# reason in `order` had no rows, and indexing `patches` by position in `order` could recolour
# the wrong bar if no patch is drawn for a missing category. NaN heights compare False.
for bar in barplot.patches:
    if bar.get_height() > highlight_threshold_days:
        bar.set_color('red')

plt.xticks(rotation=45, ha="right", fontsize=16)
plt.yticks(fontsize=16)
plt.xlabel("Reason", fontsize=18)
plt.ylabel("Median NDL Wait Length (days)", fontsize=18)
plt.tight_layout()
plt.show()

## Table Generation

### Objective 1

In [0]:
# Dummy data: reload Cohort 1 from the synthetic Excel extract. This also discards the
# derived columns that earlier cells added to cohort_1_df.
cohort_1_path = "../Dummy_Data/Cohort_1_synth.xlsx"
cohort_1_df = spark.createDataFrame(pd.read_excel(cohort_1_path))

# Real-data alternative:
#cohort_1_df = spark.read.format("parquet").load(cohort_1_link)

In [0]:
# Subsets of cohort 1 by ndl_type, used for the inpatient- and outpatient-only tables.
cohort_1_inpatient_df = cohort_1_df.where(cohort_1_df["ndl_type"] == "inpatient")
cohort_1_outpatient_df = cohort_1_df.where(cohort_1_df["ndl_type"] == "outpatient")

In [0]:
# Wait-band distribution for each breakdown column over all of cohort 1.
# Results are unpacked in the order the breakdown columns are listed.
(
    df_11_1_all,
    df_11_2_all,
    df_11_3_all,
    df_11_4_all,
    df_11_5_all,
    df_11_6_all,
    df_11_7_all,
) = [
    calculate_wait_band_distribution(cohort_1_df, breakdown)
    for breakdown in (
        "Specialty",
        "Sex",
        "ndl_age_band",
        "ndl_imd_quantile",
        "ndl_ethnicity",
        "Frailty_Level",
        "ndl_ltc",
    )
]

In [0]:
# Relabel each per-breakdown table with rename_and_add_column and stack them
# into one long table (output 1.1, all pathways).
# The df_11_*_all inputs are deliberately not reassigned: the previous
# `x = rename_and_add_column(x, ...)` form meant re-running this cell pushed
# already-relabelled tables through the helper a second time.
combined_df11_all = (
    rename_and_add_column(df_11_1_all, "Value")
    .unionByName(rename_and_add_column(df_11_2_all, "Value"))
    .unionByName(rename_and_add_column(df_11_3_all, "Value"))
    .unionByName(rename_and_add_column(df_11_4_all, "Value"))
    .unionByName(rename_and_add_column(df_11_5_all, "Value"))
    .unionByName(rename_and_add_column(df_11_6_all, "Value"))
    .unionByName(rename_and_add_column(df_11_7_all, "Value"))
)

In [0]:
# Wait-band distribution for each breakdown column, inpatient rows only.
# Results are unpacked in the order the breakdown columns are listed.
(
    df_11_1_ip,
    df_11_2_ip,
    df_11_3_ip,
    df_11_4_ip,
    df_11_5_ip,
    df_11_6_ip,
    df_11_7_ip,
) = [
    calculate_wait_band_distribution(cohort_1_inpatient_df, breakdown)
    for breakdown in (
        "Specialty",
        "Sex",
        "ndl_age_band",
        "ndl_imd_quantile",
        "ndl_ethnicity",
        "Frailty_Level",
        "ndl_ltc",
    )
]

In [0]:
# Relabel each inpatient per-breakdown table with rename_and_add_column and
# stack them into one long table (output 1.1, inpatient).
# The df_11_*_ip inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df11_ip = (
    rename_and_add_column(df_11_1_ip, "Value")
    .unionByName(rename_and_add_column(df_11_2_ip, "Value"))
    .unionByName(rename_and_add_column(df_11_3_ip, "Value"))
    .unionByName(rename_and_add_column(df_11_4_ip, "Value"))
    .unionByName(rename_and_add_column(df_11_5_ip, "Value"))
    .unionByName(rename_and_add_column(df_11_6_ip, "Value"))
    .unionByName(rename_and_add_column(df_11_7_ip, "Value"))
)


In [0]:
# Wait-band distribution for each breakdown column, outpatient rows only.
# Results are unpacked in the order the breakdown columns are listed.
(
    df_11_1_op,
    df_11_2_op,
    df_11_3_op,
    df_11_4_op,
    df_11_5_op,
    df_11_6_op,
    df_11_7_op,
) = [
    calculate_wait_band_distribution(cohort_1_outpatient_df, breakdown)
    for breakdown in (
        "Specialty",
        "Sex",
        "ndl_age_band",
        "ndl_imd_quantile",
        "ndl_ethnicity",
        "Frailty_Level",
        "ndl_ltc",
    )
]

In [0]:
# Relabel each outpatient per-breakdown table with rename_and_add_column and
# stack them into one long table (output 1.1, outpatient).
# The df_11_*_op inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df11_op = (
    rename_and_add_column(df_11_1_op, "Value")
    .unionByName(rename_and_add_column(df_11_2_op, "Value"))
    .unionByName(rename_and_add_column(df_11_3_op, "Value"))
    .unionByName(rename_and_add_column(df_11_4_op, "Value"))
    .unionByName(rename_and_add_column(df_11_5_op, "Value"))
    .unionByName(rename_and_add_column(df_11_6_op, "Value"))
    .unionByName(rename_and_add_column(df_11_7_op, "Value"))
)

In [0]:
# Wait-band distributions cross-tabulated by pairs of breakdown columns.
# Pairs 4-8 were requested by Tom: IMD x frailty, IMD x LTC, ethnicity x frailty,
# ethnicity x LTC and IMD x ethnicity.
(
    df_12_1,
    df_12_2,
    df_12_3,
    df_12_4,
    df_12_5,
    df_12_6,
    df_12_7,
    df_12_8,
) = [
    calculate_cross_wait_band_distribution(cohort_1_df, first_col, second_col)
    for first_col, second_col in (
        ("ndl_age_band", "Sex"),
        ("ndl_imd_quantile", "Sex"),
        ("ndl_imd_quantile", "ndl_age_band"),
        ("ndl_imd_quantile", "Frailty_Level"),
        ("ndl_imd_quantile", "ndl_ltc"),
        ("ndl_ethnicity", "Frailty_Level"),
        ("ndl_ethnicity", "ndl_ltc"),
        ("ndl_imd_quantile", "ndl_ethnicity"),
    )
]

In [0]:
# Relabel each cross-tabulation with rename_and_add_column_cross and stack them
# into one long table (output 1.2).
# The df_12_* inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df12 = (
    rename_and_add_column_cross(df_12_1, "Value 1", "Value 2")
    .unionByName(rename_and_add_column_cross(df_12_2, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_3, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_4, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_5, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_6, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_7, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_12_8, "Value 1", "Value 2"))
)

In [0]:
# ndl_wait_length summary statistics for cohort 1, one table per breakdown column.
# "Frailty_Level" is spelled as everywhere else in this notebook. The previous
# "Frailty_level" only resolved while spark.sql.caseSensitive is false, and it
# was the odd one out among the breakdown names used by the sibling tables.
df_13_1 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "Specialty")
df_13_2 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "Sex")
df_13_3 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "ndl_age_band")
df_13_4 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "ndl_imd_quantile")
df_13_5 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "ndl_ethnicity")
df_13_6 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "Frailty_Level")
df_13_7 = calculate_wait_length_statistics(cohort_1_df, "ndl_wait_length", "ndl_ltc")

In [0]:
# Relabel each wait-length statistics table with rename_and_add_column and stack
# them into one long table (output 1.3).
# The df_13_* inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df13 = (
    rename_and_add_column(df_13_1, "Value")
    .unionByName(rename_and_add_column(df_13_2, "Value"))
    .unionByName(rename_and_add_column(df_13_3, "Value"))
    .unionByName(rename_and_add_column(df_13_4, "Value"))
    .unionByName(rename_and_add_column(df_13_5, "Value"))
    .unionByName(rename_and_add_column(df_13_6, "Value"))
    .unionByName(rename_and_add_column(df_13_7, "Value"))
)

In [0]:
# ndl_wait_length statistics cross-tabulated by pairs of breakdown columns.
# Pairs 4-8 were requested by Tom: IMD x frailty, IMD x LTC, ethnicity x frailty,
# ethnicity x LTC and IMD x ethnicity.
(
    df_14_1,
    df_14_2,
    df_14_3,
    df_14_4,
    df_14_5,
    df_14_6,
    df_14_7,
    df_14_8,
) = [
    calculate_wait_length_statistics_cross(cohort_1_df, "ndl_wait_length", first_col, second_col)
    for first_col, second_col in (
        ("ndl_age_band", "Sex"),
        ("ndl_imd_quantile", "Sex"),
        # NOTE(review): the matching wait-band table (df_12_3) uses the reverse
        # order (ndl_imd_quantile, ndl_age_band) — confirm this swap is intended.
        ("ndl_age_band", "ndl_imd_quantile"),
        ("ndl_imd_quantile", "Frailty_Level"),
        ("ndl_imd_quantile", "ndl_ltc"),
        ("ndl_ethnicity", "Frailty_Level"),
        ("ndl_ethnicity", "ndl_ltc"),
        ("ndl_imd_quantile", "ndl_ethnicity"),
    )
]

In [0]:
# Relabel each cross-tabulated statistics table with rename_and_add_column_cross
# and stack them into one long table (output 1.4).
# The df_14_* inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df14 = (
    rename_and_add_column_cross(df_14_1, "Value 1", "Value 2")
    .unionByName(rename_and_add_column_cross(df_14_2, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_3, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_4, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_5, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_6, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_7, "Value 1", "Value 2"))
    .unionByName(rename_and_add_column_cross(df_14_8, "Value 1", "Value 2"))
)

In [0]:
objective_1_1_all = objective_1_1_all_link
objective_1_1_ip = objective_1_1_ip_link
objective_1_1_op = objective_1_1_op_link
objective_1_2 = objective_1_2_link
objective_1_3 = objective_1_3_link
objective_1_4 = objective_1_4_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save these
# tables. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    for table_df, output_path in [
        (combined_df11_all, objective_1_1_all),
        (combined_df11_ip, objective_1_1_ip),
        (combined_df11_op, objective_1_1_op),
        (combined_df12, objective_1_2),
        (combined_df13, objective_1_3),
        (combined_df14, objective_1_4),
    ]:
        table_df.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(output_path)

### Objective 2

In [0]:
# Objective 2 source: the synthetic cohort 2 Excel extract, converted to Spark.
# NOTE(review): this rebinds cohort_2_df, so the "Reason" column derived earlier
# in the notebook is not present on this frame — confirm objective2_table does
# not need it.
cohort_2_path = "../Dummy_Data/Cohort_2_synth.xlsx"
cohort_2_df = spark.createDataFrame(pd.read_excel(cohort_2_path))

# For the real data, replace the two lines above with:
# cohort_2_df = spark.read.format("parquet").load(cohort_2_link)

In [0]:
# Objective 2 tables for cohort 2, one per breakdown column.
# "Frailty_Level" is spelled as everywhere else in this notebook. The previous
# "Frailty_level" only resolved while spark.sql.caseSensitive is false.
df_21_1 = objective2_table(cohort_2_df, "Specialty")
df_21_2 = objective2_table(cohort_2_df, "Sex")
df_21_3 = objective2_table(cohort_2_df, "ndl_age_band")
df_21_4 = objective2_table(cohort_2_df, "ndl_imd_quantile")
df_21_5 = objective2_table(cohort_2_df, "ndl_ethnicity")
df_21_6 = objective2_table(cohort_2_df, "Frailty_Level")
df_21_7 = objective2_table(cohort_2_df, "ndl_ltc")

In [0]:
# Relabel each objective 2 table with rename_and_add_column and stack them into
# one long table (output 2).
# The df_21_* inputs are deliberately not reassigned so that re-running this
# cell does not push already-relabelled tables through the helper again.
combined_df_2 = (
    rename_and_add_column(df_21_1, "Value")
    .unionByName(rename_and_add_column(df_21_2, "Value"))
    .unionByName(rename_and_add_column(df_21_3, "Value"))
    .unionByName(rename_and_add_column(df_21_4, "Value"))
    .unionByName(rename_and_add_column(df_21_5, "Value"))
    .unionByName(rename_and_add_column(df_21_6, "Value"))
    .unionByName(rename_and_add_column(df_21_7, "Value"))
)

In [0]:
objective_2 = objective_2_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save the
# table. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    combined_df_2.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(objective_2)

### Objective 3

In [0]:
# Objective 3 source: the synthetic cohort 3 Excel extract, converted to Spark.
# na.fill(0) replaces nulls with 0 in the numeric columns before any statistics
# are computed.
cohort_3_path = "../Dummy_Data/Cohort_3_synth.xlsx"
cohort_3_df = spark.createDataFrame(pd.read_excel(cohort_3_path)).na.fill(0)

# For the real data, replace the two lines above with:
# cohort_3_df = spark.read.format('parquet').load(cohort_3_link).na.fill(0)

In [0]:
# Measure columns of cohort_3_df summarised for Objective 3.
use_prefixes = [
    # *_healthcare_use_sum columns
    "gp_healthcare_use_sum",
    "u111_healthcare_use_sum",
    "u999_healthcare_use_sum",
    "u00H_healthcare_use_sum",
    "ae_healthcare_use_sum",
    "nel_healthcare_use_sum",
    "el_healthcare_use_sum",
    "op_healthcare_use_sum",
    # *_pres_sum columns
    "all_pres_sum",
    "antib_pres_sum",
    "antidep_pres_sum",
    "pain_pres_sum",
    # other
    "sick_note_sum"
]

# Breakdown columns; each measure gets one stats table per breakdown.
investigation_cols = [
    "Specialty", "Sex", "ndl_age_band", "ndl_imd_quantile",
    "ndl_ethnicity", "Frailty_Level", "ndl_ltc",
]

result_dfs = {}

for measure in use_prefixes:
    # The comprehension variable is `breakdown` rather than `col` so it does not
    # shadow the `col` column function used elsewhere in the notebook.
    per_breakdown_tables = [
        rename_and_add_column(objective_3_stats(cohort_3_df, breakdown, measure), "Value")
        for breakdown in investigation_cols
    ]
    result_dfs[measure] = reduce(lambda upper, lower: upper.unionByName(lower), per_breakdown_tables)

# Unpacked in use_prefixes order: gp, u111, u999, u00H, ae, nel, el, op,
# all_pres, antib, antidep, pain, sick_note.
(
    df_31_1_use, df_31_2_use, df_31_3_use, df_31_4_use, df_31_5_use,
    df_31_6_use, df_31_7_use, df_31_8_use, df_31_9_use, df_31_10_use,
    df_31_11_use, df_31_12_use, df_31_13_use,
) = (result_dfs[measure] for measure in use_prefixes)

In [0]:
objective_3_1_gp = objective_3_1_gp_link
objective_3_1_111 = objective_3_1_111_link
objective_3_1_999 = objective_3_1_999_link
objective_3_1_ooh = objective_3_1_ooh_link
objective_3_1_ae = objective_3_1_ae_link
objective_3_1_nel = objective_3_1_nel_link
objective_3_1_el = objective_3_1_el_link
objective_3_1_op = objective_3_1_op_link
objective_3_1_all_pres = objective_3_1_all_pres_link
objective_3_1_antib = objective_3_1_antib_link
objective_3_1_antidep = objective_3_1_antidep_link
objective_3_1_pain = objective_3_1_pain_link
objective_3_1_sick_note = objective_3_1_sick_note_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save these
# tables. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    for table_df, output_path in [
        (df_31_1_use, objective_3_1_gp),
        (df_31_2_use, objective_3_1_111),
        (df_31_3_use, objective_3_1_999),
        (df_31_4_use, objective_3_1_ooh),
        (df_31_5_use, objective_3_1_ae),
        (df_31_6_use, objective_3_1_nel),
        (df_31_7_use, objective_3_1_el),
        (df_31_8_use, objective_3_1_op),
        (df_31_9_use, objective_3_1_all_pres),
        (df_31_10_use, objective_3_1_antib),
        (df_31_11_use, objective_3_1_antidep),
        (df_31_12_use, objective_3_1_pain),
        (df_31_13_use, objective_3_1_sick_note),
    ]:
        table_df.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(output_path)

In [0]:
# Cost columns of cohort_3_df summarised for Objective 3.
cost_prefixes = ["op_Total_Cost", "ae_Total_Cost", "gp_Total_Cost", "el_Total_Cost"]

result_dfs = {}

for measure in cost_prefixes:
    # One stats table per investigation column for this cost measure, stacked.
    per_breakdown_tables = [
        rename_and_add_column(objective_3_stats(cohort_3_df, breakdown, measure), "Value")
        for breakdown in investigation_cols
    ]
    result_dfs[measure] = reduce(lambda upper, lower: upper.unionByName(lower), per_breakdown_tables)

# Unpacked in cost_prefixes order: op, ae, gp, el.
df_31_1_cost, df_31_2_cost, df_31_3_cost, df_31_4_cost = (
    result_dfs[measure] for measure in cost_prefixes
)

In [0]:
objective_3_1_op_cost = objective_3_1_op_cost_link
objective_3_1_ae_cost = objective_3_1_ae_cost_link
objective_3_1_gp_cost = objective_3_1_gp_cost_link
objective_3_1_el_cost = objective_3_1_el_cost_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save these
# tables. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    for table_df, output_path in [
        (df_31_1_cost, objective_3_1_op_cost),
        (df_31_2_cost, objective_3_1_ae_cost),
        (df_31_3_cost, objective_3_1_gp_cost),
        (df_31_4_cost, objective_3_1_el_cost),
    ]:
        table_df.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(output_path)

In [0]:
# Objective 3.2: the same use measures, broken down by wait band only.
ob_3_2_cols = ["ndl_wait_band"]

# Adds to the result_dfs dict created in an earlier cell; existing entries for
# these measures are overwritten.
for measure in use_prefixes:
    per_breakdown_tables = [
        rename_and_add_column(objective_3_stats(cohort_3_df, breakdown, measure), "Value")
        for breakdown in ob_3_2_cols
    ]
    result_dfs[measure] = reduce(lambda upper, lower: upper.unionByName(lower), per_breakdown_tables)

# Unpacked in use_prefixes order: gp, u111, u999, u00H, ae, nel, el, op,
# all_pres, antib, antidep, pain, sick_note.
(
    df_32_1_use, df_32_2_use, df_32_3_use, df_32_4_use, df_32_5_use,
    df_32_6_use, df_32_7_use, df_32_8_use, df_32_9_use, df_32_10_use,
    df_32_11_use, df_32_12_use, df_32_13_use,
) = (result_dfs[measure] for measure in use_prefixes)

In [0]:
objective_3_2_gp = objective_3_2_gp_link
objective_3_2_111 = objective_3_2_111_link
objective_3_2_999 = objective_3_2_999_link
objective_3_2_ooh = objective_3_2_ooh_link
objective_3_2_ae = objective_3_2_ae_link
objective_3_2_nel = objective_3_2_nel_link
objective_3_2_el = objective_3_2_el_link
objective_3_2_op = objective_3_2_op_link
objective_3_2_all_pres = objective_3_2_all_pres_link
objective_3_2_antib = objective_3_2_antib_link
objective_3_2_antidep = objective_3_2_antidep_link
objective_3_2_pain = objective_3_2_pain_link
objective_3_2_sick_note = objective_3_2_sick_note_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save these
# tables. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    for table_df, output_path in [
        (df_32_1_use, objective_3_2_gp),
        (df_32_2_use, objective_3_2_111),
        (df_32_3_use, objective_3_2_999),
        (df_32_4_use, objective_3_2_ooh),
        (df_32_5_use, objective_3_2_ae),
        (df_32_6_use, objective_3_2_nel),
        (df_32_7_use, objective_3_2_el),
        (df_32_8_use, objective_3_2_op),
        (df_32_9_use, objective_3_2_all_pres),
        (df_32_10_use, objective_3_2_antib),
        (df_32_11_use, objective_3_2_antidep),
        (df_32_12_use, objective_3_2_pain),
        (df_32_13_use, objective_3_2_sick_note),
    ]:
        table_df.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(output_path)

In [0]:
# Objective 3.2: the same cost measures, broken down by wait band only.
ob_3_2_cols = ["ndl_wait_band"]

# Adds to the result_dfs dict created in an earlier cell; existing entries for
# these measures are overwritten.
for measure in cost_prefixes:
    per_breakdown_tables = [
        rename_and_add_column(objective_3_stats(cohort_3_df, breakdown, measure), "Value")
        for breakdown in ob_3_2_cols
    ]
    result_dfs[measure] = reduce(lambda upper, lower: upper.unionByName(lower), per_breakdown_tables)

# Unpacked in cost_prefixes order: op, ae, gp, el.
df_32_1_cost, df_32_2_cost, df_32_3_cost, df_32_4_cost = (
    result_dfs[measure] for measure in cost_prefixes
)

In [0]:
objective_3_2_op_cost = objective_3_2_op_cost_link
objective_3_2_ae_cost = objective_3_2_ae_cost_link
objective_3_2_gp_cost = objective_3_2_gp_cost_link
objective_3_2_el_cost = objective_3_2_el_cost_link

# Persisting is opt-in instead of living in commented-out code: set
# WRITE_OUTPUTS = True (e.g. in a config cell) before running to save these
# tables. When it is unset, nothing is written, as before.
if globals().get("WRITE_OUTPUTS", False):
    for table_df, output_path in [
        (df_32_1_cost, objective_3_2_op_cost),
        (df_32_2_cost, objective_3_2_ae_cost),
        (df_32_3_cost, objective_3_2_gp_cost),
        (df_32_4_cost, objective_3_2_el_cost),
    ]:
        table_df.write.format('parquet').mode('overwrite').option('overwriteSchema','True').save(output_path)