In [None]:
# =============================================================================
# CONFIGURATION - All inputs and parameters
# =============================================================================

# Client configuration
CLIENT_DATA = 'EMBLEM'
CLIENT_NAME = 'EMBLEM'

# Database configuration
DATABASE = "CSZNB_PRD_PS_PFA_DB"
SCHEMA = "STAGE"
BASE_SCHEMA = "BASE"
SUPP_SCHEMA = "SUPP_DATA"

# Snowflake connection parameters
SF_ACCOUNT = "uhgdwaas.east-us-2.azure"
SF_ROLE = "AZU_SDRP_CSZNB_PRD_DEVELOPER_ROLE"
SF_WAREHOUSE = "CSZNB_PRD_ANALYTICS_XS_WH"

# Pipeline parameters
AUTO_WINDOW = True   # Set to False to use manual dates below
DRY_RUN = False      # Set to True to skip writing to Snowflake (for testing)
DEBUG_MODE = True    # Set to True to cache intermediate dataframes for debugging

# Manual date configuration (used if AUTO_WINDOW = False)
MANUAL_BIRTH_START = "2021-01-01"
MANUAL_BIRTH_END = "2022-12-31"
MANUAL_RUNOUT_END = "2023-03-31"

# Output table suffix (for testing/production)
TABLE_SUFFIX = "_TST"

# =============================================================================
# Clinical and Business Rule Thresholds
# =============================================================================

# Readmission and hospitalization thresholds
INIT_HOSP_THRESHOLD_DAYS = 4        # Initial hospitalization window after birth
READMIT_THRESHOLD_DAYS = 30         # Readmission window (days)
HOSP_STAY_GAP_DAYS = 4              # Gap > 4 days between discharges = new episode
NEWBORN_SERVICE_WINDOW_DAYS = 4     # Services within 4 days of birth count as newborn

# Cost thresholds
HIGH_COST_CLAIM_THRESHOLD = 500000           # CMS extreme outlier definition ($500k)
NICU_LOW_COST_PER_DAY_THRESHOLD = 150       # Data quality: minimum expected NICU cost/day

# Length of stay thresholds
INAPPROPRIATE_NICU_MAX_LOS = 5      # Max LOS for inappropriate NICU (DRG-based, short stay)
LONG_STAY_THRESHOLD = 3             # LOS >= 3 days = "Long Stay"

# DRG code ranges for NICU identification
NICU_MS_DRG_RANGE = (580, 640)      # MS-DRG 580-640: Neonate diagnoses
NICU_APR_DRG_RANGE = (789, 795)     # APR-DRG 789-795: Extreme neonate conditions

# Revenue code ranges  
NICU_REV_CODE_RANGE = (170, 179)    # Rev codes 170-179: Nursery levels (I-IV)
ROOM_BOARD_REV_PREFIXES = ["011", "012", "013", "014", "015", "016", "017", "020"]

# Manageable and critical care CPT codes
MANAGEABLE_CPT_CODES = ["99233", "99479", "99480", "99478", "99231", "99232", "99462"]
CRITICAL_CARE_CPT_CODES = ["99468", "99469", "99471", "99472"]

# Place of Service codes
POS_INPATIENT = "21"
POS_EMERGENCY = "23"

# Discharge status code groups (for prioritization)
DISCHARGE_STATUS_DEATH = "20"
DISCHARGE_STATUS_AMA = "07"
DISCHARGE_STATUS_TRANSFERS = ["02", "05", "66", "43", "62", "63", "65"]
DISCHARGE_STATUS_SNF = "30"
DISCHARGE_STATUS_HOME = ["01", "06"]
DISCHARGE_STATUS_EXCLUDED = ["04", "41", "50", "51", "70", "03", "64"]

# =============================================================================
# Configuration Validation
# =============================================================================
if not CLIENT_DATA:
    raise ValueError("CLIENT_DATA must be set")
if TABLE_SUFFIX and not TABLE_SUFFIX.startswith('_'):
    raise ValueError("TABLE_SUFFIX must start with underscore or be empty")
if INIT_HOSP_THRESHOLD_DAYS < 1:
    raise ValueError("INIT_HOSP_THRESHOLD_DAYS must be >= 1")
if READMIT_THRESHOLD_DAYS < 1:
    raise ValueError("READMIT_THRESHOLD_DAYS must be >= 1")

print(f"✓ Configuration validated for client: {CLIENT_DATA}")
print(f"✓ Database: {DATABASE}")
print(f"✓ Table suffix: {TABLE_SUFFIX}")
print(f"✓ Dry-run mode: {DRY_RUN}")
print(f"✓ Debug mode: {DEBUG_MODE}")

# NRS Beta 2 - Newborn Risk Stratification Analysis

This notebook implements the NRS (Newborn Risk Stratification) analytics pipeline for processing newborn and NICU claims data.

## Overview

The pipeline performs the following key steps:
1. **Membership Processing**: Loads and processes member eligibility data
2. **Newborn Identification**: Identifies newborns using diagnosis and revenue codes
3. **Claims Processing**: Loads and enriches claims data with reference flags
4. **Hospital Rollup**: Aggregates hospital stays and calculates length of stay
5. **NICU Analysis**: Identifies and analyzes NICU admissions with cost breakdowns
6. **Final Export**: Combines newborn and NICU data for final output

## Key Changes from Original
- Uses `MEM_EFF_DT` and `MEM_EXP_DT` directly from source instead of deriving from YEARMO
- Changed field name from `PRODUCT_CD` to `PRDCT_CD`
- All configuration parameters centralized in the first cell

In [6]:
from snowflake.snowpark import Session
from snowflake.snowpark.functions import (
    col, row_number, to_date, concat, lit, when,
    min as smin, max as smax, greatest, least,
    datediff, first_value, sum as ssum, abs as sabs,
    coalesce, length, lag, sql_expr, to_char,
    substring, count_distinct, try_cast
)
from snowflake.snowpark.window import Window
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
import pandas as pd
import os
import logging
import builtins
import numpy as np
from datetime import timedelta, datetime, date
from dateutil.relativedelta import relativedelta

# ---------------------------------------------
# Logging setup
# ---------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)

# ---------------------------------------------
# Configuration
# ---------------------------------------------
def get_snowflake_session():
    pkey_pem = os.getenv("MY_SF_PKEY")
    pkey = serialization.load_pem_private_key(
        pkey_pem.encode("utf-8"),
        password=None,
        backend=default_backend()
    )
    connection = {
        "account": "uhgdwaas.east-us-2.azure",
        "user": os.getenv('MY_SF_USER'),
        "private_key": pkey,
        "role": "AZU_SDRP_CSZNB_PRD_DEVELOPER_ROLE",
        "warehouse": "CSZNB_PRD_ANALYTICS_XS_WH",
        "database": "CSZNB_PRD_PS_PFA_DB",
        "schema": "STAGE"
    }
    return Session.builder.configs(connection).create()

def get_table_name(table_type: str, client: str = None) -> str:
    """
    Build fully qualified table names consistently.
    
    Args:
        table_type: Type of table ('medical', 'membership', 'ps_membership', 'ps_newborns', 'ref_*')
        client: Client name (uses CLIENT_DATA from config if None)
    
    Returns:
        str: Fully qualified table name
    """
    client = client or CLIENT_DATA
    
    if table_type == 'medical':
        return f"{DATABASE}.{SCHEMA}.FA_MEDICAL_{client}"
    elif table_type == 'membership':
        return f"{DATABASE}.{SCHEMA}.FA_MEMBERSHIP_{client}"
    elif table_type == 'ps_membership':
        return f"{DATABASE}.{BASE_SCHEMA}.PS_MEMBERSHIP_{client}{TABLE_SUFFIX}"
    elif table_type == 'ps_newborns':
        return f"{DATABASE}.{BASE_SCHEMA}.PS_NEWBORNS_{client}{TABLE_SUFFIX}"
    elif table_type.startswith('ref_'):
        ref_name = table_type[4:].upper()
        return f"{DATABASE}.{SUPP_SCHEMA}.{ref_name}"
    else:
        raise ValueError(f"Unknown table type: {table_type}")

# ---------------------------------------------
# Auto-calculate birth window dates
# ---------------------------------------------
def calculate_birth_window(session, client_data):
    logger.info("Calculating birth and runout window from source data")
    table = f"CSZNB_PRD_PS_PFA_DB.STAGE.FA_MEDICAL_{client_data}"
    query = f"""
    SELECT
    MIN(SRVC_FROM_DT) AS MIN_FROMDATE,
    MAX(SRVC_FROM_DT) AS MAX_FROMDATE,
    MAX(PROCESS_DT) AS MAX_PAIDDATE
    FROM {table}
    WHERE SRVC_FROM_DT IS NOT NULL AND PROCESS_DT IS NOT NULL
    """
    df = session.sql(query).to_pandas()
    min_dt = pd.to_datetime(df.at[0, 'MIN_FROMDATE'])
    max_dt = pd.to_datetime(df.at[0, 'MAX_FROMDATE'])
    max_ro_dt = pd.to_datetime(df.at[0, 'MAX_PAIDDATE'])
    
    if pd.isna(min_dt) or pd.isna(max_dt):
        raise ValueError("FROMDATE range is invalid. Cannot determine birth window.")
    
    num_months = (max_dt.year - min_dt.year) * 12 + (max_dt.month - min_dt.month) + 1
    if num_months < 24:
        raise ValueError(f"Only {num_months} months available. Minimum 24 months required.")
    
    last_complete_month_end = max_dt.replace(day=1) - pd.Timedelta(days=1)
    runout_end = last_complete_month_end
    runout_start = runout_end - relativedelta(months=3) + pd.Timedelta(days=1)
    birth_window_end = runout_start - pd.Timedelta(days=1)
    birth_window_start = birth_window_end - relativedelta(months=24) + pd.Timedelta(days=1)
    birth_window_mid = birth_window_start + relativedelta(months=12)
    
    birth_window_start = pd.to_datetime(birth_window_start).to_pydatetime()
    birth_window_end = pd.to_datetime(birth_window_end).to_pydatetime()
    runout_end = pd.to_datetime(runout_end).to_pydatetime()
    
    logger.info(f"Birth window: {birth_window_start.date()} to {birth_window_end.date()}")
    logger.info(f"Runout end: {runout_end.date()}")
    return birth_window_start, birth_window_end, birth_window_mid, runout_end

def _pydate(x):
    """
    Convert various date types to Python date object.
    
    Args:
        x: Date-like object (pandas Timestamp, datetime, or date)
        
    Returns:
        date: Python date object
        
    Raises:
        TypeError: If x is not a recognized date type
    """
    if hasattr(x, "to_pydatetime"):
        return x.to_pydatetime().date()
    if isinstance(x, datetime):
        return x.date()
    if isinstance(x, date):
        return x
    raise TypeError(type(x))

def process_membership(session: Session, client: str, birth_start, birth_mid, birth_end, client_nm: str):
    """
    Process membership data and create PS_MEMBERSHIP table.

    Uses MEM_EFF_DT and MEM_EXP_DT directly from source.
    Processes ALL enrollment periods per member (not just most recent).
    Aggregates member-months per member per study period.

    Args:
        session: Snowpark session
        client: Client identifier
        birth_start: Start of birth window
        birth_mid: Midpoint of birth window (splits Previous/Current periods)
        birth_end: End of birth window
        client_nm: Client name for output
    """

    birth_start = _pydate(birth_start)
    birth_mid = _pydate(birth_mid)
    birth_end = _pydate(birth_end)


    src = (session.table(f"FA_MEMBERSHIP_{client}")
           .filter(col("INDV_ID").is_not_null() & col("YEARMO").is_not_null())
           # YEARMO like '202401' -> 2024-01-01
           .with_column("MM_DATE", to_date(concat(col("YEARMO"), lit("01")), "YYYYMMDD")))

    # Step 1: Get most recent demographics per member (by MM_DATE desc)
    w = Window.partition_by("INDV_ID").order_by(col("MM_DATE").desc())
    demographics = (src.with_column("RN", row_number().over(w))
                       .filter(col("RN") == 1)
                       .select("INDV_ID", "GENDER", "BTH_DT", "BUS_LINE_CD", "PRDCT_CD", "STATE"))

    # Step 2: Get ALL distinct enrollment periods per member (no row_number filter)
    enrollment = (src.select("INDV_ID",
                            col("MEM_EFF_DT").cast("DATE").alias("MEM_EFF_DT"),
                            col("MEM_EXP_DT").cast("DATE").alias("MEM_EXP_DT"))
                     .distinct())

    # Step 3: Join demographics to all enrollment periods
    base = enrollment.join(demographics, "INDV_ID")

    prev_high = birth_mid - timedelta(days=1)  # first window ends day before mid
    prev_high = _pydate(prev_high)

    # Calculate overlaps using MEM_EFF_DT and MEM_EXP_DT directly (no aliasing)
    eff_prev = greatest(col("MEM_EFF_DT"), lit(birth_start))
    exp_prev = least(col("MEM_EXP_DT"), lit(prev_high))
    eff_curr = greatest(col("MEM_EFF_DT"), lit(birth_mid))
    exp_curr = least(col("MEM_EXP_DT"), lit(birth_end))

    with_mmyr = (base
        .with_column("MMYR1", when(exp_prev < eff_prev, lit(0))
            .otherwise(datediff("month", eff_prev, exp_prev) + lit(1)))
        .with_column("MMYR2", when(exp_curr < eff_curr, lit(0))
            .otherwise(datediff("month", eff_curr, exp_curr) + lit(1)))
        .with_column("AGE",
            when(col("BTH_DT").is_null(), lit(None))
            .otherwise(datediff("year", col("BTH_DT"), lit(birth_end))))
        .with_column("CLIENT_NAME", lit(client_nm))
        .with_column("PREVIOUS_PERIOD",
            lit(f"{birth_start:%b %Y} - {prev_high:%b %Y}"))
        .with_column("CURRENT_PERIOD",
            lit(f"{birth_mid:%b %Y} - {birth_end:%b %Y}"))
    )

    # Materialize two study-year slices and prepare for aggregation
    prev_df = (with_mmyr.filter(col("MMYR1") > 0)
               .with_column("STUDY_YR", lit("Previous"))
               .with_column("MEMBER_MONTHS", col("MMYR1")))

    curr_df = (with_mmyr.filter(col("MMYR2") > 0)
               .with_column("STUDY_YR", lit("Current"))
               .with_column("MEMBER_MONTHS", col("MMYR2")))

    # Union and aggregate member-months per member per study year
    member_df = (prev_df.union_all(curr_df)
                 .group_by("INDV_ID", "STUDY_YR", "GENDER", "BTH_DT", "BUS_LINE_CD",
                           "PRDCT_CD", "STATE", "AGE", "CLIENT_NAME",
                           "PREVIOUS_PERIOD", "CURRENT_PERIOD")
                 .agg(ssum("MEMBER_MONTHS").alias("MEMBER_MONTHS")))

    # Write PS_MEMBERSHIP_<CLIENT>
    export_to_snowflake(member_df, f"CSZNB_PRD_PS_PFA_DB.BASE.PS_MEMBERSHIP_{client}_TST")
# ---------------------------------------------
# Export final DataFrame to Snowflake
# ---------------------------------------------
def fetch_newborn_keys(session, client_data, birth_start, birth_end, runout_end):
    """
    Identify all distinct INDV_IDs with birth-related claims during the birth window.
    Uses reference tables in SUPP_DATA instead of in-memory dictionaries.
    """
    table_name = f"FA_MEDICAL_{client_data}"
    df = session.table(table_name).filter(
        (col("SRVC_FROM_DT") >= birth_start) &
        (col("SRVC_FROM_DT") <= birth_end) &
        (col("PROCESS_DT") <= runout_end) &
        col("INDV_ID").is_not_null()
    )
    # --- Load reference tables ---
    rev_ref = session.table("SUPP_DATA.REF_NEWBORN_REVCODE").select(col("CODE").alias("REV_CODE"))
    icd_ref = session.table("SUPP_DATA.REF_NEWBORN_ICD").select(col("CODE").alias("ICD_CODE"))
    msdrg_ref = session.table("SUPP_DATA.REF_NICU_MSDRG").select(col("CODE").alias("MSDRG"))
    aprdrg_ref = session.table("SUPP_DATA.REF_NICU_APRDRG").select(col("CODE").alias("APRDRG"))
    # --- Join claims to reference tables ---
    cond_rev = df["RVNU_CD"].cast("string") == rev_ref["REV_CODE"]
    cond_icd = (
        (df["DIAG_1_CD"].cast("string") == icd_ref["ICD_CODE"]) |
        (df["DIAG_2_CD"].cast("string") == icd_ref["ICD_CODE"]) |
        (df["DIAG_3_CD"].cast("string") == icd_ref["ICD_CODE"]) |
        (df["DIAG_4_CD"].cast("string") == icd_ref["ICD_CODE"]) |
        (df["DIAG_5_CD"].cast("string") == icd_ref["ICD_CODE"])
    )
    cond_msdrg = df["DRG"].substr(1,3) == msdrg_ref["MSDRG"]
    cond_aprdrg = df["DRG"].substr(1,3) == aprdrg_ref["APRDRG"]
    newborn_keys = (
        df.join(rev_ref, cond_rev, "left")
          .join(icd_ref, cond_icd, "left")
          .join(msdrg_ref, cond_msdrg, "left")
          .join(aprdrg_ref, cond_aprdrg, "left")
          .filter(
              col("REV_CODE").is_not_null() |
              col("ICD_CODE").is_not_null() |
              col("MSDRG").is_not_null() |
              col("APRDRG").is_not_null()
          )
          .select("INDV_ID").distinct()
          .to_pandas()
    )
    return newborn_keys["INDV_ID"].tolist()


def load_newborn_claims(session, client_data, newborn_keys, birth_start, birth_end, runout_end):
    fa_medical = session.table(f"FA_MEDICAL_{client_data}")
    
    if not newborn_keys:
        logger.warning("No newborn keys found - skipping claims pull")
        return session.create_dataframe([], schema=["INDV_ID"])
      
    newborn_keys = [str(k) for k in newborn_keys if k is not None]
    
    claims_df = (        fa_medical
        .filter((fa_medical['SRVC_FROM_DT'] >= birth_start) & (fa_medical['SRVC_FROM_DT'] <= birth_end))        .filter((fa_medical['PROCESS_DT'] <= runout_end))
        .filter(fa_medical['INDV_ID'].isin(newborn_keys))
        .select(            fa_medical['INDV_ID'],            fa_medical['CLM_AUD_NBR'].alias('CLAIMNO'),            fa_medical['SRVC_FROM_DT'].alias('FROMDATE'),            fa_medical['SRVC_THRU_DT'].alias('THRUDATE'),            fa_medical['PROCESS_DT'].alias('PAIDDATE'),            fa_medical['ADMIT_DT'],            fa_medical['DSCHRG_DT'].alias('DISCH_DT'),            fa_medical['DIAG_1_CD'].alias('DIAG1'),            fa_medical['DIAG_2_CD'].alias('DIAG2'),            fa_medical['DIAG_3_CD'].alias('DIAG3'),            fa_medical['DIAG_4_CD'].alias('DIAG4'),            fa_medical['DIAG_5_CD'].alias('DIAG5'),            fa_medical['PROC_1_CD'].alias('PROC1'),            fa_medical['PROC_2_CD'].alias('PROC2'),            fa_medical['PROC_3_CD'].alias('PROC3'),            fa_medical['PROC_CD'].alias('CPTCODE'),            fa_medical['DSCHRG_STS'].alias('DSCHRG_STS'),            fa_medical['SBMT_CHRG_AMT'].alias('SBMT_CHRG_AMT'),            fa_medical['DRG'].substr(0, 3).alias('DERIV_DRG_CD'),            fa_medical['DRG_TYPE'],            fa_medical['DRG_OTLR_FLG'],            fa_medical['DRG_OTLR_COST'],            fa_medical['NET_PD_AMT'].alias('AMTPAID'),            fa_medical['PL_OF_SRVC_CD'].alias('POS'),            fa_medical['RVNU_CD'].alias('REV_CD'),            fa_medical['PROV_NPI'].alias("PROV_ID"),            fa_medical['PROV_TIN'],            fa_medical['PROV_FULL_NM'].alias('PROV_FULL_NM'),            fa_medical['PROV_STATE'],            fa_medical['PROV_TYP_CD'].alias('PROV_TYP_CD')        )    )
    
    return claims_df


def create_fa_elig(session: Session, client: str):
    """
    Returns a Snowpark DataFrame containing the most recent row per INDV_ID
    from PS_MEMBERSHIP_<client>, preferring 'Current' over 'Previous'.
    """
    src = session.table(f"CSZNB_PRD_PS_PFA_DB.BASE.PS_MEMBERSHIP_{client}_TST")     # Prefer 'Current' study year; if MEM_EXP exists, use it as a tie-breaker
    prefer = when(col("STUDY_YR") == lit("Current"), lit(0)).otherwise(lit(1))
    order_cols = [prefer]
    if "MEM_EXP" in src.columns:
        order_cols.append(col("MEM_EXP").desc_nulls_last())
    w = Window.partition_by("INDV_ID").order_by(*order_cols)
    elig_df = (        src.with_column("RN", row_number().over(w))           .filter(col("RN") == 1)           .select(               col("INDV_ID"),               col("GENDER"),               col("BTH_DT"),               col("BUS_LINE_CD"),               col("PRDCT_CD"),               col("STATE")           )    )
    return elig_df


def merge_eligibility(session, client_data, newborn_keys, claims_df, elig_df):
    """
    Join eligibility details for each newborn using the ELIG table.
    Assumes eligibility is at the MEMBER_ID level and static for now.
    """
    merged = claims_df.join(elig_df, claims_df["INDV_ID"] == elig_df["INDV_ID"], "left")
    elig_non_key_vals = [c for c in elig_df.columns if c.upper() != "INDV_ID"]       # Explicitly state the columns we want to keep and drop the duplicate INDV_ID column created on the join
    result = merged.select(        claims_df["INDV_ID"].alias("INDV_ID"),        *[col(f'"{c}"') for c in claims_df.columns if c.upper() != "INDV_ID"],        *[col(f'"{c}"') for c in elig_non_key_vals]    )     #result = result.with_column_renamed("l_0030_INDV_ID", "INDV_ID")
    return result


def assign_claim_type(df):
    # Define conditions for CLAIM_TYPE assignment
    return (
        df.with_column(
            "CLAIM_TYPE",
            when(
                (col("POS") == POS_INPATIENT) |
                (col("REV_CD").between("0100", "0210")) |
                (col("REV_CD") == "0987") |
                (col("CPTCODE").between("99221", "99239")) |
                (col("CPTCODE").between("99251", "99255")) |
                (col("CPTCODE").between("99261", "99263")) |
                (col("DERIV_DRG_CD").is_not_null()),
                lit("IP")
            ).when(
                (col("POS") == POS_EMERGENCY) |
                (col("CPTCODE").isin(["99281", "99282", "99283", "99284", "99285", "99286", "99287", "99288"])) |
                (col("REV_CD").startswith("045")) |
                (col("REV_CD") == "0981"),
                lit("ER")
            ).otherwise(lit("OP"))
        )
    )

def tag_icd_flag(session, claims_df, ref_table_name, diag_cols, flag_name):
    """
    Tag claims with ICD diagnosis codes from reference table.
    
    Performance optimizations:
    - Filters null diagnosis codes before union (reduces row count ~60-80%)
    - Uses distinct to deduplicate diagnosis matches
    """
    ref_icd = session.table(ref_table_name).select(col("CODE").cast("STRING").alias("ICD_CODE")).distinct()
    diag_union = None
    
    for diag_col in diag_cols:
        diag_part = claims_df.select(
            col("INDV_ID"),
            col("CLAIMNO"),
            col(diag_col).cast("STRING").alias("DIAG_CODE")
        ).filter(col(diag_col).is_not_null())  # Filter nulls to reduce union size
        
        diag_union = diag_part if diag_union is None else diag_union.union_all(diag_part)
    
    # Deduplicate before join - significantly improves performance
    diag_union = diag_union.distinct()
    
    diag_flagged = diag_union.join(ref_icd, diag_union["DIAG_CODE"] == ref_icd["ICD_CODE"]) \
                             .select(
                                 col("INDV_ID").alias("MATCH_INDV_ID"),
                                 col("CLAIMNO").alias("MATCH_CLAIMNO")
                             ).distinct()
    
    flagged_claims_df = claims_df.join(
        diag_flagged,
        on=(claims_df["INDV_ID"] == diag_flagged["MATCH_INDV_ID"]) & (claims_df["CLAIMNO"] == diag_flagged["MATCH_CLAIMNO"]),
        how="left"
    ).with_column(
        flag_name,
        col("MATCH_CLAIMNO").is_not_null()
    ).drop("MATCH_INDV_ID", "MATCH_CLAIMNO")
    
    return flagged_claims_df


def tag_rev_flag(session, claims_df, ref_table_name, flag_name):
    ref_rev = session.table(ref_table_name).select(col("CODE").cast("STRING").alias("REV_CODE")).distinct()
    flagged = claims_df.join(        ref_rev,        claims_df["REV_CD"].cast("STRING") == ref_rev["REV_CODE"],        how="left"    ).with_column(        flag_name,        col("REV_CODE").is_not_null()    ).drop("REV_CODE")
    return flagged


def tag_drg_flag(session, claims_df, ref_table_name, flag_name):
    ref_drg = session.table(ref_table_name).select(col("CODE").cast("STRING").alias("DRG_CODE")).distinct()
    flagged = claims_df.with_column("DRG_3", col("DERIV_DRG_CD").cast("STRING").substr(1, 3)) \
        .join(            ref_drg,            col("DRG_3") == ref_drg["DRG_CODE"],            how="left"        ).with_column(            flag_name,            col("DRG_CODE").is_not_null()        ).drop("DRG_CODE", "DRG_3")
    return flagged


def tag_all_reference_flags(session, claims_df):
    """
    Tag claims with newborn/NICU reference flags from lookup tables.
        Args:        session: Snowpark session        claims_df: Claims DataFrame to tag            Returns:        Tagged claims DataFrame            Note:        Uses lazy evaluation - does NOT cache intermediate results.        Final caching happens in main() after column selection for optimal performance.
    """
    diag_cols = ['DIAG1', 'DIAG2', 'DIAG3', 'DIAG4', 'DIAG5']     # ICD tags - NO intermediate caching    logger.info("Tagging ICD codes (4 tags)...")
    icd_tags = [        ('SUPP_DATA.REF_NEWBORN_ICD', 'NEWBORN_ICD'),        ('SUPP_DATA.REF_SINGLETON_ICD', 'SINGLE'),        ('SUPP_DATA.REF_TWIN_ICD', 'TWIN'),        ('SUPP_DATA.REF_MULTIPLE_ICD', 'MULTIPLE')    ]    for i, (ref_table, flag) in enumerate(icd_tags, 1):
        logger.info(f"  [{i}/4] Tagging {flag}...")        claims_df = tag_icd_flag(session, claims_df, ref_table, diag_cols, flag)     # Revenue code tags - NO intermediate caching    logger.info("Tagging revenue codes (2 tags)...")
    rev_tags = [        ('SUPP_DATA.REF_NEWBORN_REVCODE', 'NEWBORN_REV'),        ('SUPP_DATA.REF_NICU_REVCODE', 'NICU_REV')    ]    for i, (ref_table, flag) in enumerate(rev_tags, 1):
        logger.info(f"  [{i}/2] Tagging {flag}...")        claims_df = tag_rev_flag(session, claims_df, ref_table, flag)     # DRG tags - NO intermediate caching    logger.info("Tagging DRG codes (2 tags)...")
    drg_tags = [        ('SUPP_DATA.REF_NICU_MSDRG', 'NICU_MSDRG'),        ('SUPP_DATA.REF_NICU_APRDRG', 'NICU_APRDRG')    ]    for i, (ref_table, flag) in enumerate(drg_tags, 1):
        logger.info(f"  [{i}/2] Tagging {flag}...")        claims_df = tag_drg_flag(session, claims_df, ref_table, flag)
    logger.info("✓ All reference flags tagged (using lazy evaluation)")
    return claims_dfdef newborn_rollup(session, client, claims_df):
    """
    Snowpark version of apply_birth_type_hierarchy_and_aggregate.
    Input: claims_df Snowpark DF with flags: NEWBORN_ICD, NEWBORN_REV, MULTIPLE, TWIN, SINGLE,           plus INDV_ID, BTH_DT, FROMDATE, NICU_REV, NICU_MSDRG, NICU_APRDRG, AMTPAID, etc.    Output: Snowpark DF of newborn_claims (enriched and filtered) like your Pandas version.
    """     # 0) Ensure date types are DATE (if needed)
    c = (claims_df         .with_column("BTH_DT", col("BTH_DT").cast("DATE"))         .with_column("FROMDATE", col("FROMDATE").cast("DATE")))     # 1) Per-claim birth type priority (Multiple > Twin > Single)
    c = (c         .with_column("BIRTH_PRI",            when(col("MULTIPLE"), lit(3))            .when(col("TWIN"),     lit(2))            .when(col("SINGLE"),   lit(1))            .otherwise(lit(0)))         .with_column("BIRTH_TYPE",            when(col("BIRTH_PRI")==3, lit("Multiple"))            .when(col("BIRTH_PRI")==2, lit("Twin"))            .when(col("BIRTH_PRI")==1, lit("Single"))            .otherwise(lit("Unknown")))    )     # 2) Likely newborn records (same as Pandas filter)
    newborn_only = c.filter(col("NEWBORN_ICD") | col("NEWBORN_REV"))     # 3) Group per baby: INDV_ID + BTH_DT    #    - pick the highest priority BIRTH_TYPE (via max priority, then map)    #    - SVC_DATE = min(FROMDATE)    #    - NICU flags = any -> max over boolean cast to int    newborns = (newborn_only        .group_by("INDV_ID", "BTH_DT")        .agg(            smax("BIRTH_PRI").alias("MAX_PRI"),            smin("FROMDATE").alias("SVC_DATE"),            smax(when(col("NICU_REV"),   lit(1)).otherwise(lit(0))).alias("HAS_NICU_REV"),            smax(when(col("NICU_MSDRG"), lit(1)).otherwise(lit(0))).alias("HAS_NICU_MSDRG"),            smax(when(col("NICU_APRDRG"),lit(1)).otherwise(lit(0))).alias("HAS_NICU_APRDRG"),        )        .with_column("BIRTH_TYPE",       
    when(col("MAX_PRI")==3, lit("Multiple"))            .when(col("MAX_PRI")==2, lit("Twin"))            .when(col("MAX_PRI")==1, lit("Single"))            .otherwise(lit("Unknown")))        .drop("MAX_PRI")    )     # 4) Derived baby-level fields
    newborns = (newborns        .with_column("IN_DAYS",            (sabs(datediff("day", col("SVC_DATE"), col("BTH_DT"))) <= lit(NEWBORN_SERVICE_WINDOW_DAYS)))        .with_column("BABY_TYPE",            when( (col("HAS_NICU_REV")==1) | (col("HAS_NICU_MSDRG")==1) | (col("HAS_NICU_APRDRG")==1),                  lit("NICU")).otherwise(lit("Normal Newborn")))        .with_column("CONTRACT",            when( (col("HAS_NICU_MSDRG")==1) | (col("HAS_NICU_APRDRG")==1),                  lit("DERIV_DRG_CD")).otherwise(lit("Per-Diem")))        .with_column_renamed("BIRTH_TYPE", "EP_BIRTH_TYPE")    )     # 5) DELIVERY_DT = earliest service date per INDV_ID (same as Pandas transform('min'))
    w_key = Window.partition_by("INDV_ID").order_by(col("SVC_DATE").asc())
    newborns = newborns.with_column("DELIVERY_DT", first_value("SVC_DATE").over(w_key))     # 6) Join back to all claims on INDV_ID (left), like your Pandas merge    joined = (newborns.select("INDV_ID", "EP_BIRTH_TYPE", "DELIVERY_DT", "IN_DAYS", "BABY_TYPE", "CONTRACT")              .join(c, "INDV_ID", "left"))     # 7) Filter claims to on/after delivery date
    newborn_claims = joined.filter(col("FROMDATE") >= col("DELIVERY_DT"))     # 8) Optional: high-cost flag (keep only if it’s actually part of business rules)
    newborn_claims = newborn_claims.with_column("HIGH_COST", col("AMTPAID") > lit(HIGH_COST_CLAIM_THRESHOLD))        return newborns, newborn_claims  def build_hosp_rollup(claims_df, runout_end):
    """
    Recreates the original newborn hospital stay rollup in Snowpark.
     Requirements implemented:    - filter: ~HIGH_COST (if column exists), CLAIM_TYPE == 'IP',              and (ADMIT_DT >= DELIVERY_DT) OR (FROMDATE >= DELIVERY_DT)    - fill ADMIT_DT/DSCHRG_DT from FROMDATE/THRUDATE if null    - stitch stays per member with a new stay when gap > 4 days    - stay-level ADMIT = min admit, DSCHRG = max discharge, sum AMTPAID    - clip DSCHRG to runout_end; flag OUT_END_DATE    - ADMIT = max(ADMIT, DELIVERY_DT)    - LOS = days between ADMIT and DSCHRG; if equal, LOS = 1; keep LOS >= 1
    """    # 1) Fill missing dates
    c = (claims_df         .with_column("ADMIT_DT_FIL", coalesce(col("ADMIT_DT"), col("FROMDATE")))         .with_column("DISCH_DT_FIL", coalesce(col("DISCH_DT"), col("THRUDATE")))         .with_column("ADMIT_DT_FIL", col("ADMIT_DT_FIL").cast("DATE"))         .with_column("DISCH_DT_FIL", col("DISCH_DT_FIL").cast("DATE"))         .with_column("DELIVERY_DT",  col("DELIVERY_DT").cast("DATE")))     # 2) Core filter (uses 'IP' because assign_claim_type emits 'IP', not 'Inpatient')
    base_filter = (        (col("CLAIM_TYPE") == lit("IP")) &        ((col("ADMIT_DT_FIL") >= col("DELIVERY_DT")) | (col("FROMDATE") >= col("DELIVERY_DT")))    )
    if "HIGH_COST" in c.columns:        c = c.filter(~col("HIGH_COST") & base_filter)    else:        c = c.filter(base_filter)     # 3) Order + gap logic (new stay when previous discharge gap > 4 days)
    w_sort = Window.partition_by("INDV_ID", "DELIVERY_DT").order_by(col("ADMIT_DT_FIL"), col("DISCH_DT_FIL"))
    prev_dis = lag(col("DISCH_DT_FIL")).over(w_sort)
    gap_days = datediff("day", prev_dis, col("ADMIT_DT_FIL"))
    new_stay_flag = when(prev_dis.is_null() | (gap_days > lit(HOSP_STAY_GAP_DAYS)), lit(1)).otherwise(lit(0))
    c = c.with_column("NEW_STAY", new_stay_flag)     # Cumulative sum → HOSP_STAY number per INDV_ID + DELIVERY_DT
    w_cume = Window.partition_by("INDV_ID", "DELIVERY_DT").order_by(col("ADMIT_DT_FIL"), col("DISCH_DT_FIL")).rows_between(Window.UNBOUNDED_PRECEDING, Window.CURRENT_ROW)
    c = c.with_column("HOSP_STAY", ssum(col("NEW_STAY")).over(w_cume))     # 4) Aggregate to stay level    stays = (c.group_by("INDV_ID", "DELIVERY_DT", "HOSP_STAY")               .agg(                   smin("ADMIT_DT_FIL").alias("ADMIT"),                   smax("DISCH_DT_FIL").alias("DSCHRG"),              
    ssum("AMTPAID").alias("PAID_AMT")               ))     # 5) Clip to runout, floor to delivery date, compute LOS
    stays = (stays             .with_column("OUT_END_DATE", (col("DSCHRG") > lit(runout_end)).cast("int"))             .with_column("DSCHRG", least(col("DSCHRG"), lit(runout_end)))             .with_column("ADMIT", greatest(col("ADMIT"), col("DELIVERY_DT")))             .with_column("LOS_RAW", datediff("day", col("ADMIT"), col("DSCHRG")))             .with_column("LOS", when(col("DSCHRG") == col("ADMIT"), lit(1)).otherwise(col("LOS_RAW")))             .drop("LOS_RAW")             .filter(col("LOS") >= lit(1))            )
    return stays  # equivalent to hosp_rollup_df (KEY→INDV_ID in your pipeline)


def build_newborn_and_nicu_ids(    claims_df,    hosp_rollup_df,    birth_window_start,    birth_window_mid,    init_hosp_threshold_days=INIT_HOSP_THRESHOLD_DAYS,        readmit_threshold_days=READMIT_THRESHOLD_DAYS):    # 1) Join claims to episode windows (INDV_ID + DELIVERY_DT)
    nh = (        claims_df        .join(            hosp_rollup_df.select("INDV_ID","DELIVERY_DT","ADMIT","DSCHRG","LOS"),            ["INDV_ID","DELIVERY_DT"],            "inner"        )    )    # 2) Core in‑window filter & optional high‑cost filter
    base_filter = (col("ADMIT") <= col("FROMDATE")) & (col("FROMDATE") <= col("DSCHRG"))
    nh = nh.filter(base_filter)
    if "HIGH_COST" in nh.columns:        nh = nh.filter(~col("HIGH_COST"))    # 3) Keep only the columns we actually need downstream (prevents accidental dup sources)
    keep_cols = [        "INDV_ID","DELIVERY_DT","ADMIT","DSCHRG","LOS",        "CLAIMNO","FROMDATE","THRUDATE","ADMIT_DT","DISCH_DT","PAIDDATE",        "AMTPAID","CPTCODE","REV_CD","DERIV_DRG_CD",        "DIAG1","DIAG2","DIAG3","DIAG4","DIAG5",        "PROC1","PROC2","PROC3",        "BUS_LINE_CD",        "PRDCT_CD",     # later mapped to LOB if LOB missing        "BTH_DT",        "BIRTH_TYPE","BABY_TYPE","CONTRACT",        "DSCHRG_STS",        "PROV_ID","PROV_TIN","PROV_FULL_NM","PROV_STATE"    ]    keep_cols = [c for c in keep_cols if c in nh.columns]    nh = nh.select(*[col(c) for c in keep_cols])    # 4) Episode metadata & study year
    nh = (        nh        .with_column("STAY_TYPE", when(col("LOS") >= lit(LONG_STAY_THRESHOLD), lit("Long Stay")).otherwise(lit("Short Stay")))        .with_column("IN_DAYS",   col("ADMIT") <= (col("DELIVERY_DT") + lit(init_hosp_threshold_days)))        .with_column("STUDY_YR",            when(                (col("DELIVERY_DT") >= lit(birth_window_start)) &                (col("DELIVERY_DT") <  lit(birth_window_mid)),                lit("Previous")            ).otherwise(lit("Current"))        )        .with_column("ADMIT_GAP", datediff("day", col("DELIVERY_DT"), col("ADMIT")))        .filter((col("ADMIT_GAP") < lit(readmit_threshold_days)) & col("IN_DAYS"))    )    # 5) *** Claim‑level de‑dup per episode *** (prevents AMTPAID inflation)
    # If a claim can appear multiple times in the joined set, pick ONE row per INDV_ID, DELIVERY_DT, CLAIMNO    # Prefer the latest service info within the episode.
    w_claim = (        Window.partition_by("INDV_ID","DELIVERY_DT","CLAIMNO")              .order_by(col("THRUDATE").desc_nulls_last(), col("FROMDATE").desc_nulls_last())    )
    claim_base = (        nh.with_column("RN_CLAIM", row_number().over(w_claim))          .filter(col("RN_CLAIM") == 1)          .drop("RN_CLAIM")    )
    # Map PRODUCT→LOB if LOB not present
    if "LOB" not in claim_base.columns and "PRDCT_CD" in claim_base.columns:        claim_base = claim_base.with_column("LOB", col("PRDCT_CD"))    # 6) Episode‑level rollup from the de‑duped claims only    grp_ep = [        "INDV_ID","BTH_DT","DELIVERY_DT","ADMIT","DSCHRG","LOS",        "STAY_TYPE","BIRTH_TYPE","CONTRACT","BUS_LINE_CD","LOB","STUDY_YR"    ]    grp_ep = [g for g in grp_ep if g in claim_base.columns]
    newborn_ident_ep = (        claim_base.group_by(*grp_ep)                  .agg(                      ssum("AMTPAID").alias("AMTPAID"),                      # BABY_TYPE/CONTRACT should be constant within an episode; if not, pick a stable representative:                 
    smin("BABY_TYPE").alias("BABY_TYPE")                  )    )    # 7) Newborn‑level rollup (one row per INDV_ID + DELIVERY_DT)
    # Collapse multiple episodes (if any) and derive final fields. AMTPAID remains correct because    # each episode was built from de‑duped claims.
    newborn_ident_df = (        newborn_ident_ep        .group_by("INDV_ID","BTH_DT","DELIVERY_DT","BUS_LINE_CD","LOB","STUDY_YR")        .agg(            smin("ADMIT").alias("ADMIT"),            smax("DSCHRG").alias("DSCHRG"),            ssum("AMTPAID").alias("AMTPAID"),            # Any NICU inside the newborn ⇒ NICU            smax(when(col("BABY_TYPE")==lit("NICU"), lit(1)).otherwise(lit(0))).alias("ANY_NICU"),            # Birth type priority: Multiple > Twin > Single > Unknown            smax(                when(col("BIRTH_TYPE")==lit("Multiple"), lit(3))                .when(col("BIRTH_TYPE")==lit("Twin"),     lit(2))                .when(col("BIRTH_TYPE")==lit("Single"),   lit(1))                .otherwise(lit(0))            ).alias("BT_PRI"),            # Contract priority: if any episode is DRG ⇒ DRG            smax(when(col("CONTRACT")==lit("DERIV_DRG_CD"), lit(1)).otherwise(lit(0))).alias("ANY_DRG")        )        .with_column("LOS_RAW", datediff("day", col("ADMIT"), col("DSCHRG")))        .with_column("LOS", when(col("DSCHRG")==col("ADMIT"), lit(1)).otherwise(col("LOS_RAW")))        .drop("LOS_RAW")        .with_column("BABY_TYPE",            when(col("ANY_NICU")==lit(1), lit("NICU")).otherwise(lit("Normal Newborn")))        .with_column("BIRTH_TYPE",            when(col("BT_PRI")==lit(3), lit("Multiple"))            .when(col("BT_PRI")==lit(2), lit("Twin"))            .when(col("BT_PRI")==lit(1), lit("Single"))            .otherwise(lit("Unknown")))        .with_column("CONTRACT",       
    when(col("ANY_DRG")==lit(1), lit("DERIV_DRG_CD")).otherwise(lit("Per-Diem")))        .drop("ANY_NICU","BT_PRI","ANY_DRG")    )    # 8) NICU subset (one row per newborn)
    nicu_ident = (        newborn_ident_df        .filter(col("BABY_TYPE")==lit("NICU"))        .with_column_renamed("AMTPAID","TOTAL_NICU_COST")    )    # 9) Build nicu_claims_df from the de‑duped claim_base to avoid any row blow‑up    claim_cols = [        "INDV_ID","CLAIMNO","FROMDATE","THRUDATE","ADMIT_DT","DISCH_DT",        "PROV_ID","CPTCODE","REV_CD","DERIV_DRG_CD","DSCHRG_STS","AMTPAID",        "DIAG1","DIAG2","DIAG3","DIAG4","DIAG5","PROC1","PROC2","PROC3"    ]    claim_cols = [c for c in claim_cols if c in claim_base.columns]
    nicu_claims_df = (        claim_base.select(*claim_cols, "DELIVERY_DT","LOS")
    # LOS comes only from rollup path                  .join(nicu_ident.select("INDV_ID","ADMIT","DSCHRG"), ["INDV_ID"], "inner")                  .filter((col("FROMDATE") >= col("ADMIT")) & (col("FROMDATE") <= col("DSCHRG")))    )    # 10) LAST_DISCHARGE_STATUS
    order_col = (        when(col("DSCHRG_STS") == lit("20"), lit(0))        .when(col("DSCHRG_STS") == lit("07"), lit(1))        .when(col("DSCHRG_STS").isin(["02","05","66","43","62","63","65"]), lit(2))        .when(col("DSCHRG_STS") == lit("30"), lit(3))        .when(col("DSCHRG_STS").isin(["01","06"]), lit(4))        .when(            (length(col("DSCHRG_STS")) < lit(2)) |            (col("DSCHRG_STS").isin(["04","41","50","51","70","03","64"])) |            (col("DSCHRG_STS").between(lit("08"), lit("19"))) |            (col("DSCHRG_STS").between(lit("21"), lit("29"))) |            (col("DSCHRG_STS").between(lit("31"), lit("39"))) |            (col("DSCHRG_STS").between(lit("44"), lit("49"))) |            (col("DSCHRG_STS").between(lit("52"), lit("60"))) |            (col("DSCHRG_STS").between(lit("67"), lit("69"))) |            (col("DSCHRG_STS").between(lit("71"), lit("99")))        , lit(6)).otherwise(lit(9))    )
    ranked = (        nicu_claims_df        .filter((col("DSCHRG_STS") != lit("00")) | col("DSCHRG_STS").is_not_null())        .with_column("ORDER", order_col)        .with_column(            "RN",            row_number().over(                Window.partition_by("INDV_ID","ADMIT","DSCHRG")                      .order_by(col("ORDER").asc(),                                col("DISCH_DT").desc(),                                col("FROMDATE").desc(),                                col("DSCHRG_STS").asc())            )        )    )
    last_status = (        ranked.filter(col("RN")==1)              .select("INDV_ID","ADMIT","DSCHRG", col("DSCHRG_STS").alias("LAST_DISCHARGE_STATUS"))    )
    nicu_claims_df = nicu_claims_df.join(last_status, ["INDV_ID","ADMIT","DSCHRG"], "left")    # 11) Discharge provider attribution (episode × provider) using de‑duped claims
    if all(x in nicu_claims_df.columns for x in ["AMTPAID","ADMIT_DT","DISCH_DT","PROV_TIN"]):        ep = (            nicu_claims_df            .filter(col("PROV_ID").is_not_null())            .group_by("INDV_ID","ADMIT","DSCHRG","DELIVERY_DT","LOS","PROV_ID")            .agg(                ssum("AMTPAID").alias("HOSPPAID"),                smin("ADMIT_DT").alias("HOSPADMIT"),                smax("DISCH_DT").alias("HOSPDISCHG")            )            .with_column(                "HOSPLOS",                when(col("HOSPDISCHG")==col("DSCHRG"),                     datediff("day", col("HOSPADMIT"), col("HOSPDISCHG")))                .otherwise(datediff("day", col("HOSPADMIT"), col("HOSPDISCHG")) + lit(1))            )        )        w_best = (            Window.partition_by("INDV_ID","ADMIT","DSCHRG")                  .order_by(col("HOSPDISCHG").desc(),                            col("HOSPLOS").desc(),                            col("HOSPPAID").desc())        )        best = ep.with_column("RN", row_number().over(w_best)).filter(col("RN")==1).drop("RN")        hosplist = (            claims_df.select("PROV_ID","PROV_FULL_NM","PROV_STATE")                     .filter(col("PROV_ID").is_not_null())                     .distinct()        )        nicu_dischg_provider = (            best.join(hosplist, ["PROV_ID"], "left")                .with_column("PROV_FULL_NM", coalesce(col("PROV_FULL_NM"), lit("Unknown")))                .with_column("PROV_STATE",   coalesce(col("PROV_STATE"),   lit("Unknown")))        )    else:    
    nicu_dischg_provider = None    # 12) REV & DRG episode features (computed from de‑duped claims only, so stable)
    rev_ep = (        nicu_claims_df        .select("INDV_ID","ADMIT","DSCHRG","REV_CD")        .with_column("REV_NUM", sql_expr("TRY_TO_NUMBER(REV_CD)"))        .filter(col("REV_NUM").between(*NICU_REV_CODE_RANGE))        .select("INDV_ID","ADMIT","DSCHRG","REV_NUM")        .distinct()    )
    rev_min = (        rev_ep.group_by("INDV_ID","ADMIT","DSCHRG")              .agg(smin("REV_NUM").alias("FINAL_REV_NUM"))              .with_column("FINAL_REV_CD", sql_expr("TO_VARCHAR(FINAL_REV_NUM)"))              .select("INDV_ID","ADMIT","DSCHRG","FINAL_REV_CD")    )
    w_rev = Window.partition_by("INDV_ID","ADMIT","DSCHRG").order_by(col("REV_NUM").asc())
    rev_second = (        rev_ep.with_column("RN", row_number().over(w_rev))              .filter(col("RN")==2)              .select("INDV_ID","ADMIT","DSCHRG", col("REV_NUM").alias("REV_NUM_2"))    )
    rev_features = (        rev_min.join(rev_second, ["INDV_ID","ADMIT","DSCHRG"], "left")               .with_column("REV_LEVELING", col("REV_NUM_2").is_not_null())               .select("INDV_ID","ADMIT","DSCHRG","FINAL_REV_CD","REV_LEVELING")    )
    drg_ep = (        nicu_claims_df        .select("INDV_ID","ADMIT","DSCHRG","DERIV_DRG_CD")        .with_column("DRG_NUM", sql_expr("TRY_TO_NUMBER(DRG)"))        .filter( (col("DRG_NUM").between(*NICU_MS_DRG_RANGE)) | (col("DRG_NUM").between(*NICU_APR_DRG_RANGE)) )        .select("INDV_ID","ADMIT","DSCHRG","DRG_NUM")        .distinct()    )
    drg_min = (        drg_ep.group_by("INDV_ID","ADMIT","DSCHRG")              .agg(smin("DRG_NUM").alias("FINAL_DRG_NUM"))              .with_column("FINAL_DRG_CD", sql_expr("TO_VARCHAR(FINAL_DRG_NUM)"))              .select("INDV_ID","ADMIT","DSCHRG","FINAL_DRG_CD")    )    # 13) Episode features joined to newborn_ident_ep (episode‑level) or newborn_ident_df (newborn‑level)
    ep_with_rev = newborn_ident_ep.join(rev_features, ["INDV_ID","ADMIT","DSCHRG"], "left")
    ep_with_drg = (        ep_with_rev.join(drg_min, ["INDV_ID","ADMIT","DSCHRG"], "left")                   .with_column("FINAL_DRG_CD", sql_expr("IFF(CONTRACT='DERIV_DRG_CD', FINAL_DRG_CD, NULL)"))    )
    # Outputs
    return {        "newborn_hosp_clms": claim_base,         # claims limited to episode window, de‑duped by claim        "newborn_ident_df": newborn_ident_df,    # one row per newborn (INDV_ID + DELIVERY_DT)        "nicu_ident": nicu_ident,                # subset of newborn_ident_df        "nicu_claims_df": nicu_claims_df,        # claims within NICU newborn stay, de‑duped        "nicu_dischg_provider": nicu_dischg_provider,        "rev_out": rev_features,                  # episode + REV features        "drg_out": drg_min                 # episode + REV + DRG features    } # --- 1) Professional fees (all, manageable CPT set, critical care set) ---def _prof_fee_aggregates(nicu_claims_df):    # Only professional claims (CPT present)
    prof = nicu_claims_df.filter(col("CPTCODE").is_not_null()) \                         .select("INDV_ID","ADMIT","DSCHRG","LOS","FROMDATE","CPTCODE","AMTPAID","ADMIT_DT","DISCH_DT")
    # All professional fees    all_prof = (prof.group_by("INDV_ID","ADMIT","DSCHRG")                     .agg(ssum("AMTPAID").alias("ALL_PROFFEE")))        # Manageable CPT set    manageable = MANAGEABLE_CPT_CODES
    man = prof.filter(col("CPTCODE").isin(manageable))
    # Unique service-days counted per (INDV_ID, ADMIT, DSCHRG) & CPT & FROMDATE    # Equivalent to nunique(KEY-ADMIT-DSCHRG-FROMDATE-CPT) in Pandas
    man_days_key = concat(col("INDV_ID").cast("string"), lit("-"),                          to_char(col("ADMIT"), "YYYY-MM-DD"), lit("-"),                          to_char(col("DSCHRG"), "YYYY-MM-DD"), lit("-"),                          to_char(col("FROMDATE"), "YYYY-MM-DD"), lit("-"),                          col("CPTCODE"))
    man_aggs = (man.with_column("CPT_DAYS_KEY", man_days_key)                   .group_by("INDV_ID","ADMIT","DSCHRG")                   .agg(                       ssum("AMTPAID").alias("MANAGEABLE_PROFFEE"),                       count_distinct("CPT_DAYS_KEY").alias("MANAGEABLE_SVC_DAYS")                   ))
    # Critical care CPT set    critical = CRITICAL_CARE_CPT_CODES
    crit = prof.filter(col("CPTCODE").isin(critical))
    crit_days_key = concat(col("INDV_ID").cast("string"), lit("-"),                           to_char(col("ADMIT"), "YYYY-MM-DD"), lit("-"),                           to_char(col("DSCHRG"), "YYYY-MM-DD"), lit("-"),                           to_char(col("FROMDATE"), "YYYY-MM-DD"))
    crit_aggs = (crit.with_column("CRITICAL_DAYS_KEY", crit_days_key)                    .group_by("INDV_ID","ADMIT","DSCHRG")                    .agg(                        ssum("AMTPAID").alias("CRITICAL_CARE_PROFFEE"),                        count_distinct("CRITICAL_DAYS_KEY").alias("CRITICAL_CARE_DAYS")                    ))
    return all_prof, man_aggs, crit_aggs# --- 2) Room & Board (REV 011–017,020) w/ CPT null ---def _room_and_board(nicu_claims_df):
    rb_prefix_ok = substring(col("REV_CD"), 1, 3).isin(ROOM_BOARD_REV_PREFIXES)
    room = (nicu_claims_df            .filter(rb_prefix_ok & col("CPTCODE").is_null())            .group_by("INDV_ID","ADMIT","DSCHRG")            .agg(ssum("AMTPAID").alias("FACILITY_RM_COST")))
    return room# --- 3) Readmissions (next episode within 30 days) ---def _readmissions(nicu_ident, hosp_rollup_df):    # nicu_ident: episode-level (INDV_ID, ADMIT, DSCHRG, ... TOTAL_NICU_COST, LOS, etc.)    # hosp_rollup_df: (INDV_ID, DELIVERY_DT, HOSP_STAY, ADMIT, DSCHRG, PAID_AMT, LOS, ...)
    future = (hosp_rollup_df              .select(col("INDV_ID"),                      col("ADMIT").alias("READMIT_DT"),                      col("PAID_AMT").alias("READMIT_PAID_AMT"),                      col("LOS").alias("READMIT_LOS")))    # join and keep when READMIT_DT in (DSCHRG+1, DSCHRG+30]
    j = (nicu_ident.select("INDV_ID","ADMIT","DSCHRG")         .join(future, ["INDV_ID"], "inner")         .filter((col("READMIT_DT") > col("DSCHRG")) &                 (datediff("day", col("DSCHRG"), col("READMIT_DT")) <= lit(30))))
    readm = (j.group_by("INDV_ID","ADMIT","DSCHRG")               .agg(                   count_distinct("READMIT_DT").alias("READMIT"),                   ssum("READMIT_PAID_AMT").alias("READMIT_PAID_AMT"),                   ssum("READMIT_LOS").alias("READMIT_LOS")               ))
    return readm# --- 4) “Unpivot” DIAG/PROC columns using unions (Snowpark lacks UNPIVOT API) ---def _union_diag_proc(nicu_claims_df):
    diag_cols = [c for c in nicu_claims_df.columns if c.upper().startswith("DIAG")]    proc_cols = [c for c in nicu_claims_df.columns if c.upper().startswith("PROC")]    # DIAGTMP rows
    diag_parts = []    for d in diag_cols:
        diag_parts.append(            nicu_claims_df.select(                col("INDV_ID"), col("ADMIT"), col("DSCHRG"),                col(d).alias("DIAGTMP")            ).filter(col("DIAGTMP").is_not_null())        )
    diag_tmp = None    for p in diag_parts:        diag_tmp = p if diag_tmp is None else diag_tmp.union_all(p)
    if diag_tmp is None:        diag_tmp = nicu_claims_df.session.create_dataframe([], schema=["INDV_ID","ADMIT","DSCHRG","DIAGTMP"])
    # PROCTMP rows
    proc_parts = []    for pr in proc_cols:
        proc_parts.append(            nicu_claims_df.select(                col("INDV_ID"), col("ADMIT"), col("DSCHRG"),                col(pr).alias("PROCTMP")            ).filter(col("PROCTMP").is_not_null())        )
    proc_tmp = None    for p in proc_parts:        proc_tmp = p if proc_tmp is None else proc_tmp.union_all(p)
    if proc_tmp is None:        proc_tmp = nicu_claims_df.session.create_dataframe([], schema=["INDV_ID","ADMIT","DSCHRG","PROCTMP"])    # de-dup
    diag_tmp = diag_tmp.distinct()
    proc_tmp = proc_tmp.distinct()
    return diag_tmp, proc_tmp# --- 5) Birthweight / Gestational age / NAS via REF tables ---def _bw_ga_nas(session, diag_tmp):    # Expect REF tables with CODE -> category mapping    # Adjust the table/column names if your REF schemas differ.    bw_ref  = session.table("SUPP_DATA.REF_BIRTHWEIGHT_ICD").select(col("CODE").alias("ICD_CODE"), col("DESCRIPTION").alias("BW_CAT"))    ga_ref  = session.table("SUPP_DATA.REF_GEST_AGE_ICD").select(col("CODE").alias("ICD_CODE"), col("DESCRIPTION").alias("GA_CAT"))
    # Birthweight (first category per INDV_ID by lexical order)
    bw = (diag_tmp.join(bw_ref, diag_tmp["DIAGTMP"] == bw_ref["ICD_CODE"], "inner")                  .select("INDV_ID","ADMIT","DSCHRG","BW_CAT"))
    # Pick one BW_CAT per INDV_ID (if you want per episode, include ADMIT/DSCHRG in the window)
    w_bw = Window.partition_by("INDV_ID").order_by(col("BW_CAT").asc())
    bw = (bw.with_column("RN", row_number().over(w_bw))           .filter(col("RN")==1)           .drop("RN"))
    # Gestational age
    ga = (diag_tmp.join(ga_ref, diag_tmp["DIAGTMP"] == ga_ref["ICD_CODE"], "inner")                  .select("INDV_ID","ADMIT","DSCHRG", "GA_CAT"))
    w_ga = Window.partition_by("INDV_ID").order_by(col("GA_CAT").asc())
    ga = (ga.with_column("RN", row_number().over(w_ga))           .filter(col("RN")==1)           .drop("RN"))
    # NAS flag (ICD-10 code “P96.1”; your data used "P961" – keep the exact string that appears in claims)
    nas = (diag_tmp.filter(col("DIAGTMP") == lit("P961"))                  .select("INDV_ID","ADMIT","DSCHRG")                  .with_column("NAS", lit(True))                  .distinct())
    return bw, ga, nasdef build_nicu_rollup(    session,    # from prior steps    nicu_ident,            # episode-level NICU babies (INDV_ID, ADMIT, DSCHRG, LOS, TOTAL_NICU_COST, CONTRACT, etc.)    nicu_claims_df,        # claim-level subset bounded to NICU episodes    hosp_rollup_df,        # episode stays (INDV_ID, DELIVERY_DT, HOSP_STAY, ADMIT, DSCHRG, PAID_AMT, LOS)    rev_out,               # episode-level REV features: (INDV_ID, ADMIT, DSCHRG, FINAL_REV_CD, REV_LEVELING)    drg_out,               # episode-level DRG features: (INDV_ID, ADMIT, DSCHRG, FINAL_DRG_CD)    nicu_dischg_provider   # episode-level provider attribution):    # 1) prof fee rollups    all_prof, man_aggs, crit_aggs = _prof_fee_aggregates(nicu_claims_df)    # 2) room & board    room = _room_and_board(nicu_claims_df)    # 3) readmissions
    readm = _readmissions(nicu_ident, hosp_rollup_df)    # 4) diag/proc “unpivot”    diag_tmp, proc_tmp = _union_diag_proc(nicu_claims_df)    # 5) birthweight / gest age / NAS via REF tables    bw, ga, nas = _bw_ga_nas(session, diag_tmp)
    # Start from episode-level NICU set
    base = nicu_ident    # Left-join all features on (INDV_ID, ADMIT, DSCHRG)
    keys = ("INDV_ID","ADMIT","DSCHRG")
    out = (base           .join(all_prof.select(*keys, "ALL_PROFFEE"), keys, "left")           .join(man_aggs.select(*keys, "MANAGEABLE_PROFFEE"), keys, "left")           .join(crit_aggs.select(*keys, "CRITICAL_CARE_PROFFEE"), keys, "left")           .join(room.select(*keys, "FACILITY_RM_COST"), keys, "left")           .join(readm.select(*keys, "READMIT", "READMIT_PAID_AMT", "READMIT_LOS"), keys, "left")           .join(nas.select(*keys, "NAS"), keys, "left")           .join(ga.select(*keys, "GA_CAT"), keys, "left")           .join(bw.select(*keys, "BW_CAT"), keys, "left")           .join(rev_out.select(*keys, "FINAL_REV_CD", "REV_LEVELING"), keys, "left")           .join(drg_out.select(*keys, "FINAL_DRG_CD"), keys, "left")    )
    if nicu_dischg_provider is not None:        out = (out.join(            nicu_dischg_provider.select(*keys, "PROV_TIN","PROV_FULL_NM", "PROV_STATE"),            keys, "left"        ))    # 6) derived rollup metrics
    out = (out           .with_column("ALL_PROFFEE",        coalesce(col("ALL_PROFFEE"), lit(0)))           .with_column("MANAGEABLE_PROFFEE", coalesce(col("MANAGEABLE_PROFFEE"), lit(0)))           .with_column("CRITICAL_CARE_PROFFEE", coalesce(col("CRITICAL_CARE_PROFFEE"), lit(0)))           .with_column("FACILITY_RM_COST",   coalesce(col("FACILITY_RM_COST"), lit(0)))           .with_column("TOTAL_NICU_COST",    coalesce(col("TOTAL_NICU_COST"), lit(0)))           .with_column("LOS",                coalesce(col("LOS"), lit(0)))           .with_column("ALL_FACILITY_COST",  col("TOTAL_NICU_COST") - col("ALL_PROFFEE"))           .with_column("LOW_PAID_NICU", when((col("LOS") > lit(0)) &                                              ((col("TOTAL_NICU_COST")/col("LOS")) < lit(NICU_LOW_COST_PER_DAY_THRESHOLD)),                                              lit(True)).otherwise(lit(False)))           .with_column("INAPPROPRIATE_NICU",                        (col("CONTRACT")==lit("DERIV_DRG_CD")) &                        (col("LOS") <= lit(INAPPROPRIATE_NICU_MAX_LOS)) &                        col("FINAL_REV_CD").isin(["170","171"]))    )
    return out


def prepare_final_export(newborn_df, nicu_df):
    join_keys = ['INDV_ID', 'ADMIT', 'DSCHRG']    left_cols = set(newborn_df.columns)
    right_cols = [c for c in nicu_df.columns if c not in join_keys and c not in left_cols]    nicu_df_trimmed = nicu_df.select(        *[col(k) for k in join_keys],        *[col(c) for c in right_cols]    )
    newborns_out = newborn_df.join(nicu_df_trimmed, join_keys, "left")
    return newborns_out


def export_to_snowflake(df, table_name):
    """
    Write DataFrame to Snowflake table.
        Args:        df: Snowpark DataFrame to export        table_name: Fully qualified table name            Note:        Respects DRY_RUN configuration - will only preview data if DRY_RUN=True
    """
    if DRY_RUN:        row_count = df.count()        logger.info(f"[DRY-RUN] Would export {row_count:,} rows to {table_name}")        logger.info(f"[DRY-RUN] Preview of first 5 rows:")        df.show(5)        return     logger.info(f"Exporting data to Snowflake table: {table_name}")    df.write.mode("overwrite").save_as_table(table_name)
    logger.info("Export complete.") # ---------------------------------------------# Main pipeline# ---------------------------------------------def main(auto_window=True):
    logger.info("Starting NICU pipeline")
    session = get_snowflake_session()
    client_data = CLIENT_DATA     if auto_window:        birth_window_start, birth_window_end, birth_window_mid, runout_end = calculate_birth_window(session, client_data)    else:        birth_window_start = pd.Timestamp("2021-01-01")        birth_window_end = pd.Timestamp("2022-12-31")        runout_end = pd.Timestamp("2023-03-31")
    logger.info("Processing Membership Data")    process_membership(session, client_data, birth_window_start, birth_window_mid, birth_window_end, client_data)        logger.info("Fetching newborn keys")
    newborn_keys = fetch_newborn_keys(session, client_data, birth_window_start, birth_window_end, runout_end)
    logger.info(f"Found {len(newborn_keys)} unique newborn keys")
    logger.info("Loading newborn claims")
    claims_df = load_newborn_claims(session, client_data, newborn_keys, birth_window_start, birth_window_end, runout_end)
    logger.info("Creating the temporary Elig table")
    elig_df = create_fa_elig(session, client_data)
    logger.info("Merging eligibility data")
    claims_df = merge_eligibility(session, client_data, newborn_keys, claims_df, elig_df).cache_result()
    logger.info("Assigning Claim Types")
    claims_df = assign_claim_type(claims_df).cache_result()
    logger.info("Flagging newborn and NICU enrichments")
    claims_df = tag_all_reference_flags(session, claims_df)    
    # DEBUG: Cache claims after tagging for inspection
    if DEBUG_MODE:
        claims_df_tagged = claims_df.cache_result()
        logger.info(f"[DEBUG] Cached claims_df_tagged ({claims_df_tagged.count():,} rows)")
    else:
        claims_df_tagged = None
    logger.info("Selecting final claim columns and materializing results...")
    claims_df = claims_df.select(
        "INDV_ID","CLAIMNO","FROMDATE","THRUDATE","PAIDDATE","ADMIT_DT","DISCH_DT",
        "DIAG1","DIAG2","DIAG3","DIAG4","DIAG5","PROC1","PROC2","PROC3","CPTCODE",
        "DSCHRG_STS","BILLED","DERIV_DRG_CD","AMTPAID","POS","REV_CD",
        "PROV_ID","PROV_TIN","PROV_FULL_NM","PROV_STATE","PROV_TYPE",
        "GENDER","BTH_DT","BUS_LINE_CD","PRDCT_CD","STATE",
        "NEWBORN_ICD","NEWBORN_REV","SINGLE","TWIN","MULTIPLE","NICU_REV","NICU_MSDRG","NICU_APRDRG",
        "CLAIM_TYPE"
    ).cache_result()
    
    logger.info("Applying newborn rollup logic")
    newborns_df, claims_df = newborn_rollup(session, client_data, claims_df)
    
    # DEBUG: Cache newborn claims for inspection
    if DEBUG_MODE:
        newborn_claims = claims_df.cache_result()
        logger.info(f"[DEBUG] Cached newborn_claims ({newborn_claims.count():,} rows)")
    else:
        newborn_claims = None
    logger.info("Rolling up hospital stays")
    hosp_rollup_df = build_hosp_rollup(claims_df, runout_end)
    
    logger.info("Building NICU artifact tables")
    ids = build_newborn_and_nicu_ids(
        claims_df,
        hosp_rollup_df,
        birth_window_start.date(), 
        birth_window_mid.date(),
        init_hosp_threshold_days=INIT_HOSP_THRESHOLD_DAYS,
        readmit_threshold_days=READMIT_THRESHOLD_DAYS
    )
    
    newborn_hosp_clms   = ids["newborn_hosp_clms"]
    newborn_ident_df    = ids["newborn_ident_df"]
    nicu_ident          = ids["nicu_ident"]
    nicu_claims_df      = ids["nicu_claims_df"]
    nicu_dischg_provider= ids["nicu_dischg_provider"]
    rev_out             = ids["rev_out"]
    drg_out             = ids["drg_out"]
    # Debug retention - cache intermediate dataframes when DEBUG_MODE is enabled
    if DEBUG_MODE:
        logger.info("[DEBUG] Caching intermediate dataframes for inspection...")
        newborn_ident_df = newborn_ident_df.cache_result()
        nicu_ident = nicu_ident.cache_result()
        logger.info(f"[DEBUG] Cached newborn_ident_df ({newborn_ident_df.count():,} rows)")
        logger.info(f"[DEBUG] Cached nicu_ident ({nicu_ident.count():,} rows)")
    logger.info("Building NICU rollup")
    nicu_rollup = build_nicu_rollup(
        session,
        nicu_ident,
        nicu_claims_df,
        hosp_rollup_df,
        rev_out,
        drg_out,
        nicu_dischg_provider
    )
    
    logger.info("Merging Newborns and NICU tables")
    newborns_df = prepare_final_export(newborn_ident_df, nicu_rollup)
    
    export_to_snowflake(newborns_df, f"CSZNB_PRD_PS_PFA_DB.BASE.PS_NEWBORNS_{client_data}_TST")
    # Return debug dataframes when DEBUG_MODE is enabled
    if DEBUG_MODE:
        return {
            # Priority 1: Critical debugging variables
            'newborn_ident_df': newborn_ident_df,      # Final newborn identities
            'nicu_ident': nicu_ident,                  # NICU subset
            'newborns_df': newborns_df,                # Output table
            'nicu_rollup': nicu_rollup,                # NICU aggregations
            
            # Priority 1: Critical intermediate dataframes
            'claims_df_tagged': claims_df_tagged,      # Claims after flag tagging
            'hosp_rollup_df': hosp_rollup_df,          # Hospital episodes (KEY!)
            'newborn_claims': newborn_claims,          # Claims after newborn rollup
            
            # Priority 2: Detailed analysis variables
            'newborn_hosp_clms': newborn_hosp_clms,    # Claims bounded to episodes
            'nicu_claims_df': nicu_claims_df,          # NICU claims detail
            'nicu_dischg_provider': nicu_dischg_provider,  # Discharge providers
            'rev_out': rev_out,                        # Revenue code analysis
            'drg_out': drg_out                         # DRG analysis
        }
if __name__ == "__main__":
    main(auto_window=True)


2025-08-23 14:33:07,211 [INFO] Starting NICU pipeline
2025-08-23 14:33:07,268 [INFO] Snowflake Connector for Python Version: 3.15.0, Python Version: 3.12.6, Platform: macOS-15.6-arm64-arm-64bit
2025-08-23 14:33:07,268 [INFO] Connecting to GLOBAL Snowflake domain
2025-08-23 14:33:08,567 [INFO] Snowpark Session information: 
"version" : 1.34.0,
"python.version" : 3.12.6,
"python.connector.version" : 3.15.0,
"python.connector.session.id" : 66583254756907991,
"os.name" : Darwin

2025-08-23 14:33:08,568 [INFO] Calculating birth and runout window from source data
2025-08-23 14:33:08,791 [INFO] Birth window: 2023-01-01 to 2024-12-31
2025-08-23 14:33:08,793 [INFO] Runout end: 2025-03-31
2025-08-23 14:33:08,794 [INFO] Processing Membership Data
2025-08-23 14:33:10,534 [INFO] Exporting data to Snowflake table: CSZNB_PRD_PS_PFA_DB.BASE.PS_MEMBERSHIP_EMBLEM_TST
2025-08-23 14:33:28,075 [INFO] Export complete.
2025-08-23 14:33:28,075 [INFO] Fetching newborn keys
2025-08-23 14:33:31,898 [INFO] Found 

## Execute Pipeline

Run the main pipeline with the configured parameters above.

In [None]:
# Run the pipeline with configuration from the first cell
if __name__ == "__main__":
    result = main(auto_window=AUTO_WINDOW)
    
    # Extract debug dataframes if DEBUG_MODE was enabled
    if DEBUG_MODE and result is not None:
        # Priority 1: Core outputs
        newborn_ident_df = result['newborn_ident_df']
        nicu_ident = result['nicu_ident']
        newborns_df = result['newborns_df']
        nicu_rollup = result['nicu_rollup']
        
        # Priority 1: Critical intermediate dataframes
        claims_df_tagged = result['claims_df_tagged']
        hosp_rollup_df = result['hosp_rollup_df']
        newborn_claims = result['newborn_claims']
        
        # Priority 2: Detailed analysis
        newborn_hosp_clms = result['newborn_hosp_clms']
        nicu_claims_df = result['nicu_claims_df']
        nicu_dischg_provider = result['nicu_dischg_provider']
        rev_out = result['rev_out']
        drg_out = result['drg_out']
        
        
        print("\n" + "="*80)
        print("  DEBUG MODE: All intermediate dataframes available for inspection")
        print("="*80)
        
        print("\nPRIORITY 1 - Critical Debugging:")
        print("  • claims_df_tagged      : Claims after all reference flags tagged")
        print("  • hosp_rollup_df        : Hospital episodes (CHECK IF EMPTY!)")
        print("  • newborn_claims        : Claims after newborn rollup")
        print("  • newborn_ident_df      : Final newborn identities")
        print("  • nicu_ident            : NICU subset")
        
        print("\nPRIORITY 2 - Detailed Analysis:")
        print("  • newborn_hosp_clms     : Claims bounded to hospital episodes")
        print("  • nicu_claims_df        : NICU claims detail")
        print("  • nicu_dischg_provider  : Discharge provider analysis")
        print("  • rev_out               : Revenue code aggregations")
        print("  • drg_out               : DRG aggregations")
        
        print("\nOUTPUTS:")
        print("  • newborns_df           : Final newborns output table")
        print("  • nicu_rollup           : NICU rollup table")
        
        print("\n" + "="*80)
        print("DIAGNOSTIC WORKFLOW:")
        print("="*80)
        print("# 1. Check claim type distribution")
        print("   claims_df_tagged.group_by('CLAIM_TYPE').count().show()")
        print()
        print("# 2. Check if hosp_rollup_df is empty (likely culprit!)")
        print("   print(f'Hospital rollups: {hosp_rollup_df.count():,}')")
        print()
        print("# 3. If empty, check IP classification")
        print("   claims_df_tagged.filter(col('CLAIM_TYPE') == 'IP').count()")
        print()
        print("# 4. Check NICU revenue codes specifically")
        print("   claims_df_tagged.filter(col('REV_CD').between('0170', '0179')).\\")
        print("       select('REV_CD', 'CLAIM_TYPE').show()")
        print()
        print("# 5. Export to pandas for detailed inspection")
        print("   hosp_rollup_df.limit(100).to_pandas()")
        print("="*80 + "\n")


## Pipeline Results Summary

Display key metrics and statistics from the pipeline execution.

In [None]:
# Display pipeline execution summary
# Note: Run this cell after the main pipeline completes

try:
    # Check if pipeline variables exist
    if 'newborn_ident_df' in locals() and 'nicu_ident' in locals():
        total_newborns = newborn_ident_df.count()
        total_nicu = nicu_ident.count()
        nicu_rate = (total_nicu / total_newborns * 100) if total_newborns > 0 else 0
        
        # Get study year breakdown
        study_yr_counts = newborn_ident_df.group_by("STUDY_YR").count().collect()
        prev_count = next((row['COUNT'] for row in study_yr_counts if row['STUDY_YR'] == 'Previous'), 0)
        curr_count = next((row['COUNT'] for row in study_yr_counts if row['STUDY_YR'] == 'Current'), 0)
        
        # Calculate costs if available
        try:
            total_cost_data = nicu_ident.select(ssum("TOTAL_NICU_COST").alias("TOTAL")).collect()[0]
            total_nicu_cost = total_cost_data['TOTAL'] if total_cost_data['TOTAL'] else 0
            avg_nicu_cost = total_nicu_cost / total_nicu if total_nicu > 0 else 0
        except:
            total_nicu_cost = None
            avg_nicu_cost = None        
        # Calculate average LOS if available
        try:
            from snowflake.snowpark.functions import sum as ssum, count
            avg_los_data = nicu_ident.select(ssum("LOS").alias("TOTAL_LOS"), count("*").alias("COUNT")).collect()[0]
            total_los = avg_los_data['TOTAL_LOS'] if avg_los_data['TOTAL_LOS'] else 0
            avg_los = total_los / avg_los_data['COUNT'] if avg_los_data['COUNT'] > 0 else 0
        except:
            avg_los = None

        
        print(f"\n{'='*70}")
        print(f"  PIPELINE EXECUTION SUMMARY")
        print(f"{'='*70}")
        print(f"\nConfiguration:")
        print(f"  Client:              {CLIENT_DATA}")
        print(f"  Database:            {DATABASE}")
        print(f"  Table Suffix:        {TABLE_SUFFIX}")
        print(f"  Dry-Run Mode:        {DRY_RUN}")
        
        print(f"\nDate Windows:")
        if 'birth_window_start' in locals():
            print(f"  Birth Window:        {birth_window_start.date()} to {birth_window_end.date()}")
            print(f"  Period Split:        {birth_window_mid.date()}")
            print(f"  Runout End:          {runout_end.date()}")
        
        print(f"\nNewborn Statistics:")
        print(f"  Total Newborns:      {total_newborns:,}")
        print(f"    - Previous Period: {prev_count:,}")
        print(f"    - Current Period:  {curr_count:,}")
        
        print(f"\nNICU Statistics:")
        print(f"  NICU Cases:          {total_nicu:,}")
        print(f"  NICU Rate:           {nicu_rate:.1f}%")
        
        if avg_los is not None:
            print(f"  Average LOS:         {avg_los:.1f} days")
        
        if total_nicu_cost is not None:
            print(f"  Total NICU Cost:     ${total_nicu_cost:,.0f}")
            print(f"  Average Cost/Case:   ${avg_nicu_cost:,.0f}")
        
        print(f"\nOutput Tables:")
        print(f"  Membership:          {get_table_name('ps_membership', CLIENT_DATA)}")
        print(f"  Newborns:            {get_table_name('ps_newborns', CLIENT_DATA)}")
        
        print(f"\n{'='*70}\n")
        
    else:
        print("\n⚠ Pipeline has not been executed yet. Run the execution cell first.\n")
        
except Exception as e:
    print(f"\n⚠ Error generating summary: {e}")
    print("This cell should be run after the main pipeline execution.\n")
