
## What is a Truncated Factorisation Causal Bootstrap Algorithm?

- It is an algorithm which resamples the dataset to create a new dataset using the **truncated factorisation** of a causal graph.
- The idea comes from Pearl’s **do-calculus** and the **g-formula**.
- Instead of conditioning on mediators (front door) or confounders (back door) alone, it uses the **entire causal structure**.
- The joint distribution of all variables is factorised according to the causal graph.
- To simulate an intervention `do(X = x)`, we **truncate** the factorisation: the conditional distribution of `X` given its parents is removed, and `X` is fixed to `x` wherever it appears as a parent.
- All other variables are resampled from their conditional distributions given their parents.
- This allows us to estimate causal effects even in complex systems with:
  - multiple confounders,
  - mediators,
  - chains of causes.
- The key requirement is that the **causal graph is correctly specified** and all parent variables are observed.
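
Written out, the two factorisations look as follows. The symbols $V_1, \dots, V_n$ (the variables in the graph) and $\mathrm{pa}_i$ (the values of the parents of $V_i$) are notation introduced here for illustration. The observational joint distribution is

$$
P(v_1, \dots, v_n) = \prod_{i=1}^{n} P(v_i \mid \mathrm{pa}_i)
$$

and under the intervention $do(X = x)$ the factor belonging to $X$ is dropped:

$$
P(v_1, \dots, v_n \mid do(X = x)) = \prod_{i \,:\, V_i \neq X} P(v_i \mid \mathrm{pa}_i) \Big|_{X = x}
$$

for assignments consistent with $X = x$, and zero otherwise.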


## Example

Imagine you want to assess whether a new teaching method (the “cause” **X**) improves exam performance (the “effect” **Y**).

- There are multiple observed variables:
  - prior grades,
  - motivation,
  - study hours,
  - practice problems.
- Some of these variables affect both the teaching method assignment and exam performance.
- Others lie on causal paths between **X** and **Y**.

Causal structure (simplified):

- Prior Grades -> Teaching Method -> Study Hours -> Exam Performance
- Motivation -> Study Hours


- We assume all parent variables of each node are observed.
- To estimate the causal effect of the teaching method:
  - We **intervene** on **X** by fixing it to a chosen value.
  - We remove the probability model `P(X | Parents(X))`.
  - We resample all downstream variables using their conditional distributions.
- By repeatedly resampling, we generate a bootstrapped dataset that reflects:
  - “What exam performance would look like if everyone received the same teaching method.”
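
The steps above can be sketched on synthetic data. This is a minimal illustration, not the procedure applied to the heart dataset later in this notebook. The variables are made binary for simplicity, and the sample size, the coefficients and the helper name `truncated_bootstrap` are all assumptions made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000

# Hypothetical synthetic data that follows the simplified graph:
# grades -> method -> hours -> exam, and motivation -> hours
grades = rng.integers(0, 2, n)
motivation = rng.integers(0, 2, n)
method = rng.binomial(1, 0.2 + 0.6 * grades)                     # P(X | grades)
hours = rng.binomial(1, 0.2 + 0.4 * method + 0.3 * motivation)   # P(S | X, motivation)
exam = rng.binomial(1, 0.3 + 0.5 * hours)                        # P(Y | S)

def truncated_bootstrap(x, size):
    """Resample a dataset under do(method = x) via the truncated factorisation."""
    # Root variables keep their observational distributions.
    g = rng.choice(grades, size)
    m = rng.choice(motivation, size)
    # The factor P(method | grades) is dropped: method is fixed to x.
    x_new = np.full(size, x)
    # hours ~ P(hours | method = x, motivation = m), drawn from matching observed rows.
    s = np.empty(size, dtype=int)
    for m_val in (0, 1):
        pool = hours[(method == x) & (motivation == m_val)]
        mask = m == m_val
        s[mask] = rng.choice(pool, mask.sum())
    # exam ~ P(exam | hours = s), again drawn from matching observed rows.
    y = np.empty(size, dtype=int)
    for s_val in (0, 1):
        pool = exam[hours == s_val]
        mask = s == s_val
        y[mask] = rng.choice(pool, mask.sum())
    return {"grades": g, "motivation": m, "method": x_new, "hours": s, "exam": y}

boot1 = truncated_bootstrap(1, n)
boot0 = truncated_bootstrap(0, n)
effect_estimate = boot1["exam"].mean() - boot0["exam"].mean()
print("estimated effect of the teaching method:", effect_estimate)
```

With these made-up coefficients the true interventional difference is 0.5 × 0.4 = 0.2, so `effect_estimate` should land close to that value, up to sampling noise.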


In [8]:
import pandas as pd
import numpy as np
from pathlib import Path
from scipy.spatial.distance import cdist

df = pd.read_csv('heart_disease_preprocessed.csv')
df.head()

Unnamed: 0,age,trestbps,chol,thalach,oldpeak,ca,sex_Female,sex_Male,cp_Asymptomatic,cp_AtypicalAngina,...,restecg_STTAbnormality,exang_NoExAngina,exang_YesExAngina,slope_Downsloping,slope_Flat,slope_Upsloping,thal_FixedDefect,thal_Normal,thal_ReversibleDefect,heartdiseasepresence
0,0.950883,0.743598,-0.289108,0.040935,1.180495,-0.740979,0,1,0,0,...,0,1,0,1,0,0,1,0,0,0
1,1.397584,1.593663,0.78534,-1.757678,0.647625,2.527338,0,1,1,0,...,0,0,1,0,1,0,0,1,0,1
2,1.397584,-0.673176,-0.370199,-0.858371,1.3475,1.437899,0,1,1,0,...,0,0,1,0,1,0,0,0,1,1
3,-1.952676,-0.106466,0.055526,1.625427,1.77579,-0.740979,0,1,0,0,...,0,1,0,1,0,0,0,1,0,0
4,-1.505975,-0.106466,-0.877014,0.983065,0.569273,-0.740979,1,0,0,1,...,0,1,0,0,0,1,0,1,0,0


In [9]:
mediators = ['thalach', 'oldpeak', 'slope_Downsloping', 'slope_Flat', 'slope_Upsloping', 'cp_Asymptomatic', 'cp_AtypicalAngina', 'cp_NonAnginalPain', 'cp_TypicalAngina']
confounders = ['age', 'sex_Male', 'sex_Female']

effect = 'heartdiseasepresence'  # binary (0/1) target column in the heart dataset

causes = [c for c in df.columns if c not in mediators + confounders + [effect]]

print('effect variable:', effect)
print('front-door/mediator set:', mediators)
print('back-door/confounder set:', confounders)
print('cause variables:', causes)

effect variable: heartdiseasepresence
front-door/mediator set: ['thalach', 'oldpeak', 'slope_Downsloping', 'slope_Flat', 'slope_Upsloping', 'cp_Asymptomatic', 'cp_AtypicalAngina', 'cp_NonAnginalPain', 'cp_TypicalAngina']
back-door/confounder set: ['age', 'sex_Male', 'sex_Female']
cause variables: ['trestbps', 'chol', 'ca', 'fbs_<=120', 'fbs_>120', 'restecg_LVHypertrophy', 'restecg_NormalECG', 'restecg_STTAbnormality', 'exang_NoExAngina', 'exang_YesExAngina', 'thal_FixedDefect', 'thal_Normal', 'thal_ReversibleDefect']


**Important functions**

- `gaussian_kernel_matrix` computes a Gaussian kernel between row vectors to measure how similar two samples are. Here each vector holds a row's confounder, mediator and effect values, and the similarities determine how likely each row is to be drawn into the resampled dataset.
- `ensure_2d` converts the selected dataframe columns into a 2D float NumPy array of shape `(n_rows, n_columns)`, which is the input format the kernel function expects.

In [10]:
def gaussian_kernel_matrix(A, B=None, bandwidth=1.0):
    """Return Gaussian kernel matrix K_ij = exp(-0.5 * ||a_i - b_j||^2 / h^2)."""
    if B is None:
        B = A
    dists = cdist(A, B, metric='euclidean')
    K = np.exp(-0.5 * (dists / bandwidth) ** 2)
    return K

def ensure_2d(df, cols):
    """Return df[cols] as a float array of shape (n_rows, n_cols), even for a single column."""
    return df[cols].to_numpy(dtype=float).reshape(len(df), -1)
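
To get a feel for the values the kernel returns, here is a small check on three made-up points. To keep the snippet self-contained it uses a NumPy-only equivalent of the same formula; the name `gaussian_kernel_numpy` and the toy points are introduced here purely for illustration.

```python
import numpy as np

def gaussian_kernel_numpy(A, B, bandwidth=1.0):
    """NumPy-only Gaussian kernel: K_ij = exp(-0.5 * ||a_i - b_j||^2 / h^2)."""
    # Pairwise squared Euclidean distances via broadcasting, shape (len(A), len(B)).
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists / bandwidth ** 2)

# Toy points: the first two are close together, the third is far away.
A = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]])
K = gaussian_kernel_numpy(A, A)
K_wide = gaussian_kernel_numpy(A, A, bandwidth=5.0)

print(K)       # diagonal is 1 (each point is identical to itself)
print(K_wide)  # a larger bandwidth makes distant points look more similar
```

The matrix is symmetric with ones on the diagonal, and similarity decays with distance. The bandwidth controls how quickly it decays, which in turn controls how concentrated the resampling probabilities are around each query row.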

In [11]:
N = len(df) #number of samples
y = df[effect].to_numpy() #converting target column to numpy array
unique_y = np.unique(y) 

In [12]:
conditionstointervene = confounders + mediators + [effect]
conditionstointervene_matrix = ensure_2d(df, conditionstointervene)

bandwidth_gen = 1.0 # bandwidth for general kernel
conditionstointervene_kernel = gaussian_kernel_matrix(conditionstointervene_matrix, bandwidth=bandwidth_gen)
#similarities between all samples based on confounders + mediators + effect

print(conditionstointervene_matrix)
print(conditionstointervene_kernel) 

[[0.9508825  1.         0.         ... 0.         1.         0.        ]
 [1.39758378 1.         0.         ... 0.         0.         1.        ]
 [1.39758378 1.         0.         ... 0.         0.         1.        ]
 ...
 [1.5092591  1.         0.         ... 0.         0.         1.        ]
 [0.28083058 1.         0.         ... 0.         0.         1.        ]
 [0.28083058 0.         1.         ... 0.         0.         1.        ]]
[[1.00000000e+00 1.27879743e-02 4.88944127e-02 ... 5.59886540e-02
  1.57554366e-02 1.02994844e-03]
 [1.27879743e-02 1.00000000e+00 5.22417150e-01 ... 2.03224583e-01
  4.97281332e-01 2.84640042e-04]
 [4.88944127e-02 5.22417150e-01 1.00000000e+00 ... 8.08601210e-01
  2.86498134e-01 5.51626718e-04]
 ...
 [5.59886540e-02 2.03224583e-01 8.08601210e-01 ... 1.00000000e+00
  1.04418929e-01 4.11204007e-04]
 [1.57554366e-02 4.97281332e-01 2.86498134e-01 ... 1.04418929e-01
  1.00000000e+00 1.76963743e-03]
 [1.02994844e-03 2.84640042e-04 5.51626718e-04 ... 4.112

In [13]:
rng = np.random.default_rng(0)
tf_dfs = []
for y_star in unique_y:
    rows = []
    for i in range(N):  # loop over every sample in the dataframe
        # Query vector: row i's confounder and mediator values, with the effect set to y_star
        orig = df.iloc[i]  # row i as a Series indexed by column name
        query_vals = []
        for c in conditionstointervene:
            if c == effect:
                query_vals.append(float(y_star))  # target column: use the intervened value
            else:
                query_vals.append(float(orig[c]))  # other columns: keep the original value
        query = np.array(query_vals).reshape(1, -1)  # 2D array with a single row, as cdist expects
        # Gaussian similarity between the query and every row of conditionstointervene_matrix
        K_q = gaussian_kernel_matrix(query, conditionstointervene_matrix, bandwidth=bandwidth_gen).ravel()
        p = K_q / np.maximum(K_q.sum(), 1e-8)  # normalise similarities into sampling probabilities
        idx = rng.choice(np.arange(N), p=p)  # rows more similar to the query are more likely to be drawn
        row = df.iloc[[idx]].copy()  # one-row dataframe for the sampled index
        row['do_' + effect] = y_star  # record the intervention value
        row[effect] = y_star  # in the interventional world, the effect is set to y_star
        rows.append(row)
    tf_df = pd.concat(rows, ignore_index=True)  # all resampled rows for this y_star
    tf_dfs.append(tf_df)

tf_df = pd.concat(tf_dfs, ignore_index=True) #concatenate all dataframes into one

tf_df

Unnamed: 0,age,trestbps,chol,thalach,oldpeak,ca,sex_Female,sex_Male,cp_Asymptomatic,cp_AtypicalAngina,...,exang_NoExAngina,exang_YesExAngina,slope_Downsloping,slope_Flat,slope_Upsloping,thal_FixedDefect,thal_Normal,thal_ReversibleDefect,heartdiseasepresence,do_heartdiseasepresence
0,-0.389221,-0.389821,-0.045837,0.726120,1.237794,-0.740979,0,1,0,0,...,1,0,0,1,0,0,1,0,0,0
1,0.839207,-0.673176,0.400161,-2.143095,0.865141,1.437899,0,1,1,0,...,0,1,0,1,0,0,0,1,0,0
2,1.397584,-0.673176,-0.370199,-0.858371,1.347500,1.437899,0,1,1,0,...,0,1,0,1,0,0,0,1,0,0
3,-1.952676,-0.106466,0.055526,1.625427,1.775790,-0.740979,0,1,0,0,...,1,0,1,0,0,0,1,0,0,0
4,-0.612572,-0.786518,-1.992008,-0.986844,0.017112,2.527338,0,1,0,0,...,1,0,0,0,1,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
539,0.504181,0.120217,-0.877014,0.554824,0.017112,1.437899,0,1,0,0,...,1,0,0,0,1,0,1,0,1,1
540,0.280831,-1.239886,1.778698,-0.258834,1.549723,0.348460,0,1,1,0,...,0,1,0,1,0,0,0,1,1,1
541,0.615857,-0.389821,0.217707,-0.344482,1.451274,0.348460,0,1,1,0,...,0,1,0,1,0,0,0,1,1,1
542,0.057480,1.593663,0.846158,-0.173186,0.017112,0.348460,0,1,1,0,...,0,1,0,1,0,0,0,1,1,1
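
The per-row loop above can also be written in vectorised form. The sketch below runs on a made-up four-row array rather than the heart dataset. The function name `kernel_resample` and the toy values are assumptions for illustration, and because it draws its random numbers differently from `rng.choice` it will not reproduce the exact rows sampled above.

```python
import numpy as np

def kernel_resample(features, effect_col, y_star, bandwidth=1.0, seed=0):
    """For every row, set the effect column to y_star and draw one similar row index."""
    rng = np.random.default_rng(seed)
    queries = features.copy()
    queries[:, effect_col] = y_star  # same query construction as the loop
    # Pairwise squared distances between queries and original rows, shape (N, N).
    sq_dists = ((queries[:, None, :] - features[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-0.5 * sq_dists / bandwidth ** 2)
    P = K / K.sum(axis=1, keepdims=True)  # each row of P is a probability vector
    # Inverse-CDF sampling: one uniform draw per query row.
    cdf = np.cumsum(P, axis=1)
    u = rng.random(len(features))
    idx = (u[:, None] > cdf).sum(axis=1)
    idx = np.minimum(idx, len(features) - 1)  # guard against floating-point round-off
    return idx, P

# Toy data: column 0 is a feature, column 1 plays the role of the effect.
toy = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 1.0], [5.2, 1.0]])
idx, P = kernel_resample(toy, effect_col=1, y_star=1.0)
print(P.round(3))
print(idx)
```

This builds the full `(N, N)` kernel in memory, so it trades memory for speed. For the few hundred rows used here that is not a concern, but the loop version scales more gently for large datasets.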
