In [1]:
from snowflake.snowpark import Session
from snowflake.snowpark.functions import (col, row_number, to_date, concat, lit, lpad, year, month, dayofmonth, when, sum as ssum,
        count_distinct, coalesce)
from snowflake.snowpark.window import Window
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import pandas as pd
import os
import logging
import builtins
import numpy as np
from datetime import timedelta, datetime, date
from dateutil.relativedelta import relativedelta


In [2]:
# ---------------------------------------------
# Logging: INFO-level records, timestamped, sent to a stream handler
# ---------------------------------------------
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this notebook.
logger = logging.getLogger(__name__)
 
# ---------------------------------------------
# Configuration
# ---------------------------------------------
def get_snowflake_session():
    """Create a Snowpark session using key-pair authentication.

    The PEM-encoded private key is read from the ``MY_SF_PKEY`` environment
    variable and the login name from ``MY_SF_USERNAME``. Account, role,
    warehouse, database and default schema are fixed in the connection
    dictionary below.

    Returns:
        The session produced by ``Session.builder.configs(...).create()``.

    Raises:
        RuntimeError: If ``MY_SF_PKEY`` or ``MY_SF_USERNAME`` is unset or empty.
    """
    pkey_pem = os.getenv("MY_SF_PKEY")
    username = os.getenv("MY_SF_USERNAME")

    # Fail early with a readable message instead of an AttributeError on
    # ``None.encode`` (missing key) or a connector error (missing user).
    missing = [
        name
        for name, value in (("MY_SF_PKEY", pkey_pem), ("MY_SF_USERNAME", username))
        if not value
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    # password=None: the key is loaded without a passphrase.
    pkey = serialization.load_pem_private_key(
        pkey_pem.encode("utf-8"),
        password=None,
        backend=default_backend()
    )

    connection = {
        "account": "uhgdwaas.east-us-2.azure",
        "user": username,
        "private_key": pkey,
        "role": "AZU_SDRP_CSZNB_PRD_DEVELOPER_ROLE",
        "warehouse": "CSZNB_PRD_ANALYTICS_XS_WH",
        "database": "CSZNB_PRD_PS_PFA_DB",
        "schema": "STAGE"
    }
    return Session.builder.configs(connection).create()

In [3]:
def calculate_study_period(session, client):
    """Derive the 12-month study window and the 3-month run-out end for a client.

    The latest service date in ``FA_MEDICAL_<client>`` fixes the timeline:

    * the month containing that date is treated as incomplete and ignored;
    * the run-out window is the three calendar months before it;
    * the study window is the twelve calendar months before the run-out window.

    All boundaries are computed from first-of-month dates, so every window
    starts on the 1st and ends on the last day of a month regardless of month
    length or leap years.

    Args:
        session: Object exposing ``sql(query).to_pandas()`` (a Snowpark session).
        client: Suffix of the medical claims table name.

    Returns:
        Tuple ``(study_start, study_end, runout_end)`` of ``datetime.datetime``
        values at midnight.

    Raises:
        ValueError: If the table yields no usable service dates, or if the
            service dates span fewer than 12 calendar months.
    """
    logger.info("Calculating birth and runout window from source data")
    table = f"CSZNB_PRD_PS_PFA_DB.STAGE.FA_MEDICAL_{client}"
    query = f"""
    SELECT
    MIN(SRVC_FROM_DT) AS MIN_FROMDATE,
    MAX(SRVC_FROM_DT) AS MAX_FROMDATE,
    MAX(PROCESS_DT) AS MAX_PAIDDATE
    FROM {table}
    WHERE SRVC_FROM_DT IS NOT NULL AND PROCESS_DT IS NOT NULL
    """
    # MAX_PAIDDATE is returned by the query but is not needed for the window
    # arithmetic below, so it is not read here.
    bounds = session.sql(query).to_pandas()
    min_dt = pd.to_datetime(bounds.at[0, 'MIN_FROMDATE'])
    max_dt = pd.to_datetime(bounds.at[0, 'MAX_FROMDATE'])

    if pd.isna(min_dt) or pd.isna(max_dt):
        raise ValueError("FROMDATE range is invalid. Cannot determine birth window.")

    # Single source of truth for the threshold so the message cannot drift
    # away from the condition that is actually enforced.
    min_required_months = 12
    num_months = (max_dt.year - min_dt.year) * 12 + (max_dt.month - min_dt.month) + 1
    if num_months < min_required_months:
        raise ValueError(
            f"Only {num_months} months available. "
            f"Minimum {min_required_months} months required."
        )

    # Work from month starts: subtracting whole months from a month *end*
    # clamps the day (e.g. Jun 30 - 3 months = Mar 30), which would shift the
    # windows off calendar-month boundaries.
    current_month_start = max_dt.normalize().replace(day=1)
    runout_end = current_month_start - pd.Timedelta(days=1)
    runout_start = current_month_start - pd.DateOffset(months=3)
    study_end = runout_start - pd.Timedelta(days=1)
    study_start = runout_start - pd.DateOffset(months=12)

    if study_start < min_dt.normalize():
        # The minimum-months check above does not guarantee the full
        # 12 + 3 month look-back is covered; surface that instead of hiding it.
        logger.warning(
            "Study window starts %s but the earliest service date is %s; "
            "the study period is only partially covered by the data.",
            study_start.date(), min_dt.date(),
        )

    study_start = study_start.to_pydatetime()
    study_end = study_end.to_pydatetime()
    runout_end = runout_end.to_pydatetime()

    logger.info(f"Study window: {study_start.date()} to {study_end.date()}")
    logger.info(f"Runout end: {runout_end.date()}")

    return study_start, study_end, runout_end

In [4]:
def _pydate(x):
    """Coerce a date-like value to a plain ``datetime.date``.

    Objects exposing ``to_pydatetime()`` are converted through it, ``datetime``
    instances are truncated to their date, and ``date`` instances are returned
    unchanged.

    Raises:
        TypeError: For any other input; the offending type is the sole argument.
    """
    if hasattr(x, "to_pydatetime"):
        as_datetime = x.to_pydatetime()
        return as_datetime.date()
    # datetime is a subclass of date, so a single isinstance check rejects
    # everything that is neither.
    if not isinstance(x, date):
        raise TypeError(type(x))
    return x.date() if isinstance(x, datetime) else x

In [5]:
def export_to_snowflake(df, table_name):
    """Persist a DataFrame as a Snowflake table, replacing existing contents.

    ``df`` must expose the ``df.write.mode(...).save_as_table(...)`` writer
    interface (the Snowpark DataFrames built elsewhere in this notebook are
    passed in here); the write is issued in ``"overwrite"`` mode.

    Args:
        df: DataFrame to persist.
        table_name: Name of the destination table.
    """
    logger.info(f"Exporting data to Snowflake table: {table_name}")

    overwrite_writer = df.write.mode("overwrite")
    overwrite_writer.save_as_table(table_name)

    logger.info("Export complete.")

In [6]:

def process_membership(session, client, study_start, study_end):
    """Snapshot each member's most recent in-window membership month.

    Reads ``FA_MEMBERSHIP_<client>``, keeps rows with a non-null ``INDV_ID``
    and ``MM`` whose month start (``MM`` with ``"01"`` appended, parsed as
    ``YYYYMMDD``) lies between the study bounds, retains one row per
    ``INDV_ID`` (the one with the latest ``MM_DATE``), tags it with
    ``CLIENT_NAME`` and ``STUDY_PERIOD`` and overwrites
    ``CSZNB_PRD_PS_PFA_DB.BASE.SMS_MEMBERSHIP_<client>``.

    Args:
        session: Snowpark session used to read the source table.
        client: Suffix of the source and target table names.
        study_start: First day of the study window (any value ``_pydate`` accepts).
        study_end: Last day of the study window (any value ``_pydate`` accepts).
    """
    window_start = _pydate(study_start)
    window_end = _pydate(study_end)
    period_label = f"{window_start:%b %Y} - {window_end:%b %Y}"

    # Reusable column expressions
    has_member_and_month = col("INDV_ID").is_not_null() & col("MM").is_not_null()
    month_start = to_date(concat(col("MM"), lit("01")), "YYYYMMDD")
    inside_study_window = (col("MM_DATE") >= window_start) & (col("MM_DATE") <= window_end)

    # Membership months that fall inside the study window
    in_window_months = (
        session.table(f"FA_MEMBERSHIP_{client}")
        .filter(has_member_and_month)
        .with_column("MM_DATE", month_start)
        .filter(inside_study_window)
    )

    # Newest month first within each member
    newest_first = Window.partition_by("INDV_ID").order_by(col("MM_DATE").desc())

    output_columns = [
        "INDV_ID",
        "SBSCR_ID",
        "GENDER",
        "BTH_DT",
        "BUS_LINE_CD",
        "PRODUCT",
        "STATE",
        "REL_CD",
        "MBR_ZIP_CD",
        "MM_DATE",
    ]

    numbered = in_window_months.with_column("RN", row_number().over(newest_first))
    latest_only = numbered.filter(col("RN") == 1).select(*output_columns)
    member_df = (
        latest_only
        .with_column("CLIENT_NAME", lit(client))
        .with_column("STUDY_PERIOD", lit(period_label))
    )

    # Write SMS_MEMBERSHIP_<CLIENT> in the BASE schema
    export_to_snowflake(member_df, f"CSZNB_PRD_PS_PFA_DB.BASE.SMS_MEMBERSHIP_{client}")

In [7]:
def data_pull(session, client, study_start, study_end):
    """Pull study-window claim lines, enrich them and persist the result.

    Claim lines from ``FA_MEDICAL_<client>`` are kept when ``SRVC_FROM_DT`` is
    on/after ``study_start``, ``LST_SRVC_DT`` is on/before ``study_end`` and
    ``PL_OF_SRVC_CD`` is one of 22, 24, 19 or 99. They are left-joined to
    ``BASE.SMS_MEMBERSHIP_<client>`` on ``INDV_ID`` and to
    ``SUPP_DATA.XREF_PROC_CD_XWALK`` on ``PROC_CD``, and a ``PROV_CAT`` column
    ("Facility" / "Professional" / "Unknown") is derived from, in priority
    order, the revenue code prefix, the place of service and the numeric range
    of the procedure code.

    The result overwrites ``BASE.TMP_SMS_DATA_<client>``.

    Args:
        session: Snowpark session used to read the tables.
        client: Suffix of the per-client table names.
        study_start: Lower bound applied to ``SRVC_FROM_DT``.
        study_end: Upper bound applied to ``LST_SRVC_DT``.

    Returns:
        The Snowpark DataFrame that was written.
    """
    # Load tables
    clm = session.table(f"FA_MEDICAL_{client}").alias("clm")
    mem = session.table(f"BASE.SMS_MEMBERSHIP_{client}").alias("mem")
    xref = session.table("SUPP_DATA.XREF_PROC_CD_XWALK").alias("xref")

    # Snowflake's RLIKE implicitly anchors the pattern at BOTH ends, so a bare
    # "^(01|02)" only matches the exact strings "01"/"02". The trailing ".*"
    # turns these into the prefix tests they are meant to be.
    rev_cd = col("RVNU_CD")
    facility_revenue = rev_cd.rlike(r"^(01|02|03|05|06|049|045).*")
    professional_revenue = rev_cd.rlike(r"^(04|07|08|09).*")

    # try_cast yields NULL for alphanumeric codes instead of raising; SQL does
    # not guarantee the rlike guard is evaluated before the cast.
    proc_is_numeric = col("PROC_CD").rlike(r"^\d+$")
    proc_num = col("PROC_CD").try_cast("int")

    def proc_between(low, high):
        # Purely numeric procedure code inside [low, high].
        return proc_is_numeric & (proc_num >= low) & (proc_num <= high)

    # Branch order matters: the first matching rule wins (e.g. revenue prefixes
    # 045/049 are Facility even though prefix 04 is Professional).
    prov_cat = (
        when(facility_revenue, "Facility")
        .when(professional_revenue, "Professional")
        .when(col("PL_OF_SRVC_CD").isin("21", "22", "23", "31", "32", "33", "34"), "Facility")
        .when(col("PL_OF_SRVC_CD").isin("11", "12", "20", "49", "50"), "Professional")
        .when(proc_between(10000, 69999), "Professional")
        .when(proc_between(70000, 79999), "Professional")
        .when(proc_between(80000, 89999), "Facility")
        .otherwise("Unknown")
    )

    # Join and select
    clm_pull = (
        clm
        .filter((clm["SRVC_FROM_DT"] >= study_start))
        .filter((clm["LST_SRVC_DT"] <= study_end))
        .filter(clm["PL_OF_SRVC_CD"].isin("22", "24", "19", "99"))
        .join(mem, clm["INDV_ID"] == mem["INDV_ID"], how="left")
        .join(xref, clm["PROC_CD"] == xref["PROC_CD"], how="left")
        .select(
            clm["INDV_ID"],
            #mem["CUST_SEG_NBR"].alias("POLICY_ID"),
            mem["BUS_LINE_CD"].alias("BUS_LINE_CD"),
            mem["REL_CD"].alias("REL_CD"),
            #mem["MBR_LST_NM"].alias("MEMBER_LAST_NAME"),
            #mem["MBR_FST_NM"].alias("MEMBER_FIRST_NAME"),
            mem["BTH_DT"],
            mem["GENDER"].alias("GENDER"),
            mem["SBSCR_ID"],
            # String positions are 1-based in Snowflake; start at 1 so the
            # result does not depend on how a start of 0 is interpreted.
            mem["MBR_ZIP_CD"].substr(1, 5).alias("MBR_ZIP_CD"),
            mem["STATE"].alias("MBR_STATE"),
            clm["PROCESS_DT"],
            clm["PL_OF_SRVC_CD"],
            clm["BILL_TYPE"],
            clm["SRVC_FROM_DT"],
            clm["PROC_CD"],
            xref["PROC_CD_DESC"],
            xref['ASGN_SPCLTY'].alias("PROCEDURE_SPECIALTY"),
            xref['INCLD_ORDRNG'].alias("INCLUDE_IN_ORDERING"),
            xref['ASC_ELIG'].alias("ELIGIBLE"),
            clm["RVNU_CD"],
            clm["PROV_NPI"],
            clm["PROV_FULL_NM"],
            clm["PROV_STATE"],
            # Alias is required: without it the column is named after the
            # SUBSTR expression and is no longer addressable as PROV_ZIP_CD.
            clm["PROV_ZIP_CD"].substr(1, 5).alias("PROV_ZIP_CD"),
            clm["COV_AMT"],
            clm["ALLW_AMT"]
        )
        .with_column("PROV_CAT", prov_cat)
    )

    # Save to table
    clm_pull.write.save_as_table(f"BASE.TMP_SMS_DATA_{client}", mode="overwrite")

    return clm_pull


In [8]:
def derive_ids(session, client, claims_df):
    """Add ``OUTPATIENT_EVENT_ID`` to the claim lines and persist the result.

    The identifier is ``INDV_ID`` concatenated with the service-from date
    rendered as zero-padded ``YYYYMMDD``, so all lines for the same member on
    the same service date share one event id. The result overwrites
    ``BASE.TMP_SMS_DATA_DRVD_<client>``.

    Args:
        session: Unused; kept so all pipeline stages share the same signature.
        client: Suffix of the target table name.
        claims_df: Snowpark DataFrame containing ``INDV_ID`` and ``SRVC_FROM_DT``.

    Returns:
        ``claims_df`` with the additional ``OUTPATIENT_EVENT_ID`` column.
    """
    service_date = col("SRVC_FROM_DT")

    # Date parts with leading zeros so the date suffix is always 8 characters
    srvc_year = lpad(year(service_date).cast("string"), 4, lit("0"))
    srvc_month = lpad(month(service_date).cast("string"), 2, lit("0"))
    srvc_day = lpad(dayofmonth(service_date).cast("string"), 2, lit("0"))

    derived = claims_df.with_column(
        "OUTPATIENT_EVENT_ID",
        concat(col("INDV_ID"), srvc_year, srvc_month, srvc_day),
    )

    # Save the result
    derived.write.save_as_table(f"BASE.TMP_SMS_DATA_DRVD_{client}", mode="overwrite")

    return derived


In [9]:
def _sum_by_event(lines, amt_col, out_name):
    """Sum ``amt_col`` per OUTPATIENT_EVENT_ID, giving one row per event."""
    return lines.group_by("OUTPATIENT_EVENT_ID").agg(ssum(col(amt_col)).alias(out_name))


def sms_analyze(session, client, claims_df, amt_col):
    """Roll non-professional claim lines up to one primary line per event.

    Processing steps:

    1. Keep lines whose ``PROV_CAT`` is not "Professional".
    2. Split them on ``INCLUDE_IN_ORDERING``: lines flagged "Y" are ranked,
       everything else (including NULL flags) is summed per event as
       ``COVERED_NOT_IN_RANKING``.
    3. Rank the flagged, non-zero-amount lines within each event by
       ``amt_col`` descending.
    4. Among ranked lines, those with ``ELIGIBLE`` = "Y" supply the primary
       (rank 1), secondary (rank 2), tertiary (rank 3) and all-other amounts;
       the rest (including NULL flags) are summed as
       ``NOT_ELIGIBLE_COVERED_AMT``.
    5. The primary line is joined to the per-event sums and to
       ``SUPP_DATA.XREF_CBSA_XWALK`` on the 5-character ZIP, and
       ``EPISODE_COVERED`` is the NULL-safe total of all buckets.

    Only events that have an eligible rank-1 line appear in the output, because
    the roll-up starts from the primary rows. The result overwrites
    ``BASE.TMP_SMS_ROLLUP_<client>``.

    Args:
        session: Snowpark session used to read the CBSA crosswalk.
        client: Suffix of the target table name.
        claims_df: Snowpark DataFrame with ``OUTPATIENT_EVENT_ID``, ``PROV_CAT``,
            ``INCLUDE_IN_ORDERING``, ``ELIGIBLE``, ``MBR_ZIP_CD`` and ``amt_col``.
        amt_col: Name of the amount column used for ranking and for every
            amount bucket, including the primary amount.

    Returns:
        The roll-up Snowpark DataFrame that was written.
    """
    # Step 1: everything that is not Professional is analysed here
    fac = claims_df.filter(col("PROV_CAT") != "Professional")

    # Steps 2-3: split on INCLUDE_IN_ORDERING. The flag comes from a left join,
    # so it can be NULL; coalescing makes the two sides exact complements
    # (a plain `!= "Y"` silently discards NULL rows from both sides).
    in_ordering = coalesce(col("INCLUDE_IN_ORDERING"), lit("N")) == "Y"
    include_in_ranking = fac.filter(in_ordering)
    not_in_ranking = _sum_by_event(fac.filter(~in_ordering), amt_col, "COVERED_NOT_IN_RANKING")

    # Step 4: rank service lines within each event, largest amount first
    by_amount_desc = Window.partition_by("OUTPATIENT_EVENT_ID").order_by(col(amt_col).desc())
    ranked = (
        include_in_ranking.filter(col(amt_col) != 0)
        .with_column("RANK", row_number().over(by_amount_desc))
    )

    # Step 5: eligible vs. not eligible, again as NULL-safe complements
    is_eligible = coalesce(col("ELIGIBLE"), lit("N")) == "Y"
    eligible_cpt = ranked.filter(is_eligible)
    not_eligible_cpt = _sum_by_event(ranked.filter(~is_eligible), amt_col, "NOT_ELIGIBLE_COVERED_AMT")

    # Step 6: primary line plus per-event sums for the other ranks
    primary = eligible_cpt.filter(col("RANK") == 1)
    secondary = _sum_by_event(eligible_cpt.filter(col("RANK") == 2), amt_col, "SECONDARY_COVERED_AMT")
    tertiary = _sum_by_event(eligible_cpt.filter(col("RANK") == 3), amt_col, "TERTIARY_COVERED_AMT")
    all_other = _sum_by_event(
        eligible_cpt.filter(~col("RANK").isin(1, 2, 3)), amt_col, "ALLOTH_COVERED_AMT"
    )

    # Step 7: roll-up. ZIP codes are cast to string and left-padded to five
    # characters on both sides so the join keys have the same type and width.
    cbsa = session.table("SUPP_DATA.XREF_CBSA_XWALK")
    cbsa_zip = cbsa.with_column("ELIG_ZIP_CD_5", lpad(col("ELIG_ZIP_CD").cast("string"), 5, lit("0")))
    primary_zip = primary.with_column("MBR_ZIP_CD_5", lpad(col("MBR_ZIP_CD").cast("string"), 5, lit("0")))

    # Each side table already has exactly one row per event (see _sum_by_event),
    # so these joins cannot multiply the primary rows.
    base = (
        primary_zip
        .join(secondary, ["OUTPATIENT_EVENT_ID"], "left")
        .join(tertiary, ["OUTPATIENT_EVENT_ID"], "left")
        .join(all_other, ["OUTPATIENT_EVENT_ID"], "left")
        .join(not_eligible_cpt, ["OUTPATIENT_EVENT_ID"], "left")
        .join(not_in_ranking, ["OUTPATIENT_EVENT_ID"], "left")
        .join(cbsa_zip, primary_zip["MBR_ZIP_CD_5"] == cbsa_zip["ELIG_ZIP_CD_5"], "left")
    )

    # Episode total with NULL-safe arithmetic. The primary amount is the ranked
    # column itself, so it must follow amt_col like every other bucket.
    covered = (
        coalesce(col(amt_col), lit(0))
        + coalesce(col("SECONDARY_COVERED_AMT"), lit(0))
        + coalesce(col("TERTIARY_COVERED_AMT"), lit(0))
        + coalesce(col("ALLOTH_COVERED_AMT"), lit(0))
        + coalesce(col("NOT_ELIGIBLE_COVERED_AMT"), lit(0))
        + coalesce(col("COVERED_NOT_IN_RANKING"), lit(0))
    ).alias("EPISODE_COVERED")

    # Grouping columns. OUTPATIENT_EVENT_ID keeps the output at one row per
    # event; drop it for a higher-level roll-up.
    dims = [
        "BUS_LINE_CD",
        "OUTPATIENT_EVENT_ID",
        "CBSA", "CBSA_NM",
        "SBSCR_ID", "REL_CD", "INDV_ID", "GENDER",
        "MBR_STATE", "MBR_ZIP_CD",
        "PROCESS_DT", "PL_OF_SRVC_CD", "BILL_TYPE",
        "SRVC_FROM_DT", "PROC_CD", "PROC_CD_DESC", "PROCEDURE_SPECIALTY", amt_col,
        "INCLUDE_IN_ORDERING", "ELIGIBLE",
        "PROV_NPI", "PROV_FULL_NM", "PROV_STATE", "PROV_ZIP_CD", "PROV_CAT",
    ]

    # Keep only columns that actually exist, but say so when one is missing
    # rather than changing the output shape silently.
    available = set(base.columns)
    missing_dims = [d for d in dims if d not in available]
    if missing_dims:
        logger.warning("Grouping columns missing from the joined data and skipped: %s", missing_dims)
    dims = [d for d in dims if d in available]

    # Group & aggregate
    rollup = (
        base.select(
            *[col(d) for d in dims],
            col(amt_col).alias("PRIMARY_COVERED_AMT"),
            col("SECONDARY_COVERED_AMT"),
            col("TERTIARY_COVERED_AMT"),
            col("ALLOTH_COVERED_AMT"),
            col("NOT_ELIGIBLE_COVERED_AMT"),
            col("COVERED_NOT_IN_RANKING"),
            covered
        )
        .group_by(*dims)
        .agg(
            # Distinct events per group; 1 per row while OUTPATIENT_EVENT_ID is a dim.
            count_distinct(col("OUTPATIENT_EVENT_ID")).alias("VOLUME"),
            ssum(col("PRIMARY_COVERED_AMT")).alias("PRIMARY_COVERED_AMT"),
            ssum(col("SECONDARY_COVERED_AMT")).alias("SECONDARY_COVERED_AMT"),
            ssum(col("TERTIARY_COVERED_AMT")).alias("TERTIARY_COVERED_AMT"),
            ssum(col("ALLOTH_COVERED_AMT")).alias("ALLOTH_COVERED_AMT"),
            ssum(col("NOT_ELIGIBLE_COVERED_AMT")).alias("NOT_ELIGIBLE_COVERED_AMT"),
            ssum(col("COVERED_NOT_IN_RANKING")).alias("COVERED_NOT_IN_RANKING"),
            ssum(col("EPISODE_COVERED")).alias("EPISODE_COVERED"),
        )
    )

    # Save final result
    rollup.write.save_as_table(f"BASE.TMP_SMS_ROLLUP_{client}", mode="overwrite")

    return rollup



In [10]:
# Open the Snowflake session used by every stage below
# (requires MY_SF_PKEY and MY_SF_USERNAME in the environment).
session  = get_snowflake_session()

# Client suffix appended to the per-client table names (e.g. FA_MEDICAL_<client>).
client = 'MOLINA'

# Derive the study window and run-out end from the client's medical claims.
study_start, study_end, runout_end = calculate_study_period(session, client)

# Remaining stages are currently disabled. When re-enabled they must run in
# this order: data_pull reads the membership table written by
# process_membership, derive_ids consumes data_pull's output, and sms_analyze
# needs the OUTPATIENT_EVENT_ID column added by derive_ids.
#process_membership(session, client, study_start, study_end)

#claims_df = data_pull(session, client, study_start, study_end)

#claims_df = derive_ids(session, client, claims_df)

#sms_analyze(session, client, claims_df, "ALLW_AMT")



2025-09-30 17:07:00,319 [INFO] Snowflake Connector for Python Version: 3.15.0, Python Version: 3.12.6, Platform: macOS-15.7-arm64-arm-64bit
2025-09-30 17:07:00,320 [INFO] Connecting to GLOBAL Snowflake domain
2025-09-30 17:07:01,666 [INFO] Snowpark Session information: 
"version" : 1.34.0,
"python.version" : 3.12.6,
"python.connector.version" : 3.15.0,
"python.connector.session.id" : 66583257609210899,
"os.name" : Darwin

2025-09-30 17:07:01,667 [INFO] Calculating birth and runout window from source data


SnowparkSQLException: (1304): 01bf6793-0d0c-65ef-00ec-8d3b1358320f: 002003 (42S02): 01bf6793-0d0c-65ef-00ec-8d3b1358320f: SQL compilation error:
Object 'CSZNB_PRD_PS_PFA_DB.STAGE.FA_MEDICAL_MOLINA' does not exist or not authorized.