In [1]:
import pandas as pd
import numpy as np
import os
import sys
import time
from tabulate import tabulate
from dotenv import load_dotenv
import snowflake.connector

# --- Initialize Connection ---
load_dotenv()
conn = None
try:
    conn = snowflake.connector.connect(
        user=os.getenv('SNOWFLAKE_USER'),
        password=os.getenv('SNOWFLAKE_PASSWORD'),
        account=os.getenv('SNOWFLAKE_ACCOUNT'),
        warehouse=os.getenv('SNOWFLAKE_WAREHOUSE'),
        database='INCREMENTALITY',
        schema='INCREMENTALITY_RESEARCH'
    )
    print("✅ Connection to Snowflake successful!")
except Exception as e:
    print(f"❌ ERROR: Could not connect to Snowflake. {e}", file=sys.stderr)
    exit()

# --- Load the Master Purchaser Universe ---
UNIVERSE_FILE = "final_purchaser_universe.parquet"
print(f"\n--- Loading master list from '{UNIVERSE_FILE}' ---")
try:
    df_universe = pd.read_parquet(UNIVERSE_FILE)
    print(f"✅ Loaded {len(df_universe):,} total unique purchasing users.")
except FileNotFoundError:
    print(f"❌ FATAL ERROR: The universe file '{UNIVERSE_FILE}' was not found.")
    exit()

✅ Connection to Snowflake successful!

--- Loading master list from 'final_purchaser_universe.parquet' ---
✅ Loaded 4,926,305 total unique purchasing users.


In [2]:
import polars as pl
import os
from tabulate import tabulate

# --- Configuration ---
# This is the main 32M row panel dataset
PANEL_FILE = "/Users/pranjal/Code/topsort-incrementality/panel/user_panel_full_history.parquet"
# This is the final output from your long-running script
HOLDOUT_FILE = "final_holdout_user_ids_final.parquet"
# This is the file we will create
OUTPUT_FILE = "user_panel_with_holdout_flag.parquet"

# --- Load the datasets using Polars ---
print("--- Loading Datasets ---")
try:
    print(f"-> Loading main panel from '{PANEL_FILE}'...")
    df_panel = pl.read_parquet(PANEL_FILE)
    print(f"   ✅ Loaded panel with {df_panel.height:,} rows.")

    print(f"-> Loading final holdout user list from '{HOLDOUT_FILE}'...")
    df_holdouts = pl.read_parquet(HOLDOUT_FILE)
    print(f"   ✅ Loaded {df_holdouts.height:,} holdout user IDs.")

    # --- NEW: Identify and load the "purchaser" list from the panel ---
    print(f"-> Identifying all users with at least one purchase from the panel...")
    df_purchasers = df_panel.filter(pl.col("purchases") > 0).select(pl.col("user_id").unique())
    print(f"   ✅ Identified {df_purchasers.height:,} unique purchasing users.")

except FileNotFoundError as e:
    print(f"\n❌ FATAL ERROR: A required file was not found. Please ensure the previous scripts have completed.")
    print(f"   Details: {e}")
    df_panel = None 

--- Loading Datasets ---
-> Loading main panel from '/Users/pranjal/Code/topsort-incrementality/panel/user_panel_full_history.parquet'...
   ✅ Loaded panel with 32,060,768 rows.
-> Loading final holdout user list from 'final_holdout_user_ids_final.parquet'...
   ✅ Loaded 784,133 holdout user IDs.
-> Identifying all users with at least one purchase from the panel...
   ✅ Identified 4,926,674 unique purchasing users.


In [3]:
if df_panel is not None:
    print("\n--- Enriching Panel with Holdout and Purchaser Flags ---")
    
    # --- Standardize column names to prevent errors ---
    print("-> Standardizing column names to lowercase...")
    df_panel.columns = [col.lower() for col in df_panel.columns]
    df_holdouts.columns = [col.lower() for col in df_holdouts.columns]
    df_purchasers.columns = [col.lower() for col in df_purchasers.columns]
    print("   ✅ Column names standardized.")
    
    # Use sets for maximum performance
    holdout_set = set(df_holdouts['user_id'])
    purchaser_set = set(df_purchasers['user_id'])

    # --- Create the flag columns using efficient expressions ---
    print("-> Creating 'is_holdout' and 'is_purchaser' columns...")
    df_enriched = df_panel.with_columns(
        is_holdout = pl.col('user_id').is_in(holdout_set).cast(pl.UInt8),
        is_purchaser = pl.col('user_id').is_in(purchaser_set).cast(pl.UInt8)
    )
    print("   ✅ Flag columns created successfully.")

    # --- Verification Step ---
    print("\n--- Verifying the new flags ---")
    print("Distribution of the 'is_holdout' flag:")
    print(df_enriched['is_holdout'].value_counts())
    print("\nDistribution of the 'is_purchaser' flag:")
    print(df_enriched['is_purchaser'].value_counts())


--- Enriching Panel with Holdout and Purchaser Flags ---
-> Standardizing column names to lowercase...
   ✅ Column names standardized.
-> Creating 'is_holdout' and 'is_purchaser' columns...
   ✅ Flag columns created successfully.

--- Verifying the new flags ---
Distribution of the 'is_holdout' flag:
shape: (2, 2)
┌────────────┬──────────┐
│ is_holdout ┆ count    │
│ ---        ┆ ---      │
│ u8         ┆ u32      │
╞════════════╪══════════╡
│ 0          ┆ 31220194 │
│ 1          ┆ 840574   │
└────────────┴──────────┘

Distribution of the 'is_purchaser' flag:
shape: (2, 2)
┌──────────────┬──────────┐
│ is_purchaser ┆ count    │
│ ---          ┆ ---      │
│ u8           ┆ u32      │
╞══════════════╪══════════╡
│ 0            ┆ 7703346  │
│ 1            ┆ 24357422 │
└──────────────┴──────────┘


In [4]:
if 'df_enriched' in locals():
    print("\n--- Performing Overall and Temporal EDA on the Enriched Panel ---")

    # --- 1. Calculate Overall Summary Statistics ---
    print("-> Calculating overall summary stats...")
    summary_stats = {
        "Total Rows (User-Weeks)": df_enriched.height,
        "Total Unique Users": df_enriched['user_id'].n_unique(),
        "Total Unique Purchasers": df_enriched.filter(pl.col('is_purchaser') == 1)['user_id'].n_unique(),
        "Total Unique Holdouts": df_enriched.filter(pl.col('is_holdout') == 1)['user_id'].n_unique(),
        "Total Clicks": df_enriched['clicks'].sum(),
        "Total Purchases": df_enriched['purchases'].sum(),
        "Total Revenue": df_enriched['revenue_dollars'].sum()
    }
    print("   ✅ Overall stats calculated.")

    # --- 2. Calculate Weekly Evolution of Metrics ---
    print("-> Calculating weekly evolution of metrics...")
    df_weekly = df_enriched.group_by('week').agg(
        active_users=pl.n_unique('user_id'),
        purchasing_users=pl.col('user_id').filter(pl.col('is_purchaser') == 1).n_unique(),
        holdout_users=pl.col('user_id').filter(pl.col('is_holdout') == 1).n_unique(),
        total_clicks=pl.sum('clicks'),
        total_purchases=pl.sum('purchases'),
        total_revenue=pl.sum('revenue_dollars')
    ).sort('week')
    print("   ✅ Weekly aggregates calculated.")

    # --- 3. Calculate Key Weekly Fractions ---
    print("-> Calculating weekly fractions...")
    df_weekly_fractions = df_weekly.with_columns(
        # What percentage of active users in a week are purchasers?
        purchaser_rate = pl.col('purchasing_users') / pl.col('active_users'),
        # What percentage of active users in a week are from the holdout group?
        holdout_rate = pl.col('holdout_users') / pl.col('active_users'),
        # How many clicks per active user?
        clicks_per_user = pl.col('total_clicks') / pl.col('active_users')
    ).select(['week', 'purchaser_rate', 'holdout_rate', 'clicks_per_user'])
    print("   ✅ Weekly fractions calculated.")
    
    # --- 4. Generate the Final Report ---
    report_filename = "panel_eda_summary_report.txt"
    with open(report_filename, "w") as f:
        f.write("Exploratory Data Analysis of the Enriched User-Week Panel\n")
        f.write("=" * 58 + "\n\n")

        f.write("Overall Dataset Summary\n")
        f.write("-----------------------\n")
        for key, value in summary_stats.items():
            f.write(f"- {key}: {value:,.0f}\n")
        f.write("\n\n")

        f.write("Evolution of Key Weekly Metrics\n")
        f.write("-------------------------------\n")
        f.write(tabulate(df_weekly.to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=",.0f"))
        f.write("\n\n")

        f.write("Evolution of Key Weekly Fractions\n")
        f.write("---------------------------------\n")
        f.write(tabulate(df_weekly_fractions.to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=".4f"))

    print(f"\n✅ ANALYSIS COMPLETE. A detailed EDA report has been saved to '{report_filename}'")
    
    # --- 5. Print a sample of the results to the console ---
    print("\n--- Evolution of Key Weekly Metrics (First 10 Weeks) ---")
    print(tabulate(df_weekly.head(10).to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=",.0f"))
    
    print("\n--- Evolution of Key Weekly Fractions (First 10 Weeks) ---")
    print(tabulate(df_weekly_fractions.head(10).to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=".4f"))


--- Performing Overall and Temporal EDA on the Enriched Panel ---
-> Calculating overall summary stats...
   ✅ Overall stats calculated.
-> Calculating weekly evolution of metrics...
   ✅ Weekly aggregates calculated.
-> Calculating weekly fractions...
   ✅ Weekly fractions calculated.

✅ ANALYSIS COMPLETE. A detailed EDA report has been saved to 'panel_eda_summary_report.txt'

--- Evolution of Key Weekly Metrics (First 10 Weeks) ---
+---------------------+----------------+--------------------+-----------------+----------------+-------------------+-----------------+
| week                |   active_users |   purchasing_users |   holdout_users |   total_clicks |   total_purchases |   total_revenue |
| 2025-03-10 00:00:00 |         571746 |             459848 |           14772 |        1537431 |            374695 |      17,165,085 |
+---------------------+----------------+--------------------+-----------------+----------------+-------------------+-----------------+
| 2025-03-17 00:00:00

In [1]:
import polars as pl
from tabulate import tabulate

# --- Configuration ---
ENRICHED_PANEL_FILE = "user_panel_with_holdout_flag.parquet"
FINAL_REPORT_FILE = "covariate_balance_check_report.txt"
CUTOFF_DATE = "2025-07-01"

# ==============================================================================
# PHASE 1: LOAD DATA AND DEFINE TIME PERIODS
# ==============================================================================
print("--- Phase 1: Loading Data and Defining Periods ---")
try:
    df_enriched = pl.read_parquet(ENRICHED_PANEL_FILE)
    print(f"✅ Successfully loaded panel with {df_enriched.height:,} rows.")
except FileNotFoundError:
    print(f"❌ FATAL ERROR: The enriched panel file '{ENRICHED_PANEL_FILE}' was not found.")
    df_enriched = None

if df_enriched is not None:
    cutoff_date_pl = pl.lit(CUTOFF_DATE).str.to_date()
    df_period1 = df_enriched.filter(pl.col("week") < cutoff_date_pl)
    df_period2 = df_enriched.filter(pl.col("week") >= cutoff_date_pl)
    print(f"   -> Split data with cutoff: {CUTOFF_DATE}")

    # ==============================================================================
    # PHASE 2: CONSTRUCT THE ANALYSIS DATAFRAME
    # ==============================================================================
    print("\n--- Phase 2: Constructing DataFrame with Controls ---")

    # Define Base Population (>= 3 purchases in P1)
    p1_user_purchases = df_period1.group_by("user_id").agg(total_purchases_p1=pl.sum("purchases"))
    base_users = p1_user_purchases.filter(pl.col("total_purchases_p1") >= 3).select("user_id")
    print(f"   -> Identified {base_users.height:,} users for the analysis base.")

    # Engineer Controls (X) from Period 1
    df_controls_p1 = df_period1.group_by("user_id").agg(
        revenue_p1=pl.sum("revenue_dollars"),
        purchases_p1=pl.sum("purchases"),
        clicks_p1=pl.sum("clicks")
    )
    
    # Get user holdout status
    user_holdout_status = df_enriched.select(["user_id", "is_holdout"]).unique(subset="user_id")
    
    # Construct the final DataFrame
    df_analysis = base_users.join(user_holdout_status, on="user_id", how="inner")
    df_analysis = df_analysis.join(df_controls_p1, on="user_id", how="left").fill_null(0)
    print(f"   -> Final analysis DataFrame has {df_analysis.height:,} users (rows).")
    
    # ==============================================================================
    # PHASE 3: COVARIATE BALANCE CHECK (INSPECTING THE CONTROLS)
    # ==============================================================================
    print("\n--- Phase 3: Performing Covariate Balance Check ---")

    # Group by cohort and calculate summary stats for Period 1 controls
    control_summary = df_analysis.group_by("is_holdout").agg(
        user_count=pl.len(),
        avg_revenue_p1=pl.mean("revenue_p1"),
        avg_purchases_p1=pl.mean("purchases_p1"),
        avg_clicks_p1=pl.mean("clicks_p1"),
        median_revenue_p1=pl.median("revenue_p1")
    ).with_columns(
        cohort = pl.when(pl.col('is_holdout') == 1).then(pl.lit("Control")).otherwise(pl.lit("Treatment"))
    ).select("cohort", "user_count", "avg_revenue_p1", "avg_purchases_p1", "avg_clicks_p1", "median_revenue_p1")
    
    print("   -> Summary statistics for controls calculated.")
    
    # ==============================================================================
    # PHASE 4: FINAL REPORTING
    # ==============================================================================
    print(f"\n--- Phase 4: Generating Final Report ---")

    with open(FINAL_REPORT_FILE, "w") as f:
        f.write("Covariate Balance Check Report (Comparison of Period 1 Controls)\n")
        f.write("=" * 64 + "\n\n")
        f.write("Methodology:\n")
        f.write("This table compares the average characteristics of the Treatment and Control groups\n")
        f.write("based *only* on their activity in Period 1 (before the outcome period).\n")
        f.write("Large differences here indicate strong selection bias that regression must control for.\n\n")
        
        f.write(tabulate(control_summary.to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=".2f"))

    print(f"\n✅ ANALYSIS COMPLETE. Covariate balance report saved to '{FINAL_REPORT_FILE}'")
    
    # Print the results to the console
    print("\n--- Covariate Balance Check ---")
    print(tabulate(control_summary.to_pandas(), headers='keys', tablefmt='grid', showindex=False, floatfmt=".2f"))

--- Phase 1: Loading Data and Defining Periods ---
✅ Successfully loaded panel with 32,060,768 rows.
   -> Split data with cutoff: 2025-07-01

--- Phase 2: Constructing DataFrame with Controls ---
   -> Identified 1,119,128 users for the analysis base.
   -> Final analysis DataFrame has 1,119,128 users (rows).

--- Phase 3: Performing Covariate Balance Check ---
   -> Summary statistics for controls calculated.

--- Phase 4: Generating Final Report ---

✅ ANALYSIS COMPLETE. Covariate balance report saved to 'covariate_balance_check_report.txt'

--- Covariate Balance Check ---
+-----------+--------------+------------------+--------------------+-----------------+---------------------+
| cohort    |   user_count |   avg_revenue_p1 |   avg_purchases_p1 |   avg_clicks_p1 |   median_revenue_p1 |
| Treatment |      1111277 |           391.34 |               9.30 |           48.92 |              173.00 |
+-----------+--------------+------------------+--------------------+-----------------+----

In [8]:
!pip install doubleml scikit-learn lightgbm -q

import polars as pl
import pandas as pd
import numpy as np
import os
from tabulate import tabulate
import doubleml as dml
from lightgbm import LGBMRegressor, LGBMClassifier
from sklearn.base import clone

# --- Configuration ---
ENRICHED_PANEL_FILE = "user_panel_with_holdout_flag.parquet"
FINAL_REPORT_FILE = "doubleml_irm_tenure_control_report.txt"
CUTOFF_DATE = "2025-07-01"

# ==============================================================================
# PHASE 1: FEATURE & OUTCOME ENGINEERING (WITH TENURE PROXY)
# ==============================================================================
print("--- Phase 1: Engineering Features, Outcomes, and Tenure Proxy ---")

df_enriched = pl.read_parquet(ENRICHED_PANEL_FILE)
cutoff_date_pl = pl.lit(CUTOFF_DATE).str.to_date()
df_period1 = df_enriched.filter(pl.col("week") < cutoff_date_pl)
df_period2 = df_enriched.filter(pl.col("week") >= cutoff_date_pl)

# Define Base Population (>= 3 purchases in P1)
p1_user_purchases = df_period1.group_by("user_id").agg(total_purchases_p1=pl.sum("purchases"))
base_users = p1_user_purchases.filter(pl.col("total_purchases_p1") >= 3).select("user_id")
print(f"   -> Identified {base_users.height:,} users for the analysis base.")

# --- Engineer the Tenure Proxy ---
df_tenure = df_period1.group_by("user_id").agg(
    first_week_p1=pl.min("week")
)
period1_start_date = df_period1.select(pl.min("week"))[0, 0]

# --- THE FIX IS HERE: Use .dt.total_days() instead of .dt.days() ---
df_tenure = df_tenure.with_columns(
    join_week_index_p1 = ((pl.col('first_week_p1') - period1_start_date).dt.total_days() // 7)
).select(["user_id", "join_week_index_p1"])
print("   -> Engineered 'join_week_index_p1' as a tenure proxy.")

# Engineer Per-Week Controls (X) from Period 1
df_controls_p1 = df_period1.group_by("user_id").agg(
    avg_weekly_revenue_p1=pl.mean("revenue_dollars"),
    avg_weekly_purchases_p1=pl.mean("purchases"),
    avg_weekly_clicks_p1=pl.mean("clicks")
)

# Engineer Per-Week Outcome (Y) from Period 2
df_outcomes_p2 = df_period2.group_by("user_id").agg(
    avg_weekly_revenue_p2=pl.mean("revenue_dollars")
)

# ==============================================================================
# PHASE 2: CONSTRUCT FINAL REGRESSION DATAFRAME
# ==============================================================================
print("\n--- Phase 2: Constructing Final DataFrame for Regression ---")
user_holdout_status = df_enriched.select(["user_id", "is_holdout"]).unique(subset="user_id")

df_analysis = base_users.join(user_holdout_status, on="user_id", how="inner")
df_analysis = df_analysis.join(df_controls_p1, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.join(df_outcomes_p2, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.join(df_tenure, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.with_columns(is_treated = 1 - pl.col("is_holdout"))

data_pd = df_analysis.to_pandas()
float_cols = [col for col in data_pd.columns if 'revenue' in col or 'purchases' in col or 'clicks' in col or 'index' in col]
data_pd[float_cols] = data_pd[float_cols].astype('float64')

print(f"   -> Final DataFrame for analysis has {len(data_pd):,} users (rows).")

# ==============================================================================
# PHASE 3: DOUBLEML SETUP AND ATE ESTIMATION
# ==============================================================================
print("\n--- Phase 3: Estimating ATE with Tenure Control ---")

x_cols = ['avg_weekly_revenue_p1', 'avg_weekly_purchases_p1', 'avg_weekly_clicks_p1', 'join_week_index_p1']
dml_data = dml.DoubleMLData(data_pd,
                              y_col='avg_weekly_revenue_p2',
                              d_cols='is_treated',
                              x_cols=x_cols)

learner_g = LGBMRegressor(n_jobs=-1, random_state=42, verbose=-1)
learner_m = LGBMClassifier(n_jobs=-1, random_state=42, verbose=-1)

dml_irm_obj = dml.DoubleMLIRM(dml_data,
                            ml_g=clone(learner_g),
                            ml_m=clone(learner_m))

dml_irm_obj.fit(store_predictions=True)
print("   -> DoubleML IRM model fitting complete.")
print(dml_irm_obj.summary)

# ==============================================================================
# PHASE 4: HETEROGENEITY ANALYSIS (GATES)
# ==============================================================================
print("\n--- Phase 4: Analyzing Heterogeneous Effects with GATEs ---")

tenure_median = data_pd['join_week_index_p1'].median()

groups = pd.DataFrame({
    'Early_Joiners_P1': (data_pd['join_week_index_p1'] <= tenure_median),
    'Late_Joiners_P1': (data_pd['join_week_index_p1'] > tenure_median)
})

gate_results = dml_irm_obj.gate(groups)
print("   -> Group Average Treatment Effects (GATEs) calculated.")
print(gate_results)

# ==============================================================================
# PHASE 5: FINAL REPORTING
# ==============================================================================
print(f"\n--- Phase 5: Generating Final Report ---")

with open(FINAL_REPORT_FILE, "w") as f:
    f.write("DoubleML IRM Analysis with Tenure Control\n")
    f.write("=" * 41 + "\n\n")
    f.write("Methodology:\n")
    f.write("A DoubleML IRM model was used to estimate the causal effect of ad exposure on the\n")
    f.write("average weekly revenue in Period 2. The model now includes a 'join_week_index_p1'\n")
    f.write("as a tenure proxy to control for when a user first became active.\n\n")
    
    f.write("Overall Average Treatment Effect (ATE) on Weekly Revenue\n")
    f.write("--------------------------------------------------------\n")
    f.write(str(dml_irm_obj.summary))
    f.write("\n\n")
    
    f.write("Heterogeneous Effects by User Tenure (GATEs)\n")
    f.write("--------------------------------------------\n")
    f.write("This table shows if the ATE on weekly revenue is different for users who joined\n")
    f.write("early in Period 1 vs. those who joined later.\n\n")
    f.write(str(gate_results))

print(f"\n✅ ANALYSIS COMPLETE. Final report saved to '{FINAL_REPORT_FILE}'")

--- Phase 1: Engineering Features, Outcomes, and Tenure Proxy ---
   -> Identified 1,119,128 users for the analysis base.
   -> Engineered 'join_week_index_p1' as a tenure proxy.

--- Phase 2: Constructing Final DataFrame for Regression ---
   -> Final DataFrame for analysis has 1,119,128 users (rows).

--- Phase 3: Estimating ATE with Tenure Control ---




   -> DoubleML IRM model fitting complete.
                coef   std err         t  P>|t|    2.5 %   97.5 %
is_treated  7.837105  0.171577  45.67686    0.0  7.50082  8.17339

--- Phase 4: Analyzing Heterogeneous Effects with GATEs ---
   -> Group Average Treatment Effects (GATEs) calculated.

------------------ Fit summary ------------------
                      coef   std err          t         P>|t|    [0.025  \
Early_Joiners_P1  9.797284  0.200115  48.958217  0.000000e+00  9.405065   
Late_Joiners_P1   4.873772  0.306875  15.881931  8.453542e-57  4.272307   

                     0.975]  
Early_Joiners_P1  10.189502  
Late_Joiners_P1    5.475236  

--- Phase 5: Generating Final Report ---

✅ ANALYSIS COMPLETE. Final report saved to 'doubleml_irm_tenure_control_report.txt'


In [12]:
!pip install doubleml scikit-learn lightgbm -q

import polars as pl
import pandas as pd
import numpy as np
import os
from tabulate import tabulate
import doubleml as dml
from lightgbm import LGBMRegressor, LGBMClassifier
from sklearn.base import clone

# --- Configuration ---
ENRICHED_PANEL_FILE = "user_panel_with_holdout_flag.parquet"
FINAL_REPORT_FILE = "final_interpreted_causal_analysis_report.txt"
CUTOFF_DATE = "2025-07-01"

# ==============================================================================
# PHASE 1: FEATURE & OUTCOME ENGINEERING (WITH TENURE PROXY)
# ==============================================================================
print("--- Phase 1: Engineering Features, Outcomes, and Tenure Proxy ---")

df_enriched = pl.read_parquet(ENRICHED_PANEL_FILE)
cutoff_date_pl = pl.lit(CUTOFF_DATE).str.to_date()
df_period1 = df_enriched.filter(pl.col("week") < cutoff_date_pl)
df_period2 = df_enriched.filter(pl.col("week") >= cutoff_date_pl)

# Define Base Population (>= 3 purchases in P1)
p1_user_purchases = df_period1.group_by("user_id").agg(total_purchases_p1=pl.sum("purchases"))
base_users = p1_user_purchases.filter(pl.col("total_purchases_p1") >= 3).select("user_id")
print(f"   -> Identified {base_users.height:,} users for the analysis base.")

# Engineer the Tenure Proxy
df_tenure = df_period1.group_by("user_id").agg(first_week_p1=pl.min("week"))
period1_start_date = df_period1.select(pl.min("week"))[0, 0]
df_tenure = df_tenure.with_columns(
    join_week_index_p1 = ((pl.col('first_week_p1') - period1_start_date).dt.total_days() // 7)
).select(["user_id", "join_week_index_p1"])
print("   -> Engineered 'join_week_index_p1' as a tenure proxy.")

# Engineer Per-Week Controls (X) from Period 1
df_controls_p1 = df_period1.group_by("user_id").agg(
    avg_weekly_revenue_p1=pl.mean("revenue_dollars"),
    avg_weekly_purchases_p1=pl.mean("purchases"),
    avg_weekly_clicks_p1=pl.mean("clicks")
)

# Engineer Per-Week Outcome (Y) from Period 2
df_outcomes_p2 = df_period2.group_by("user_id").agg(
    avg_weekly_revenue_p2=pl.mean("revenue_dollars")
)

# ==============================================================================
# PHASE 2: CONSTRUCT FINAL REGRESSION DATAFRAME
# ==============================================================================
print("\n--- Phase 2: Constructing Final DataFrame for Regression ---")
user_holdout_status = df_enriched.select(["user_id", "is_holdout"]).unique(subset="user_id")

df_analysis = base_users.join(user_holdout_status, on="user_id", how="inner")
df_analysis = df_analysis.join(df_controls_p1, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.join(df_outcomes_p2, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.join(df_tenure, on="user_id", how="left").fill_null(0)
df_analysis = df_analysis.with_columns(is_treated = 1 - pl.col("is_holdout"))

data_pd = df_analysis.to_pandas()
# Fix for TypeError: Convert all key numeric columns to standard float64
float_cols = [col for col in data_pd.columns if 'revenue' in col or 'purchases' in col or 'clicks' in col or 'index' in col]
data_pd[float_cols] = data_pd[float_cols].astype('float64')

print(f"   -> Final DataFrame for analysis has {len(data_pd):,} users (rows).")

# ==============================================================================
# PHASE 3: DOUBLEML ATE ESTIMATION WITH TENURE CONTROL
# ==============================================================================
print("\n--- Phase 3: Estimating ATE with Tenure Control ---")

x_cols = ['avg_weekly_revenue_p1', 'avg_weekly_purchases_p1', 'avg_weekly_clicks_p1', 'join_week_index_p1']
dml_data = dml.DoubleMLData(data_pd,
                              y_col='avg_weekly_revenue_p2',
                              d_cols='is_treated',
                              x_cols=x_cols)

learner_g = LGBMRegressor(n_jobs=-1, random_state=42, verbose=-1)
learner_m = LGBMClassifier(n_jobs=-1, random_state=42, verbose=-1)

dml_irm_obj = dml.DoubleMLIRM(dml_data,
                            ml_g=clone(learner_g),
                            ml_m=clone(learner_m))

dml_irm_obj.fit(store_predictions=True)
print("   -> DoubleML IRM model fitting complete.")
print(dml_irm_obj.summary)

# ==============================================================================
# PHASE 4: HETEROGENEITY ANALYSIS (GATES)
# ==============================================================================
print("\n--- Phase 4: Analyzing Heterogeneous Effects by Tenure ---")

tenure_median = data_pd['join_week_index_p1'].median()
groups = pd.DataFrame({
    'Early_Joiners_P1': (data_pd['join_week_index_p1'] <= tenure_median),
    'Late_Joiners_P1': (data_pd['join_week_index_p1'] > tenure_median)
})

gate_results = dml_irm_obj.gate(groups)
print("   -> Group Average Treatment Effects (GATEs) calculated.")
print(gate_results.summary)

# ==============================================================================
# PHASE 5: FINAL INTERPRETABLE REPORTING
# ==============================================================================
print(f"\n--- Phase 5: Generating Final Interpretable Report ---")

# Calculate baseline revenue for the control group
control_group_data = data_pd[data_pd['is_treated'] == 0]
baseline_revenue = control_group_data['avg_weekly_revenue_p2'].mean()

# Interpret Overall ATE
ate_summary = dml_irm_obj.summary
ate_abs = ate_summary.loc['is_treated', 'coef']
ate_lower = ate_summary.loc['is_treated', '2.5 %']
ate_upper = ate_summary.loc['is_treated', '97.5 %']
ate_pct = (ate_abs / baseline_revenue) * 100 if baseline_revenue > 0 else float('inf')
ate_pct_lower = (ate_lower / baseline_revenue) * 100 if baseline_revenue > 0 else float('inf')
ate_pct_upper = (ate_upper / baseline_revenue) * 100 if baseline_revenue > 0 else float('inf')

# Interpret GATEs
gate_summary_df = gate_results.summary
interpreted_gates = []
for group_name in gate_summary_df.index:
    gate_abs = gate_summary_df.loc[group_name, 'coef']
    gate_lower = gate_summary_df.loc[group_name, '2.5 %']
    gate_upper = gate_summary_df.loc[group_name, '97.5 %']
    
    group_mask = groups[group_name]
    control_in_gate = data_pd[(data_pd['is_treated'] == 0) & (group_mask)]
    gate_baseline = control_in_gate['avg_weekly_revenue_p2'].mean()
    
    if gate_baseline > 0:
        gate_pct = (gate_abs / gate_baseline) * 100
        gate_pct_lower = (gate_lower / gate_baseline) * 100
        gate_pct_upper = (gate_upper / gate_baseline) * 100
    else:
        gate_pct, gate_pct_lower, gate_pct_upper = np.nan, np.nan, np.nan
        
    interpreted_gates.append({
        'Subgroup': group_name,
        'ATE ($ Lift)': f"${gate_abs:.2f}",
        '95% CI Range ($)': f"(${gate_lower:.2f}, ${gate_upper:.2f})",
        'ATE (% Lift)': f"{gate_pct:+.2f}%",
        'Baseline Revenue': f"${gate_baseline:.2f}"
    })
gates_df = pd.DataFrame(interpreted_gates)

# Write the final report file
with open(FINAL_REPORT_FILE, "w") as f:
    f.write("Final Interpreted Report: Causal Lift of Advertising\n")
    f.write("=" * 51 + "\n\n")
    
    f.write("Overall Average Treatment Effect (ATE)\n")
    f.write("--------------------------------------\n")
    f.write(f"This is the average causal impact of advertising on a core customer's weekly revenue,\n")
    f.write(f"after controlling for their prior behavior and tenure.\n\n")
    f.write(f"  - Baseline (Control Group Avg. Weekly Revenue): ${baseline_revenue:.2f}\n")
    f.write(f"  - Absolute Incremental Lift:                    ${ate_abs:.2f} per user per week\n")
    f.write(f"  - 95% Confidence Interval (Absolute):         (${ate_lower:.2f}, ${ate_upper:.2f})\n\n")
    f.write(f"  - Relative Incremental Lift:                    {ate_pct:+.2f}%\n")
    f.write(f"  - 95% Confidence Interval (Relative):         ({ate_pct_lower:+.2f}%, {ate_pct_upper:+.2f}%)\n\n\n")
    
    f.write("Heterogeneous Effects by User Tenure (GATEs)\n")
    f.write("--------------------------------------------\n")
    f.write("This table shows how the treatment effect varies based on a user's tenure in Period 1.\n\n")
    f.write(tabulate(gates_df, headers='keys', tablefmt='grid', showindex=False))

print(f"\n✅ ANALYSIS COMPLETE. Final, interpreted report saved to '{FINAL_REPORT_FILE}'")

--- Phase 1: Engineering Features, Outcomes, and Tenure Proxy ---
   -> Identified 1,119,128 users for the analysis base.
   -> Engineered 'join_week_index_p1' as a tenure proxy.

--- Phase 2: Constructing Final DataFrame for Regression ---
   -> Final DataFrame for analysis has 1,119,128 users (rows).

--- Phase 3: Estimating ATE with Tenure Control ---




   -> DoubleML IRM model fitting complete.
                coef   std err          t  P>|t|     2.5 %    97.5 %
is_treated  7.952556  0.165686  47.997892    0.0  7.627818  8.277294

--- Phase 4: Analyzing Heterogeneous Effects by Tenure ---
   -> Group Average Treatment Effects (GATEs) calculated.
                       coef   std err          t         P>|t|    [0.025  \
Early_Joiners_P1  10.074092  0.207872  48.463047  0.000000e+00  9.666671   
Late_Joiners_P1    4.745287  0.272761  17.397251  8.655826e-68  4.210686   

                     0.975]  
Early_Joiners_P1  10.481513  
Late_Joiners_P1    5.279888  

--- Phase 5: Generating Final Interpretable Report ---


KeyError: '2.5 %'