In [0]:
# Importing packages
from pyspark.sql import functions as F # Importing functions from pyspark.sql
from pyspark.sql.functions import *
from pyspark.sql import Window
import pandas as pd

In [0]:
%run "/Workspace/Repos/yuan.niu@bayer.com/heme_new_writer_models_dev_repo/02_data_processing/helper_functions"

In [0]:
%run "../00_config/set-up"

In [0]:
# Load the overlap Rx table into a Spark DataFrame.
overlap_raw_data = spark.sql("SELECT * FROM heme_data.overlap_rx")

# Report the table dimensions (count() runs a Spark job over the table).
print(
    'Row count: ', overlap_raw_data.count(),
    'Column Count: ', len(overlap_raw_data.columns),
)

In [0]:
# Cast the DecimalType column IU to a primitive floating-point type.
# Background: Spark emits "UserWarning: The conversion of DecimalType columns is
# inefficient and may take a long time. Column names: [IU, PTD_FNL_CLM_AMT]" when such
# columns are converted to pandas, and recommends dropping them or casting to
# primitive types first. Only IU is cast here; PTD_FNL_CLM_AMT is left as-is.
# "double" (64-bit) is used instead of "float" (32-bit single precision) so that IU
# values are not truncated to ~7 significant digits before they are summed downstream.
overlap_raw_data = overlap_raw_data.withColumn("IU", overlap_raw_data["IU"].cast("double"))

In [0]:
# Columns kept for the analysis. Columns with a high percentage of null values are
# generally excluded. The list order is preserved exactly as originally declared.
col_shortlist = [
    'BH_ID', 'PATIENT_ID', 'SHP_DT',
    'WINNING_PATIENT_ID', 'SP_SOURCE_PTNT_ID', 'SHS_SOURCE_PTNT_ID',
    'PRD_NM', 'DRUG_NM', 'PRD_GRP_NM', 'MKT_NM',
    'DRUG_STRG_QTY', 'IU',
    'SRC_SP', 'SOURCE_TYPE',
    'BRTH_YR', 'SP_PTNT_BRTH_YR', 'PTNT_AGE_GRP', 'PTNT_WGT',
    'PTNT_GENDER', 'PTNT_GNDR', 'ETHNC_CD',
    'EPSDC', 'SEVRTY', 'PRPHY',
    'INSN_ID', 'INSN_NM', 'AFFL_TYP', 'SPCL_CD',
    'DATE_ID', 'MTH_ID',
    'PAYR_NM', 'PAYR_TYP', 'PAY_TYP_CD',
    'COPAY_AMT', 'TOTL_PAID_AMT',
    'CLAIM_TYP', 'PRESCRIBED_UNIT', 'DAYS_SUPPLY_CNT',
    'REFILL_AUTHORIZED_CD', 'FILL_DT', 'ELIG_DT',
]

In [0]:
# Restrict the raw overlap data to the shortlisted columns.
overlap_subset = overlap_raw_data.select(*col_shortlist)

In [0]:
# Add SHP_YR_MO: the SHP_DT value formatted with the pattern "yyyy-MM".
overlap_subset = overlap_subset.withColumn(
    "SHP_YR_MO",
    F.date_format(F.col("SHP_DT"), "yyyy-MM"),
)

In [0]:
# Preview up to 20 rows of the subset (including the derived SHP_YR_MO column).
display(overlap_subset.limit(20))

#### PRH_A1: Total IU per BH_ID, month (SHP_YR_MO), and patient

In [0]:
# Total IU per (BH_ID, SHP_YR_MO, PATIENT_ID).
# F.sum is referenced explicitly: the bare name `sum` only resolves to the Spark
# aggregate because of the wildcard import from pyspark.sql.functions, which shadows
# Python's built-in sum and would break if that name were ever rebound.
sum_col = 'IU'
overlap_hcp_ptnt_iu_sum = (
    overlap_subset
    .groupBy("BH_ID", "SHP_YR_MO", "PATIENT_ID")
    .agg(F.sum(sum_col).alias(f'OVP_{sum_col}_TOTAL'))
)

#### PRH_A2: Total IU by brand (PRD_NM) and use type (PRPHY_CD)

In [0]:
# Binary flag PRPHY_CD: 0 when PRPHY equals 0, otherwise 1.
# NOTE(review): rows where PRPHY is null fall into the otherwise branch and get 1 —
# confirm that this is the intended treatment of missing values.
overlap_subset = overlap_subset.withColumn(
    'PRPHY_CD',
    F.when(F.col('PRPHY') == 0, 0).otherwise(1),
)

In [0]:
# Total IU per (BH_ID, SHP_YR_MO, PATIENT_ID) further split by product name and
# the PRPHY_CD flag.
# F.sum is referenced explicitly: the bare name `sum` only resolves to the Spark
# aggregate because of the wildcard import from pyspark.sql.functions, which shadows
# Python's built-in sum and would break if that name were ever rebound.
grp_by_col = ['PRD_NM', 'PRPHY_CD']
sum_col = 'IU'
overlap_hcp_ptnt_groupby_sum = (
    overlap_subset
    .groupBy("BH_ID", "SHP_YR_MO", "PATIENT_ID", *grp_by_col)
    .agg(F.sum(sum_col).alias(f'OVP_{sum_col}_TOTAL'))
)

In [0]:
# Keep only the rows whose PRD_NM is one of the listed brands.
brand_list = ['KOVALTRY', 'KOGENATE', 'JIVI', 'HEMLIBRA', 'ALTUVIIIO']
overlap_hcp_ptnt_groupby_sum_brand = overlap_hcp_ptnt_groupby_sum.where(
    F.col('PRD_NM').isin(brand_list)
)

In [0]:
# Spot-check the long-format brand totals for a single patient.
display(
    overlap_hcp_ptnt_groupby_sum_brand.where(
        F.col('PATIENT_ID') == "E69C1764-64EB-44AE-ADA2-9669FFFFAF69"
    )
)

In [0]:
def long_to_wide(df, group_cols, pivot_col, agg_col, agg_func, pivot_values=None):
    """
    Pivot a long-format Spark DataFrame into wide format.

    Rows are grouped by ``group_cols``; each distinct value of ``pivot_col``
    becomes one output column holding ``agg_func`` applied to ``agg_col``.
    Pivoted columns are renamed to ``f"{agg_col}_{<pivot value>}"``; the
    grouping columns keep their original names.

    Parameters
    ----------
    df : Spark DataFrame
        Long-format input.
    group_cols : str or list of str
        Column name(s) to group by. A single name may be given as a plain string.
    pivot_col : str
        Column whose distinct values become the new columns.
    agg_col : str
        Column to aggregate inside each (group, pivot value) cell.
    agg_func : str
        Name of the aggregate function, as accepted by the dict form of
        ``GroupedData.agg`` (e.g. ``"first"``, ``"sum"``).
    pivot_values : list, optional
        Explicit list of pivot values. When omitted (the default), Spark first
        computes the distinct values of ``pivot_col`` itself.

    Returns
    -------
    Spark DataFrame
        The wide DataFrame, marked for caching via ``.cache()``.
    """
    # Normalise to a list: with a bare string, the `c in group_cols` test below
    # would silently become a substring check instead of a membership check.
    if isinstance(group_cols, str):
        group_cols = [group_cols]
    else:
        group_cols = list(group_cols)

    grouped = df.groupBy(*group_cols)
    if pivot_values is None:
        pivoted = grouped.pivot(pivot_col)
    else:
        pivoted = grouped.pivot(pivot_col, list(pivot_values))

    # With a single aggregation, the pivoted columns are named after the pivot values.
    df_pivot = pivoted.agg({agg_col: agg_func})

    # Prefix every pivoted column with the aggregated column's name.
    renamed_columns = [
        F.col(c) if c in group_cols else F.col(c).alias(f"{agg_col}_{c}")
        for c in df_pivot.columns
    ]

    # Only the final result is cached; caching the intermediate pivot as well would
    # hold a second copy of the same data that is never unpersisted.
    return df_pivot.select(*renamed_columns).cache()

In [0]:
# Build the pivot key PRD_NM_PRPHY_CD as "<PRD_NM>_<PRPHY_CD>".
overlap_hcp_ptnt_groupby_sum_brand = overlap_hcp_ptnt_groupby_sum_brand.withColumn(
    "PRD_NM_PRPHY_CD",
    F.concat(F.col("PRD_NM"), F.lit("_"), F.col("PRPHY_CD")),
)

In [0]:
# Pivot to one row per (BH_ID, SHP_YR_MO, PATIENT_ID) with one OVP_IU_TOTAL_* column
# per PRD_NM_PRPHY_CD value.
# "first" is used as the aggregate; presumably each (group, pivot key) cell holds a
# single pre-aggregated row from the earlier groupby — verify if the inputs change.
group_cols = ["BH_ID", "SHP_YR_MO", "PATIENT_ID"]
pivot_col = "PRD_NM_PRPHY_CD"
agg_col = 'OVP_IU_TOTAL'
agg_func = "first"
overlap_hcp_ptnt_groupby_sum_brand_type = long_to_wide(
    df=overlap_hcp_ptnt_groupby_sum_brand,
    group_cols=group_cols,
    pivot_col=pivot_col,
    agg_col=agg_col,
    agg_func=agg_func,
)

In [0]:
# Replace nulls in the wide table with 0 (e.g. brand/flag keys with no rows for a group).
overlap_hcp_ptnt_groupby_sum_brand_type = (
    overlap_hcp_ptnt_groupby_sum_brand_type.na.fill(0)
)

In [0]:
# Spot-check the wide-format result for a single patient.
display(
    overlap_hcp_ptnt_groupby_sum_brand_type.where(
        F.col('PATIENT_ID') == "E69C1764-64EB-44AE-ADA2-9669FFFFAF69"
    )
)