
## What is a Truncated Factorisation Causal Bootstrap Algorithm?

- It is a resampling algorithm that builds a new dataset from the **truncated factorisation** of a causal graph.
- The idea comes from Pearl’s **do-calculus** and Robins’ **g-formula**.
- Instead of conditioning on mediators (front door) or confounders (back door) alone, it uses the **entire causal structure**.
- The joint distribution of all variables is factorised according to the causal graph.
- To simulate an intervention `do(X = x)`, we **truncate** (remove) the conditional distribution of `X` given its parents and fix `X` to the value `x`.
- All other variables are resampled from their conditional distributions given their parents, with `X = x` substituted wherever `X` is a parent.
- This allows us to estimate causal effects even in complex systems with:
  - multiple confounders,
  - mediators,
  - chains of causes.
- The key requirement is that the **causal graph is correctly specified** and all parent variables are observed.
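
In formula form, using standard notation with $pa_i$ for the values of the parents of $V_i$: the graph factorises the joint distribution, and the intervention deletes the factor for $X$ while plugging $X = x$ into every remaining factor.

$$
P(v_1, \dots, v_n) = \prod_{i=1}^{n} P(v_i \mid pa_i)
$$

$$
P(v_1, \dots, v_n \mid do(X = x)) =
\begin{cases}
\displaystyle\prod_{i:\, V_i \neq X} P(v_i \mid pa_i) & \text{if the value of } X \text{ in } (v_1, \dots, v_n) \text{ equals } x,\\[4pt]
0 & \text{otherwise.}
\end{cases}
$$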


## Example

Imagine you want to assess whether a new teaching method (the “cause” **X**) improves exam performance (the “effect” **Y**).

- There are multiple observed variables:
  - prior grades,
  - motivation,
  - study hours,
  - practice problems.
- Some of these variables affect both the teaching method assignment and exam performance.
- Others lie on causal paths between **X** and **Y**.

Causal structure (simplified):

- Prior Grades -> Teaching Method -> Study Hours -> Exam Performance
- Motivation -> Study Hours
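
Applied to this simplified graph, and assuming it is the complete graph (so prior grades and motivation are root nodes), the truncated factorisation drops only the teaching-method factor. The letters are shorthand introduced here: $G$ prior grades, $M$ motivation, $X$ teaching method, $H$ study hours, $Y$ exam performance. Practice problems are left out because they do not appear in the simplified graph.

$$
P(g, m, x, h, y) = P(g)\,P(m)\,P(x \mid g)\,P(h \mid x, m)\,P(y \mid h)
$$

$$
P(g, m, h, y \mid do(X = x)) = P(g)\,P(m)\,P(h \mid x, m)\,P(y \mid h)
$$

Summing out $g$, $m$ and $h$ (the $G$ term vanishes because $\sum_g P(g) = 1$) gives the interventional distribution of the outcome:

$$
P(y \mid do(X = x)) = \sum_{m,\,h} P(m)\,P(h \mid x, m)\,P(y \mid h)
$$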


- We assume all parent variables of each node are observed.
- To estimate the causal effect of the teaching method:
  - We **intervene** on **X** by fixing it to a chosen value.
  - We remove the probability model `P(X | Parents(X))`.
  - We resample all downstream variables from their conditional distributions given their parents.
- By repeatedly resampling, we generate a bootstrapped dataset that reflects:
  - “What exam performance would look like if everyone received the same teaching method.”
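
The steps above can be sketched on simulated data. This is a toy illustration under stated assumptions, not the notebook's method: all variables are binary, every probability in the hypothetical `simulate_observational` is made up, and the hypothetical `truncated_bootstrap` resamples the root nodes (prior grades, motivation) with an ordinary row bootstrap, drops $P(X \mid G)$, and draws each downstream variable from its empirical conditional distribution given its parents.

```python
import numpy as np
import pandas as pd

def simulate_observational(n, rng):
    """Hypothetical binary SCM for the teaching example (all numbers are made up)."""
    g = rng.binomial(1, 0.5, n)                    # prior grades (root)
    m = rng.binomial(1, 0.5, n)                    # motivation (root)
    x = rng.binomial(1, 0.2 + 0.6 * g)             # teaching method <- prior grades
    h = rng.binomial(1, 0.1 + 0.4 * x + 0.4 * m)   # study hours <- teaching method, motivation
    y = rng.binomial(1, 0.2 + 0.6 * h)             # exam pass <- study hours
    return pd.DataFrame({"G": g, "M": m, "X": x, "H": h, "Y": y})

def truncated_bootstrap(df, x_star, rng):
    """Sample from P(G) P(M) P(H | X=x_star, M) P(Y | H); the factor P(X | G) is dropped."""
    n = len(df)
    # Empirical conditional distributions, stored as pools of observed values.
    h_pool = {m: df.loc[(df["X"] == x_star) & (df["M"] == m), "H"].to_numpy() for m in (0, 1)}
    y_pool = {h: df.loc[df["H"] == h, "Y"].to_numpy() for h in (0, 1)}
    # Root nodes: ordinary bootstrap of (G, M) rows.
    roots = df[["G", "M"]].to_numpy()[rng.integers(0, n, n)]
    # Downstream nodes in topological order, each given its (intervened) parents.
    h_new = np.array([rng.choice(h_pool[m]) for m in roots[:, 1]])
    y_new = np.array([rng.choice(y_pool[h]) for h in h_new])
    return pd.DataFrame({"G": roots[:, 0], "M": roots[:, 1], "X": x_star, "H": h_new, "Y": y_new})

rng = np.random.default_rng(0)
obs = simulate_observational(5000, rng)
y1 = truncated_bootstrap(obs, 1, rng)["Y"].mean()
y0 = truncated_bootstrap(obs, 0, rng)["Y"].mean()
print(f"estimated effect of do(X=1) vs do(X=0): {y1 - y0:.3f}")
```

With these made-up numbers the true interventional means are $0.2 + 0.6 \times 0.7 = 0.62$ and $0.2 + 0.6 \times 0.3 = 0.38$, so the printed estimate should land near $0.24$.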


**Interpretation note**

- This bootstrap keeps mediator values (e.g., `ca`) close to each row’s original values while intervening on the target label.
- That means we are sampling rows compatible with `(confounders, mediators, Y=y*)`, not simulating how mediators would change under `do(Y=y*)`.
- It is still graph-respecting in that we condition on parents (confounders + mediator), but the target-side intervention is label-level, not a full structural rollout.
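
Why rows with a different label can still be drawn: the Gaussian kernel used in the resampling loop (`In [6]`) factorises over the columns it compares. Writing $s$ for the stacked confounder and mediator values and $h$ for the bandwidth, a candidate row $j$ is weighted against the query built from row $i$ as

$$
K\big((s_i, y^*), (s_j, y_j)\big) = \exp\!\Big(-\frac{\lVert s_i - s_j \rVert^2}{2h^2}\Big)\,\exp\!\Big(-\frac{(y^* - y_j)^2}{2h^2}\Big).
$$

For a binary label, a row with $y_j \neq y^*$ is down-weighted by the factor $\exp(-1/(2h^2))$ (about $0.61$ for $h = 1$) rather than excluded, and its label is then overwritten with $y^*$.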


In [1]:
```python
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist

df = pd.read_csv('heart_disease_preprocessed.csv')
df.head()
```

```
,age,trestbps,chol,thalach,oldpeak,ca,sex_Female,sex_Male,cp_Asymptomatic,cp_AtypicalAngina,...,restecg_STTAbnormality,exang_NoExAngina,exang_YesExAngina,slope_Downsloping,slope_Flat,slope_Upsloping,thal_FixedDefect,thal_Normal,thal_ReversibleDefect,heartdiseasepresence
0,0.950883,0.743598,-0.289108,0.040935,1.180495,-0.740979,0,1,0,0,...,0,1,0,1,0,0,1,0,0,0
1,1.397584,1.593663,0.78534,-1.757678,0.647625,2.527338,0,1,1,0,...,0,0,1,0,1,0,0,1,0,1
2,1.397584,-0.673176,-0.370199,-0.858371,1.3475,1.437899,0,1,1,0,...,0,0,1,0,1,0,0,0,1,1
3,-1.952676,-0.106466,0.055526,1.625427,1.77579,-0.740979,0,1,0,0,...,0,1,0,1,0,0,0,1,0,0
4,-1.505975,-0.106466,-0.877014,0.983065,0.569273,-0.740979,1,0,0,1,...,0,1,0,0,0,1,0,1,0,0
```


In [2]:
```python
mediators = ['ca']
confounders = ['age', 'sex_Male', 'sex_Female']

effect = 'heartdiseasepresence'  # binary outcome (target label) in the heart disease dataset

# Every remaining column is treated as a cause
causes = [c for c in df.columns if c not in mediators + confounders + [effect]]

print('effect variable:', effect)
print('front-door/mediator set:', mediators)
print('back-door/confounder set:', confounders)
print('cause variables:', causes)
```

```
effect variable: heartdiseasepresence
front-door/mediator set: ['ca']
back-door/confounder set: ['age', 'sex_Male', 'sex_Female']
cause variables: ['trestbps', 'chol', 'thalach', 'oldpeak', 'cp_Asymptomatic', 'cp_AtypicalAngina', 'cp_NonAnginalPain', 'cp_TypicalAngina', 'fbs_<=120', 'fbs_>120', 'restecg_LVHypertrophy', 'restecg_NormalECG', 'restecg_STTAbnormality', 'exang_NoExAngina', 'exang_YesExAngina', 'slope_Downsloping', 'slope_Flat', 'slope_Upsloping', 'thal_FixedDefect', 'thal_Normal', 'thal_ReversibleDefect']
```


**Important functions**

- `gaussian_kernel_matrix` applies a Gaussian kernel to rows built from the conditioning columns (confounders, mediator and effect) to measure how similar two rows are. A row's similarity to the query row sets its probability of being drawn into the resampled dataset.
- `ensure_2d` converts the selected DataFrame columns into a 2D float NumPy array, the input shape that `cdist`, and therefore the kernel function, expects.
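
A tiny check of what the kernel returns, using two made-up rows that differ by one unit in a single column (for example, the same confounder and mediator values but a different label). The function body is repeated from `In [3]` so the snippet runs on its own.

```python
import numpy as np
from scipy.spatial.distance import cdist

def gaussian_kernel_matrix(A, B=None, bandwidth=1.0):
    """Return Gaussian kernel matrix K_ij = exp(-0.5 * ||a_i - b_j||^2 / h^2)."""
    if B is None:
        B = A
    dists = cdist(A, B, metric='euclidean')
    return np.exp(-0.5 * (dists / bandwidth) ** 2)

# Two toy rows exactly one unit apart (same first two columns, different last column).
A = np.array([[0.0, 1.0, 0.0],
              [0.0, 1.0, 1.0]])
K = gaussian_kernel_matrix(A, bandwidth=1.0)
print(K.round(3))  # diagonal is 1.0; off-diagonal is exp(-0.5), about 0.607

# A smaller bandwidth makes the same one-unit gap count for much less.
K_narrow = gaussian_kernel_matrix(A, bandwidth=0.5)
print(K_narrow.round(3))  # off-diagonal is exp(-2), about 0.135
```

The bandwidth therefore controls how strictly donor rows must match the query: smaller values concentrate the sampling probability on near-identical rows.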

In [3]:
```python
def gaussian_kernel_matrix(A, B=None, bandwidth=1.0):
    """Return Gaussian kernel matrix K_ij = exp(-0.5 * ||a_i - b_j||^2 / h^2)."""
    if B is None:
        B = A
    dists = cdist(A, B, metric='euclidean')
    K = np.exp(-0.5 * (dists / bandwidth) ** 2)
    return K

def ensure_2d(df, cols):
    """Return the selected columns of df as a 2D float array of shape (n_rows, n_cols)."""
    return df[cols].to_numpy(dtype=float).reshape(len(df), -1)
```

In [4]:
```python
N = len(df)  # number of samples
y = df[effect].to_numpy()  # target column as a NumPy array
unique_y = np.unique(y)  # label values to intervene on
```

In [5]:
```python
# Columns the kernel compares: confounders, mediator, and the effect (last)
conditionstointervene = confounders + mediators + [effect]
conditionstointervene_matrix = ensure_2d(df, conditionstointervene)

bandwidth_gen = 1.0  # bandwidth h of the Gaussian kernel
# Pairwise similarities between all samples based on confounders + mediators + effect
# (printed for inspection; the resampling loop scores one query row at a time)
conditionstointervene_kernel = gaussian_kernel_matrix(conditionstointervene_matrix, bandwidth=bandwidth_gen)

print(conditionstointervene_matrix)
print(conditionstointervene_kernel)
```

```
[[ 0.9508825   1.          0.         -0.74097874  0.        ]
 [ 1.39758378  1.          0.          2.52733829  1.        ]
 [ 1.39758378  1.          0.          1.43789928  1.        ]
 ...
 [ 1.5092591   1.          0.          1.43789928  1.        ]
 [ 0.28083058  1.          0.          0.34846027  1.        ]
 [ 0.28083058  0.          1.          0.34846027  1.        ]]
[[1.         0.00263014 0.05112275 ... 0.0483327  0.26769059 0.09847787]
 [0.00263014 1.         0.55242441 ... 0.54899038 0.0499205  0.01836473]
 [0.05112275 0.55242441 1.         ... 0.99378371 0.29611496 0.1089346 ]
 ...
 [0.0483327  0.54899038 0.99378371 ... 1.         0.25977043 0.0955642 ]
 [0.26769059 0.0499205  0.29611496 ... 0.25977043 1.         0.36787944]
 [0.09847787 0.01836473 0.1089346  ... 0.0955642  0.36787944 1.        ]]
```


In [6]:
```python
rng = np.random.default_rng(0)
tf_dfs = []
for y_star in unique_y:  # one interventional arm per label value
    rows = []
    for i in range(N):  # one resampled row per original row
        # Query vector: the original confounder/mediator values of row i, but with Y = y_star
        orig = df.iloc[i]  # row at position i
        query_vals = []
        for c in conditionstointervene:
            if c == effect:
                query_vals.append(float(y_star))  # the effect column takes the intervened value
            else:
                query_vals.append(float(orig[c]))  # other columns keep row i's value
        query = np.array(query_vals).reshape(1, -1)  # shape (1, n_cols), as cdist expects
        # Similarity of every row in the data to the query
        K_q = gaussian_kernel_matrix(query, conditionstointervene_matrix, bandwidth=bandwidth_gen).ravel()
        p = K_q / K_q.sum()  # normalise similarities into sampling probabilities
        # More similar rows are more likely to be drawn. Rows with Y != y_star carry an
        # extra factor exp(-0.5 / h^2) in their weight; they are not excluded, and are relabelled below.
        idx = rng.choice(np.arange(N), p=p)
        row = df.iloc[[idx]].copy()  # one-row DataFrame for the drawn sample
        row['do_' + effect] = y_star  # record the intervention
        row[effect] = y_star  # in the interventional world, Y is set to y_star
        rows.append(row)
    tf_df = pd.concat(rows, ignore_index=True)  # all resampled rows for this y_star
    tf_dfs.append(tf_df)

tf_df = pd.concat(tf_dfs, ignore_index=True)  # stack the interventional arms

tf_df
```

```
,age,trestbps,chol,thalach,oldpeak,ca,sex_Female,sex_Male,cp_Asymptomatic,cp_AtypicalAngina,...,exang_NoExAngina,exang_YesExAngina,slope_Downsloping,slope_Flat,slope_Upsloping,thal_FixedDefect,thal_Normal,thal_ReversibleDefect,heartdiseasepresence,do_heartdiseasepresence
0,0.504181,2.387057,0.035253,-0.258834,-1.111053,-0.740979,1,0,1,0,...,0,1,0,1,0,0,1,0,0,0
1,1.174233,-1.239886,0.014981,0.383528,-0.208954,1.437899,0,1,1,0,...,1,0,0,0,1,1,0,0,0,0
2,0.839207,0.460243,0.420433,0.469176,1.817975,1.437899,1,0,1,0,...,1,0,1,0,0,0,1,0,0,0
3,-1.952676,-0.106466,0.055526,1.625427,1.775790,-0.740979,0,1,0,0,...,1,0,1,0,0,0,1,0,0,0
4,-1.394299,-0.673176,0.967794,0.554824,-1.111053,-0.740979,0,1,0,1,...,1,0,0,0,1,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
539,0.169155,-0.106466,0.724522,-1.971798,0.722903,-0.740979,0,1,1,0,...,0,1,1,0,0,0,0,1,1,1
540,-1.170949,-0.673176,-0.431017,0.854593,-1.111053,-0.740979,0,1,0,0,...,1,0,0,0,1,0,1,0,1,1
541,0.504181,-1.239886,-0.167473,-0.301658,0.402268,0.348460,0,1,1,0,...,0,1,0,1,0,0,0,1,1,1
542,-0.277546,-0.219808,-0.856742,1.496954,-1.111053,-0.740979,0,1,0,1,...,1,0,0,0,1,0,1,0,1,1
```
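
The per-row loop above can be restructured so that all queries for one label are scored in a single `cdist` call. The sketch below is a hypothetical rewrite (the function name `label_level_bootstrap` and the tiny demo table are not from the notebook): it keeps each row's confounder and mediator values, sets the effect entry to `y_star`, draws one donor row per query with probability proportional to the kernel weight, and relabels it. It should follow the same sampling distribution as the loop but will not reproduce its exact random draws, and it omits the `do_` column that `In [7]` drops anyway. Like the notebook's full kernel matrix, it holds an N × N array in memory.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

def label_level_bootstrap(df, cond_cols, effect, bandwidth=1.0, seed=0):
    """Hypothetical vectorised version of the notebook's label-level resampling loop."""
    rng = np.random.default_rng(seed)
    cond = df[cond_cols].to_numpy(dtype=float)
    eff_pos = cond_cols.index(effect)             # position of the effect in each query
    parts = []
    for y_star in np.unique(df[effect]):
        queries = cond.copy()
        queries[:, eff_pos] = y_star              # keep confounders/mediators, set Y = y_star
        K = np.exp(-0.5 * (cdist(queries, cond) / bandwidth) ** 2)  # (N, N) query-to-row weights
        P = K / K.sum(axis=1, keepdims=True)      # one sampling distribution per query row
        # Inverse-CDF sampling: index of the first cumulative probability >= u, per query row.
        u = rng.random((len(df), 1))
        donors = (P.cumsum(axis=1) < u).sum(axis=1)
        donors = np.minimum(donors, len(df) - 1)  # guard against float round-off at the top end
        part = df.iloc[donors].copy()
        part[effect] = y_star                     # label-level intervention
        parts.append(part)
    return pd.concat(parts, ignore_index=True)

# Made-up demo table with one confounder, one mediator and a binary label.
demo = pd.DataFrame({"age": [0.1, -0.5, 1.2, 0.3], "ca": [0.0, 1.0, 0.0, 2.0],
                     "label": [0, 1, 0, 1]})
boot = label_level_bootstrap(demo, ["age", "ca", "label"], "label")
print(boot["label"].value_counts())  # each label value appears len(demo) times
```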


In [7]:
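```python
# Keep a random half of each arm so the saved dataset has about N rows, balanced across labels
tf_df = tf_df.groupby("heartdiseasepresence", group_keys=False).sample(frac=0.5, random_state=42)
print(tf_df["heartdiseasepresence"].value_counts())
tf_df = tf_df.drop(columns=["do_heartdiseasepresence"])  # redundant: always equals the label now
tf_df
```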

```
0    136
1    136
Name: heartdiseasepresence, dtype: int64
```


```
,age,trestbps,chol,thalach,oldpeak,ca,sex_Female,sex_Male,cp_Asymptomatic,cp_AtypicalAngina,...,restecg_STTAbnormality,exang_NoExAngina,exang_YesExAngina,slope_Downsloping,slope_Flat,slope_Upsloping,thal_FixedDefect,thal_Normal,thal_ReversibleDefect,heartdiseasepresence
30,-0.500897,0.686927,-0.958105,-0.986844,0.120886,-0.740979,0,1,1,0,...,0,0,1,0,1,0,0,0,1,0
116,1.062558,2.160373,-0.410744,0.255055,-0.208954,-0.740979,0,1,0,0,...,0,1,0,0,1,0,0,0,1,0
79,0.057480,0.176888,0.055526,0.512000,0.569273,-0.740979,1,0,0,1,...,0,1,0,0,1,0,0,1,0,0
127,1.174233,-0.673176,-1.424375,-0.387306,-0.465247,-0.740979,0,1,1,0,...,0,1,0,0,0,1,0,0,1,0
196,1.397584,-0.673176,-0.370199,-0.858371,1.347500,1.437899,0,1,1,0,...,0,0,1,0,1,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
379,-0.389221,-2.146622,-0.410744,0.212231,-1.111053,0.348460,0,1,0,0,...,0,0,1,0,0,1,0,0,1,1
433,-0.947598,-1.749925,-1.018922,0.297879,-1.111053,-0.740979,0,1,0,1,...,0,1,0,0,0,1,0,0,1,1
324,0.169155,-0.673176,-1.100013,0.554824,0.932494,-0.740979,0,1,0,0,...,0,1,0,0,1,0,0,0,1,1
427,0.057480,2.727082,1.616517,-1.372261,1.732656,-0.740979,1,0,1,0,...,1,0,1,0,1,0,0,1,0,1
```


In [8]:
```python
tf_df.to_csv('heart_disease_preprocessed_tf.csv', header=True, index=False)
```