# Abstract Models Visualization

This notebook provides a visualization
of the sampling process for pairs
of abstract and concrete linear SCMs.

In [None]:
from tqdm.auto import tqdm
import igraph as ig
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn as sns

from causabs.dataset import generate_datasets, load_dataset
from causabs.utils import check_cancelling_paths
from causabs.utils import seed_everything

# Seed with a fixed value (42) before any dataset is sampled.
# NOTE(review): presumably seeds every RNG the pipeline relies on so that
# Restart & Run All is reproducible — confirm in causabs.utils.seed_everything.
seed_everything(42)

Generate dataset with the provided parameters.

In [None]:
# Where the datasets live and how many runs to generate
data_dir = "data/"
num_runs = 10

# Sampling configuration, forwarded unchanged to `generate_datasets`
dset_params = dict(
    abs_nodes=10,
    abs_edges=20,
    abs_type="ER",
    min_block_size=10,
    max_block_size=15,
    alpha=1e3,
    relevant_ratio=0.5,
    internal=True,
    n_samples=1000,
    noise_term="exponential",
    noise_abs=0.0,
)

# The returned signature is the handle later passed to `load_dataset`.
# NOTE(review): force=False presumably reuses datasets already on disk — confirm.
signature = generate_datasets(dset_params, data_dir, num_runs, force=False)

Analyze the generated weights in each concrete adjacency matrix $\mathbf{W}$.

In [None]:
# Collect one summary record per run and build the DataFrame once at the end.
# Growing a frame with pd.concat inside the loop is quadratic and leaves every
# row with index 0; building from a list of records gives a 0..num_runs-1 index.
run_records = []
for run in range(num_runs):
    # Only the concrete weight matrix (first element) is needed for the summary.
    cnc_weights, _, _, _, _, _ = load_dataset(data_dir, signature, run)
    cancelling = not check_cancelling_paths(cnc_weights)
    # NOTE(review): the column labels read |W_ij| but the statistics below are
    # computed on the signed weights (no np.abs) — confirm which is intended.
    run_records.append(
        {
            "Canc. Paths": cancelling,
            "Max |W_ij|": np.max(cnc_weights),
            "Min |W_ij|": np.min(cnc_weights),
            "Mean |W_ij|": np.mean(cnc_weights),
            "Std |W_ij|": np.std(cnc_weights),
        }
    )

# Row label == run number, so a row can be traced back to its dataset.
run_summary = pd.DataFrame(run_records).rename_axis("run")
run_summary

Analyze a given run.

In [None]:
# Select one run and unpack its stored artifacts; the last element returned by
# `load_dataset` is discarded because it is not used below.
run = 3
(
    cnc_weights,
    abs_weights,
    tau_adj,
    gamma_adj,
    partitions,
    _,
) = load_dataset(data_dir, signature, run)

Check the necessary condition from Theorem 3 for each pair $(i,j)$ of abstract nodes with $i < j$: $\mathbf{W}_{ij} \mathbf{F}_{jj} \boldsymbol{t}_j = m_{ij} \boldsymbol{t}_i$.

In [None]:
# Number of abstract / concrete nodes (leading dimension of each weight matrix)
abs_nodes, cnc_nodes = abs_weights.shape[0], cnc_weights.shape[0]

# Block boundaries: block y covers indices [block_start[y], block_end[y]).
# The end offsets are the running sum of the block sizes; the start offsets
# are the same sequence shifted right by one with a leading 0.
block_end = np.cumsum(partitions)
block_start = np.concatenate([[0], np.cumsum(partitions[:-1])])


# Map an abstract node to the index range of its concrete block
def get_block(y: int) -> slice:
    """Return the slice of concrete indices belonging to abstract node ``y``.

    The slice starts at ``block_start[y]`` and spans ``partitions[y]`` entries.
    """
    first = block_start[y]
    return slice(first, first + partitions[y])


# F = (I - W)^{-1} for the concrete weight matrix W
F = np.linalg.inv(np.eye(cnc_nodes) - cnc_weights)

# Verify W_ij F_jj t_j == m_ij t_i for every pair of abstract nodes y1 < y2.
# Each violating pair is printed together with the norm of its residual.
n_pairs = 0
n_violations = 0
for y1 in range(abs_nodes):
    for y2 in range(y1 + 1, abs_nodes):
        W_ij = cnc_weights[get_block(y1), get_block(y2)]
        F_jj = F[get_block(y2), get_block(y2)]
        t_j = tau_adj[get_block(y2), y2]
        t_i = tau_adj[get_block(y1), y1]
        m_ij = abs_weights[y1, y2]

        # Evaluate both sides once and reuse them for the test and the residual.
        lhs = W_ij @ F_jj @ t_j
        rhs = m_ij * t_i
        test = np.allclose(lhs, rhs)

        n_pairs += 1
        if not test:
            n_violations += 1
            print(y1, y2, test, np.linalg.norm(lhs - rhs))

# Make a successful check visible instead of leaving the cell without output.
if n_violations == 0:
    print(f"All {n_pairs} abstract node pairs satisfy the condition.")

Construct the sets of relevant variables $\Pi_R$ and block variables $\Pi$.

In [None]:
# Relevant variables: for each abstract node, the concrete indices whose
# tau coefficient is non-zero
relevant = [
    list(np.where(np.abs(tau_adj[:, y]) > 0.0)[0]) for y in range(abs_nodes)
]

# Constitutive variables: same construction on the gamma coefficients
block = [
    list(np.where(np.abs(gamma_adj[:, y]) > 0.0)[0]) for y in range(abs_nodes)
]

# Block membership matrix: entry (x, y) is 1 when concrete index x lies in
# the contiguous range [block_start[y], block_end[y]) of abstract node y
p_matrix = np.zeros_like(tau_adj)
for y in range(abs_nodes):
    p_matrix[block_start[y] : block_end[y], y] = 1

Visualize the abstract model, the concrete model, and their abstraction function.

In [None]:
fig = plt.figure(figsize=(15, 10))

# Common colour range for the signed weight matrices
vmin = -2
vmax = 2

# Panel 1: tau
ax_tau = fig.add_subplot(2, 3, 1)
sns.heatmap(
    tau_adj,
    annot=False,
    cmap="RdBu_r",
    vmin=vmin,
    vmax=vmax,
    cbar=False,
    ax=ax_tau,
)
ax_tau.set_title(r"$\tau$-abstraction")

# Panel 2: gamma (keeps its colour bar)
ax_gamma = fig.add_subplot(2, 3, 2)
sns.heatmap(
    gamma_adj, annot=False, cmap="RdBu_r", vmin=vmin, vmax=vmax, ax=ax_gamma
)
ax_gamma.set_title(r"$\gamma$-abstraction")

# Panel 3: block membership matrix, drawn on a fixed [-1, 1] range
ax_part = fig.add_subplot(2, 3, 3)
sns.heatmap(
    p_matrix,
    annot=False,
    cmap="RdBu_r",
    vmin=-1,
    vmax=1.0,
    cbar=False,
    ax=ax_part,
)
ax_part.set_title(r"$\Pi$-partition")

# Panel 4: concrete model weights with the block structure outlined
ax_cnc = fig.add_subplot(2, 3, 4)
sns.heatmap(
    cnc_weights,
    annot=False,
    cmap="RdBu_r",
    vmin=vmin,
    vmax=vmax,
    cbar=False,
    ax=ax_cnc,
)
n_rows = cnc_weights.shape[0]
n_cols = cnc_weights.shape[1]
color = "#666666FF"
lw = 1
ax_cnc.axhline(0.0, color=color, lw=lw)  # line along row 0
ax_cnc.axvline(n_cols, color=color, lw=lw)  # line along the last column edge
for cnc_start, cnc_size in zip(block_start, partitions):
    cnc_stop = cnc_start + cnc_size
    # Horizontal segment at the end of the block's rows, starting at the
    # block's first column (xmin is a fraction of the axes width).
    ax_cnc.axhline(cnc_stop, xmin=cnc_start / n_cols, color=color, lw=lw)
    # Vertical segment at the block's first column; ymin/ymax are fractions of
    # the axes height, so the segment ends at axes fraction 1.0.
    ax_cnc.axvline(
        cnc_start,
        ymin=1 - cnc_stop / n_rows,
        ymax=1.0,
        color=color,
        lw=lw,
    )
ax_cnc.set_title(r"$\mathcal{L}$ Concrete model")

# Panel 5: abstract model weights (keeps its colour bar)
ax_abs = fig.add_subplot(2, 3, 5)
sns.heatmap(
    abs_weights, annot=False, cmap="RdBu_r", vmin=vmin, vmax=vmax, ax=ax_abs
)
ax_abs.set_title(r"$\mathcal{H}$ Abstract model")

fig.tight_layout()
plt.show()