## TRUST: OrderSPN Learning

TRUST is a Bayesian structure learning method that approximately infers a posterior over Bayesian network structures
given data. The posterior distribution over graphs (structures) is represented as an OrderSPN, which is a variant of
the sum-product network (SPN) for distributions over orderings/graphs. 

The distinguishing feature of OrderSPNs is their
ability to perform tractable *exact* inference for a number of useful queries, such as the marginal probability of an
edge, or the Bayesian model-averaged causal effect (BACE).

In [1]:
import math
import torch
import numpy as np

from trust.learning.learn_trust import learn_ordergraph, learn_orderspn
from trust.oracle.gadget_oracle import GadgetOracle
from trust.oracle.dibs_oracle import DibsOracle
from trust.oracle.enumeration_oracle import EnumerationOracle
from trust.oracle.random_oracle import RandomOracle
from trust.utils.generation import generate_linear_data, make_linear_model, make_erdosrenyi_graph
from trust.learning.split_strategy import ThresholdStrategy
from trust.utils.bge import BGe
from trust.utils.metrics import auroc, mll, pairwise_linear_ce, pairwise_linear_ce_mse, pdag_shd
from trust.leaf_scores.leaf_handler import LeafHandler
from trust.orderspn.evidence import MarginalEvidence
from trust.utils.misc import HiddenPrints

from dibs.config.example import DiBSMarginalExampleSettings
from sumu.gadget import Gadget

  if (self.terms is not 'auto') and not (isinstance(self.terms, (TermList, Term, type(None)))):
  if self.terms is 'auto':
  if flatten_attrs and k is not 'terms':


### Problem Configuration

In [None]:
# All tunable settings for the notebook live in this cell.

# BN Graph
d = 8  # number of nodes in the Bayesian network
edges_per_node = 2  # Erdos-Renyi graph parameter

# Data
train_size = 100
test_size = 1000
weight_mean = 0.0  # mean edge weight for linear model

# OrderSPN
expansion_factors = [30, 6, 2]  # must be of length ceil(log_2(d)); checked by the assert below
strong_oracle = "gadget"  # "gadget" or "dibs" or "random"
weak_oracle = "enumeration"  # "enumeration" or "random"
min_dimension = 5  # minimum dimension to use the strong oracle for

learning_rate = 0.1
epochs = 500

# Misc
num_candidate_parents = min(d - 1, 16)
# Prefer the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
total_time_budget = 300  # Increase or decrease depending on d, expansion_factors
precomputation_ratio = 0.2  # fraction of total_time_budget spent on leaf-score precomputation

seed = 2
# Seed the global numpy / torch generators as well, so any library code that
# relies on them (rather than on an explicitly passed rng) is repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

assert len(expansion_factors) == math.ceil(math.log2(d)), (
    f"expansion_factors must have ceil(log2(d)) = {math.ceil(math.log2(d))} entries "
    f"for d = {d}, got {len(expansion_factors)}"
)

### Generate Data

In [11]:
# Seeded generator shared by the model and data sampling below.
rng = np.random.default_rng(seed)

# Ground-truth graph and the linear model (weight matrix) defined on it.
G = make_erdosrenyi_graph(edges_per_node=edges_per_node, d=d)
B = make_linear_model(G, rng, weight_mean=weight_mean)

# The training set is drawn before the test set; both consume the same rng
# stream, so this order must be kept for results to be repeatable.
X_train = generate_linear_data(train_size, B, rng)
X_test = generate_linear_data(test_size, B, rng)

print('Ground truth weight matrix', B)

Ground truth weight matrix [[ 0.         -0.         -0.         -0.          0.          0.
  -0.          0.          0.         -0.          0.         -0.
  -0.         -0.          0.         -0.        ]
 [ 0.54528871 -0.          0.         -0.          0.          0.
   0.          0.         -0.          0.          0.         -0.
  -0.         -0.          0.          0.        ]
 [ 1.07834244  0.          0.          0.         -0.          0.
  -0.         -0.          0.          1.80142086 -0.76446412 -1.07906046
  -0.          0.         -0.          0.        ]
 [-0.          0.          0.         -0.          0.          0.
   0.          0.          0.          0.         -0.         -0.
   0.         -0.         -0.         -0.        ]
 [ 0.         -1.06633961 -0.         -0.          0.          0.
   0.         -0.         -0.          0.         -0.         -1.25418666
  -0.          0.          0.          0.        ]
 [-0.27655172 -0.          0.         -0. 

### Precomputation of Leaf Scores

In [12]:
# Run Gadget on the training data to obtain the candidate parent sets and the
# associated score objects. The run is time-budgeted: it receives the
# precomputation share of the total time budget.
with HiddenPrints():
    candidates, score, score_array = Gadget(
        data=X_train,
        mcmc={"n_dags": 10000},
        run_mode={
            "name": "budget",
            "params": {"t": total_time_budget*precomputation_ratio},
        },
        cons={"K": num_candidate_parents},
    ).return_cand_parents_and_score()

Number of candidate parent sets after pruning (unpruned 2^K = 32768):

node	psets	ratio
0	31978	0.975891
1	32733	0.998932
2	32745	0.999298
3	32526	0.992615
4	32767	0.999969
5	32528	0.992676
6	32768	1
7	32768	1
8	32756	0.999634
9	32317	0.986237
10	32767	0.999969
11	32751	0.999481
12	32513	0.992218
13	32656	0.996582
14	32743	0.999237
15	30398	0.927673

Number of score sums stored in cc cache: 0



### Precomputation of Max and Sum Score Arrays

In [13]:
# Wrap the Gadget outputs in a LeafHandler, then precompute its sum and max
# score arrays (log=True shows the progress bar seen in the output).
lh = LeafHandler(
    C=candidates,
    c_r_score=score.c_r_score,
    c_c_score=score.c_c_score,
    score_array=score_array,
)
lh.precompute_sum_and_max(log=True)

100%|██████████████████████████| 16/16 [00:49<00:00,  3.06s/it]


### Set up Oracle(s) and Splitting Strategy

In [14]:
# Turn the oracle names chosen in the configuration cell into oracle objects
# and combine them into the splitting strategy.
#
# `strong_oracle` / `weak_oracle` are rebound from their string names to the
# constructed objects. The isinstance guards make the cell safe to re-run:
# an already-built oracle is reused instead of being compared against names.
with HiddenPrints():
    if isinstance(strong_oracle, str):
        if strong_oracle == "gadget":
            strong_oracle = GadgetOracle(X_train, K=num_candidate_parents)
        elif strong_oracle == "dibs":
            strong_oracle = DibsOracle(G, B, X_train)
        elif strong_oracle == "random":
            # Documented as a valid choice in the configuration cell.
            strong_oracle = RandomOracle()
        else:
            # Fail early instead of handing a raw string to ThresholdStrategy.
            raise ValueError(
                f"Unknown strong_oracle {strong_oracle!r}; expected 'gadget', 'dibs' or 'random'"
            )

    if isinstance(weak_oracle, str):
        if weak_oracle == "enumeration":
            weak_oracle = EnumerationOracle()
        elif weak_oracle == "random":
            weak_oracle = RandomOracle()
        else:
            raise ValueError(
                f"Unknown weak_oracle {weak_oracle!r}; expected 'enumeration' or 'random'"
            )

    strategy = ThresholdStrategy(strong_oracle, weak_oracle, min_dimension=min_dimension)

Number of candidate parent sets after pruning (unpruned 2^K = 32768):

node	psets	ratio
0	31978	0.975891
1	32733	0.998932
2	32745	0.999298
3	32526	0.992615
4	32767	0.999969
5	32528	0.992676
6	32768	1
7	32768	1
8	32756	0.999634
9	32317	0.986237
10	32767	0.999969
11	32751	0.999481
12	32513	0.992218
13	32656	0.996582
14	32743	0.999237
15	30398	0.927673

Number of score sums stored in cc cache: 0



### Structure Learning

Perform learning of the OrderSPN structure, using the chosen oracle methods.

In [15]:
# Structure learning receives whatever part of the total time budget was not
# reserved for the leaf-score precomputation.
og = learn_ordergraph(
    d,
    strategy,
    expansion_factors,
    time_budget=total_time_budget*(1-precomputation_ratio),
    seed=seed,
    suppress_prints=True,
    log=True,
)

Layer 0


  0%|                                    | 0/1 [00:00<?, ?it/s]

100%|████████████████████████████| 1/1 [00:59<00:00, 59.54s/it]


Layer 1


100%|████████████████████████| 128/128 [00:58<00:00,  2.18it/s]


Layer 2


100%|███████████████████| 4096/4096 [00:00<00:00, 56313.96it/s]


Layer 3


100%|█████████████████| 49152/49152 [00:00<00:00, 66217.88it/s]


### Parameter Learning

Solve for the optimal parameters of the OrderSPN.

In [18]:
# Use the `device` chosen in the configuration cell instead of a hard-coded
# 'cpu'. If a CUDA device was requested but none is available, fall back to
# the CPU so the cell still runs.
spn_device = device
if str(spn_device).startswith("cuda") and not torch.cuda.is_available():
    spn_device = "cpu"

ospn = learn_orderspn(og, device=spn_device, leaf_function=lh)
# learn_spn() solves for the parameters; its return value is printed.
# NOTE(review): presumably this is the optimised log-score of the OrderSPN - confirm.
print(ospn.learn_spn())

-616.36017


### Evaluation Metrics

Compute evaluation metrics for the OrderSPN posterior.

In [19]:
# Graphs sampled from the learned posterior; used by the sampling-based metrics.
G_samples = ospn.sample(10000, d)
bge_model = BGe(d=d, alpha_u=1)

# Compute expected SHD
# (np.copy guards the originals in case the metric modifies its inputs.)
trust_shd = pdag_shd(np.copy(G_samples), np.copy(G))
print(f'SHD |  {trust_shd:4.1f},')

# Compute marginal edge probabilities.
# For each ordered pair, evidence for the single edge i -> j is added, the exact
# (log) marginal is queried, and the evidence is removed again so that
# `marg_details` is back to its previous state for the next pair.
marg_details = MarginalEvidence(d)
pairwise_edge_probs = np.zeros((d, d))
for j in range(d):
    for i in range(d):
        if (j != i) and (i not in marg_details[j]):
            marg_details.add_node_evidence(j, i)
            # The marginal is exponentiated, i.e. the query result is in log space.
            pairwise_edge_probs[i, j] = np.exp(ospn.marginal(marg_details).cpu().detach().numpy())
            marg_details.remove_node_evidence(j, i)

# Compute AUROC
trust_auroc = auroc(pairwise_edge_probs, np.copy(G))
print(f'AUROC| {trust_auroc:5.2f},')

# Compute the MLL metric on the held-out test set (this is not a KL-divergence).
trust_mll = mll(np.copy(G_samples), X_test, bge_model)
trust_kl = trust_mll  # legacy alias: the value was previously stored under this name
print(f'MLL| {trust_mll:4.1f},')

# Compute BACE matrix; compare approximate (sampling-based) and exact methods
approx_avg_pairwise_effects = pairwise_linear_ce(np.copy(G_samples), X_train, bge_model)
exact_avg_pairwise_effects = ospn.bace(X_train, bge_model).cpu().detach().numpy()

# Both estimates are scored against the ground-truth weight matrix B.
approx_trust_mse = pairwise_linear_ce_mse(approx_avg_pairwise_effects, B)
print(f'CE_MSE| {approx_trust_mse:5.5f},')
exact_trust_mse = pairwise_linear_ce_mse(exact_avg_pairwise_effects, B)
print(f'CE_MSE_exact| {exact_trust_mse:5.5f},')







SHD |  27.9,
AUROC|  0.97,
MLL| -4753.2,
CE_MSE| 0.10931,
CE_MSE_exact| 0.14998,


### Example Queries

A key feature of OrderSPNs is the ability to perform exact inference (i.e., without sampling) on the learned posterior distribution of graphs.

In particular, we demonstrate here the following queries:
- MPE $\arg\max_{G} p(G)$: Finding the most likely graph in the posterior distribution;
- COND $p(G_{ij}|G_{kl})$: Finding the conditional probability of an edge, given the presence (or absence) of other edges;
- BACE $\mathbb{E}[ACE(i \to j)]$: Computing the average causal effect of variable $X_i$ on $X_j$, averaged over the posterior distribution of causal graphs.

In [20]:
# Most likely graph (MPE query with no evidence set)
marg_details = MarginalEvidence(d)
scc, scG = ospn.mpe(marg_details, d)
print("Maximum a posteriori graph:")
print(scG)

# Marginal and Conditional probability:
# Convention used below: add_node_evidence(child, parent) asserts the edge
# parent -> child. `evidence1` holds the conditioning edge(s); `evidence2`
# holds the conditioning edge(s) plus the queried edge. The result of
# ospn.conditional is exponentiated, i.e. it is returned in log space.
evidence1 = MarginalEvidence(d)
evidence1.add_node_evidence(1, 4)  # edge 4->1
evidence2 = MarginalEvidence(d)
evidence2.add_node_evidence(1, 4)  # edge 4->1
evidence2.add_node_evidence(1, 3)  # edge 3->1
prob = torch.exp(ospn.conditional(evidence1, evidence2)).item()
print("Conditional probability of edge 3->1 given edge 4->1:")
print(prob)

# Sanity check: the two edges together would form a 2-cycle, which no DAG contains.
evidence1 = MarginalEvidence(d)
evidence1.add_node_evidence(1, 7)  # edge 7->1
evidence2 = MarginalEvidence(d)
evidence2.add_node_evidence(1, 7)  # edge 7->1
evidence2.add_node_evidence(7, 1)  # edge 1->7
prob = torch.exp(ospn.conditional(evidence1, evidence2)).item()
print("Conditional probability of edge 1->7 given edge 7->1 (should be exactly 0 as this is a cycle):")
print(prob)

# Bayesian averaged causal effect (BACE) matrix
# (computed exactly in the evaluation cell; that cell must have been run first)
print("BACE matrix:")
print(exact_avg_pairwise_effects)

Maximum a posterior graph:
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Conditional probability of edge 3->1 given edge 4->1:
0.042344268411397934
Conditional probability of edge 1->7 given edge 7->1 (should be exactly 0 as this