In [0]:
from pyspark.sql import functions as F # Importing functions from pyspark.sql
from pyspark.sql.functions import *
from pyspark.sql import Window
import pandas as pd
from pyspark.sql.functions import col, concat, lit

In [0]:
%run "/Workspace/Repos/yuan.niu@bayer.com/heme_new_writer_models_dev_repo/02_data_processing/helper_functions"

In [0]:
%run "../00_config/set-up"

### Load data

In [0]:
MDM_STG_SPCLTY_XREF = get_data_snowflake(
f"""
  SELECT * FROM PHCDW.PHCDW_STG.MDM_STG_SPCLTY_XREF
"""
)
print(MDM_STG_SPCLTY_XREF.count(), len(MDM_STG_SPCLTY_XREF.columns))

In [0]:
hcp_demo = spark.sql("SELECT * FROM heme_data.hcp_demo")
print(hcp_demo.count(), len(hcp_demo.columns))

In [0]:
display(hcp_demo)

In [0]:
hcp_demo.select('SPECIALTY_1').distinct().show()

In [0]:
display(hcp_demo.limit(20))

In [0]:
hcp_affl = spark.sql("SELECT * FROM heme_data.hcp_affl")

In [0]:
hcp_affl.filter(col('HMT_BU_RANKING')!=col('ENTERPRISE_RANK')).select("BH_ID","BHO_ID",'HMT_BU_RANKING','ENTERPRISE_RANK').display()

In [0]:
hcp_hco_count = hcp_affl.groupby("BH_ID").agg(countDistinct("BHO_ID").alias("HCO_CNT"))

In [0]:
acct_demo = spark.sql("SELECT * FROM heme_data.acct_demo")

In [0]:
acct_demo = get_data_snowflake(
f"""
  SELECT * FROM PHCDW.PHCDW_STG.MDM_STG_ACCT_PRFL
"""
)
print(acct_demo.count(), len(acct_demo.columns))

In [0]:
overlap_raw_data = spark.sql("SELECT * FROM heme_data.overlap_rx")
# Converting the original overlap data spark dataframe to pandas dataframe
""" Convert DecimalType columns to float to avoid UserWarning: The conversion of DecimalType columns is inefficient and may take a long time. Column names: [IU, PTD_FNL_CLM_AMT] If those columns are not necessary, you may consider dropping them or converting to primitive types before the conversion."""
overlap_raw_data = overlap_raw_data.withColumn("IU", overlap_raw_data["IU"].cast("float"))
# creat yyyy-MM column
overlap_raw_data = overlap_raw_data.withColumn("SHP_YR_MO",F.date_format("SHP_DT", "yyyy-MM"))
print('Row count: ', overlap_raw_data.count(), 'Column Count: ', len(overlap_raw_data.columns))

In [0]:
overlap_hcp_hco_count = overlap_raw_data.select("BH_ID").distinct().join(hcp_hco_count, on = ['BH_ID'],how='left')

In [0]:
overlap_hcp_spec = overlap_raw_data.select("BH_ID").distinct().join(hcp_demo.select('BH_ID',"SPECIALTY_1",'SPECIALTY_2','SPECIALTY_3'), on = ['BH_ID'],how='left')


In [0]:
MDM_STG_SPCLTY_XREF.filter(col("SPECIALTY_CODE").isin(['HO','ON','PHO','PD','GE','ADU','HEP'])).display()

In [0]:
resonable_spec = ['HO','ON','PHO','HEM','HEP']
#ADU ->PCP,AC

In [0]:
'GP','FM','PD','IM'

In [0]:
overlap_hcp_spec.groupby("SPECIALTY_1").agg(countDistinct("BH_ID").alias("HCP_CNT")).display()

In [0]:
overlap_hcp_hco_count.approxQuantile("HCO_CNT", [0.25, 0.5, 0.75], 0.01)

In [0]:
overlap_hcp_hco_count.select("HCO_CNT").describe().show()

In [0]:
print("% of hcps are associated with multiple HCOs: ", overlap_hcp_hco_count.filter(col("HCO_CNT")>1).count()/overlap_hcp_hco_count.count())

In [0]:
def long_to_wide(df, group_cols, pivot_col, agg_col, agg_func):
    """
    group_cols: groupby columns

    """
    df_pivot = (df.groupBy(group_cols)
                  .pivot(pivot_col)
                  .agg({agg_col: agg_func})).cache()
    # Rename the columns to add the category names to column names

    renamed_columns = [col(c).alias(f"{agg_col}_{c}") if c not in group_cols else col(c)
                       for c in df_pivot.columns]

    df_pivot = df_pivot.select(
    *renamed_columns).cache()

   
    return df_pivot

In [0]:
def calc_cumulative_sums_excl_current(df, id_col, date_col, value_col):
    """
    Calculate cumulative sums over past 1, 3, 6, 9, and 12 records for each BAYER_HCP_ID, excluding the current record.

    Parameters:
    df (DataFrame): The input DataFrame containing the data.
    id_col (str): The column name used for partitioning the data (e.g., 'BAYER_HCP_ID').
    date_col (str): The column name used for ordering the data within each partition (e.g., 'CALL_MONTH').
    value_col (str): The column name whose values will be summed (e.g., 'call_cnt').

    Returns:
    DataFrame: The DataFrame with additional columns for cumulative sums over past 1, 3, 6, 9, and 12 records, excluding the current record.
    """

    window_specs = {
        #f"{value_col}_cumsum_last_1M": Window.partitionBy(id_col).orderBy(date_col).rowsBetween(-1, -1),
        f"{value_col}_cumsum_last_3M": Window.partitionBy(id_col).orderBy(date_col).rowsBetween(-3, -1),
        f"{value_col}_cumsum_last_6M": Window.partitionBy(id_col).orderBy(date_col).rowsBetween(-6, -1),
        f"{value_col}_cumsum_last_9M": Window.partitionBy(id_col).orderBy(date_col).rowsBetween(-9, -1),
        f"{value_col}_cumsum_last_12M": Window.partitionBy(id_col).orderBy(date_col).rowsBetween(-12, -1)
    }
    
    for sum_col, window_spec in window_specs.items():
        df = df.withColumn(sum_col, F.sum(value_col).over(window_spec))
    
    return df

In [0]:
display(overlap_raw_data.limit(20))

### HCP Specialty:
1) Hematology (HEM, HEP), 2) Hem/Onc (HO, ON), 3) Pediatric Hem/Onc (PHO), 4) Internal Medicine (Family medicine, general practitioner) - this group may emerge as needed after we check the Speicalty_2 and _3; 5) Others - not in above 

In [0]:
hcp_specialty = hcp_demo.select("BH_ID",'SPECIALTY_1').distinct()

In [0]:
hcp_specialty.groupby("BH_ID").agg(countDistinct("SPECIALTY_1").alias("HCP_SPECIALTY_COUNT")).filter(col("HCP_SPECIALTY_COUNT") > 1).show()

#### HCP Affiliation with institution flag

In [0]:
hcp_aff_wi_insn = (overlap_raw_data
                   .filter(col("INSN_NM").isNotNull())
                   .withColumn("AFFL_WI_INSN",lit(1))
                   .select("BH_ID","SHP_YR_MO", 'AFFL_WI_INSN').distinct()).cache()

In [0]:
display(hcp_aff_wi_insn.limit(20))

#### Size of HCP's affiliated account

In [0]:
hco_account_size = (overlap_raw_data.groupby("INSN_ID")
                    .agg(countDistinct("BH_ID").alias("HCO_ACCOUNT_SIZE"))).cache()

### Total unique patient count per HCP 

In [0]:
hcp_ptnt_cnt = (overlap_raw_data
                .filter(col("SHP_DT")>='2022-01-01')
                .filter(col("SHP_DT")<='2022-12-31')
                .groupby("BH_ID").agg(countDistinct("PATIENT_ID").alias("PTNT_CNT")))

### Number of patients by age buckets per HCP

In [0]:
## Creating the age bucket
overlap_raw_data = (overlap_raw_data
                .withColumn('Age', ceil(datediff(current_date(), overlap_raw_data.BRTH_YR)/365))
                .withColumn(
    "patient_age",
    when(col("Age") < 12, 'PAT_AGE_0_11')
    .when((col("Age") >= 12) & (col("Age") <= 17), 'PAT_AGE_12_17')
    .when((col("Age") >= 18) & (col("Age") <= 24), 'PAT_AGE_18_24')
    .when((col("Age") >= 25) & (col("Age") <= 65), 'PAT_AGE_25_65')
    .when(col("Age") > 65, 'PAT_AGE_65+')
    .otherwise('PAT_AGE_NA'))).cache()

In [0]:
hcp_ptnt_age = (overlap_raw_data
                .filter(col("SHP_DT")>='2022-01-01')
                .filter(col("SHP_DT")<='2022-12-31')
                .groupby("BH_ID","patient_age").agg(countDistinct("PATIENT_ID").alias("PTNT_CNT")))

In [0]:
## convert to wide format
hcp_ptnt_age_wide = (hcp_ptnt_age.groupBy("BH_ID").pivot("patient_age").agg(F.first("PTNT_CNT")))
hcp_ptnt_age_wide = hcp_ptnt_age_wide.fillna(0)

In [0]:
for col in hcp_ptnt_age_wide.columns[1:]:  # Skip the first column (BH_ID)
    new_col_name = f"{col}_CNT"  # Create new column name
    hcp_ptnt_age_wide = hcp_ptnt_age_wide.withColumnRenamed(col, new_col_name)

In [0]:
hcp_ptnt_age_wide.show()

### Per HCP unqiue count of patient by Brand and Use type

In [0]:

from pyspark.sql.functions import date_trunc, to_date

columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU',"SHP_YR_MO"]
df_selected = overlap_raw_data.select(*columns)


# calculate month unique patient count by brand and type
patient_cnt_by_brand_use = (df_selected.filter(F.col("PRD_NM").isNotNull()).select(
    "BH_ID",
    "PRD_NM",
    "PRPHY",
    "SHP_YR_MO",
    "PATIENT_ID").distinct()
.withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1'))) \
.orderBy(F.col('BH_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_YR_MO').asc()))

In [0]:
## filter on selected brand
brand_list = ['KOVALTRY','KOGENATE','JIVI','HEMLIBRA','ALTUVIIIO']
patient_cnt_by_brand_use_filtered = (patient_cnt_by_brand_use.filter(col("PRD_NM").isin(brand_list))
                                     .withColumn("PRD_NM_PRPHY_CD",concat(col("PRD_NM"), lit("_"), col("PRPHY")))
                                     .select("BH_ID",'SHP_YR_MO','PRD_NM_PRPHY_CD','PATIENT_ID').distinct()
                                     .orderBy(F.col('BH_ID').desc(), F.col('SHP_YR_MO').desc())
                                     )  

In [0]:
patient_cnt_by_brand_use_filtered.count()

In [0]:
## filter on  brands and aggregate by type
other_shl = ['ADVATE', 'RECOMBINATE', 'XYNTHA','NUWIQ','NOVOEIGHT','AFSTYLA']
other_ehl = ['ELOCATE', 'ADYNOVATE', 'ESPEROCT']
patient_cnt_by_category = (patient_cnt_by_brand_use.withColumn("PROD_TYPE",
                                                                when(col("PRD_NM").isin(other_shl), "OTHER_SHL").when(col("PRD_NM").isin(other_ehl), "OTHER_EHL").otherwise("OTHERS")))
patient_cnt_by_category = (patient_cnt_by_category.filter(col("PROD_TYPE")!='OTHERS')
                           .withColumn("PRD_NM_PRPHY_CD",concat(col("PROD_TYPE"), lit("_"), col("PRPHY")))
                           .select("BH_ID",'SHP_YR_MO','PRD_NM_PRPHY_CD','PATIENT_ID'))

In [0]:
patient_cnt_by_category.count()

In [0]:
patient_cnt_by_brand_use_all = patient_cnt_by_category.union(patient_cnt_by_brand_use_filtered)

In [0]:
patient_cnt_by_brand_use_all.count() #288781

In [0]:
display(patient_cnt_by_brand_use_all.filter(col("BH_ID")=='FF3464A7-5'))

In [0]:
unique_month = overlap_raw_data.select(col("SHP_YR_MO").alias("COHORT_MONTH")).filter(col("COHORT_MONTH")>='2023-01').distinct()
hcp_month_pair = (patient_cnt_by_brand_use_all.select("BH_ID").distinct()
                  .crossJoin(unique_month)
                  .orderBy('BH_ID','COHORT_MONTH'))

In [0]:
unique_month.count()

In [0]:
## explode the dataframe
patient_cnt_by_brand_use_all_explode = (
  hcp_month_pair.join(patient_cnt_by_brand_use_all, on = ['BH_ID'], how = 'left')
  .orderBy('BH_ID','COHORT_MONTH')
  .fillna('NA')
  )

In [0]:
patient_cnt_by_brand_use_all_explode = (patient_cnt_by_brand_use_all_explode.withColumn("month_diff",F.months_between(F.col("COHORT_MONTH"), F.col("SHP_YR_MO")))
                                         .withColumn("ROLLING_WIN", F.when((F.col("month_diff")<=1)&(F.col("month_diff")>0), 'LAST_1M')
                                                                    .when((F.col("month_diff")<=3)&(F.col("month_diff")>0), 'LAST_3M')
                                                                     .when((col("month_diff")<=6)&(col("month_diff")>0), 'LAST_6M')
                                                                     .when((col("month_diff")<=9)&(col("month_diff")>0), 'LAST_9M')
                                                                     .when((col("month_diff")<=12)&(col("month_diff")>0), 'LAST_12M').otherwise('NA') ))

In [0]:
patient_cnt_by_brand_use_all_explode.count() #6641963

In [0]:
patient_share_by_brand_use_all_explode = (patient_cnt_by_brand_use_all_explode
                                          .filter(col("ROLLING_WIN")!='NA')
                                          .groupby("BH_ID", "COHORT_MONTH", "PRD_NM_PRPHY_CD",'ROLLING_WIN').agg(countDistinct("PATIENT_ID").alias('PTNT_CNT')))

In [0]:
patient_row_sum_by_brand_use_all_explode = (patient_cnt_by_brand_use_all_explode
                                          .filter(col("ROLLING_WIN")!='NA')
                                          .groupby("BH_ID", "COHORT_MONTH",'ROLLING_WIN').agg(countDistinct("PATIENT_ID").alias('ROW_TOTAL')))

In [0]:
patient_share_by_brand_use_all_explode = (patient_share_by_brand_use_all_explode.join(patient_row_sum_by_brand_use_all_explode,
                                                                                     on = ["BH_ID", "COHORT_MONTH",'ROLLING_WIN'],
                                                                                     how = 'left'))
patient_share_by_brand_use_all_explode = (patient_share_by_brand_use_all_explode.withColumn("PTNT_SHARE", col("PTNT_CNT")/col("ROW_TOTAL"))
                                                                               .withColumn("PRD_NM_PRPHY_CD_WIN",concat(col("PRD_NM_PRPHY_CD"),
                                                                                                                         lit("_"), 
                                                                                                                         col("ROLLING_WIN"),
                                                                                                                         lit("_"), 
                                                                                                                         lit("PTNT_SHARE"))))

In [0]:
display(patient_share_by_brand_use_all_explode.filter(col("BH_ID")=='FF3464A7-5'))

In [0]:
## change to wide format
group_cols = ["BH_ID", "COHORT_MONTH"]
pivot_col = "PRD_NM_PRPHY_CD_WIN"
agg_col = 'PTNT_SHARE'
agg_func = "first"
patient_share_by_brand_use_all_explode = long_to_wide(patient_share_by_brand_use_all_explode, group_cols, pivot_col, agg_col, agg_func)
patient_share_by_brand_use_all_explode = patient_share_by_brand_use_all_explode.fillna(0)

In [0]:
## eplode to fill the missing corhort months
patient_share_by_brand_use_all_explode = (
  hcp_month_pair.join(patient_share_by_brand_use_all_explode, on = ['BH_ID','COHORT_MONTH'], how = 'left')
  .orderBy('BH_ID','COHORT_MONTH')
  .fillna(0)
  )

In [0]:
patient_share_by_brand_use_all_explode.count()

In [0]:
display(patient_share_by_brand_use_all_explode.filter(col("BH_ID")=='FF3464A7-5'))

### Merge all features together

In [0]:
patient_share_by_brand_use_all_explode.select('BH_ID').distinct().count()

In [0]:
patient_share_by_brand_use_all_explode.count()

In [0]:
hcp_aff_wi_insn.select('BH_ID').distinct().count()


In [0]:
hcp_aff_wi_insn.count()


In [0]:
hcp_ptnt_cnt.select('BH_ID').distinct().count()

In [0]:
hcp_ptnt_age_wide.select('BH_ID').distinct().count()

In [0]:
all_hcp_characteristics = hcp_ptnt_cnt.join(hcp_ptnt_age_wide.drop("PAT_AGE_NA_CNT"), on = 'BH_ID', how = 'outer')

In [0]:
all_hcp_characteristics.count()

In [0]:
all_hcp_monthly_features = (patient_share_by_brand_use_all_explode.join(hcp_aff_wi_insn.select('BH_ID',col('SHP_YR_MO').alias("COHORT_MONTH"), 'AFFL_WI_INSN'), on = ['BH_ID',"COHORT_MONTH"], how = 'outer')
                            .orderBy('BH_ID','COHORT_MONTH')
                            )

In [0]:
all_hcp_monthly_features.count()

In [0]:
all_hcp_monthly_features.filter(col("PTNT_SHARE_ALTUVIIIO_0_LAST_12M_PTNT_SHARE").isNull()).count()

In [0]:
all_hcp_monthly_features.filter(col("AFFL_WI_INSN").isNull()).count()

In [0]:
all_hcp_monthly_features.display()

In [0]:
all_hcp_characteristics.columns

In [0]:
# saving the hcp monthly features to the Hivestore
save_sdf(all_hcp_monthly_features, 'jivi_new_writer_model', 'all_hcp_monthly_features')

In [0]:
# saving the hcp features to the Hivestore
save_sdf(all_hcp_characteristics, 'jivi_new_writer_model', 'all_hcp_characteristics')


### sum of unique patient cnt


In [0]:

from pyspark.sql.functions import date_trunc, to_date

columns = ['BH_ID', 'PATIENT_ID', 'PRD_NM', 'SHP_DT', 'PRPHY', 'IU',"SHP_YR_MO"]
df_selected = overlap_raw_data.select(*columns)


# calculate month unique patient count by brand and type
patient_cnt_by_brand_use = df_selected.filter(F.col("PRD_NM").isNotNull()).groupBy(
    "BH_ID",
    "PRD_NM",
    "PRPHY",
    "SHP_YR_MO"
).agg(
    countDistinct("PATIENT_ID").alias("PATIENT_CNT")
).withColumn('PRPHY', F.coalesce(F.col('PRPHY'), F.lit('1'))) \
.orderBy(F.col('BH_ID').desc(), F.col('PRD_NM').desc(), F.col('SHP_YR_MO').asc())


In [0]:
## filter on selected brand
brand_list = ['KOVALTRY','KOGENATE','JIVI','HEMLIBRA','ALTUVIIIO']
patient_cnt_by_brand_use_filtered = (patient_cnt_by_brand_use.filter(col("PRD_NM").isin(brand_list))
                                     .withColumn("PRD_NM_PRPHY_CD",concat(col("PRD_NM"), lit("_"), col("PRPHY")))
                                     .select("BH_ID",'SHP_YR_MO','PRD_NM_PRPHY_CD','PATIENT_CNT')
                                     .orderBy(F.col('BH_ID').desc(), F.col('SHP_YR_MO').desc())
                                     )                  

In [0]:
patient_cnt_by_brand_use_filtered.count()

In [0]:
## filter on  brands and aggregate by type
other_shl = ['ADVATE', 'RECOMBINATE', 'XYNTHA','NUWIQ','NOVOEIGHT','AFSTYLA']
other_ehl = ['ELOCATE', 'ADYNOVATE', 'ESPEROCT']
patient_cnt_by_category = (patient_cnt_by_brand_use.withColumn("PROD_TYPE",
                                                                when(col("PRD_NM").isin(other_shl), "OTHER_SHL").when(col("PRD_NM").isin(other_ehl), "OTHER_EHL").otherwise("OTHERS"))
                                                    .groupBy("BH_ID","PROD_TYPE", "PRPHY","SHP_YR_MO")
                                                    .agg(sum("PATIENT_CNT").alias("PATIENT_CNT"))
                                                    .orderBy(F.col('BH_ID').desc(), F.col('SHP_YR_MO').desc()))
patient_cnt_by_category = (patient_cnt_by_category.filter(col("PROD_TYPE")!='OTHERS')
                           .withColumn("PRD_NM_PRPHY_CD",concat(col("PROD_TYPE"), lit("_"), col("PRPHY")))
                           .select("BH_ID",'SHP_YR_MO','PRD_NM_PRPHY_CD','PATIENT_CNT'))

In [0]:
patient_cnt_by_category.count()

In [0]:
patient_cnt_by_brand_use_all = patient_cnt_by_category.union(patient_cnt_by_brand_use_filtered)

In [0]:
patient_cnt_by_brand_use_all.count()

In [0]:
patient_cnt_by_brand_use_all.select('BH_ID','SHP_YR_MO').distinct().count()

In [0]:
## change to wide format
group_cols = ["BH_ID", "SHP_YR_MO"]
pivot_col = "PRD_NM_PRPHY_CD"
agg_col = 'PATIENT_CNT'
agg_func = "first"
patient_cnt_by_brand_use_all = long_to_wide(patient_cnt_by_brand_use_all, group_cols, pivot_col, agg_col, agg_func)
patient_cnt_by_brand_use_all = patient_cnt_by_brand_use_all.fillna(0)

In [0]:
patient_cnt_by_brand_use_all.count()

In [0]:
## explode the dataframe
patient_cnt_by_brand_use_all_explode = (
  hcp_week_pair.join(patient_cnt_by_brand_use_all, on = ['BH_ID','SHP_YR_MO'], how = 'left')
  .orderBy('BH_ID','SHP_YR_MO')
  .fillna(0)
  )

In [0]:
## calculate the cuculumative sum in the last 3,6,9,12 months
patient_cnt_by_brand_use_all_explode_accu = calc_cumulative_sums_excl_current(patient_cnt_by_brand_use_all_explode, 'BH_ID', 'SHP_YR_MO', patient_cnt_by_brand_use_all_explode.columns[2])
for col in patient_cnt_by_brand_use_all_explode.columns[3:]:
  patient_cnt_by_brand_use_all_explode_accu = calc_cumulative_sums_excl_current(patient_cnt_by_brand_use_all_explode_accu, 'BH_ID', 'SHP_YR_MO', col)

In [0]:
cumsum_cols_3M = [x for x in patient_cnt_by_brand_use_all_explode_accu.columns if 'cumsum' and '3M' in x]
cumsum_cols_6M = [x for x in patient_cnt_by_brand_use_all_explode_accu.columns if 'cumsum' and '6M' in x]
cumsum_cols_9M = [x for x in patient_cnt_by_brand_use_all_explode_accu.columns if 'cumsum' and '9M' in x]
cumsum_cols_12M = [x for x in patient_cnt_by_brand_use_all_explode_accu.columns if 'cumsum' and '12M' in x]

In [0]:
from functools import reduce
from pyspark.sql import functions as F

patient_cnt_by_brand_use_all_explode_accu = (patient_cnt_by_brand_use_all_explode_accu
.withColumn("RowTotal_3M",reduce(lambda a, b: a + b,[F.col(c) for c in cumsum_cols_3M]))
.withColumn("RowTotal_6M",reduce(lambda a, b: a + b,[F.col(c) for c in cumsum_cols_6M]))
.withColumn("RowTotal_9M",reduce(lambda a, b: a + b,[F.col(c) for c in cumsum_cols_9M]))
.withColumn("RowTotal_12M",reduce(lambda a, b: a + b,[F.col(c) for c in cumsum_cols_12M]))).cache()

## create patient share: patient cnt per brand and use type/ row total


In [0]:
from pyspark.sql.functions import col
display(patient_cnt_by_brand_use_all_explode_accu.filter(col('BH_ID') == 'FF3464A7-5'))

In [0]:
# Create a new DataFrame with shares
patient_cnt_by_brand_use_all_explode_accu = patient_cnt_by_brand_use_all_explode_accu.select('BH_ID','SHP_YR_MO',
    *[F.col(col) / patient_cnt_by_brand_use_all_explode_accu['RowTotal_3M'] for col in cumsum_cols_3M],
    *[F.col(col) / patient_cnt_by_brand_use_all_explode_accu['RowTotal_6M'] for col in cumsum_cols_6M],
    *[F.col(col) / patient_cnt_by_brand_use_all_explode_accu['RowTotal_9M'] for col in cumsum_cols_9M],
    *[F.col(col) / patient_cnt_by_brand_use_all_explode_accu['RowTotal_12M'] for col in cumsum_cols_12M])


In [0]:
for col in patient_cnt_by_brand_use_all_explode_accu.columns[2:]:
    new_col_name = col.replace('CNT', 'SHARE').split(' / ')[0][1:]  # Replace 'CNT' with 'SHARE' and remove the "("
    patient_cnt_by_brand_use_all_explode_accu = patient_cnt_by_brand_use_all_explode_accu.withColumnRenamed(col, new_col_name)

In [0]:
display(patient_cnt_by_brand_use_all_explode_accu)

In [0]:
display(patient_cnt_by_brand_use_all_explode_accu.filter(col('BH_ID') == 'FF3464A7-5'))

In [0]:
display(patient_cnt_by_brand_use_filtered.limit(20))