Discrete Probability Trees (or Staged Tree Models) are one of the simplest models of causal generative processes, different from Causal Bayesian Networks (CBNs) and Structural Causal Models (SCMs).

A directed edge $e=(h,t)$ between two nodes in this graph is interpreted as the head being the cause of the tail.

The implementation below is deliberately simple and follows [Algorithms for Causal Reasoning in Probability Trees](https://arxiv.org/pdf/2010.12237.pdf).

In [14]:
import dataclasses as dc
from collections import namedtuple
from typing import List, Tuple

The paper represents a probability tree (PT) node as $n=(u,S,C)$: a unique identifier $u$, a list of statements $S$ (e.g. `W=rainy` and `X=0`) and a set of transitions $C$, each pairing a probability with a child node.

> A (total) realization in the probability tree is a path from the root to a leaf, and its probability is obtained by multiplying the transition probabilities along the path; and a partial realization is any connected sub-path within a total realization.

In [15]:
Stmt = namedtuple('Stmt', 'var val')

@dc.dataclass(frozen=True)
class Node:  # (u, S, C)
  id: int
  stmts: List[Stmt] = dc.field(compare=False)
  # Transitions: (probability, child) pairs.
  childs: 'List[Tuple[float, Node]]' = dc.field(
    default_factory=list, compare=False)
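
To make the quoted definition of a total realization concrete, here is a minimal sketch that builds a small tree and multiplies the transition probabilities along every root-to-leaf path. The tree (weather `W`, then barometer `X`), its numbers, and the helper `realizations` are hypothetical and only serve as illustration. Transitions are assumed to be `(probability, child)` pairs.

```python
import dataclasses as dc
from collections import namedtuple
from typing import Iterator, List, Tuple

Stmt = namedtuple('Stmt', 'var val')

@dc.dataclass(frozen=True)
class Node:  # (u, S, C), restated so the sketch runs on its own
  id: int
  stmts: List[Stmt] = dc.field(compare=False)
  # Assumption: each transition is a (probability, child) pair.
  childs: 'List[Tuple[float, Node]]' = dc.field(
    default_factory=list, compare=False)

def realizations(n: Node, p: float = 1.0,
                 path: Tuple[int, ...] = ()
                 ) -> Iterator[Tuple[Tuple[int, ...], float]]:
  '''Yield (node ids along the path, probability) per total realization.'''
  path = path + (n.id,)
  if not n.childs:
    yield path, p
  for pc, nc in n.childs:
    yield from realizations(nc, p * pc, path)

# Hypothetical example tree: W is decided first, then X.
tree = Node(0, [Stmt('O', 1)], [
  (0.25, Node(1, [Stmt('W', 'rainy')], [
    (0.75, Node(3, [Stmt('X', 1)])),
    (0.25, Node(4, [Stmt('X', 0)]))])),
  (0.75, Node(2, [Stmt('W', 'sunny')], [
    (0.25, Node(5, [Stmt('X', 1)])),
    (0.75, Node(6, [Stmt('X', 0)]))])),
])

for path, p in realizations(tree):
  print(path, p)
```

The four path probabilities add up to 1, as they should for a complete set of total realizations.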

A min-cut is a minimal representation of an event in terms of PT nodes: a set of true nodes `T` and a set of false nodes `F` that together cut every realization, so the probabilities of reaching them **sum up to 1**.

In [13]:
null = frozenset()
MinCut = namedtuple('MinCut', 'T F', defaults=(null, null))

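For a simple statement such as `X=1`, the min-cut collects the first node on each path that assigns the variable: nodes that agree with the statement go into the true set, the others into the false set.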

In [None]:
def prop(n: Node, s: Stmt):
  for var, val in n.stmts:
    if var == s.var:
      return MinCut(
        T={n.id} if val == s.val else set(),
        F=set() if val == s.val else {n.id})
  if not n.childs:
    # Reached a leaf without the variable ever being assigned.
    raise ValueError(f'{s.var} cannot be resolved')
  T, F = set(), set()
  for pc, nc in n.childs:
    Tc, Fc = prop(nc, s)
    T, F = T|Tc, F|Fc
  return MinCut(T=T, F=F)
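
As a hedged illustration of how such a min-cut can be used, the sketch below computes the min-cut of `X=1` on a small hypothetical tree and sums the probabilities of reaching its true and false nodes. `prop` is restated so the block runs on its own. The helper `reach` and the example tree are assumptions made for this example, not part of the paper.

```python
import dataclasses as dc
from collections import namedtuple
from typing import List

Stmt = namedtuple('Stmt', 'var val')
MinCut = namedtuple('MinCut', 'T F')

@dc.dataclass(frozen=True)
class Node:
  id: int
  stmts: List[Stmt] = dc.field(compare=False)
  childs: list = dc.field(default_factory=list, compare=False)

def prop(n, s):
  # Same logic as the cell above.
  for var, val in n.stmts:
    if var == s.var:
      hit = val == s.val
      return MinCut({n.id} if hit else set(), set() if hit else {n.id})
  if not n.childs:
    raise ValueError('Cannot be resolved')
  T, F = set(), set()
  for pc, nc in n.childs:
    Tc, Fc = prop(nc, s)
    T, F = T | Tc, F | Fc
  return MinCut(T, F)

def reach(n, q=1.0):
  '''Hypothetical helper: node id -> probability of reaching that node.'''
  out = {n.id: q}
  for pc, nc in n.childs:
    out.update(reach(nc, q * pc))
  return out

# Hypothetical example tree: W is decided first, then X.
tree = Node(0, [Stmt('O', 1)], [
  (0.25, Node(1, [Stmt('W', 'rainy')], [
    (0.75, Node(3, [Stmt('X', 1)])),
    (0.25, Node(4, [Stmt('X', 0)]))])),
  (0.75, Node(2, [Stmt('W', 'sunny')], [
    (0.25, Node(5, [Stmt('X', 1)])),
    (0.75, Node(6, [Stmt('X', 0)]))])),
])

cut = prop(tree, Stmt('X', 1))
r = reach(tree)
p_true = sum(r[i] for i in cut.T)   # P(X=1)
p_false = sum(r[i] for i in cut.F)  # P(X=0)
print(cut, p_true, p_false)
```

The mass on the true set is the probability of the event, and the true and false masses together account for the whole tree.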

Standard conjunction/disjunction/negation.

In [16]:
def and_(n: Node, d1: MinCut, d2: MinCut,
         found1=False, found2=False):
  if n.id in (d1.F|d2.F):
    return MinCut(F={n.id})
  if n.id in d1.T: found1 = True
  if n.id in d2.T: found2 = True
  if found1 and found2:
    return MinCut(T={n.id})
  T, F = set(), set()
  for pc, nc in n.childs:
    Tc, Fc = and_(nc, d1, d2, found1, found2)
    T, F = T|Tc, F|Fc
  return MinCut(T, F)

def not_(d: MinCut):
  # Just swap true and false sets.
  return MinCut(T=d.F, F=d.T)

def or_(n: Node, d1: MinCut, d2: MinCut):
  # De Morgan
  return not_(and_(n, not_(d1), not_(d2)))
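
The following sketch runs the three connectives on a small hypothetical tree to show what the resulting min-cuts look like. The functions are restated so the block is self-contained. The tree and the two hand-written input min-cuts (`w_rainy`, `x_one`) are assumptions for illustration.

```python
from collections import namedtuple
from dataclasses import dataclass, field

MinCut = namedtuple('MinCut', 'T F')

@dataclass(frozen=True)
class Node:
  id: int
  stmts: list = field(compare=False)
  childs: list = field(default_factory=list, compare=False)

def and_(n, d1, d2, found1=False, found2=False):
  # False as soon as either event is false on this path.
  if n.id in (d1.F | d2.F):
    return MinCut(set(), {n.id})
  if n.id in d1.T: found1 = True
  if n.id in d2.T: found2 = True
  # True once both events have become true on this path.
  if found1 and found2:
    return MinCut({n.id}, set())
  T, F = set(), set()
  for pc, nc in n.childs:
    Tc, Fc = and_(nc, d1, d2, found1, found2)
    T, F = T | Tc, F | Fc
  return MinCut(T, F)

def not_(d):
  return MinCut(T=d.F, F=d.T)

def or_(n, d1, d2):
  return not_(and_(n, not_(d1), not_(d2)))

# Hypothetical example tree: W is decided first, then X.
tree = Node(0, [('O', 1)], [
  (0.25, Node(1, [('W', 'rainy')], [
    (0.75, Node(3, [('X', 1)])),
    (0.25, Node(4, [('X', 0)]))])),
  (0.75, Node(2, [('W', 'sunny')], [
    (0.25, Node(5, [('X', 1)])),
    (0.75, Node(6, [('X', 0)]))])),
])

w_rainy = MinCut(T={1}, F={2})      # min-cut of W=rainy
x_one = MinCut(T={3, 5}, F={4, 6})  # min-cut of X=1

both = and_(tree, w_rainy, x_one)
either = or_(tree, w_rainy, x_one)
print(both, either)
```

Note how the disjunction is already true at node 1 (`W=rainy`), so its min-cut does not need to descend to that node's leaves.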

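Precedence relation.
It takes two min-cuts, one for the cause and one for the effect, and holds on the realizations where the cause becomes true before the effect is resolved and the effect then becomes true.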

In [None]:
def prec(n:Node, dc:MinCut, de:MinCut, is_cause=False):
  '''A<B precedence relation.
  
  @param dc: cause
  @param de: effect
  @return min-cut for the event where precedence holds.
  '''
  if not is_cause:
    if n.id in (de.T|dc.F|de.F):
      return MinCut(F={n.id})
    if n.id in dc.T: 
      is_cause = True
  else:
    # The cause already happened upstream: the effect now decides.
    if n.id in de.T: return MinCut(T={n.id})
    if n.id in de.F: return MinCut(F={n.id})
  T, F = (set(), set())
  for pc, nc in n.childs:
    Tc, Fc = prec(nc, dc, de, is_cause)
    T, F = T|Tc, F|Fc
  return MinCut(T, F)
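
A small sketch of the precedence relation in both directions, on a hypothetical tree where `W` is always decided before `X`. It assumes that, once the cause has been found on a path, the effect min-cut alone decides the outcome. The parameter names `cause`, `effect` and `found` are chosen here for readability.

```python
from collections import namedtuple
from dataclasses import dataclass, field

MinCut = namedtuple('MinCut', 'T F')

@dataclass(frozen=True)
class Node:
  id: int
  stmts: list = field(compare=False)
  childs: list = field(default_factory=list, compare=False)

def prec(n, cause, effect, found=False):
  if not found:
    # Effect resolved (or cause false) before the cause became true.
    if n.id in (effect.T | cause.F | effect.F):
      return MinCut(set(), {n.id})
    if n.id in cause.T:
      found = True
  else:
    # Assumption: after the cause, the effect min-cut decides.
    if n.id in effect.T: return MinCut({n.id}, set())
    if n.id in effect.F: return MinCut(set(), {n.id})
  T, F = set(), set()
  for pc, nc in n.childs:
    Tc, Fc = prec(nc, cause, effect, found)
    T, F = T | Tc, F | Fc
  return MinCut(T, F)

# Hypothetical example tree: W is decided first, then X.
tree = Node(0, [('O', 1)], [
  (0.25, Node(1, [('W', 'rainy')], [
    (0.75, Node(3, [('X', 1)])),
    (0.25, Node(4, [('X', 0)]))])),
  (0.75, Node(2, [('W', 'sunny')], [
    (0.25, Node(5, [('X', 1)])),
    (0.75, Node(6, [('X', 0)]))])),
])

w_rainy = MinCut(T={1}, F={2})
x_one = MinCut(T={3, 5}, F={4, 6})

w_before_x = prec(tree, w_rainy, x_one)
x_before_w = prec(tree, x_one, w_rainy)
print(w_before_x, x_before_w)
```

In this tree `X=1` never precedes `W=rainy`, so the second min-cut has an empty true set.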

Conditioning updates the tree after an event is revealed to be true.
It removes all probability mass from the false set of the min-cut `d` and renormalizes the remaining transition probabilities.

In [None]:
def see(n: Node, d: MinCut, q=1.0):
  '''P(A|B) :: proba of A, given B.
  
  @param n: reference proba tree
  @param d: observed event
  '''
  if n.id in d.T: return n, 1, q
  if n.id in d.F: return n, 0, 0
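  D = []  # (child, liveness, mass) triples
  sl, sp = 0, 0
  for pc, nc in n.childs:
    # After the call, pc is the mass of the observed event below nc.
    nc, lc, pc = see(nc, d, pc*q)
    D.append((nc, lc, pc))
    sl += lc
    sp += pc
  C = norm(D, sl, sp)
  return Node(n.id, n.stmts, C), 1, sp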
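
As a hedged summary of what this recursion computes (the symbols $p_i$, $c_i$ and $B$ are introduced here for illustration and are not taken from the paper): if a node has transitions $(p_i, c_i)$ and $B$ is the observed event, each transition probability is updated by Bayes' rule,

$$
p_i' = \frac{p_i \, P(B \mid c_i)}{\sum_j p_j \, P(B \mid c_j)},
$$

where $P(B \mid c_i)$ is 1 for children in the true set, 0 for children in the false set, and is obtained recursively otherwise. The running argument `q` only carries the probability of reaching the current node and cancels in the ratio. When the denominator is zero, `norm` falls back to a uniform distribution over the children that are still alive.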

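Intervening only affects realizations downstream of the critical set
(unlike conditioning, which also updates upstream information).
Hence the algorithm mirrors `see`, but it passes no probability mass upward: each node only learns whether a child is still alive and renormalizes its own transitions.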

In [None]:
def do(n:Node, d:MinCut):
  '''P(A|do(B)) :: proba of A, given B was made true.
  
  @param n: reference proba tree
  @param d: intervened event (do(B))
  '''
  if n.id in d.T: return n, True
  if n.id in d.F: return n, False
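  D = []  # (child, liveness, mass) triples
  sl, sp = 0, 0
  for pc, nc in n.childs:
    nc, b = do(nc, d)  # recurse into the child subtree
    if b:
      # Surviving branch: keep its transition probability.
      D.append((nc, 1, pc))
      sl += 1
      sp += pc
    else:
      # Branch contradicts the intervention: remove its mass.
      D.append((nc, 0, 0))
  C = norm(D, sl, sp)
  return Node(n.id, n.stmts, C), True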
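
The difference between conditioning and intervening is easiest to see side by side. The sketch below restates compact versions of `see`, `do` and `norm` and applies them to a hypothetical tree in which `W` is decided before `X`. It assumes that `do` gives zero mass to branches contradicting the intervention and renormalizes the rest, and that `norm` returns a list.

```python
from collections import namedtuple
from dataclasses import dataclass, field

MinCut = namedtuple('MinCut', 'T F')

@dataclass(frozen=True)
class Node:
  id: int
  stmts: list = field(compare=False)
  childs: list = field(default_factory=list, compare=False)

def norm(D, sl, sp):
  if sp: return [(p / sp, n) for n, l, p in D]
  return [(l / sl, n) for n, l, p in D]

def see(n, d, q=1.0):
  if n.id in d.T: return n, 1, q
  if n.id in d.F: return n, 0, 0
  D, sl, sp = [], 0, 0
  for pc, nc in n.childs:
    nc, lc, pc = see(nc, d, pc * q)  # mass of the event below nc
    D.append((nc, lc, pc))
    sl += lc
    sp += pc
  return Node(n.id, n.stmts, norm(D, sl, sp)), 1, sp

def do(n, d):
  if n.id in d.T: return n, True
  if n.id in d.F: return n, False
  D, sl, sp = [], 0, 0
  for pc, nc in n.childs:
    nc, alive = do(nc, d)
    # Assumption: dead branches lose their mass, live ones keep pc.
    D.append((nc, 1, pc) if alive else (nc, 0, 0))
    sl += 1 if alive else 0
    sp += pc if alive else 0
  return Node(n.id, n.stmts, norm(D, sl, sp)), True

# Hypothetical example tree: W is decided first, then X.
tree = Node(0, [('O', 1)], [
  (0.25, Node(1, [('W', 'rainy')], [
    (0.75, Node(3, [('X', 1)])),
    (0.25, Node(4, [('X', 0)]))])),
  (0.75, Node(2, [('W', 'sunny')], [
    (0.25, Node(5, [('X', 1)])),
    (0.75, Node(6, [('X', 0)]))])),
])
x_one = MinCut(T={3, 5}, F={4, 6})

seen, _, p_x = see(tree, x_one)
done, _ = do(tree, x_one)
p_rainy_see = seen.childs[0][0]  # P(W=rainy | X=1)
p_rainy_do = done.childs[0][0]   # P(W=rainy | do(X=1))
print(p_x, p_rainy_see, p_rainy_do)
```

Observing `X=1` shifts belief about the upstream variable `W`, whereas forcing `X=1` leaves the transition into `W=rainy` at its original value.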

A counterfactual is a statement about a subjunctive (possible or imagined) event that could've happened had the stochastic process taken a different course during its realization.

In the code below, A is the subjunctive event, B the factual premise and C the counterfactual event.

In [None]:
def counterfact(n:Node, m:Node, d:MinCut):
  '''P(Ac|B) :: proba of A given B and if C was made true.

  @param n: reference proba tree
  @param m: factual premise tree (B)
  @param d: counterfactual event (C)
  '''
  if n.id in d.T: return n, True
  if n.id in d.F: return n, False
  is_critical_bifurcation = False
  C = []  # transitions taken from the factual premise tree
  for (pn, nc), (pm, mc) in zip(n.childs, m.childs):
    nn, b = counterfact(nc, mc, d)
    if not b: is_critical_bifurcation = True
    C.append((pm, nn))
  if not is_critical_bifurcation:
    n = Node(n.id, n.stmts, C)
  return n, True

Normalize the transition probabilities, or fall back to a uniform measure over the live children if the total probability mass is zero.

In [None]:
def norm(D, sl, sp):
  # Returns a list so it matches `Node.childs` and keeps the order of D.
  if sp: return [(p/sp, n) for n,l,p in D]
  return [(l/sl, n) for n,l,p in D]
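
A quick sketch of both branches of `norm`, using strings in place of nodes. The inputs are made up for illustration, and the list-returning variant is assumed.

```python
def norm(D, sl, sp):
  # D holds (node, liveness, mass) triples.
  if sp: return [(p / sp, n) for n, l, p in D]
  return [(l / sl, n) for n, l, p in D]

# Positive total mass: divide every mass by the sum.
regular = norm([('a', 1, 0.125), ('b', 1, 0.375), ('c', 0, 0.0)],
               sl=2, sp=0.5)

# Zero total mass: spread probability uniformly over the live children.
fallback = norm([('a', 1, 0.0), ('b', 1, 0.0), ('c', 0, 0.0)],
                sl=2, sp=0.0)

print(regular)
print(fallback)
```

In both cases the child marked as not alive ends up with probability zero.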