In [0]:
%run "../00_config/set-up"

### Read Data

In [0]:
# Load the full preprocessed overlap table, then report its size and schema.
df = spark.table("heme_data.overlap_preprocessed")
n_rows, n_cols = df.count(), len(df.columns)
print('Row count: ', n_rows, 'Column Count: ', n_cols)
df.printSchema()

### Diagnostics 

In [0]:
from pyspark.sql.functions import countDistinct, year

# Overall number of distinct patients and distinct HCPs (BH_ID).
unique_counts = df.agg(
    countDistinct("PATIENT_ID").alias("unique_patients"),
    countDistinct("BH_ID").alias("unique_hcps"),
)

# Distinct patients within each PTNT_AGE_GRP value.
unique_patients_by_age_grp = (
    df.groupBy("PTNT_AGE_GRP")
      .agg(countDistinct("PATIENT_ID").alias("unique_patients"))
)

# Distinct patients and HCPs restricted to rows whose SHP_DT falls in calendar year 2023.
shipped_in_2023 = df.filter(year("SHP_DT") == 2023)
count_2023 = shipped_in_2023.agg(
    countDistinct("PATIENT_ID").alias("n_patients_2023"),
    countDistinct("BH_ID").alias("n_hcp_2023"),
)

unique_counts.display()
unique_patients_by_age_grp.display()
count_2023.display()

### Calculate Weekly IUs

In [0]:
from pyspark.sql.functions import date_trunc, to_date

# df_iu: total IU per HCP (BH_ID), patient, product, prophy flag and shipment week.

# Keep only the columns needed downstream and derive the shipment week.
# date_trunc("week", ...) truncates to the Monday of the week containing SHP_DT.
columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU']
df_selected = df.select(*columns)
df_with_week = df_selected.withColumn(
    "SHP_WK",
    to_date(date_trunc("week", F.col("SHP_DT")))
)

# Default a missing PRPHY to '1' BEFORE grouping. Coalescing after the groupBy left two
# rows with the same (BH_ID, PATIENT_ID, PRD_NM, PRPHY='1', SHP_WK) key whenever a week
# had both null-PRPHY and '1'-PRPHY shipments, splitting that week's IU total.
df_iu = (
    df_with_week
    .filter(F.col("PRD_NM").isNotNull())
    .withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1')))
    .groupBy("BH_ID", "PATIENT_ID", "PRD_NM", "PRPHY", "SHP_WK")
    .agg(F.sum("IU").alias("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_WK').asc())
)

df_iu.printSchema()
df_iu.display()
df_iu.count()

### Create Product_Prophy column - v1

In [0]:
# Create additional columns in df_iu
# df_iu_2: normalise PRPHY ('UNK' or null -> '0') and build the PRD_PRPHY label "<PRD_NM>-<PRPHY>".
# NOTE: df_iu already coalesces null PRPHY to '1', so in practice only 'UNK' is remapped here.
df_iu_2 = df_iu.withColumn(
    "PRPHY",
    when(col("PRPHY").isNull() | (col("PRPHY") == "UNK"), "0").otherwise(col("PRPHY"))
).withColumn(
    "PRD_PRPHY", concat(col("PRD_NM"), lit("-"), col("PRPHY"))
)

# df_iu_3: previous shipment week and first shipment week per (BH_ID, PATIENT_ID).
# NOTE(review): df_iu has one row per product, so several rows can share a SHP_WK. lag() then
# returns that same week (gap 0) for all but one of them, and which row gets the real gap is arbitrary.
df_iu_3 = df_iu_2.withColumn("PREV_SHP_WK",
                             expr("lag(SHP_WK) over (partition by BH_ID, PATIENT_ID order by SHP_WK)")) \
                 .withColumn("FST_SHP_WK",
                             expr("first(SHP_WK) over (partition by BH_ID, PATIENT_ID order by SHP_WK)"))

# df_iu_4: whole weeks since the previous shipment and since the first shipment.
# Only the two gap columns are zero-filled (the first shipment has no previous week).
# A bare fillna(0) would also silently turn a null IU total into 0.
df_iu_4 = df_iu_3.withColumn(
    "WK_SINCE_LST_SHP",
    (datediff(col("SHP_WK"), col("PREV_SHP_WK")) / 7).cast("integer")
).withColumn(
    "WK_SINCE_FST_SHP",
    (datediff(col("SHP_WK"), col("FST_SHP_WK")) / 7).cast("integer")
).fillna(0, subset=["WK_SINCE_LST_SHP", "WK_SINCE_FST_SHP"])

# df_iu_4_filtered: first shipment week on/after 2023-01-01 for each BH_ID, PATIENT_ID
date_threshold = "2023-01-01"
df_iu_4_filtered = (df_iu_4.filter(df_iu_4.SHP_WK >= date_threshold)
                    .groupBy('BH_ID', 'PATIENT_ID')
                    .agg(F.min("SHP_WK").alias("FIRST_SHP_WK"))
)

# df_iu_5: attach FIRST_SHP_WK to every row. The left join keeps pairs with no shipment
# from 2023 on, for which FIRST_SHP_WK is null.
df_iu_5 = df_iu_4.join(df_iu_4_filtered, on=["BH_ID", "PATIENT_ID"], how="left")

# df_iu_6: mark with "||" the rows in the first shipment week from 2023-01-01 onward
# (every product shipped in that week is marked). All other rows, including null
# FIRST_SHP_WK, get "".
df_iu_6 = df_iu_5.withColumn("IS_FST_GT_DEC22",
    F.when(df_iu_5.SHP_WK == df_iu_5.FIRST_SHP_WK, "||")
    .otherwise("")
)
# Show the updated DataFrame
df_iu_6.display()
df_iu_6.count()

In [0]:
# df_iu_7: one row per (BH_ID, PATIENT_ID) with a chronological Rx history string PRD_PRPHY_ALL
# Select the columns
selected_columns_df = df_iu_6.select("BH_ID", "PATIENT_ID", "PRD_PRPHY", "SHP_WK", "WK_SINCE_LST_SHP", "WK_SINCE_FST_SHP", "IS_FST_GT_DEC22")

# Concatenate IS_FST_GT_DEC22 and PRD_PRPHY: the "||" marker prefixes entries in the first
# shipment week from 2023-01-01 onward.
selected_columns_df = selected_columns_df.withColumn(
    "IS_FST_GT_DEC22_PRD_PRPHY",
    F.concat(F.col("IS_FST_GT_DEC22"), F.col("PRD_PRPHY"))
)

# collect_list() after a groupBy gives no ordering guarantee, so the history could come out
# scrambled. Collect (SHP_WK, label) structs, sort them (by SHP_WK, then by label for products
# in the same week), and keep only the labels.
rx_history_sorted = F.sort_array(
    F.collect_list(F.struct("SHP_WK", "IS_FST_GT_DEC22_PRD_PRPHY"))
).getField("IS_FST_GT_DEC22_PRD_PRPHY")

# Aggregate values at BH_ID, PATIENT_ID:
#   MAX_WK_SINCE_LST_SHP - longest gap (weeks) between consecutive shipments
#   TOTAL_DOT            - weeks from first to last shipment; WK_SINCE_FST_SHP never decreases
#                          with SHP_WK, so its max equals its value on the latest week
df_iu_7 = selected_columns_df.groupBy("BH_ID", "PATIENT_ID").agg(
    F.concat_ws(", ", rx_history_sorted).alias("PRD_PRPHY_ALL"),
    F.max("WK_SINCE_LST_SHP").alias("MAX_WK_SINCE_LST_SHP"),
    F.max("WK_SINCE_FST_SHP").alias("TOTAL_DOT"),
)

# result DataFrame
df_iu_7.display()

### Show Rx history for patients prescribed Jivi by Jivi new writers during the study period (2023-01-01 to 2024-11-30)

In [0]:
# Find patients prescribed Jivi by Jivi New Writers
df_new_writer = spark.sql('select BH_ID, PATIENT_ID, BRTH_YR, SHP_DT, PRD_NM, SRC_SP, SOURCE_TYPE, DRUG_NM, SOURCE, RX_TYP, PTNT_WGT, SEVRTY, PRPHY from jivi_new_writer_model.jivi_new_writers_overlap_raw_data')

# Jivi shipments within the study period; between() is inclusive on both ends.
# NOTE(review): if SHP_DT is a timestamp rather than a date, shipments after 00:00 on
# 2024-11-30 fall outside the upper bound — confirm the column type.
df_jivi_pt = df_new_writer.filter(
    (col("PRD_NM") == 'JIVI') &
    (col("SHP_DT").between('2023-01-01', '2024-11-30'))
).orderBy('BH_ID', 'PATIENT_ID', 'SHP_DT')

# Jivi patient ID
df_jivi_pt_id = df_jivi_pt.select('PATIENT_ID').distinct()

# Rank each HCP's Jivi shipments by date. PATIENT_ID breaks ties so that the "first patient"
# is deterministic when several patients share the HCP's earliest SHP_DT (the lowest
# PATIENT_ID wins; use rank() instead if all tied patients should count as first).
window_spec = Window.partitionBy("BH_ID").orderBy("SHP_DT", "PATIENT_ID")

# Add a row number based on the defined window
ranked_df = df_jivi_pt.withColumn("row_num", row_number().over(window_spec))

# Patient ID for the first PATIENT_ID for each BH_ID
df_fst_jivi_pt_id = ranked_df.filter(col("row_num") == 1).select("PATIENT_ID").distinct()

print(df_jivi_pt_id.count())
print(df_fst_jivi_pt_id.count())
df_fst_jivi_pt_id.display()


In [0]:
# Append Rx history to the Jivi patients
df_jivi_pt_rx = df_jivi_pt_id.join(df_iu_7, on="PATIENT_ID", how="left").orderBy('PATIENT_ID')
df_jivi_pt_rx.display()

In [0]:
# Check df_iu for Jivi patients: keep only df_iu rows whose PATIENT_ID appears in df_jivi_pt_id.
# (The previous `col('PATIENT_ID').isin()` had no values and therefore matched no rows.)
df_iu_jivi_pt = (
    df_iu
    .join(df_jivi_pt_id, on="PATIENT_ID", how="left_semi")
    .orderBy("PATIENT_ID", "SHP_WK", "PRD_NM")
)
df_iu_jivi_pt.display()