In [None]:
import os
import numpy as np
import pandas as pd
from datetime import datetime
from scipy.stats import norm

# ATE estimation
import statsmodels.api as sm
from doubleml import DoubleMLData, DoubleMLPLR
from sklearn.linear_model import LogisticRegression, LinearRegression

# DML CNN utils
from dml_utils.doubleml_image import DoubleMLPLRImage
from dml_utils.xray_dataset import XRayDataset
from dml_utils.cnn_regressor import CNNRegressor
from dml_utils.preprocess import dml_preprocess_pipeline
from dml_utils.models import CNN_2, CNN_5
import torch
from torch.utils.data import DataLoader

# Custom modules
from feature_extraction.pretrained_models_xrv import load_torchxrayvision_model, extract_features_from_folder
from utils.project import set_root
from utils.io import save_results, load_results
from visualization.plotting import plot_ate_estimates

In [None]:
# Set the working directory so the relative paths below resolve from it
set_root()

# Inputs (raw X-ray images) and outputs (DML-CNN experiment results)
dataset_dir = "data/xray/raw/all_unique"
results_dir = "results/dml_cnn/xray"

# Pretrained torchxrayvision model and the folder that caches its representations
model_name = "densenet121-res224-all"
save_dir = "data/xray/representations/" + model_name

# Cached feature matrix and the matching labels inside save_dir
features_path, labels_path = (
    os.path.join(save_dir, fname) for fname in ("latent_features.npy", "labels.npy")
)

In [None]:
# Feature extraction is expensive, so it only runs when either cached array is missing.
if not (os.path.exists(features_path) and os.path.exists(labels_path)):
    print(f"Extracting features using model '{model_name}'...")

    xrv_model = load_torchxrayvision_model(model_name)
    # Only the files written to `save_path` are used: both branches reload them
    # below, so the in-memory return values are deliberately discarded (this also
    # avoids leaking a global `labels`/`model` into later cells).
    # NOTE(review): assumes the helper writes latent_features.npy and labels.npy
    # into save_path (the names this cell reloads) — confirm against its implementation.
    extract_features_from_folder(
        dataset_dir,
        xrv_model,
        device='cpu',
        batch_size=32,
        save_path=save_dir,
    )

    print(f"Features extracted and saved to: {save_dir}")
else:
    print(f"Features already exist in {save_dir}. Skipping extraction.")

# Load the cached representations and their labels
all_features_pretrained = np.load(features_path)
all_labels_pretrained = np.load(labels_path)
n_samples_total = len(all_labels_pretrained)

# Later cells index features and labels with the same row indices and read
# features.shape[1], so reject a stale or malformed cache here with a clear message.
if all_features_pretrained.ndim != 2 or all_features_pretrained.shape[0] != n_samples_total:
    raise ValueError(
        f"Cached features with shape {all_features_pretrained.shape} do not match "
        f"{n_samples_total} cached labels; delete {save_dir} to re-extract."
    )

### ATE Estimation using Double Machine Learning with and without pre-trained representations

In [None]:
# Image dataset for DML with CNN nuisance models, preprocessed by dml_preprocess_pipeline()
dataset = XRayDataset(dataset_dir, transform=dml_preprocess_pipeline())

# One batch containing every available sample; each run subsamples from it later.
# shuffle=False keeps the dataset order so the labels can be checked against the
# cached pre-trained representations.
batch_size = n_samples_total
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [6]:
# Architecture of the CNN nuisance models (outcome regression and treatment classification)
cnn_type = "cnn2"  # Choose between "cnn2" and "cnn5"
if cnn_type not in ("cnn2", "cnn5"):
    raise ValueError("Invalid CNN type specified.")
cnn_architecture = CNN_2 if cnn_type == "cnn2" else CNN_5

# Raw PyTorch networks, each with a single output unit
outcome_net = cnn_architecture(output_dim=1, is_classifier=False)   # continuous outcome Y
treatment_net = cnn_architecture(output_dim=1, is_classifier=True)  # binary treatment A

# Wrap the networks with CNNRegressor so they can serve as DML nuisance learners
cnn_epochs = 30
cnn_for_outcome = CNNRegressor(outcome_net, epochs=cnn_epochs, is_classifier=False)
cnn_for_treatment = CNNRegressor(treatment_net, epochs=cnn_epochs, is_classifier=True)

### Label Confounding Simulation and ATE Estimation

In [None]:
# 1.1 Data-generating process of the label-confounding simulation
beta_true = 2.0              # true effect of treatment A on outcome Y (target ATE)
gamma_true = -2              # effect of pneumonia on Y
p_treat_given_pneu = 0.7     # P(A = 1 | pneumonia)
p_treat_given_normal = 0.3   # P(A = 1 | normal)

# 1.2 Experiment size and inference settings
n_runs = 5                   # number of independent simulation runs
n_samples_run = 500          # observations subsampled in each run
ci_alpha_level = 0.05        # confidence intervals at level 1 - alpha
z_score = norm.ppf(1.0 - 0.5 * ci_alpha_level)  # two-sided standard-normal critical value

# 1.3 Per-method result containers, filled by the simulation loop
methods = ['Naive', 'Oracle', 'DML (Pre-trained)', 'DML (CNN)']
estimates_dict = {name: [] for name in methods}
cis_dict = {name: dict(se=[], lower=[], upper=[]) for name in methods}

# 1.4 Pull the full, unshuffled image set out of the DataLoader as a single batch
all_images, all_labels, filenames = next(iter(dataloader))

# Encode the string labels as integers: 0 = NORMAL, 1 = PNEUMONIA
label_mapping = {"NORMAL": 0, "PNEUMONIA": 1}
all_labels = np.fromiter((label_mapping[name] for name in all_labels), dtype=int)

# The image pipeline and the cached representations must list the samples in the same order
assert np.array_equal(all_labels, all_labels_pretrained), "Labels do not match between dataset and representations."

# Base seed; the simulation loop derives a distinct seed for every run from it
seed = 42

def _record_estimate(method, beta, se):
    """Store one run's result for `method` in `estimates_dict` / `cis_dict`.

    Appends the point estimate, its standard error and the two-sided
    normal-approximation CI ``beta ± z_score * se`` together, so every list of a
    method always holds exactly one entry per run. Failed runs are recorded by
    passing NaN, which propagates to NaN CI bounds.
    """
    estimates_dict[method].append(beta)
    cis_dict[method]['se'].append(se)
    cis_dict[method]['lower'].append(beta - z_score * se)
    cis_dict[method]['upper'].append(beta + z_score * se)


# Each run subsamples without replacement, so it cannot request more rows than exist
# (the simulated noise below is drawn with exactly n_samples_run entries).
if n_samples_run > n_samples_total:
    raise ValueError(
        f"n_samples_run={n_samples_run} exceeds the {n_samples_total} available samples."
    )

for run in range(n_runs):
    print(f"\n--- Simulation Run {run + 1} ---")
    # Distinct, reproducible seed per run (seed + 2, seed + 4, ...). The CNN nuisance
    # models train in PyTorch, which keeps its own RNG, so seed it as well.
    run_seed = seed + 2 * (run + 1)
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)

    # 2. Draw one random subsample and apply the same indices to images, labels and
    # pre-trained features. Truncating the permutation first selects exactly the same
    # rows as shuffling everything and then slicing, without copying the full tensor.
    subsample_idx = np.random.permutation(n_samples_total)[:n_samples_run]
    images = all_images[subsample_idx]
    labels = all_labels[subsample_idx]
    features_pretrained = all_features_pretrained[subsample_idx]
    labels_pretrained = all_labels_pretrained[subsample_idx]

    # Sanity check to ensure labels are still matching
    assert np.array_equal(labels, labels_pretrained), "Labels do not match after shuffling."
    print(f"Sampled {n_samples_run} observations of {n_samples_total} for this run.")

    # 3.1 Simulate treatment A: P(A = 1) depends on pneumonia, making it a confounder
    pA = labels * p_treat_given_pneu + (1 - labels) * p_treat_given_normal
    A = np.random.binomial(1, pA)

    # 3.2 Simulate outcome Y = beta_true * A + gamma_true * pneumonia + N(0, 1) noise
    noise = np.random.normal(loc=0, scale=1, size=n_samples_run)
    Y = beta_true * A + gamma_true * labels + noise

    df = pd.DataFrame({'Y': Y, 'A': A, 'pneumonia': labels})

    # 4.4 Naive OLS of Y on A only (ignores the confounder)
    model_naive = sm.OLS(df['Y'], sm.add_constant(df['A'])).fit()
    beta_naive, se_naive = model_naive.params['A'], model_naive.bse['A']
    _record_estimate('Naive', beta_naive, se_naive)
    print(f"Naive OLS: β = {beta_naive:.3f}, SE = {se_naive:.3f}")

    # 4.5 Oracle OLS adjusting for the true pneumonia label
    model_oracle = sm.OLS(df['Y'], sm.add_constant(df[['A', 'pneumonia']])).fit()
    beta_oracle, se_oracle = model_oracle.params['A'], model_oracle.bse['A']
    _record_estimate('Oracle', beta_oracle, se_oracle)
    print(f"Oracle OLS: β = {beta_oracle:.3f}, SE = {se_oracle:.3f}")

    # 4.6 DoubleMLData: outcome "Y", treatment "A", remaining feature columns as controls
    X_dml_df = pd.DataFrame(
        features_pretrained,
        columns=[f"feat_{i}" for i in range(features_pretrained.shape[1])]
    )
    X_dml_df['Y'] = df['Y'].copy()
    X_dml_df['A'] = df['A'].copy()
    data_dml = DoubleMLData(X_dml_df, "Y", "A")

    # 4.6.1 DML (Pre-trained): linear nuisance models on the pre-trained representations.
    # Results are stored once after the try/except so a late failure can never leave
    # the per-method lists with mismatched lengths.
    try:
        ml_g_linear = LinearRegression()    # outcome model
        ml_m_linear = LogisticRegression()  # treatment model
        dml_plr_linear = DoubleMLPLR(data_dml, ml_g_linear, ml_m_linear, n_folds=2)
        dml_plr_linear.fit()
        beta_dml_linear = float(dml_plr_linear.coef[0])
        se_dml_linear = float(dml_plr_linear.se[0])
    except Exception as e:
        print(f"Run {run+1}: DML (Pre-trained) failed with error: {e}")
        beta_dml_linear = se_dml_linear = np.nan
    else:
        print(f"DML (Pre-trained): β = {beta_dml_linear:.3f}, SE = {se_dml_linear:.3f}")
    _record_estimate('DML (Pre-trained)', beta_dml_linear, se_dml_linear)

    # 4.6.2 DML (CNN): CNNs on the raw images as nuisance functions.
    # NOTE(review): cnn_for_outcome / cnn_for_treatment are created once and reused in
    # every run; confirm DoubleMLPLRImage clones them per fold, otherwise trained
    # weights would carry over between folds and runs.
    try:
        dml_plr_image = DoubleMLPLRImage(
            X=images.numpy(),
            y=Y,
            d=A,
            ml_l=cnn_for_outcome,
            ml_m=cnn_for_treatment,
            n_folds=2,
            n_rep=1
        )
        dml_plr_image.fit()
        # np.ravel accepts both a scalar and a length-1 array, so the value can be
        # formatted with :.3f either way
        beta_dml = float(np.ravel(dml_plr_image.coef)[0])
        se_dml = float(np.ravel(dml_plr_image.se)[0])
    except Exception as e:
        print(f"Run {run+1}: DML (CNN) failed with error: {e}")
        beta_dml = se_dml = np.nan
    else:
        print(f"DML (CNN): β = {beta_dml:.3f}, SE = {se_dml:.3f}")
    _record_estimate('DML (CNN)', beta_dml, se_dml)

In [8]:
# 5. Save this experiment's results in its own timestamped folder:
#    <results_dir>/<model_name>/<cnn_type>_<n_samples_run>/<YYYYmmdd_HHMMSS>
experiment_name = "_".join([cnn_type, str(n_samples_run)])
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
experiment_dir = os.path.join(results_dir, model_name, experiment_name, timestamp)
save_results(experiment_dir, estimates_dict, cis_dict)

### Plotting Results

In [5]:
# Reload the saved estimates and confidence intervals from `experiment_dir`
# (point it at an earlier run's folder to re-plot that experiment instead).
(estimates_dict, cis_dict) = load_results(experiment_dir)

In [None]:
# Plot every method's ATE estimates and confidence intervals against the true effect
plot_ate_estimates(
    estimates_dict=estimates_dict,
    cis_dict=cis_dict,
    ate_true=beta_true,
    n_runs=n_runs,
    plot_name=f"ate_label_conf_xray_{experiment_name}",
    save_dir=experiment_dir,
    vert_diff=0.03,
    label_break=False,
    verbose=True,
)