In [0]:
%run "../00_config/set-up"

### Read Data

In [0]:
# Source extract for this notebook; every downstream cell derives from `df`.
SOURCE_TABLE = "heme_data.overlap_rx"
df = spark.sql(f"SELECT * FROM {SOURCE_TABLE}")

# Quick shape and schema check before running any diagnostics.
n_rows, n_cols = df.count(), len(df.columns)
print('Row count: ', n_rows, 'Column Count: ', n_cols)
df.printSchema()

### Diagnostics 

In [0]:
from pyspark.sql.functions import year

# Headline cardinalities: distinct patients and distinct HCPs (BH_ID) over the whole table.
unique_counts = df.agg(
    countDistinct("PATIENT_ID").alias("unique_patients"),
    countDistinct("BH_ID").alias("unique_hcps"),
)

# Distinct patients within each PTNT_AGE_GRP value.
unique_patients_by_age_grp = (
    df.groupBy("PTNT_AGE_GRP")
      .agg(countDistinct("PATIENT_ID").alias("unique_patients"))
)

# Distinct patients with at least one row whose SHP_DT falls in calendar year 2023.
shipped_in_2023 = year("SHP_DT") == 2023
patients_2023_count = (
    df.where(shipped_in_2023)
      .agg(countDistinct("PATIENT_ID").alias("patients_2023"))
)

for summary_df in (unique_counts, unique_patients_by_age_grp, patients_2023_count):
    summary_df.display()

##### Patients with inconsistent or multiple birth years (BRTH_YR)

In [0]:
# Patients whose birth year is inconsistent between SHS and SP.
# Keep rows where SP_PTNT_BRTH_YR and SHS_PTNT_BRTH_YR are both non-null and differ.
df_filtered = df.filter(
    col("SP_PTNT_BRTH_YR").isNotNull()
    & col("SHS_PTNT_BRTH_YR").isNotNull()
    & (col("SP_PTNT_BRTH_YR") != col("SHS_PTNT_BRTH_YR"))
)

# Number of distinct patients with at least one mismatching row.
problem_patient_count = df_filtered.select("PATIENT_ID").distinct().count()

# Mismatching-row count per patient, largest first.
patient_record_counts = df_filtered.groupBy("PATIENT_ID").count().orderBy(col("count").desc())

# first() returns None when no patient has a mismatch; without this guard the cell
# raised TypeError ('NoneType' object is not subscriptable) on clean data.
top_row = patient_record_counts.first()
max_record_count = top_row["count"] if top_row is not None else 0

# Patient(s) tied for the highest number of mismatching rows (empty if there are none).
patients_with_max_records = patient_record_counts.filter(col("count") == max_record_count).select("PATIENT_ID")

# Mismatching rows for those patients.
df_problem_patients = df_filtered.join(patients_with_max_records, on="PATIENT_ID")

print(f"Number of patients with the problem: {problem_patient_count}")
df_problem_patients.display()

In [0]:
# Count the unique PATIENT_IDs that have more than one distinct non-null BRTH_YR value.

# Rows with a known BRTH_YR. This uses its own name rather than reassigning `df_filtered`,
# which the previous cell uses for the SP/SHS mismatch rows.
df_brth_yr_known = df.filter(col("BRTH_YR").isNotNull())

# groupBy("PATIENT_ID") already yields one row per patient, so no extra distinct() is needed.
df_multiple_brth_yr = (
    df_brth_yr_known.groupBy("PATIENT_ID")
    .agg(countDistinct("BRTH_YR").alias("distinct_brth_yr_count"))
    .filter(col("distinct_brth_yr_count") > 1)
    .select("PATIENT_ID")
)

patient_count = df_multiple_brth_yr.count()

print(f"Number of patients with multiple BRTH_YR values: {patient_count}")
df_multiple_brth_yr.display()

### Calculate Weekly IUs

In [0]:
# Summarize IU by HCP, patient, product, PRPHY and shipment week.
from pyspark.sql.functions import date_trunc, to_date

columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU']
df_selected = df.select(*columns)

# Bucket each shipment date into its week (date_trunc("week") truncates to the Monday of that week).
df_with_week = df_selected.withColumn(
    "SHP_WK",
    to_date(date_trunc("week", F.col("SHP_DT")))
)

# Default a missing PRPHY to '1' BEFORE grouping. Applying the default after the groupBy
# left null-PRPHY and '1'-PRPHY rows for the same key/week as two separate, un-summed rows
# that then carried identical keys.
df_iu = (
    df_with_week
    .filter(F.col("PRD_NM").isNotNull())
    .withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1')))
    .groupBy("BH_ID", "PATIENT_ID", "PRD_NM", "PRPHY", "SHP_WK")
    .agg(F.sum("IU").alias("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_WK').asc())
)

df_iu.printSchema()
df_iu.display()

### Create PRD_PRPHY label and weeks-since-last-shipment (WK_SINCE_LST_SHP) columns

In [0]:
## Create PRD_PRPHY
# 1. Combined label "<PRD_NM>_<PRPHY>".
df_iu_2 = df_iu.withColumn(
    "PRD_PRPHY",
    concat(col("PRD_NM"), lit("_"), col("PRPHY"))
)

## Create WK_SINCE_LST_SHP
# 2. Row order per HCP-patient pair. PRD_PRPHY breaks ties between different products
#    shipped in the same week, so lag() picks the same previous row on every run.
window_spec = Window.partitionBy("BH_ID", "PATIENT_ID").orderBy("SHP_WK", "PRD_PRPHY")

# 3. SHP_WK of the preceding row in that order (null for the pair's first row).
df_iu_3 = df_iu_2.withColumn("PREV_SHP_WK", F.lag("SHP_WK").over(window_spec))

# 4. Whole weeks since the previous shipment; the first shipment gets 0.
#    fillna is restricted to this column: a bare fillna(0) also zeroed every other
#    numeric null in the frame, e.g. a null summed IU.
df_iu_4 = df_iu_3.withColumn(
    "WK_SINCE_LST_SHP",
    (datediff(col("SHP_WK"), col("PREV_SHP_WK")) / 7).cast("integer")
).fillna(0, subset=["WK_SINCE_LST_SHP"])

df_iu_4.display()

In [0]:
# Select the columns
selected_columns_df = df_iu_4.select("BH_ID", "PATIENT_ID", "PRD_PRPHY", "WK_SINCE_LST_SHP")

# Group by BH_ID and PATIENT_ID and aggregate PRD_PRPHY
df_iu_5 = selected_columns_df.groupBy("BH_ID", "PATIENT_ID") \
    .agg(
        concat_ws(", ", collect_list("PRD_PRPHY")).alias("PRD_PRPHY_ALL"),
        # You can also aggregate WK_SINCE_LST_SHP if needed, for example, using max or sum
        # Here we will take the maximum WK_SINCE_LST_SHP for demonstration
        # max("WK_SINCE_LST_SHP").alias("WK_SINCE_LST_SHP")
    )

# result DataFrame
df_iu_5.display()

### Show HCP's Rx history

In [0]:
# Inspect the aggregated PRD_PRPHY history rows for one specific HCP.
HCP_ID = 'BH10200613'
df_iu_5.where(col('BH_ID') == HCP_ID).display()

In [0]:
# Patients who used JIVI
df_iu_jivi = df_iu_5.filter(col("PRD_PRPHY_ALL").contains("JIVI"))
df_iu_jivi.display()


In [0]:
# Distinct HCPs and distinct patients among the JIVI HCP-patient pairs.
for key_col in ("BH_ID", "PATIENT_ID"):
    n_unique = df_iu_jivi.select(key_col).distinct().count()
    print(f"unique {key_col}:", n_unique)

In [0]:
# All PRD_PRPHY records for JIVI patients
# Select relevant columns from df_iu_5
df_iu_5_selected = df_iu_5.select("BH_ID", "PATIENT_ID", "PRD_PRPHY_ALL")

# Step 2: Create a new DataFrame df_iu_jivi_pt by joining df_iu_jivi with df_iu_5
df_iu_jivi_pt = df_iu_jivi.alias("jivi").join(
    df_iu_5_selected.alias("iu_5"),
    on="PATIENT_ID",
    how="inner"
).select(
    "iu_5.BH_ID",
    "jivi.PATIENT_ID",
    "iu_5.PRD_PRPHY_ALL"
)

# Step 3: Show the resulting DataFrame
display(df_iu_jivi_pt)

In [0]:
# Distinct HCPs and distinct patients across the full histories of JIVI patients.
for key_col in ("BH_ID", "PATIENT_ID"):
    n_unique = df_iu_jivi_pt.select(key_col).distinct().count()
    print(f"unique {key_col}:", n_unique)