In [0]:
# Importing packages
from pyspark.sql import functions as F  # Importing functions from pyspark.sql
from pyspark.sql import Window
import pandas as pd

In [0]:
%run "../00_config/set-up"

In [0]:
# Month and Date parameters for manual control.
# Months are "YYYY-MM" strings and dates are "YYYY-MM-DD" strings.
# NOTE(review): none of these are referenced in the cells of this notebook that are
# visible here; presumably they are consumed further downstream - verify.
first_month, last_month = "2019-12", "2024-11"

study_period_start_date, study_period_start_month = "2023-01-01", "2023-01"
study_period_end_date, study_period_end_month = "2024-11-30", "2024-11"

In [0]:
def calc_cumulative_sums_excl_current(df, id_col, date_col, value_col, lookbacks=(1, 3, 6, 9, 12)):
    """
    Add trailing sums of `value_col` over the previous N records of each `id_col`
    partition, excluding the current record.

    One column is added per look-back size N, named `CAL_calls_prev_<NN>M_sum`
    (N zero-padded to two digits, e.g. `CAL_calls_prev_03M_sum`).

    The windows are row-based (`rowsBetween(-N, -1)`), not date-based: "previous N
    months" only holds when the input has exactly one row per id per month with no
    gaps. The first record of every partition has no preceding rows, so its sums
    are null; callers are expected to fill those afterwards.

    Parameters:
    df (DataFrame): The input DataFrame containing the data.
    id_col (str): The column name used for partitioning the data (e.g., 'BAYER_HCP_ID').
    date_col (str): The column name used for ordering the data within each partition (e.g., 'CALL_MONTH').
    value_col (str): The column name whose values will be summed (e.g., 'call_cnt').
    lookbacks (iterable of int): Look-back sizes in records. Defaults to (1, 3, 6, 9, 12),
        which reproduces the original fixed set of columns in the same order.

    Returns:
    DataFrame: The DataFrame with one additional trailing-sum column per look-back size.

    Raises:
    ValueError: If any look-back size is not a positive integer.
    """
    lookbacks = tuple(lookbacks)
    for n_rows in lookbacks:
        # bool is a subclass of int; reject it explicitly so True is not read as 1.
        if isinstance(n_rows, bool) or not isinstance(n_rows, int) or n_rows < 1:
            raise ValueError(
                f"lookbacks must contain positive integers, got {n_rows!r}"
            )

    # Partitioning and ordering are identical for every look-back; only the frame differs.
    base_window = Window.partitionBy(id_col).orderBy(date_col)

    for n_rows in lookbacks:
        sum_col = f"CAL_calls_prev_{n_rows:02d}M_sum"
        # Frame ends at -1 so the current record never contributes to its own feature.
        trailing_window = base_window.rowsBetween(-n_rows, -1)
        df = df.withColumn(sum_col, F.sum(value_col).over(trailing_window))

    return df

In [0]:
# Load the pre-processed call activity table from the Hivestore and report its size
call_data_sdf = spark.sql("SELECT * FROM heme_data.call_activity_data_preprocessed")
call_data_row_cnt = call_data_sdf.count()  # runs a Spark job over the table
call_data_col_cnt = len(call_data_sdf.columns)
print("Row count: ", call_data_row_cnt, "Column Count: ", call_data_col_cnt)

In [0]:
# Load the HCP monthly target spine from the Hivestore and report its size
hcp_target_spine_sdf = spark.sql("SELECT * FROM jivi_new_writer_model.hcp_target_spine")
target_spine_row_cnt = hcp_target_spine_sdf.count()  # runs a Spark job over the table
target_spine_col_cnt = len(hcp_target_spine_sdf.columns)
print("Row count: ", target_spine_row_cnt, "Column Count: ", target_spine_col_cnt)

In [0]:
# Number of call records per HCP per month. Only HCP/month pairs that have at
# least one call record appear here; the gaps are filled in later via the spine join.
calls_by_hcp_month = call_data_sdf.groupBy("BAYER_HCP_ID", "CALL_MONTH").count()
hcp_monthly_calls_sdf = (
    calls_by_hcp_month
    .withColumnRenamed("count", "call_cnt")
    .orderBy("BAYER_HCP_ID", "CALL_MONTH")
)

In [0]:
# Preview the per-HCP, per-month call counts (sorted by HCP id and month)
display(hcp_monthly_calls_sdf)

In [0]:
# Calendar of every month that occurs anywhere in the call data.
# NOTE(review): a month in which no HCP received any call would be missing from
# this calendar - confirm the data has no such gap, because the rolling features
# count rows rather than calendar months.
distinct_months = call_data_sdf.select("CALL_MONTH").distinct()
all_months_sdf = distinct_months.orderBy("CALL_MONTH")
# (disabled) .withColumnRenamed("CALL_MONTH", "index_month")
print(all_months_sdf.count())

In [0]:
# Distinct HCP ids that appear in the call data, sorted by id
call_hcp_ids_sdf = (
    call_data_sdf
    .select("BAYER_HCP_ID")
    .distinct()
    .orderBy("BAYER_HCP_ID")
)
print(call_hcp_ids_sdf.count())

In [0]:
# Dense HCP x month grid: every HCP id paired with every month of the calendar,
# so that months without calls still get a row.
hcp_month_grid = call_hcp_ids_sdf.crossJoin(all_months_sdf)
call_hcp_monthly_spine = hcp_month_grid.orderBy("BAYER_HCP_ID", "CALL_MONTH")
print(call_hcp_monthly_spine.count())

In [0]:
# Attach the observed call counts to the dense HCP x month spine.
# The join is a left join, so spine rows without a matching count keep a null call_cnt.
spine_keys = ['BAYER_HCP_ID', 'CALL_MONTH']
monthly_hcp_calls_sdf = (
    call_hcp_monthly_spine
    .join(hcp_monthly_calls_sdf, on=spine_keys, how='left')
    .orderBy(*spine_keys)
)

print(monthly_hcp_calls_sdf.count())

In [0]:
# HCP/month rows without any call come out of the left join with a null call_cnt;
# record them as zero calls.
monthly_hcp_calls_sdf = monthly_hcp_calls_sdf.na.fill({'call_cnt': 0})
print(monthly_hcp_calls_sdf.count())

In [0]:
# Trailing call-count sums per HCP (current month excluded), computed on the dense monthly grid
monthly_hcp_calls_feats = calc_cumulative_sums_excl_current(
    df=monthly_hcp_calls_sdf,
    id_col='BAYER_HCP_ID',
    date_col='CALL_MONTH',
    value_col='call_cnt',
)
feats_row_cnt = monthly_hcp_calls_feats.count()
print(feats_row_cnt)

In [0]:
# Checking rows which may contain nulls in the features
# null_containing_rows = monthly_hcp_calls_feats.filter(
#     (monthly_hcp_calls_feats['CALL_MONTH'] != '2019-12') & (
#     (monthly_hcp_calls_feats['CAL_calls_prev_01M_sum'].isNull() | 
#      monthly_hcp_calls_feats['CAL_calls_prev_03M_sum'].isNull() | 
#      monthly_hcp_calls_feats['CAL_calls_prev_06M_sum'].isNull() | 
#      monthly_hcp_calls_feats['CAL_calls_prev_09M_sum'].isNull() | 
#      monthly_hcp_calls_feats['CAL_calls_prev_12M_sum'].isNull())
#     )
# )
# display(null_containing_rows)

In [0]:
# Fill missing values with 0 for the first month in the data.
# The fill is restricted to the trailing-sum feature columns: a blanket fillna(0)
# would also overwrite nulls in any other numeric column (for example a numeric
# key column), silently turning a missing id into the value 0.
rolling_sum_cols = [
    col_name
    for col_name in monthly_hcp_calls_feats.columns
    if col_name.startswith('CAL_calls_prev_')
]
monthly_hcp_calls_feats = (
    monthly_hcp_calls_feats
    .fillna(0, subset=rolling_sum_cols)
    .orderBy('BAYER_HCP_ID', 'CALL_MONTH')
)

In [0]:
# Preview the trailing call-count features after the null fill, sorted by HCP id and month
display(monthly_hcp_calls_feats)

**Writing the calls activity feature set to the Hivestore**

In [0]:
# Saving the calls activity features to the Hivestore.
# NOTE(review): save_sdf is not defined in this notebook; presumably it comes from the
# "../00_config/set-up" notebook pulled in via %run, with the two strings being the
# target schema and table name - verify against that helper.
save_sdf(monthly_hcp_calls_feats, 'jivi_new_writer_model', 'monthly_hcp_calls_feats')

**Checking the common HCP IDs between the HCP target spine and Call activity dataset**

In [0]:
# Both inputs hold many rows per HCP (the feature table has one row per HCP per
# month), so de-duplicate in Spark before collecting: only one value per HCP is
# sent to the driver instead of every row. The resulting sets are unchanged.
hcps_in_target = {
    row['BH_ID']
    for row in hcp_target_spine_sdf.select('BH_ID').distinct().collect()
}
print(len(hcps_in_target))
hcps_in_calls = {
    row['BAYER_HCP_ID']
    for row in monthly_hcp_calls_feats.select('BAYER_HCP_ID').distinct().collect()
}
print(len(hcps_in_calls))
# HCPs present in both the target spine and the call activity features
common_hcps = hcps_in_target.intersection(hcps_in_calls)
len(common_hcps)

In [0]:
# Keep only the target-spine rows whose HCP also appears in the call activity features
target_hcp_is_common = hcp_target_spine_sdf.BH_ID.isin(common_hcps)
common_hcp_target_spine_sdf = hcp_target_spine_sdf.where(target_hcp_is_common)
print(common_hcp_target_spine_sdf.count())
# display(common_hcp_target_spine_sdf)

In [0]:
# Distinct HCP count per target class flag (JIVI_NEW_WRITER_FLG) across the full HCP target spine
display(
    hcp_target_spine_sdf
    .groupBy('JIVI_NEW_WRITER_FLG')
    .agg(F.countDistinct('BH_ID').alias('distinct_BH_ID_cnt'))
)

In [0]:
# Same class-flag breakdown, limited to HCPs found in both the target spine and the call activity dataset
display(
    common_hcp_target_spine_sdf
    .groupBy('JIVI_NEW_WRITER_FLG')
    .agg(F.countDistinct('BH_ID').alias('distinct_BH_ID_cnt'))
)

In [0]:
# Earliest and latest COHORT_MONTH among the common HCPs in the target spine
display(
    common_hcp_target_spine_sdf.agg(
        F.min('COHORT_MONTH').alias('min_COHORT_MONTH'),
        F.max('COHORT_MONTH').alias('max_COHORT_MONTH'),
    )
)

In [0]:
# Keep only the call-feature rows whose HCP also appears in the target spine
calls_hcp_is_common = monthly_hcp_calls_feats.BAYER_HCP_ID.isin(common_hcps)
common_hcp_calls_feats_sdf = monthly_hcp_calls_feats.where(calls_hcp_is_common)
print(common_hcp_calls_feats_sdf.count())
display(common_hcp_calls_feats_sdf)

In [0]:
# Earliest and latest CALL_MONTH covered by the call features of the common HCPs
display(
    common_hcp_calls_feats_sdf.agg(
        F.min('CALL_MONTH').alias('min_CALL_MONTH'),
        F.max('CALL_MONTH').alias('max_CALL_MONTH'),
    )
)