In [0]:
%run "../00_config/set-up"

In [0]:
# Load the overlap Rx extract and report its size and schema.
df = spark.sql("SELECT * FROM heme_data.overlap_rx")
row_cnt, col_cnt = df.count(), len(df.columns)
print('Row count: ', row_cnt, 'Column Count: ', col_cnt)
df.printSchema()

### Calculate Weekly IUs

In [0]:
# Summarize IU by patient, product, week
from pyspark.sql import functions as F
from pyspark.sql.functions import date_trunc, to_date

# derive week from date: date_trunc("week") truncates to the Monday of the week
columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU']
df_selected = df.select(*columns)
df_with_week = df_selected.withColumn(
    "SHP_WK",
    to_date(date_trunc("week", F.col("SHP_DT")))
)

# calculate weekly IU per (patient, product, PRPHY, week).
# Missing PRPHY is defaulted to '1' BEFORE grouping. Defaulting after the
# aggregation could leave two rows with PRPHY='1' for the same
# patient/product/week (one from the null group, one from the '1' group).
# The wide pivot would then collapse them and lose IU.
df_iu = (
    df_with_week
    .filter(F.col("PRD_NM").isNotNull())
    .withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1')))
    .groupBy("PATIENT_ID", "PRD_NM", "PRPHY", "SHP_WK")
    .agg(F.sum("IU").alias("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_WK').asc())
)

df_iu.printSchema()
df_iu.display()

#### Market Overview

In [0]:
from pyspark.sql.functions import date_trunc, to_date, col, countDistinct, sum as F_sum

# Monthly market overview: distinct patients and total IU per product,
# pivoted so each product becomes a pair of columns.
df_with_month = df_iu.withColumn("SHP_MONTH", to_date(date_trunc("month", col("SHP_WK"))))

df_summary = (
    df_with_month
    .groupBy("SHP_MONTH", "PRD_NM")
    .agg(
        countDistinct("PATIENT_ID").alias("unique_patients"),
        F_sum("IU").alias("total_IU"),
    )
)

# df_summary has one row per (month, product), so first() picks the only value
df_pivot = (
    df_summary
    .groupBy("SHP_MONTH")
    .pivot("PRD_NM")
    .agg(
        F.first("unique_patients").alias("unique_patients"),
        F.first("total_IU").alias("total_IU"),
    )
    .orderBy("SHP_MONTH")
)

df_pivot.display()

### Create Wide Rx Table

In [0]:

# transform weekly iu data to wide format: one row per (PATIENT_ID, SHP_WK)
# and one column per product/PRPHY pair, named PRD_<PRD_NM>_PRPHY_<PRPHY>.
# sum() rather than first(): if more than one input row lands in the same
# cell, first() would keep an arbitrary one and silently drop the rest.
df_iu_wide = (
    df_iu
    .withColumn(
        "PRD_PRPHY",
        F.concat(F.lit("PRD_"), F.col("PRD_NM"), F.lit("_PRPHY_"), F.col("PRPHY")),
    )
    .groupBy("PATIENT_ID", "SHP_WK")
    .pivot("PRD_PRPHY")
    .agg(F.sum("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('SHP_WK').asc())
)
df_iu_wide.printSchema()
df_iu_wide.display()

##### Check distribution

In [0]:
# Replace missing IU with 0 in the numeric columns of the wide table.
# NOTE(review): this rebinds df_iu_wide itself. Every later cell that reads
# df_iu_wide sees zeros where there used to be nulls, so null-based checks
# on it downstream no longer distinguish "no shipment" from "0 IU".
df_iu_wide = df_iu_wide.fillna(0)

In [0]:
# Number of distinct shipment weeks per patient, and its distribution.
patient_wk_cnt = (
    df_iu_wide
    .groupby("PATIENT_ID")
    .agg(countDistinct("SHP_WK").alias("week_cnt"))
)
patient_wk_cnt.describe().show()
patient_wk_cnt.approxQuantile("week_cnt", [0.25, 0.5, 0.75, 0.9], 0.01)

In [0]:
# Patients with at least 38 distinct shipment weeks
ptnt_wi_38w = patient_wk_cnt.where(F.col("week_cnt") >= 38)

In [0]:
# Size and listing of the >= 38-week cohort
ptnt_38w_cnt = ptnt_wi_38w.count()
print(ptnt_38w_cnt)
ptnt_wi_38w.display()

In [0]:
# Spot-check one patient's weekly wide rows
sample_patient_id = 'C92AA410-0AF7-48E7-8DE9-B4B98E4F0B78'
df_iu_wide.where(F.col("PATIENT_ID") == sample_patient_id).display()

In [0]:
# Per-patient standard deviation of every product/PRPHY IU column.
# df_iu_wide.columns[2:] skips the PATIENT_ID and SHP_WK keys.
iu_value_cols = df_iu_wide.columns[2:]
agg_exprs = []
for iu_col in iu_value_cols:
    agg_exprs.append(stddev(iu_col).alias(f"StdDev_{iu_col}"))
ptnt_iu_std = df_iu_wide.groupBy("PATIENT_ID").agg(*agg_exprs)

In [0]:
# Attach per-patient std-devs to the >= 38-week cohort (left join keeps every cohort patient)
ptnt_iu_std_wi_38w = ptnt_wi_38w.join(ptnt_iu_std, 'PATIENT_ID', 'left')

In [0]:
from pyspark.sql import functions as F
from functools import reduce

# RowTotal = sum of the patient's per-column IU standard deviations.
# The StdDev_* columns are selected by prefix, not by position (columns[2:]).
# With positional selection, re-running this cell would fold the previous
# RowTotal into the new one and double it.
std_dev_cols = [c for c in ptnt_iu_std_wi_38w.columns if c.startswith("StdDev_")]
ptnt_iu_std_wi_38w = ptnt_iu_std_wi_38w.withColumn(
    "RowTotal",
    reduce(lambda a, b: a + b, [F.col(c) for c in std_dev_cols])
)

In [0]:
# Show the >= 38-week cohort with its per-column std-devs and RowTotal
ptnt_iu_std_wi_38w.display()

In [0]:
# NOTE(review): scratch arithmetic with unexplained constants. It is presumably
# (#patients x #weeks), a size estimate for the patient-week cross join below.
# Confirm against patient_week_pair.count(), or delete.
27201*166

In [0]:
# Distinct shipment weeks observed in the data from 2022 onward
# (only weeks in which at least one patient had a shipment).
unique_wk = (
    df_iu_wide
    .select("SHP_WK")
    .where(F.col("SHP_WK") >= '2022-01-01')
    .distinct()
)

In [0]:
# Every (patient, week) combination: all patients x every week in unique_wk
all_patients = df_iu_wide.select("PATIENT_ID").distinct()
patient_week_pair = all_patients.crossJoin(unique_wk).orderBy('PATIENT_ID', 'SHP_WK')

In [0]:
# Number of distinct weeks in the wide table (all years, not only >= 2022)
df_iu_wide.select("SHP_WK").distinct().count()

In [0]:
# Left-join the wide IU rows onto the full patient-week grid. Weeks with no
# shipment become all-zero rows; rows before 2022 are not in the grid.
df_iu_wide_explode = (
    patient_week_pair
    .join(df_iu_wide, ['PATIENT_ID', 'SHP_WK'], 'left')
    .orderBy('PATIENT_ID', 'SHP_WK')
    .fillna(0)
)

In [0]:
# Row count after the grid join; expected to equal patient_week_pair.count()
# because df_iu_wide has one row per (PATIENT_ID, SHP_WK).
df_iu_wide_explode.count()

In [0]:
# Rolling std-dev of each product/PRPHY IU column over the 4 rows preceding
# the current one (rowsBetween(-4, -1): current row excluded), per patient.
# Rows are one per week here because df_iu_wide_explode is the patient x week grid.
# Uses F.col rather than the bare `col`: a top-level `for col in ...` loop
# elsewhere in the notebook rebinds `col` to a string, and then re-running
# this cell would fail.
window_spec = Window.partitionBy("PATIENT_ID").orderBy("SHP_WK").rowsBetween(-4, -1)
std_cols = []

for column in df_iu_wide.columns[2:]:
    rolling_name = f"RollingStdDev_{column}"
    df_iu_wide_explode = df_iu_wide_explode.withColumn(
        rolling_name, stddev(F.col(column)).over(window_spec)
    )
    std_cols.append(rolling_name)

In [0]:
# p25/p50/p75/p90 of one rolling std-dev column. The name is hard-coded and
# only exists if the pivot produced a PRD_KOATE_PRPHY_0 column.
df_iu_wide_explode.approxQuantile('RollingStdDev_PRD_KOATE_PRPHY_0', [0.25, 0.5, 0.75,0.9], 0.01)

In [0]:
# p25/p50/p75/p90/max of every rolling std-dev column.
# The loop variable is deliberately NOT named `col`. At notebook top level
# that name would rebind pyspark's `col` function to a string, and later
# cells calling col(...) (e.g. filter_patient_data) would fail.
for std_col in std_cols:
    print(std_col, df_iu_wide_explode.approxQuantile(std_col, [0.25, 0.5, 0.75, 0.9, 1.0], 0.01))

In [0]:
# Weeks (>= 2022-01-01) present in the data, used as the patient-week grid
unique_wk.display()

In [0]:
# Size of the patient x week grid
patient_week_pair.count()

### Patient Records - Jivi Patients

In [0]:
# get sample Jivi patients: top n with most number of Jivi records
from pyspark.sql.types import NumericType

n=20
# NOTE: dense_rank() <= n keeps every patient whose record count is among the
# n highest distinct counts, so ties can yield more than n patients.
# PRPHY is defaulted to '1' when df_iu is built, so the isNotNull() test is
# expected to always pass.
top_n_patients_with_rank = df_iu.filter(
    (F.col("PRD_NM") == "JIVI") & 
    (F.col("PRPHY").isNotNull())
).groupBy("PATIENT_ID"
).agg(
    F.count("*").alias("record_count")
).withColumn(
    "rank", 
    F.dense_rank().over(Window.orderBy(F.col("record_count").desc()))
).filter(
    F.col("rank") <= n
)

result_df_with_rank = df_iu.join(
    top_n_patients_with_rank, 
    on="PATIENT_ID",
    how="inner"
).orderBy(
    "rank", 
    "SHP_WK"
)
# show top Jivi patients in wide format
df_iu_wide_jivi = df_iu_wide.join(
    result_df_with_rank.select("PATIENT_ID").distinct(),
    on="PATIENT_ID",
    how="inner"
).orderBy(
    "PATIENT_ID",
    "SHP_WK"
)

# Keep only columns that carry data for these patients. df_iu_wide may have
# been zero-filled (fillna(0)), in which case a null-only test keeps every
# column. So for numeric columns a 0 also counts as "no data"; non-numeric
# columns are kept if any value is non-null. All columns are checked in one
# aggregation instead of one count() job per column.
presence = df_iu_wide_jivi.select(*[
    F.count(
        F.when(F.col(f.name) != 0, 1) if isinstance(f.dataType, NumericType)
        else F.when(F.col(f.name).isNotNull(), 1)
    ).alias(f.name)
    for f in df_iu_wide_jivi.schema.fields
]).first().asDict()
non_null_columns = [c for c in df_iu_wide_jivi.columns if presence[c] > 0]
df_iu_wide_jivi_non_null = df_iu_wide_jivi.select(*non_null_columns)

df_iu_wide_jivi_non_null.display()

# findings: top 5 JIVI users didn't use other products

In [0]:
# Columns (and types) kept for the top Jivi patients
df_iu_wide_jivi_non_null.printSchema()

In [0]:
# get sample Jivi patients: top n with most number of Jivi records who also used other products
from pyspark.sql.types import NumericType

n = 20
# patients using other products
# NOTE(review): 'ELOCATE' may be a misspelling of the product ELOCTATE. Confirm
# against the PRD_NM values in heme_data.overlap_rx; a misspelled name matches nothing.
other_products_patients = df_iu.filter(
    (df_iu.PRD_NM.isin(['ELOCATE', 'HEMLIBRA', 'ALTUVIIIO', 'ADVATE', 'RECOMBINATE', 'NUWIQ']))
).select('PATIENT_ID').distinct()

# find the PATIENT_IDs with JIVI records and count them
# (PRPHY is defaulted to '1' when df_iu is built, so isNotNull() always passes)
jivi_counts = df_iu.filter(
    (df_iu.PRD_NM == 'JIVI') &
    (df_iu.PRPHY.isNotNull())
).join(
    other_products_patients,
    'PATIENT_ID'
).groupBy('PATIENT_ID').count()

# Get the top n PATIENT_IDs based on count (ties at the cut-off are broken arbitrarily)
top_n_patients = jivi_counts.orderBy('count', ascending=False).limit(n).select('PATIENT_ID')

# get all records for these top PATIENT_IDs
result = df_iu_wide.join(
    top_n_patients,
    'PATIENT_ID'
).orderBy('PATIENT_ID', 'SHP_WK')

# Keep only columns that carry data for these patients. df_iu_wide may have
# been zero-filled (fillna(0)), so for numeric columns a 0 also counts as
# "no data"; non-numeric columns are kept if any value is non-null. One
# aggregation pass instead of one count() job per column.
presence = result.select(*[
    F.count(
        F.when(F.col(f.name) != 0, 1) if isinstance(f.dataType, NumericType)
        else F.when(F.col(f.name).isNotNull(), 1)
    ).alias(f.name)
    for f in result.schema.fields
]).first().asDict()
non_null_columns = [c for c in result.columns if presence[c] > 0]
result_non_null = result.select(*non_null_columns)

result_non_null.display()


#### Patients using Jivi in 2024 who were on other products before

In [0]:
# get sample Jivi patients: patients whose first Jivi week falls in jivi_start_yr
# and who were on other products before Jivi adoption

jivi_start_yr = 2024
# Ordered per-patient window; not used in this cell.
window_spec = Window.partitionBy('PATIENT_ID').orderBy('SHP_WK')

# Earliest JIVI week per patient. min() over the whole partition is
# order-independent and ignores the nulls produced for non-JIVI rows.
# first() over an unordered window would instead return whichever row Spark
# happens to see first, usually a non-JIVI row, i.e. null.
df_with_first_jivi = df_iu.withColumn(
    'first_jivi_date',
    F.min(
        F.when(F.col('PRD_NM') == 'JIVI', F.col('SHP_WK'))
    ).over(Window.partitionBy('PATIENT_ID'))
)

# Patients who started JIVI in jivi_start_yr and had at least one shipment of
# a product other than JIVI/KOVALTRY/KOGENATE before their first JIVI week
eligible_patients = df_with_first_jivi.filter(
    F.col('first_jivi_date').isNotNull()
    & (F.year('first_jivi_date') == jivi_start_yr)
    & (F.col('SHP_WK') < F.col('first_jivi_date'))
    & ~F.col('PRD_NM').isin(['JIVI', 'KOVALTRY', 'KOGENATE'])
).select('PATIENT_ID').distinct()

# Their JIVI rows in jivi_start_yr
final_results = df_iu.join(
    eligible_patients,
    'PATIENT_ID'
).filter(
    (F.col('PRD_NM') == 'JIVI') &
    (F.year('SHP_WK') == jivi_start_yr) 
).orderBy('PATIENT_ID', 'SHP_WK')

# count new jivi patients
distinct_patients = final_results.select('PATIENT_ID').distinct().count()
print(f"Number of distinct PATIENT_IDs in {jivi_start_yr}: {distinct_patients}")

# show rx history of new jivi patients
patient_ids = final_results.select('PATIENT_ID').distinct()

# Join with df_iu_wide to get records for these patients
df_iu_wide_jivi_pt = df_iu_wide.join(patient_ids, 'PATIENT_ID')

def clean_dataframe(df):
    """Tidy a wide patient-week IU table for display.

    Steps:
      1. Drop columns that carry no data. A numeric column is dropped if it has
         no value that is both non-null and non-zero; any other column is
         dropped if it has no non-null value.
      2. Replace remaining nulls with 0 in every column except PATIENT_ID and SHP_WK.
      3. Cast ``decimal(38,5)`` columns to int (the fractional part is dropped).

    Args:
        df: Spark DataFrame. Presumably the wide table with PATIENT_ID,
            SHP_WK and PRD_* IU columns; confirm for other callers.

    Returns:
        A new Spark DataFrame with the steps above applied.
    """
    from pyspark.sql.types import NumericType

    key_columns = ['PATIENT_ID', 'SHP_WK']

    # 1. One aggregation pass over all columns instead of a count() job per
    #    column. Zeros count as "no data" for numeric columns because the wide
    #    table may already have been zero-filled, which would otherwise hide
    #    unused products.
    presence = df.select(*[
        F.count(
            F.when(F.col(f.name) != 0, 1) if isinstance(f.dataType, NumericType)
            else F.when(F.col(f.name).isNotNull(), 1)
        ).alias(f.name)
        for f in df.schema.fields
    ]).first().asDict()
    kept_columns = [c for c in df.columns if presence[c] > 0]

    # 2. Fill nulls with 0 outside the key columns, in a single projection
    #    rather than one withColumn per column.
    df_filled = df.select(*[
        F.col(c) if c in key_columns else F.coalesce(F.col(c), F.lit(0)).alias(c)
        for c in kept_columns
    ])

    # 3. Convert decimal values to integers
    return df_filled.select(*[
        F.col(c).cast('int').alias(c) if t == 'decimal(38,5)' else F.col(c)
        for c, t in df_filled.dtypes
    ])

# Apply clean_dataframe (column pruning, zero-fill, decimal -> int) to the
# new-Jivi patients' wide records and show them.
df_iu_wide_jivi_pt_clean = clean_dataframe(df_iu_wide_jivi_pt)
df_iu_wide_jivi_pt_clean.display()

#### Get top n Jivi patients

In [0]:
# Of the new-Jivi patients in df_iu_wide_jivi_pt_clean, show the n with the
# most weekly rows.
n = 20

# weekly-row count per patient, largest first
patient_counts = (
    df_iu_wide_jivi_pt_clean
    .groupBy('PATIENT_ID')
    .count()
    .orderBy('count', ascending=False)
)

# top n patients (ties at the cut-off are broken arbitrarily by limit)
top_n_patients = patient_counts.limit(n)

# all records for those n patients, without the helper 'count' column
df_iu_wide_jivi_pt_top = (
    df_iu_wide_jivi_pt_clean
    .join(top_n_patients, 'PATIENT_ID')
    .orderBy('PATIENT_ID', 'SHP_WK')
    .drop('count')
)

print("\nDetailed records for top n patients:")
df_iu_wide_jivi_pt_top.display(truncate=False)



#### Explode patient Rx history to consecutive weeks

In [0]:
# Weekly date grid for the explosion, stepping one week from start_date to
# end_date. start_date should be a Monday so the grid lines up with SHP_WK,
# which comes from date_trunc("week", ...) (Monday-based).
start_date = "2022-01-03"
end_date = "2024-11-25"

week_grid_sql = f"""
    SELECT explode(sequence(to_date('{start_date}'), to_date('{end_date}'), interval 1 week)) as SHP_WK
"""
date_range_df = spark.sql(week_grid_sql)

# every (patient, week) pair for the top-n patients
patient_weeks_df = (
    df_iu_wide_jivi_pt_top
    .select("PATIENT_ID")
    .distinct()
    .crossJoin(date_range_df)
)

# attach the actual weekly IU; weeks without a shipment are zero-filled
result_df = (
    patient_weeks_df
    .join(df_iu_wide_jivi_pt_top, on=["PATIENT_ID", "SHP_WK"], how="left")
    .fillna(0)
    .orderBy('PATIENT_ID', 'SHP_WK')
)

result_df.printSchema()
result_df.display()

#### Visualize

##### One patient at a time

In [0]:
def filter_patient_data(df, patient_id):
    """Return only the rows of ``df`` that belong to one patient.

    Args:
        df: Spark DataFrame with a PATIENT_ID column.
        patient_id: PATIENT_ID value to keep.

    Returns:
        Spark DataFrame filtered to ``PATIENT_ID == patient_id``.
    """
    # F.col, not the bare `col`: a top-level `for col in ...` loop elsewhere in
    # this notebook rebinds `col` to a string, which makes `col(...)` raise
    # "'str' object is not callable".
    return df.filter(F.col("PATIENT_ID") == patient_id)

# Example: one patient's rows from the week-exploded result_df
patient_id_val = 'B6FFEE43-05C4-48C6-983C-D0EADD401AF5'
df_pt = filter_patient_data(result_df, patient_id_val)
df_pt.display()

In [0]:
import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of weekly IU for the single patient in df_pt:
# weeks on the y-axis, one column per PRD_* series.
prd_columns = [c for c in df_pt.columns if c.startswith('PRD_')]
df_pt_prd = df_pt.select(['SHP_WK'] + prd_columns)

# to pandas, indexed by week
df_pt_prd_pd = df_pt_prd.toPandas().set_index('SHP_WK')

plt.figure(figsize=(20, 10))
sns.heatmap(df_pt_prd_pd, cmap='YlOrBr', annot=False)
plt.title('Heatmap of PRD_ Columns')
plt.show()

##### All patients

In [0]:
from pyspark.sql.functions import col

# Function to generate heatmap for each patient
def generate_heatmap(df, patient_id, pdf):
    """Render one patient's PRD_* weekly IU heatmap as a page of ``pdf``.

    Args:
        df: Spark DataFrame with PATIENT_ID, SHP_WK and PRD_* columns.
        patient_id: patient to plot.
        pdf: an open matplotlib ``PdfPages`` (anything with ``savefig(fig)``).

    Does nothing when the patient has no rows.
    """
    prd_columns = [c for c in df.columns if c.startswith('PRD_')]
    # Collect once and test emptiness on the pandas frame. This avoids a
    # separate Spark count() job per patient.
    df_pt_prd_pd = (
        df.filter(F.col("PATIENT_ID") == patient_id)
        .select(['SHP_WK'] + prd_columns)
        .toPandas()
    )
    if df_pt_prd_pd.empty:
        return
    # chronological rows regardless of the order Spark returned them in
    df_pt_prd_pd = df_pt_prd_pd.set_index('SHP_WK').sort_index()

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(df_pt_prd_pd, cmap='YlOrBr', annot=False, ax=ax)
    ax.set_title(f'Heatmap of PRD_ Columns for Patient {patient_id}')
    plt.setp(ax.get_xticklabels(), fontsize=8, rotation=45, ha='right')
    fig.tight_layout()
    pdf.savefig(fig)
    # close this specific figure so memory does not grow across patients
    plt.close(fig)

# Imported here so this cell does not depend on the set-up notebook providing it.
from matplotlib.backends.backend_pdf import PdfPages

# Create a PDF file with one heatmap page per patient in result_df
with PdfPages('output_heatmaps.pdf') as pdf:
    # sorted so the page order is deterministic across runs
    patient_ids = result_df.select("PATIENT_ID").distinct().orderBy("PATIENT_ID").collect()
    for row in patient_ids:
        display(row)
        generate_heatmap(result_df, row["PATIENT_ID"], pdf)

In [0]:
# List the patient IDs in the top-n new-Jivi table.
# NOTE(review): this rebinds result_df to the non-exploded df_iu_wide_jivi_pt_top.
# Re-running the heatmap cells after this one plots the un-exploded rows.
result_df = df_iu_wide_jivi_pt_top
patient_ids = result_df.select("PATIENT_ID").distinct().collect()
for row in patient_ids:
    display(row)

### Patient Records - Altuviiio Patients

In [0]:
# get sample ALTUVIIIO patients: top n with most number of ALTUVIIIO records
from pyspark.sql.types import NumericType

n=20
# NOTE: dense_rank() <= n keeps every patient whose record count is among the
# n highest distinct counts, so ties can yield more than n patients.
# PRPHY is defaulted to '1' when df_iu is built, so the isNotNull() test is
# expected to always pass.
top_n_patients_with_rank = df_iu.filter(
    (F.col("PRD_NM") == "ALTUVIIIO") & 
    (F.col("PRPHY").isNotNull())
).groupBy("PATIENT_ID"
).agg(
    F.count("*").alias("record_count")
).withColumn(
    "rank", 
    F.dense_rank().over(Window.orderBy(F.col("record_count").desc()))
).filter(
    F.col("rank") <= n
)

result_df_with_rank = df_iu.join(
    top_n_patients_with_rank, 
    on="PATIENT_ID",
    how="inner"
).orderBy(
    "rank", 
    "SHP_WK"
)
# show top ALTUVIIIO patients in wide format
df_iu_wide_alt = df_iu_wide.join(
    result_df_with_rank.select("PATIENT_ID").distinct(),
    on="PATIENT_ID",
    how="inner"
).orderBy(
    "PATIENT_ID",
    "SHP_WK"
)

# Keep only columns that carry data for these patients. df_iu_wide may have
# been zero-filled (fillna(0)), so for numeric columns a 0 also counts as
# "no data"; non-numeric columns are kept if any value is non-null. One
# aggregation pass instead of one count() job per column.
presence = df_iu_wide_alt.select(*[
    F.count(
        F.when(F.col(f.name) != 0, 1) if isinstance(f.dataType, NumericType)
        else F.when(F.col(f.name).isNotNull(), 1)
    ).alias(f.name)
    for f in df_iu_wide_alt.schema.fields
]).first().asDict()
non_null_columns = [c for c in df_iu_wide_alt.columns if presence[c] > 0]
df_iu_wide_alt_non_null = df_iu_wide_alt.select(*non_null_columns)

df_iu_wide_alt_non_null.display()

