In [0]:
%run "../00_config/set-up"

In [0]:
# Read the full overlap_rx table, then report its dimensions and schema.
df = spark.sql("SELECT * FROM heme_data.overlap_rx")
_row_count = df.count()
_column_count = len(df.columns)
print('Row count: ', _row_count, 'Column Count: ', _column_count)
df.printSchema()

### Calculate Weekly IUs

In [0]:
# Summarize IU by patient, product, week
from pyspark.sql.functions import date_trunc, to_date

# derive week from date
columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU']
df_selected = df.select(*columns)
# SHP_WK is the shipment date truncated to the start of its week, stored as a date.
df_with_week = df_selected.withColumn(
    "SHP_WK",
    to_date(date_trunc("week", F.col("SHP_DT")))
)

# calculate weekly IU
# A missing PRPHY is defaulted to '1' BEFORE grouping. If the default were applied
# after the aggregation, rows with a null PRPHY and rows with an explicit '1' for the
# same patient/product/week would be summed as two separate groups and then end up as
# two rows sharing one key, which breaks the one-row-per-key assumption downstream.
df_iu = (
    df_with_week
    .filter(F.col("PRD_NM").isNotNull())  # rows without a product name are excluded
    .withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1')))
    .groupBy("PATIENT_ID", "PRD_NM", "PRPHY", "SHP_WK")
    .agg(F.sum("IU").alias("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_WK').asc())
)

df_iu.printSchema()
df_iu.display()

#### Weekly IU in Wide Format

In [0]:
# transform weekly iu data to wide format
# Each distinct PRD_NM / PRPHY combination becomes its own column named
# PRD_<PRD_NM>_PRPHY_<PRPHY>; each output row is one PATIENT_ID / SHP_WK pair.
df_iu_wide = (
    df_iu
    .withColumn(
        "PRD_PRPHY",
        F.concat(
            F.lit("PRD_"),
            F.col("PRD_NM"),
            F.lit("_PRPHY_"),
            F.col("PRPHY"),
        ),
    )
    .groupBy("PATIENT_ID", "SHP_WK")
    .pivot("PRD_PRPHY")
    # Aggregate with sum rather than first: if several input rows land in the same
    # patient/week/PRD_PRPHY cell, first() keeps only one of them (and which one is
    # not guaranteed), silently dropping IU. With a single row per cell the result
    # is the same as before.
    .agg(F.sum("IU"))
    .orderBy(F.col('PATIENT_ID').desc(), F.col('SHP_WK').asc())
)

def clean_dataframe(df):
    """
    Drop all-null columns, zero-fill remaining nulls, and cast decimal(38,5) columns to int.

    :param df: Input Spark DataFrame
    :return: DataFrame that keeps only the columns with at least one non-null value,
             has nulls replaced by 0 in every column except PATIENT_ID and SHP_WK,
             and has every decimal(38,5) column cast to int
    """
    # Remove columns with 100% null values.
    # F.count(column) counts non-null values, so one aggregation over all columns
    # replaces the previous one-Spark-job-per-column filter().count() loop.
    if df.columns:
        non_null_counts = df.select(
            [F.count(F.col(name)).alias(name) for name in df.columns]
        ).first()
    else:
        non_null_counts = ()
    non_null_columns = [
        name for name, n_values in zip(df.columns, non_null_counts) if n_values > 0
    ]
    df_non_null = df.select(*non_null_columns)

    # Fill null values with 0 for all columns except PATIENT_ID and SHP_WK.
    # A single select is used instead of a withColumn loop so the query plan does
    # not grow with the number of (pivoted) columns.
    key_columns = ('PATIENT_ID', 'SHP_WK')
    df_filled = df_non_null.select(
        *[
            F.col(name) if name in key_columns
            else F.coalesce(F.col(name), F.lit(0)).alias(name)
            for name in df_non_null.columns
        ]
    )

    # Convert decimal values to integers.
    # Only the exact type decimal(38,5) is matched; other decimal precisions are left as is.
    # NOTE(review): the int cast discards any fractional part - confirm this is intended.
    return df_filled.select(
        *[
            F.col(name).cast('int') if dtype == 'decimal(38,5)' else F.col(name)
            for name, dtype in df_filled.dtypes
        ]
    )

# Clean the wide weekly IU table (drop empty columns, zero-fill, cast), then show its schema and rows.
df_iu_wide_fmt = clean_dataframe(df_iu_wide)
df_iu_wide_fmt.printSchema()    
df_iu_wide_fmt.display()

#### Explode Patient Data to Consecutive Weeks

In [0]:
# Bounds of the weekly grid. The sequence below advances one week at a time from
# start_date, and the grid is matched to the data on exact SHP_WK equality, so only
# data weeks that fall on these generated dates are picked up by the join.
start_date = "2022-01-03"
end_date = "2024-11-25"

# One row per week between start_date and end_date
date_range_df = spark.sql(f"""
    SELECT explode(sequence(to_date('{start_date}'), to_date('{end_date}'), interval 1 week)) as SHP_WK
""")

# Every patient paired with every week of the grid
patient_ids_df = df_iu_wide_fmt.select("PATIENT_ID").distinct()
patient_weeks_df = patient_ids_df.crossJoin(date_range_df)

# Attach the observed weekly values to the grid; the grid is the left side, so
# patient/week pairs with no data are kept and their missing values become 0.
df_iu_wide_fmt_exp = (
    patient_weeks_df
    .join(df_iu_wide_fmt, ["PATIENT_ID", "SHP_WK"], "left")
    .fillna(0)
    .orderBy('PATIENT_ID', 'SHP_WK')
)

df_iu_wide_fmt_exp.printSchema()
df_iu_wide_fmt_exp.display()

#### Summarize Past N Lags

In [0]:
def summarize_lags(df, n):
  """
    Summarize each PRD_ column over the previous n rows of every patient.

    For each column whose name starts with "PRD_", a rolling sum and a rolling
    standard deviation are computed per PATIENT_ID, ordered by SHP_WK, over the
    n rows that precede the current row (the current row itself is excluded).
    The window counts rows, not calendar weeks, so it only corresponds to
    "the previous n weeks" when the input has one row per consecutive week.
    Rows with fewer than n predecessors are summarized over the rows available.

    :param df: Input DataFrame with PATIENT_ID, SHP_WK and PRD_-prefixed columns
    :param n: Number of periods for the summary stat
    :return: DataFrame with PATIENT_ID, SHP_WK, then all <col>_SUM_<n>WK columns,
             then all <col>_STD_<n>WK columns
  """
  # Define the window specification: the n rows before the current one, per patient
  window_spec = Window.partitionBy("PATIENT_ID").orderBy("SHP_WK").rowsBetween(-n, -1)

  # List of columns
  prd_columns = [name for name in df.columns if name.startswith("PRD_")]

  # Build every window expression up front and project them in a single select.
  # Chaining two withColumn calls per PRD_ column grows the query plan with the
  # number of columns, which gets slow on wide pivoted frames.
  sum_exprs = [
      F.sum(name).over(window_spec).alias(f"{name}_SUM_{n}WK") for name in prd_columns
  ]
  std_exprs = [
      F.stddev(name).over(window_spec).alias(f"{name}_STD_{n}WK") for name in prd_columns
  ]

  # Select the desired columns
  df_final = df.select("PATIENT_ID", "SHP_WK", *sum_exprs, *std_exprs)

  return df_final

# Calculate the sum and standard deviation over the previous 4 rows (weeks) for each patient
n = 4
df_result = summarize_lags(df_iu_wide_fmt_exp, n)
df_result.printSchema()
df_result.display()