In [26]:
# Student data generator for digital transformation dataset.
# Reads the shared base CSV, perturbs a few columns with seeded Gaussian draws,
# and writes a per-student copy that the following cells load.
# NOTE(review): the exact sequence of RNG calls below defines the generated
# dataset, so statement order must not change.
import pandas as pd
import numpy as np

# Write your student number here (also used as the NumPy global RNG seed below)
STUDENT_number = 715056

# Load data
df = pd.read_csv('digital_transformation.csv')
np.random.seed(STUDENT_number)
noise_vars = ['revenue', 'productivity_score', 'employee_satisfaction', 'customer_satisfaction', 'ceo_experience']
for var in noise_vars:
    # NOTE(review): np.random.normal is called without `size`, so it returns a
    # single scalar per column: every row of `var` is shifted by the same
    # amount (sd = 5% of the column's std) rather than perturbed independently.
    # Left as-is because the resulting file is presumably the one the
    # assignment expects — confirm before changing.
    df[var] += np.random.normal(0, df[var].std() * 0.05)

# Save unique dataset (read back by the next cells)
df.to_csv('digital_transformation_unique.csv', index=False)

In [27]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import statsmodels.api as sm
import statsmodels.formula.api as smf
from linearmodels.panel import PanelOLS
from scipy.stats import gaussian_kde, f
from statsmodels.stats.outliers_influence import variance_inflation_factor

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.preprocessing import StandardScaler

import warnings
warnings.filterwarnings('ignore')

# Try to import CausalML, fallback if not available
try:
    from causalml.inference.tree import CausalTreeRegressor
    CAUSAL_ML_AVAILABLE = True
except ImportError:
    print("⚠️ CausalML not available. HTE will fallback.")
    CAUSAL_ML_AVAILABLE = False

# Try to import EconML as backup.
# DMLOrthoForest lives in `econml.orf` (older releases: `econml.ortho_forest`),
# not in `econml.dml`; importing it from `econml.dml` always failed, so EconML
# was reported as unavailable even when installed.
try:
    from econml.orf import DMLOrthoForest
    ECONML_AVAILABLE = True
except ImportError:
    try:
        from econml.ortho_forest import DMLOrthoForest  # older module layout
        ECONML_AVAILABLE = True
    except ImportError:
        ECONML_AVAILABLE = False


In [28]:
# Load the per-student dataset written by the generator cell
df = pd.read_csv('digital_transformation_unique.csv')

# Column inventory
print("Variables in dataset:")
print(df.columns.tolist())

# Summary of the numeric columns
print("\nDescriptive Statistics (Numeric):")
print(df.describe())

# Summary of the object-typed (text) columns
print("\nDescriptive Statistics (Non-Numeric):")
print(df.select_dtypes(include=['object']).describe())


Variables in dataset:
['firm_id', 'year', 'industry_type', 'firm_strategy', 'firm_size', 'firm_age', 'ceo_experience', 'ceo_tenure', 'prior_tech_adoption', 'tech_budget_pct', 'digital_transformation_program', 'market_conditions', 'revenue', 'revenue_growth', 'operating_costs', 'rd_spending', 'productivity_score', 'employee_satisfaction', 'customer_satisfaction', 'profit_margin', 'rd_intensity', 'large_firm', 'post_2020']

Descriptive Statistics (Numeric):
           firm_id        year      firm_size    firm_age  ceo_experience  \
count  3000.000000  3000.00000    3000.000000  3000.00000     3000.000000   
mean    250.500000  2020.50000   11203.480333    23.31000       10.222409   
std     144.361341     1.70811   30609.496929    10.39724        7.752553   
min       1.000000  2018.00000      23.000000     4.00000        2.106009   
25%     125.750000  2019.00000    1364.000000    14.00000        4.381009   
50%     250.500000  2020.50000    3698.500000    23.00000        7.806009   
7

In [29]:
def preprocess_data(df):
    """Add the derived columns used by the balance and VIF cells.

    Writes the new columns into ``df`` in place and returns it:

    * ``log_firm_size``  -- log(1 + firm_size)
    * ``post_treatment`` -- 1 for years after 2020, else 0
    * ``industry_group`` -- ``industry_type``, with every industry outside the
      three most frequent ones collapsed into ``'Other'``
    """
    size_plus_one = df['firm_size'] + 1
    df['log_firm_size'] = np.log(size_plus_one)
    df['post_treatment'] = df['year'].gt(2020).astype(int)

    # Keep the three largest industries, pool the rest
    most_common = df['industry_type'].value_counts().nlargest(3).index
    keep_label = df['industry_type'].isin(most_common)
    df['industry_group'] = df['industry_type'].where(keep_label, 'Other')
    return df

def assess_covariate_balance(df):
    """Build a TableOne comparing baseline covariates across treatment arms.

    The table is grouped by ``digital_transformation_program`` and includes
    p-values with the name of the test used, plus standardized mean
    differences. The caller passes a single (pre-treatment) year of data.
    Returns the ``TableOne`` object.
    """
    from tableone import TableOne

    baseline_covariates = [
        'log_firm_size',
        'firm_age',
        'prior_tech_adoption',
        'tech_budget_pct',
        'ceo_experience',
        'industry_group',
    ]
    return TableOne(
        df,
        baseline_covariates,
        groupby='digital_transformation_program',
        pval=True,
        smd=True,
        htest_name=True,
    )

df = preprocess_data(df)

# Compare treatment groups in 2018, the pre-treatment baseline year
balance_table = assess_covariate_balance(df.loc[df['year'].eq(2018)])
print(balance_table)

                                         Grouped by digital_transformation_program                                                                        
                                                                           Missing      Overall            0           1 SMD (0,1) P-Value            Test
n                                                                                           500          417          83                                  
log_firm_size, mean (SD)                                                         0    8.1 (1.5)    8.0 (1.5)   8.8 (1.4)     0.610  <0.001  Welch’s T-test
firm_age, mean (SD)                                                              0  20.8 (10.3)  20.9 (10.4)  20.3 (9.6)    -0.061   0.602  Welch’s T-test
prior_tech_adoption, n (%) 0                                                         311 (62.2)   251 (60.2)   60 (72.3)     0.258   0.051     Chi-squared
                           1                                          

In [30]:
def check_multicollinearity(df, covariates):
    """Print and return variance inflation factors (VIF) for ``covariates``.

    Object/category columns are one-hot encoded (``drop_first=True``) and
    every column is then coerced to numeric. A column is removed only when
    coercion fails for a value that was actually present. Rows that still
    contain missing values are dropped afterwards, because
    ``variance_inflation_factor`` cannot handle NaN. Previously any column
    with a single missing value was silently discarded as "non-convertible".

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing every name in ``covariates``.
    covariates : list of str
        Columns whose collinearity is assessed.

    Returns
    -------
    pandas.DataFrame
        Columns ``Variable`` and ``VIF``, sorted by VIF descending. Empty,
        with the same columns, if no usable covariate or row remains.
    """
    print("\n🔗 MULTICOLLINEARITY ANALYSIS:")
    print("-" * 40)

    X = df[covariates].copy()

    # One-hot encode categoricals; the first level becomes the reference.
    cat_cols = X.select_dtypes(include=['object', 'category']).columns
    if not cat_cols.empty:
        print(f"Converting categorical variables: {list(cat_cols)}")
        X = pd.get_dummies(X, columns=cat_cols, drop_first=True)

    # Coerce column by column. A conversion failure is a cell that was present
    # before coercion and NaN after it. Pre-existing NaNs are missing data and
    # are handled by dropping rows below, not by discarding the column.
    X_num = pd.DataFrame(
        {col: pd.to_numeric(X[col], errors='coerce') for col in X.columns},
        index=X.index,
    )
    failed = X_num.isna() & X.notna()
    non_numeric_cols = X_num.columns[failed.any()]
    if not non_numeric_cols.empty:
        print(f"Removing non-convertible columns: {list(non_numeric_cols)}")
        X_num = X_num.drop(columns=non_numeric_cols)

    n_before = len(X_num)
    X_num = X_num.dropna()
    if len(X_num) < n_before:
        print(f"Dropping {n_before - len(X_num)} rows with missing values")

    if X_num.empty:
        print("\n⚠️ No usable covariates or rows left for VIF computation")
        return pd.DataFrame(columns=['Variable', 'VIF'])

    # VIF of each regressor against all the others plus an intercept.
    X_const = sm.add_constant(X_num)
    exog = X_const.to_numpy(dtype=float)  # built once, reused for every column
    vif_data = [
        {'Variable': col, 'VIF': variance_inflation_factor(exog, i)}
        for i, col in enumerate(X_const.columns)
        if col != 'const'
    ]
    vif_df = pd.DataFrame(vif_data).sort_values('VIF', ascending=False)

    # Conventional cut-offs: VIF > 10 serious, 5 < VIF <= 10 moderate.
    high_vif = vif_df[vif_df['VIF'] > 10]
    moderate_vif = vif_df[(vif_df['VIF'] > 5) & (vif_df['VIF'] <= 10)]

    print("\nVariance Inflation Factors (Top 10):")
    print(vif_df.head(10))

    if not high_vif.empty:
        print(f"\n🔴 Critical multicollinearity (VIF > 10):")
        print(high_vif)
    elif not moderate_vif.empty:
        print(f"\n🟠 Moderate multicollinearity (5 < VIF ≤ 10):")
        print(moderate_vif)
    else:
        print("\n🟢 No significant multicollinearity detected")

    return vif_df

# Covariates screened for multicollinearity. 'industry_group' is a text
# column, so check_multicollinearity one-hot encodes it.
key_covariates = [
    'log_firm_size', 'firm_age', 'ceo_experience', 'prior_tech_adoption',
    'tech_budget_pct', 'market_conditions', 'industry_group',
]

vif_results = check_multicollinearity(df, key_covariates)


🔗 MULTICOLLINEARITY ANALYSIS:
----------------------------------------
Converting categorical variables: ['industry_group']

Variance Inflation Factors (Top 10):
                       Variable       VIF
7          industry_group_Other  2.275772
8    industry_group_Real Estate  1.811153
6  industry_group_Manufacturing  1.797048
3           prior_tech_adoption  1.013640
2                ceo_experience  1.011113
4               tech_budget_pct  1.009149
1                      firm_age  1.006049
0                 log_firm_size  1.002426
5             market_conditions  1.000971

🟢 No significant multicollinearity detected


In [34]:
# ======================
# DATA PREPARATION
# ======================

def load_and_validate_data(path='digital_transformation_unique.csv'):
    """Read the dataset CSV at ``path``.

    On success, prints a confirmation and the frame's shape and returns the
    DataFrame. If the file does not exist, prints an error and returns
    ``None`` instead of raising. Any other read error propagates.
    """
    try:
        data = pd.read_csv(path)
    except FileNotFoundError:
        print("❌ Dataset file not found. Please ensure the CSV exists.")
        return None

    print("✅ Dataset loaded successfully")
    print(f"Dataset shape: {data.shape}")
    return data

def winsorize_series(series, lower=0.01, upper=0.99):
    """Clip a pandas Series to its own ``lower``/``upper`` quantiles.

    Values below the ``lower`` quantile are raised to it and values above the
    ``upper`` quantile are lowered to it. Everything else, including the
    index, is unchanged.
    """
    floor = series.quantile(lower)
    ceiling = series.quantile(upper)
    return series.clip(floor, ceiling)

def preprocess_data(df):
    """Feature engineering with time-varying covariates for the MSM.

    Works on a copy, so the caller's frame is never modified. The previous
    version wrote ``log_firm_size`` and ``firm_size_bin`` into the caller's
    object before sorting produced a new frame.

    Adds:

    * ``log_firm_size`` -- log(1 + firm_size); winsorized at the end
    * ``firm_size_bin`` -- ``pd.cut`` of the *unwinsorized* log size into
      Micro/Small/Medium/Large with edges 0, 7, 9, 11, 15
    * within-firm lags, taken after sorting by (firm_id, year): one-period
      lags of rd_intensity, tech_budget_pct, ceo_experience and
      productivity_score, and one/two-period lags of the treatment. A lag
      with no earlier row (a firm's first observed year) is filled with 0.

    The lags are computed from the raw values. Afterwards the columns in
    ``features_to_winsorize`` (including the outcomes productivity_score and
    employee_satisfaction) are clipped to their 1st/99th percentiles.

    Returns ``None`` when ``df`` is ``None``. Otherwise returns the new
    frame, sorted by firm_id and year, with the original index labels kept.
    """
    if df is None:
        return None

    df = df.copy()

    # Basic feature engineering
    df['log_firm_size'] = np.log(df.get('firm_size', 0) + 1)
    df['firm_size_bin'] = pd.cut(df['log_firm_size'], bins=[0, 7, 9, 11, 15],
                                 labels=['Micro', 'Small', 'Medium', 'Large'])

    # Lags must follow time order within each firm
    df = df.sort_values(['firm_id', 'year'])

    # Lagged confounders and treatment history for the MSM (Session 7 approach).
    # NOTE(review): filling a firm's first-year lag with 0 puts e.g.
    # ceo_exp_lag1 below the observed range of ceo_experience; confirm intended.
    lag_specs = {
        'rd_intensity_lag1': ('rd_intensity', 1),
        'tech_budget_lag1': ('tech_budget_pct', 1),
        'ceo_exp_lag1': ('ceo_experience', 1),
        'productivity_lag1': ('productivity_score', 1),
        'treatment_lag1': ('digital_transformation_program', 1),
        'treatment_lag2': ('digital_transformation_program', 2),
    }
    for new_col, (source_col, periods) in lag_specs.items():
        df[new_col] = df.groupby('firm_id')[source_col].shift(periods).fillna(0)

    # Winsorize key features to handle outliers
    features_to_winsorize = [
        'rd_intensity', 'tech_budget_pct', 'productivity_score',
        'employee_satisfaction', 'log_firm_size', 'rd_intensity_lag1'
    ]
    for feature in features_to_winsorize:
        if feature in df.columns:
            df[feature] = winsorize_series(df[feature])

    return df

# ======================
# MARGINAL STRUCTURAL MODELS (MSM) - Session 7
# ======================

def estimate_treatment_propensities(df):
    """Estimate stabilized inverse-probability-of-treatment weights (Session 7).

    A logistic regression (standardized inputs, ``random_state=42``) of
    ``digital_transformation_program`` on the lagged confounders present in
    ``df`` plus year dummies gives ``p = P(D=1 | X)``. A treated row gets
    ``P(D=1) / p`` and an untreated row gets ``(1 - P(D=1)) / (1 - p)``, where
    ``P(D=1)`` is the sample treated share. Denominators are floored at 0.01
    and the weights are clipped to [0.1, 10].

    NOTE(review): these are per-period weights. Time-varying MSMs usually
    take the cumulative product over each firm's history; confirm this
    matches the intended specification.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``digital_transformation_program``, ``year`` and (some of)
        the lagged confounders created in ``preprocess_data``.

    Returns
    -------
    pandas.DataFrame
        Always a copy of ``df`` with a ``msm_weights`` column; the input is
        never modified. Rows excluded from the propensity model because of
        missing values get weight 1.0. If no usable rows remain, or treatment
        does not vary (nothing to reweight, and scikit-learn cannot fit a
        single class), every weight is 1.0.
    """
    print("Estimating treatment propensities for MSM...")

    # Define confounders based on DAG (not empirical balance checking)
    confounders = ['rd_intensity_lag1', 'tech_budget_lag1', 'ceo_exp_lag1',
                   'productivity_lag1', 'treatment_lag1', 'treatment_lag2',
                   'firm_age', 'log_firm_size']

    available_conf = [c for c in confounders if c in df.columns]
    print(f"Using confounders: {available_conf}")

    ps_data = df[available_conf + ['digital_transformation_program', 'year']].dropna()

    if ps_data.empty:
        print("No valid data for propensity estimation")
        return df.assign(msm_weights=1.0)

    y = ps_data['digital_transformation_program']
    if y.nunique() < 2:
        # Without treatment variation there is no confounding to correct.
        print("Treatment does not vary in the estimation sample; using unit weights")
        return df.assign(msm_weights=1.0)

    # Year fixed effects enter as dummies
    year_dummies = pd.get_dummies(ps_data['year'], prefix='year')
    X = pd.concat([ps_data[available_conf], year_dummies], axis=1)

    # Fit logistic regression for P(D_it | past D, past confounders)
    X_scaled = StandardScaler().fit_transform(X)
    lr = LogisticRegression(random_state=42, max_iter=1000)
    lr.fit(X_scaled, y)
    prob_treatment = lr.predict_proba(X_scaled)[:, 1]

    # Stabilized weights: marginal probability over conditional probability
    marginal_prob = y.mean()
    weights = np.where(y == 1,
                       marginal_prob / np.maximum(prob_treatment, 0.01),
                       (1 - marginal_prob) / np.maximum(1 - prob_treatment, 0.01))

    # Trim extreme weights
    weights = np.clip(weights, 0.1, 10)

    # Rows dropped from ps_data (missing confounders) default to weight 1.0
    df_copy = df.copy()
    weight_series = pd.Series(weights, index=ps_data.index)
    df_copy['msm_weights'] = df_copy.index.map(weight_series).fillna(1.0)

    print(f"✅ MSM weights computed: mean={weights.mean():.3f}, min={weights.min():.3f}, max={weights.max():.3f}")

    return df_copy

def run_msm_analysis(df, outcome):
    """Weighted two-way fixed-effects regression of ``outcome`` on treatment.

    Uses the ``msm_weights`` column as regression weights. If ``df`` lacks
    it, the weights are first computed with ``estimate_treatment_propensities``.
    The model is ``outcome ~ digital_transformation_program + EntityEffects +
    TimeEffects``, fitted with PanelOLS on a (firm_id, year) index and
    entity-clustered standard errors.

    Returns a dict with ``ate``, ``se`` and ``pval`` for the treatment
    coefficient plus the fitted ``model``. Returns ``None``, after printing
    the error, if estimation raises.
    """
    print(f"Running MSM analysis for {outcome}...")

    data = df if 'msm_weights' in df.columns else estimate_treatment_propensities(df)
    panel = data.set_index(['firm_id', 'year'])

    treat = 'digital_transformation_program'
    spec = f"{outcome} ~ digital_transformation_program + EntityEffects + TimeEffects"

    try:
        fit = PanelOLS.from_formula(spec, data=panel, weights=panel['msm_weights']).fit(
            cov_type='clustered', cluster_entity=True
        )
        return {
            'ate': fit.params[treat],
            'se': fit.std_errors[treat],
            'pval': fit.pvalues[treat],
            'model': fit,
        }
    except Exception as e:
        print(f"MSM estimation failed: {e}")
        return None

# ======================
# PARALLEL TRENDS TEST - Session 4 Approach
# ======================

def test_parallel_trends(df, outcome, pre_years=[2018, 2019]):
    """Test parallel trends using lead dummies with collinearity fix"""
    print(f"Testing parallel trends for {outcome}...")
    
    # Create ever_treated indicator
    ever_treated = df.groupby('firm_id')['digital_transformation_program'].max()
    df = df.merge(ever_treated.rename('ever_treated'), left_on='firm_id', right_index=True)
    
    # Filter to pre-treatment years
    pre_data = df[df['year'].isin(pre_years)].copy()
    
    if pre_data.empty:
        print("No pre-treatment data available")
        return None
    
    # Create year dummies and interactions
    pre_data['year_2019'] = (pre_data['year'] == 2019).astype(int)
    pre_data['treated_x_2019'] = pre_data['ever_treated'] * pre_data['year_2019']
    
    # Run regression with collinearity fix
    try:
        panel_data = pre_data.set_index(['firm_id', 'year'])
        
        # SOLUTION: Combine both approaches
        formula = f"{outcome} ~ year_2019 + treated_x_2019 + EntityEffects"
        model = PanelOLS.from_formula(
            formula, 
            data=panel_data,
            drop_absorbed=True  # Option A: Auto-drop collinear terms
        )
        results = model.fit(cov_type='clustered', cluster_entity=True)
        
        # Test if treated_x_2019 coefficient is significant
        coef = results.params.get('treated_x_2019', np.nan)
        pval = results.pvalues.get('treated_x_2019', np.nan)
        
        print(f"Parallel trends test: coef={coef:.4f}, p-value={pval:.4f}")
        print("✅ Parallel trends holds" if pval > 0.05 else "⚠️ Possible violation")
        
        return {
            'coefficient': coef,
            'pvalue': pval,
            'model': results
        }
    except Exception as e:
        # Fallback to Option B if still failing
        print(f"Primary method failed ({e}), using simple DiD fallback")
        return simple_did_fallback(pre_data, outcome)

def simple_did_fallback(df, outcome):
    """Pooled-OLS two-period DiD (no fixed effects), used when PanelOLS fails.

    Fits ``outcome ~ ever_treated * year_2019`` with statsmodels' formula API
    and reports the ``ever_treated:year_2019`` interaction. ``df`` must
    already contain ``ever_treated`` and ``year_2019``.

    Returns a dict with ``coefficient``, ``pvalue`` and ``model``, or
    ``None`` (after printing the error) if fitting fails.
    """
    term = 'ever_treated:year_2019'
    try:
        fitted = smf.ols(f"{outcome} ~ ever_treated * year_2019", data=df).fit()
        coef = fitted.params.get(term, np.nan)
        pval = fitted.pvalues.get(term, np.nan)
        print(f"Simple DiD test: coef={coef:.4f}, p-value={pval:.4f}")
        return {'coefficient': coef, 'pvalue': pval, 'model': fitted}
    except Exception as e:
        print(f"Fallback method failed: {e}")
        return None

# ======================
# SUBGROUP ANALYSIS - Session 5 Approach
# ======================

def run_basic_did(df, outcome):
    """Two-way fixed-effects DiD of ``outcome`` on the treatment indicator.

    Fits ``outcome ~ digital_transformation_program + EntityEffects +
    TimeEffects`` with PanelOLS on a (firm_id, year) index and
    entity-clustered standard errors.

    Returns the fitted PanelOLS results, or ``None`` if estimation fails.
    The failure reason is printed. The previous bare ``except:`` swallowed
    every error silently, including KeyboardInterrupt/SystemExit.
    """
    try:
        panel_data = df.set_index(['firm_id', 'year'])
        formula = f"{outcome} ~ digital_transformation_program + EntityEffects + TimeEffects"
        model = PanelOLS.from_formula(formula, data=panel_data)
        return model.fit(cov_type='clustered', cluster_entity=True)
    except Exception as e:
        print(f"Basic DiD estimation failed for {outcome}: {e}")
        return None

def run_subgroup_analysis(df, outcome, group_var, min_obs=50):
    """Filter-and-regress subgroup analysis (Session 5).

    For each non-missing level of ``group_var`` (the declared categories for
    a categorical column, otherwise the observed unique values), runs
    ``run_basic_did`` on that subset and records the treatment coefficient.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``group_var``, ``outcome``, ``firm_id``, ``year`` and
        ``digital_transformation_program``.
    outcome : str
        Outcome column.
    group_var : str
        Column defining the subgroups.
    min_obs : int, default 50
        Subgroups with fewer rows are skipped (previously hard-coded).

    Returns
    -------
    list of dict
        One dict per estimated subgroup with keys ``group``, ``ate``, ``se``,
        ``pval`` and ``n_obs``. Empty if ``group_var`` is missing or no
        subgroup could be estimated.
    """
    print(f"Running subgroup analysis by {group_var}...")

    if group_var not in df.columns:
        print(f"Group variable {group_var} not found")
        return []

    column = df[group_var]
    if isinstance(column.dtype, pd.CategoricalDtype):
        groups = column.cat.categories
    else:
        groups = column.unique()
    groups = [g for g in groups if pd.notna(g)]

    treat = 'digital_transformation_program'
    results = []
    for group in groups:
        subset = df[column == group]

        if len(subset) < min_obs:
            print(f"Group {group} too small ({len(subset)} observations)")
            continue

        did_results = run_basic_did(subset, outcome)
        # Explicit None check rather than relying on the results object's truthiness
        if did_results is None or treat not in did_results.params.index:
            continue

        results.append({
            'group': group,
            'ate': did_results.params[treat],
            'se': did_results.std_errors[treat],
            'pval': did_results.pvalues[treat],
            'n_obs': len(subset)
        })

    return results

def run_moderation_analysis(df, outcome, moderator):
    """Two-way FE regression with a treatment x moderator interaction (Session 5).

    Fits ``outcome ~ D + moderator + D*moderator + EntityEffects +
    TimeEffects`` with PanelOLS on a (firm_id, year) index and
    entity-clustered SEs, where ``D`` is ``digital_transformation_program``.

    ``drop_absorbed=True`` is passed, as in ``test_parallel_trends``. A
    moderator that is constant within firm is absorbed by the firm effects;
    without this flag PanelOLS raises and the whole moderation analysis is
    lost. With it, only the absorbed main effect is dropped and the
    interaction can still be estimated.

    Returns a dict with ``main_effect``, ``interaction_effect``,
    ``interaction_pval`` (NaN for terms that were dropped) and ``model``.
    Returns ``None`` if ``moderator`` is not a column or estimation raises
    (e.g. a non-numeric moderator).
    """
    print(f"Running moderation analysis with {moderator}...")

    if moderator not in df.columns:
        print(f"Moderator {moderator} not found")
        return None

    try:
        df_mod = df.copy()
        df_mod['treatment_x_moderator'] = (df_mod['digital_transformation_program'] *
                                           df_mod[moderator])

        panel_data = df_mod.set_index(['firm_id', 'year'])

        formula = f"{outcome} ~ digital_transformation_program + {moderator} + treatment_x_moderator + EntityEffects + TimeEffects"
        model = PanelOLS.from_formula(formula, data=panel_data, drop_absorbed=True)
        results = model.fit(cov_type='clustered', cluster_entity=True)

        return {
            'main_effect': results.params.get('digital_transformation_program', np.nan),
            'interaction_effect': results.params.get('treatment_x_moderator', np.nan),
            'interaction_pval': results.pvalues.get('treatment_x_moderator', np.nan),
            'model': results
        }
    except Exception as e:
        print(f"Moderation analysis failed: {e}")
        return None

# ======================
# CAUSAL TREES - Session 5 Approach
# ======================

def standardize_outcomes(df, outcomes):
    """Return a copy of ``df`` with each listed outcome converted to z-scores.

    Each name in ``outcomes`` that is a column of ``df`` is centred on its
    mean and divided by its standard deviation (pandas default, ddof=1).
    Names that are not columns are ignored. The input frame is not modified.
    """
    standardized = df.copy()
    present = (col for col in outcomes if col in standardized.columns)
    for col in present:
        values = standardized[col]
        standardized[col] = (values - values.mean()) / values.std()
    return standardized

def run_causal_tree_analysis(df, outcome):
    """Fit a causal tree for ``outcome`` and attach per-row effect predictions.

    Steps:

    1. z-score ``outcome`` with ``standardize_outcomes``.
    2. Keep the rows where the features, outcome and treatment are all present.
    3. Winsorize the features.
    4. Fit causalml's ``CausalTreeRegressor`` (``min_samples_leaf=50``,
       ``max_depth=5``).
    5. Predict for every row of the standardized frame, replacing missing
       feature values with the column mean.

    Returns
    -------
    tuple
        ``(causal_tree, df_with_hte, feature_importances)``, where
        ``df_with_hte`` is a copy of the standardized frame with an
        ``hte_prediction`` column and ``feature_importances`` maps feature
        name to importance. Returns ``(None, None, None)`` when CausalML is
        unavailable, no candidate feature is present, fewer than 100
        complete rows remain, or fitting/prediction raises.
    """
    print(f"Running causal tree analysis for '{outcome}'...")

    # CAUSAL_ML_AVAILABLE is set by the notebook's import cell. Look it up
    # defensively so a missing causalml yields a clear message instead of a
    # NameError surfacing from inside the try block below.
    if not globals().get('CAUSAL_ML_AVAILABLE', False):
        print("❌ CausalML not available; skipping causal tree analysis.")
        return None, None, None

    # Standardize outcome for comparable effect sizes
    df = standardize_outcomes(df, [outcome])

    features = [
        'log_firm_size', 'firm_age', 'ceo_experience',
        'prior_tech_adoption', 'tech_budget_pct', 'rd_intensity'
    ]
    available_features = [f for f in features if f in df.columns]

    if not available_features:
        print("❌ No available features for causal tree analysis.")
        return None, None, None

    required_columns = available_features + [outcome, 'digital_transformation_program']
    # Explicit copy: the winsorized features are written back into this frame.
    analysis_data = df[required_columns].dropna().copy()

    if len(analysis_data) < 100:
        print(f"❌ Insufficient data: {len(analysis_data)} observations (min 100 required).")
        return None, None, None

    for feature in available_features:
        analysis_data[feature] = winsorize_series(analysis_data[feature])

    X = analysis_data[available_features].values
    T = analysis_data['digital_transformation_program'].values
    Y = analysis_data[outcome].values

    try:
        causal_tree = CausalTreeRegressor(
            random_state=42,
            min_samples_leaf=50,
            max_depth=5,
            min_impurity_decrease=-np.inf
        )
        causal_tree.fit(X=X, treatment=T, y=Y)

        # Prediction uses the raw (unwinsorized), mean-imputed features.
        # NOTE(review): for an axis-aligned tree whose split thresholds lie inside
        # the winsorized training range, out-of-range values should land in the
        # same leaf as their clipped value; confirm for causalml's implementation.
        X_full = df[available_features].fillna(df[available_features].mean()).values
        predictions = causal_tree.predict(X_full)

        importances = causal_tree.feature_importances_
        feature_importances = dict(zip(available_features, importances))

        df_with_hte = df.copy()
        df_with_hte['hte_prediction'] = predictions

        print("✅ Causal tree analysis completed successfully.")
        return causal_tree, df_with_hte, feature_importances

    except Exception as e:
        print(f"❌ Causal tree analysis failed: {e}")
        return None, None, None


# ======================
# VISUALIZATION FUNCTIONS
# ======================

def plot_subgroup_results(subgroup_results, outcome, group_var):
    """Plot per-subgroup ATEs with 95% normal-approximation error bars.

    ``subgroup_results`` is a list of dicts with at least the keys ``group``,
    ``ate`` and ``se``. Returns the matplotlib Figure, or ``None`` when the
    list is empty.
    """
    if not subgroup_results:
        return None

    table = pd.DataFrame(subgroup_results)
    positions = range(len(table))
    half_width = 1.96 * table['se']

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.errorbar(positions, table['ate'], yerr=half_width,
                fmt='o-', capsize=5, markersize=8)
    ax.axhline(0, color='red', linestyle='--', alpha=0.7)  # no-effect reference

    pretty_group = group_var.replace("_", " ").title()
    pretty_outcome = outcome.replace("_", " ").title()
    ax.set_xticks(positions)
    ax.set_xticklabels(table['group'], rotation=45)
    ax.set_ylabel('Average Treatment Effect (SD Units)')
    ax.set_title(f'Treatment Effects by {pretty_group}\n{pretty_outcome}')
    ax.grid(alpha=0.3)

    plt.tight_layout()
    return fig

def plot_hte_distribution(df, outcome):
    """Two-panel figure of the predicted treatment effects (``hte_prediction``).

    Left panel: histogram with the mean marked. Right panel, drawn only if
    ``log_firm_size`` is present: mean +/- std of the prediction by firm-size
    quartile. ``outcome`` is accepted but not used. Returns the Figure, or
    ``None`` if there are no non-missing predictions.
    """
    if 'hte_prediction' not in df.columns:
        return None

    effects = df['hte_prediction'].dropna()
    if effects.empty:
        return None

    fig, (ax_hist, ax_bar) = plt.subplots(1, 2, figsize=(14, 6))

    mean_effect = effects.mean()
    ax_hist.hist(effects, bins=30, alpha=0.7, edgecolor='black')
    ax_hist.axvline(mean_effect, color='red', linestyle='--',
                    label=f'Mean: {mean_effect:.3f}')
    ax_hist.set_xlabel('Predicted Treatment Effect (SD Units)')
    ax_hist.set_ylabel('Frequency')
    ax_hist.set_title('Distribution of Heterogeneous Treatment Effects')
    ax_hist.legend()
    ax_hist.grid(alpha=0.3)

    if 'log_firm_size' in df.columns:
        sized = df.dropna(subset=['hte_prediction', 'log_firm_size'])
        if len(sized) > 0:
            sized = sized.assign(size_quartile=pd.qcut(sized['log_firm_size'], 4,
                                                       labels=['Q1', 'Q2', 'Q3', 'Q4']))
            by_quartile = sized.groupby('size_quartile')['hte_prediction'].agg(['mean', 'std'])

            ax_bar.bar(by_quartile.index, by_quartile['mean'],
                       yerr=by_quartile['std'], capsize=5, alpha=0.7)
            ax_bar.set_xlabel('Firm Size Quartile')
            ax_bar.set_ylabel('Mean Treatment Effect (SD Units)')
            ax_bar.set_title('HTE by Firm Size Quartile')
            ax_bar.grid(alpha=0.3)

    plt.tight_layout()
    return fig

def plot_feature_hte(df, feature, outcome):
    """Scatter predicted treatment effects against one feature.

    Points are jittered horizontally by up to +/-10% of the feature's
    observed range (drawn from NumPy's global RNG) to reduce overplotting. A
    red seaborn regression line with a 95% CI is overlaid using the
    unjittered values. ``outcome`` is accepted but not used.

    Returns the Figure, or ``None`` if ``hte_prediction`` or ``feature`` is
    missing, or if no row has both values.
    """
    if 'hte_prediction' not in df.columns or feature not in df.columns:
        return None

    plot_data = df.dropna(subset=['hte_prediction', feature])
    if len(plot_data) == 0:
        return None

    fig, ax = plt.subplots(figsize=(10, 6))

    # NOTE(review): +/-10% of the range is heavy jitter for continuous features;
    # it mainly helps binary/discrete ones.
    jitter = 0.1 * (plot_data[feature].max() - plot_data[feature].min())
    x_jitter = plot_data[feature] + np.random.uniform(-jitter, jitter, size=len(plot_data))

    ax.scatter(x_jitter, plot_data['hte_prediction'], alpha=0.5)

    # The fit line is optional decoration. If seaborn cannot draw it, report
    # why and keep the scatter, rather than silently swallowing every
    # exception (including KeyboardInterrupt) with a bare ``except``.
    try:
        sns.regplot(x=feature, y='hte_prediction', data=plot_data,
                    scatter=False, color='red', ax=ax,
                    ci=95, line_kws={'lw': 2})
    except Exception as exc:
        print(f"Could not add regression line for {feature}: {exc}")

    ax.set_xlabel(feature.replace('_', ' ').title())
    ax.set_ylabel('Predicted Treatment Effect (SD Units)')
    ax.set_title(f'Treatment Effect Heterogeneity by {feature.replace("_", " ").title()}')
    ax.grid(alpha=0.3)

    plt.tight_layout()
    return fig

def plot_ate_comparison(ate_results):
    """Dot-and-whisker plot of ATEs with 95% CIs (+/-1.96 SE) per outcome.

    ``ate_results`` maps each outcome name to a dict holding ``ate`` and
    ``se``. Returns the matplotlib Figure.
    """
    names = list(ate_results)
    point = [ate_results[name]['ate'] for name in names]
    lower = [ate_results[name]['ate'] - 1.96 * ate_results[name]['se'] for name in names]
    half_width = [p - lo for p, lo in zip(point, lower)]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.errorbar(names, point, yerr=half_width, fmt='o', capsize=5)
    ax.axhline(0, color='red', linestyle='--')
    ax.set_title('Average Treatment Effects Across Outcomes (SD Units)')
    ax.set_ylabel('ATE with 95% CI (Standardized)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig

# ======================
# MAIN ANALYSIS PIPELINE
# ======================

def _significance_stars(pval, not_significant=''):
    """Map a p-value to conventional stars: *** (<0.01), ** (<0.05), * (<0.1).

    Returns ``not_significant`` otherwise, including for NaN p-values (every
    comparison with NaN is False).
    """
    if pval < 0.01:
        return '***'
    if pval < 0.05:
        return '**'
    if pval < 0.1:
        return '*'
    return not_significant


def _save_figure(fig, path):
    """Save ``fig`` to ``path`` (300 dpi, tight bbox) and always close it.

    Returns True if a figure was written, False if ``fig`` was falsy (a plot
    helper returned None because there was nothing to plot).
    """
    if not fig:
        return False
    try:
        fig.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        # Close even when saving fails so figures don't accumulate in memory.
        plt.close(fig)
    return True


def _try_save_plot(path, outcome, plot_fn, *plot_args):
    """Build a figure via ``plot_fn(*plot_args)`` and save it; report errors instead of raising."""
    try:
        _save_figure(plot_fn(*plot_args), path)
    except Exception as e:
        print(f"⚠️ Visualization error for {outcome}: {e}")


def _save_outcome_plots(outcome, outcome_results):
    """Write the subgroup and HTE figures for one outcome.

    Each figure is attempted independently, so one failing plot does not stop
    the remaining plots for the same outcome from being written.
    """
    subgroup_specs = (
        ('subgroup_size', 'firm_size_bin', f'{outcome}_subgroup_size.png'),
        ('subgroup_tech', 'prior_tech_adoption', f'{outcome}_subgroup_tech.png'),
    )
    for key, group_col, path in subgroup_specs:
        subgroup_results = outcome_results.get(key)
        if subgroup_results:
            _try_save_plot(path, outcome, plot_subgroup_results,
                           subgroup_results, outcome, group_col)

    df_with_hte = outcome_results.get('hte_data')
    if df_with_hte is not None:
        _try_save_plot(f'{outcome}_hte_distribution.png', outcome,
                       plot_hte_distribution, df_with_hte, outcome)

        feature_importances = outcome_results.get('feature_importances')
        if feature_importances:
            # Plot heterogeneity against the single most important feature.
            top_feature = max(feature_importances, key=feature_importances.get)
            _try_save_plot(f'{outcome}_hte_{top_feature}.png', outcome,
                           plot_feature_hte, df_with_hte, top_feature, outcome)


def _analyze_outcome(df, outcome):
    """Run the MSM, parallel-trends, subgroup, moderation and causal-tree steps for one outcome.

    Returns a dict of per-step results. The outcome's figures are saved as a
    side effect.
    """
    print(f"\n{'='*20} ANALYZING: {outcome.upper()} {'='*20}")
    outcome_results = {}

    # 1. MSM Analysis (Session 7)
    print("\n1️⃣ Running MSM/IPTW Analysis...")
    outcome_results['msm'] = run_msm_analysis(df, outcome)

    # 2. Parallel Trends Test (Session 4)
    print("\n2️⃣ Testing Parallel Trends...")
    outcome_results['parallel_trends'] = test_parallel_trends(df, outcome)

    # 3. Subgroup Analysis (Session 5)
    print("\n3️⃣ Running Subgroup Analysis...")
    if 'firm_size_bin' in df.columns:
        outcome_results['subgroup_size'] = run_subgroup_analysis(df, outcome, 'firm_size_bin')
    if 'prior_tech_adoption' in df.columns:
        outcome_results['subgroup_tech'] = run_subgroup_analysis(df, outcome, 'prior_tech_adoption')

    # 4. Moderation Analysis (Session 5)
    print("\n4️⃣ Running Moderation Analysis...")
    if 'log_firm_size' in df.columns:
        outcome_results['moderation'] = run_moderation_analysis(df, outcome, 'log_firm_size')

    # 5. Causal Tree Analysis (Session 5)
    print("\n5️⃣ Running Causal Tree Analysis...")
    tree_model, df_with_hte, feature_importances = run_causal_tree_analysis(df, outcome)
    outcome_results['causal_tree'] = tree_model
    outcome_results['hte_data'] = df_with_hte
    outcome_results['feature_importances'] = feature_importances

    _save_outcome_plots(outcome, outcome_results)
    return outcome_results


def _build_report_lines(df, ate_results, all_results):
    """Assemble the plain-text results report, one list element per output line."""
    lines = [
        "COURSE-ALIGNED CAUSAL ANALYSIS RESULTS",
        "="*60,
        f"Analysis Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Total Observations: {len(df)}",
        f"Unique Firms: {df['firm_id'].nunique()}",
        f"Time Periods: {sorted(df['year'].unique())}",
    ]

    # ATE comparison across outcomes (normal-approximation 95% CI)
    lines += ["\n" + "="*60, "AVERAGE TREATMENT EFFECTS (ATE) ACROSS OUTCOMES", "="*60]
    for outcome, res in ate_results.items():
        ate, se = res['ate'], res['se']
        lines += [
            f"\n📊 {outcome.upper()}:",
            f"  ATE: {ate:.4f} SD Units",
            f"  95% CI: [{ate - 1.96*se:.4f}, {ate + 1.96*se:.4f}]",
            f"  Standard Error: ±{se:.4f}",
        ]

    # Detailed results per outcome
    for outcome_name, results in all_results.items():
        lines.append(f"\n{'='*20} {outcome_name.upper()} DETAILED RESULTS {'='*20}")

        msm = results.get('msm')
        if msm:
            lines += [
                "\n📊 MSM-IPTW RESULTS:",
                f"   Average Treatment Effect: {msm['ate']:.4f} SD Units",
                f"   Standard Error: {msm['se']:.4f}",
                f"   P-value: {msm['pval']:.4f}",
                f"   Significance: {_significance_stars(msm['pval'], 'Not significant')}",
            ]

        pt = results.get('parallel_trends')
        if pt:
            assessment = '✅ Assumption likely holds' if pt['pvalue'] > 0.05 else '⚠️ Potential violation'
            lines += [
                "\n📈 PARALLEL TRENDS TEST:",
                f"   Lead coefficient: {pt['coefficient']:.4f}",
                f"   P-value: {pt['pvalue']:.4f}",
                f"   Assessment: {assessment}",
            ]

        for subgroup_key, group_name in (('subgroup_size', 'Firm Size'),
                                         ('subgroup_tech', 'Tech Adoption')):
            subgroup_results = results.get(subgroup_key)
            if subgroup_results:
                lines.append(f"\n🏢 SUBGROUP ANALYSIS ({group_name}):")
                for result in subgroup_results:
                    sig = _significance_stars(result['pval'])
                    lines.append(f"   {result['group']}: ATE={result['ate']:.4f} (SE={result['se']:.4f}) {sig}")

        mod = results.get('moderation')
        if mod:
            lines += [
                "\n🔄 MODERATION ANALYSIS:",
                f"   Main Effect: {mod['main_effect']:.4f}",
                f"   Interaction Effect: {mod['interaction_effect']:.4f}",
                f"   Interaction P-value: {mod['interaction_pval']:.4f}",
            ]

        hte_data = results.get('hte_data')
        if isinstance(hte_data, pd.DataFrame) and 'hte_prediction' in hte_data.columns:
            hte_vals = hte_data['hte_prediction'].dropna()
            if len(hte_vals) > 0:
                lines += [
                    "\n🌳 HETEROGENEOUS TREATMENT EFFECTS:",
                    f"   Mean HTE: {hte_vals.mean():.4f} SD Units",
                    f"   Median HTE: {hte_vals.median():.4f}",
                    f"   Standard Deviation: {hte_vals.std():.4f}",
                    f"   Range: [{hte_vals.min():.4f}, {hte_vals.max():.4f}]",
                    f"   % Positive Effects: {(hte_vals > 0).mean()*100:.1f}%",
                ]

        feat_importances = results.get('feature_importances')
        if feat_importances:
            lines.append("\n🔍 FEATURE IMPORTANCE FOR HTE:")
            for feature, importance in sorted(feat_importances.items(), key=lambda kv: kv[1], reverse=True):
                lines.append(f"   {feature}: {importance:.4f}")

    return lines


def main_analysis():
    """Run the complete course-aligned causal analysis pipeline.

    Loads and preprocesses the data, standardizes the outcomes and estimates
    MSM weights. For each available outcome it then runs MSM/IPTW, a
    parallel-trends test, subgroup and moderation analyses and a causal tree,
    saving figures along the way. Finally it writes
    ``course_aligned_causal_results.txt`` and, when any ATE was estimated,
    ``ate_comparison.png``.

    Returns
    -------
    dict or None
        Maps outcome name -> dict of per-step results, or None if loading,
        preprocessing or the required-column checks fail.
    """
    print("="*60)
    print("COURSE-ALIGNED CAUSAL ANALYSIS PIPELINE")
    print("="*60)

    # Load and preprocess data
    df = load_and_validate_data()
    if df is None:
        return None

    df = preprocess_data(df)
    if df is None:
        return None

    # Check required columns
    required_cols = ['firm_id', 'year', 'digital_transformation_program']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None

    outcome_candidates = ['productivity_score', 'employee_satisfaction']

    # Standardize outcomes for comparable effect sizes (reported in SD units)
    df = standardize_outcomes(df, [col for col in outcome_candidates if col in df.columns])

    # Estimate MSM weights
    df = estimate_treatment_propensities(df)

    # Re-check availability after the preprocessing steps above
    available_outcomes = [col for col in outcome_candidates if col in df.columns]
    if not available_outcomes:
        print("❌ No valid outcome variables found")
        return None

    print(f"📊 Analyzing outcomes: {available_outcomes}")

    all_results = {}
    ate_results = {}
    for outcome in available_outcomes:
        outcome_results = _analyze_outcome(df, outcome)
        all_results[outcome] = outcome_results

        msm_results = outcome_results['msm']
        if msm_results:
            ate_results[outcome] = {'ate': msm_results['ate'], 'se': msm_results['se']}

    print("\n📝 Generating comprehensive report...")
    report_lines = _build_report_lines(df, ate_results, all_results)

    try:
        # `report_file` (not `f`) avoids shadowing scipy.stats.f imported at the top.
        with open('course_aligned_causal_results.txt', 'w', encoding='utf-8') as report_file:
            report_file.write('\n'.join(report_lines))

        ate_plot_saved = bool(ate_results) and _save_figure(
            plot_ate_comparison(ate_results), 'ate_comparison.png')

        print("✅ Analysis complete! Results saved:")
        print("  - course_aligned_causal_results.txt")
        # Only claim the ATE plot was saved when it actually was.
        if ate_plot_saved:
            print("  - ate_comparison.png")
        print("  - Outcome-specific visualizations")
    except Exception as e:
        print(f"⚠️ Error saving report: {e}")

    return all_results

# A Jupyter kernel executes cells with __name__ == "__main__", so running this
# cell runs the full pipeline and keeps its output in `results`.
if __name__ == "__main__":
    results = main_analysis()

COURSE-ALIGNED CAUSAL ANALYSIS PIPELINE
✅ Dataset loaded successfully
Dataset shape: (3000, 23)
Estimating treatment propensities for MSM...
Using confounders: ['rd_intensity_lag1', 'tech_budget_lag1', 'ceo_exp_lag1', 'productivity_lag1', 'treatment_lag1', 'treatment_lag2', 'firm_age', 'log_firm_size']
✅ MSM weights computed: mean=0.813, min=0.470, max=8.514
📊 Analyzing outcomes: ['productivity_score', 'employee_satisfaction']


1️⃣ Running MSM/IPTW Analysis...
Running MSM analysis for productivity_score...

2️⃣ Testing Parallel Trends...
Testing parallel trends for productivity_score...
Parallel trends test: coef=0.5168, p-value=0.0084
⚠️ Possible violation

3️⃣ Running Subgroup Analysis...
Running subgroup analysis by firm_size_bin...
Running subgroup analysis by prior_tech_adoption...

4️⃣ Running Moderation Analysis...
Running moderation analysis with log_firm_size...

5️⃣ Running Causal Tree Analysis...
Running causal tree analysis for 'productivity_score'...
✅ Causal tree analysi