
# Functional Bayesian Networks

Functional Bayesian Networks (FBNs) are Bayesian networks in which each CPD is a Python function that returns a Pyro distribution. This lets you model arbitrary discrete, continuous, or mixed relationships while keeping standard graph semantics for sampling, interventions, and learning.

Like other Bayesian network classes in pgmpy, the model has two main components: (1) the graphical structure and (2) the parameterization, which is defined using `FunctionalCPD`.

This tutorial introduces **Functional Bayesian Networks (FBNs)** and the accompanying **Functional CPDs** in `pgmpy`. You'll learn how to:
- Define FunctionalCPDs as *functions that return Pyro distributions*.
- Build FunctionalBayesianNetworks on mixed data.
- Simulate data from the model.
- Use vectorized CPDs for performance.
- Fit simple parametric FBNs using **SVI** and **MCMC** via Pyro.

In [1]:

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import torch
import pyro
import pyro.distributions as dist

from pgmpy.models import FunctionalBayesianNetwork
from pgmpy.factors.hybrid import FunctionalCPD

# Reproducibility
pyro.set_rng_seed(123)



## Functional CPDs: the core idea

A **Functional CPD** takes three main arguments:
1. `variable`: The variable for which the FunctionalCPD is being defined. For a fully parameterized model, each node in the model needs a `FunctionalCPD` associated with it.
2. `fn`: A Python callable that takes a dictionary of parents' values as input and returns a **Pyro distribution**.
3. `parents`: The parents of `variable` in the model.
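
Taken together, the CPDs define the usual Bayesian network factorization. As a sketch in notation that does not appear in the tutorial itself (writing $\mathrm{pa}(x_i)$ for the parents' values of node $x_i$ and $f_i$ for the `fn` of its CPD):

$$
p(x_1, \dots, x_n) = \prod_{i=1}^{n} p\bigl(x_i \mid \mathrm{pa}(x_i)\bigr),
\qquad
p\bigl(\cdot \mid \mathrm{pa}(x_i)\bigr) = f_i\bigl(\mathrm{pa}(x_i)\bigr)
$$

For a root node the parent set is empty, so its `fn` ignores the argument and returns a fixed distribution.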


## Example 1: A simple Gaussian chain: `x1 → x2 → x3`

In [2]:

# Build the structure
gauss_chain = FunctionalBayesianNetwork([("x1", "x2"), ("x2", "x3")])

# Define CPDs
cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))  # prior
cpd_x2 = FunctionalCPD("x2", fn=lambda parents: dist.Normal(1.0 + 0.8 * parents["x1"], 0.5), parents=["x1"])
cpd_x3 = FunctionalCPD("x3", fn=lambda parents: dist.Normal(0.3 + 1.0 * parents["x2"], 1.0), parents=["x2"])

gauss_chain.add_cpds(cpd_x1, cpd_x2, cpd_x3)
gauss_chain.check_model()

# Draw a few samples
samples_gc = gauss_chain.simulate(n_samples=5, seed=123)
samples_gc


|   | x1 | x2 | x3 |
|---|---|---|---|
| 0 | -0.111467 | 1.015461 | 1.525792 |
| 1 | 0.120363 | 0.610113 | 0.51927 |
| 2 | -0.369635 | 0.32677 | 0.861743 |
| 3 | -0.240418 | 0.969617 | 1.934877 |
| 4 | -1.196924 | -0.011801 | 0.64102 |
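
Because every CPD in this chain is linear-Gaussian, the marginal means and variances of `x2` and `x3` can be worked out by hand, which gives a simple sanity check on simulated data. The sketch below is not part of pgmpy: it re-implements ancestral sampling with the standard library (the helper `sample_chain` is a hypothetical name) and compares empirical moments with the analytic ones. It assumes, as in Pyro, that the second argument of `Normal` is a standard deviation.

```python
import random
import statistics

def sample_chain(n, seed=0):
    """Ancestral sampling for x1 -> x2 -> x3 with the coefficients from Example 1."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x1 = rng.gauss(0.0, 1.0)             # x1 ~ Normal(0, 1)
        x2 = rng.gauss(1.0 + 0.8 * x1, 0.5)  # x2 | x1 ~ Normal(1 + 0.8*x1, 0.5)
        x3 = rng.gauss(0.3 + 1.0 * x2, 1.0)  # x3 | x2 ~ Normal(0.3 + x2, 1)
        rows.append((x1, x2, x3))
    return rows

# (mean, variance) implied by the linear-Gaussian structure
var_x2 = 0.8**2 * 1.0 + 0.5**2          # 0.89
var_x3 = 1.0**2 * var_x2 + 1.0**2       # 1.89
analytic = {"x2": (1.0, var_x2), "x3": (0.3 + 1.0 * 1.0, var_x3)}

rows = sample_chain(50_000)
x2_vals = [r[1] for r in rows]
x3_vals = [r[2] for r in rows]
empirical = {
    "x2": (statistics.fmean(x2_vals), statistics.pvariance(x2_vals)),
    "x3": (statistics.fmean(x3_vals), statistics.pvariance(x3_vals)),
}

for name in ("x2", "x3"):
    print(name, "analytic:", analytic[name], "empirical:", empirical[name])
```

The same comparison can be made against `gauss_chain.simulate(...)` output by computing the column means and variances of the returned DataFrame.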


## Example 2: A more complex model with mixed data

In [3]:
complex_bn = FunctionalBayesianNetwork(
    [("x1", "w"), ("x2", "w"), ("x1", "y"),
     ("x2", "y"), ("w", "y"), ("y", "z"), 
     ("w", "z"), ("y", "c"), ("w", "c")]
)

# Roots
cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))
cpd_x2 = FunctionalCPD("x2", fn=lambda _: dist.Normal(0.5, 1.2))

# Continuous mediator: w = 0.7*x1 - 0.3*x2 + ε
cpd_w = FunctionalCPD(
    "w",
    fn=lambda parents: dist.Normal(0.7 * parents["x1"] - 0.3 * parents["x2"], 0.5),
    parents=["x1", "x2"]
)

# Bernoulli target with logistic link: y ~ Bernoulli(sigmoid(-0.7 + 1.5*x1 + 0.8*x2 + 1.2*w))
cpd_y = FunctionalCPD(
    "y",
    fn=lambda parents: dist.Bernoulli(logits=(-0.7 + 1.5 * parents["x1"] + 0.8 * parents["x2"] + 1.2 * parents["w"])),
    parents=["x1", "x2", "w"]
)

# Downstream Bernoulli influenced by y and w
cpd_z = FunctionalCPD(
    "z",
    fn=lambda parents: dist.Bernoulli(logits=(-1.2 + 0.8 * parents["y"] + 0.2 * parents["w"])),
    parents=["y", "w"]
)

# Continuous outcome depending on y and w: c = 0.2 + 0.5*y + 0.3*w + ε
cpd_c = FunctionalCPD(
    "c",
    fn=lambda parents: dist.Normal(0.2 + 0.5 * parents["y"] + 0.3 * parents["w"], 0.7),
    parents=["y", "w"]
)

complex_bn.add_cpds(cpd_x1, cpd_x2, cpd_w, cpd_y, cpd_z, cpd_c)
complex_bn.check_model()

# Simulate data from it
complex_bn.simulate(n_samples=8, seed=123)

|   | x1 | x2 | w | y | z | c |
|---|---|---|---|---|---|---|
| 0 | -0.111467 | 0.888683 | -0.36394 | 0.0 | 1.0 | -0.126706 |
| 1 | 0.120363 | 0.369773 | -0.469728 | 0.0 | 0.0 | 0.203071 |
| 2 | -0.369635 | 0.752397 | -0.719913 | 0.0 | 0.0 | 0.660035 |
| 3 | -0.240418 | 0.030989 | -0.39104 | 0.0 | 1.0 | 0.576696 |
| 4 | -1.196924 | 0.781968 | -1.086602 | 0.0 | 0.0 | 0.384325 |
| 5 | 0.209269 | 1.298313 | 0.468003 | 1.0 | 0.0 | 1.734681 |
| 6 | -0.972355 | 0.923385 | -1.151972 | 0.0 | 1.0 | -0.902663 |
| 7 | -0.755045 | 1.667385 | -1.473914 | 1.0 | 1.0 | -0.654092 |
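
To make the logistic link in `cpd_y` concrete, the sketch below evaluates the success probability that `dist.Bernoulli(logits=...)` would use for two rows of the simulated table. It uses only the standard library; the helper names `sigmoid` and `p_y_given_parents` are illustrative and not part of pgmpy or Pyro.

```python
import math

def sigmoid(t):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-t))

def p_y_given_parents(x1, x2, w):
    """P(y = 1 | x1, x2, w) for the logit used in cpd_y."""
    logit = -0.7 + 1.5 * x1 + 0.8 * x2 + 1.2 * w
    return sigmoid(logit)

# Row 0 of the simulated data (observed y = 0)
p_row0 = p_y_given_parents(x1=-0.111467, x2=0.888683, w=-0.36394)
# Row 5 of the simulated data (observed y = 1)
p_row5 = p_y_given_parents(x1=0.209269, x2=1.298313, w=0.468003)

print(round(p_row0, 3), round(p_row5, 3))  # roughly 0.356 and 0.771
```

Row 0 has a negative logit and row 5 a positive one, which is consistent with the sampled `y` values in those rows, although any single Bernoulli draw could of course go either way.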



## Vectorized CPDs for speed

Set `vectorized=True` and have your `fn(parent_df)` return a **batched** Pyro distribution whose batch size equals the number of rows in the provided parent DataFrame. This makes sampling much faster for large `n_samples`.


In [4]:

from pgmpy import config

vec_bn = FunctionalBayesianNetwork([("x1", "x2")])

cpd_x1 = FunctionalCPD("x1", fn=lambda _: dist.Normal(0.0, 1.0))

def x2_fn_vec(P):
    x1 = torch.tensor(P["x1"].values, dtype=config.get_dtype(), device=config.get_device())
    mu = 0.5 + 0.9 * x1
    sigma = torch.full_like(mu, 0.3)
    return dist.Normal(mu, sigma)

cpd_x2 = FunctionalCPD("x2", fn=x2_fn_vec, parents=["x1"], vectorized=True)

vec_bn.add_cpds(cpd_x1, cpd_x2)
vec_bn.check_model()

# Large draw to highlight performance of vectorized CPDs
vec_samples = vec_bn.simulate(n_samples=5000, seed=123)
vec_samples.head()


|   | x1 | x2 |
|---|---|---|
| 0 | -0.111467 | 0.796616 |
| 1 | 0.120363 | 0.786959 |
| 2 | -0.369635 | -0.07729 |
| 3 | -0.240418 | 0.097413 |
| 4 | -1.196924 | -0.559637 |
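
The idea behind a vectorized CPD can be illustrated without Pyro. The sketch below is a NumPy stand-in, not the pgmpy implementation: `draw_rowwise` and `draw_batched` are hypothetical helpers that use the same coefficients as `x2_fn_vec`. The first builds parameters and draws once per row; the second computes parameters for all rows at once and draws in a single call, returning one sample per row.

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(0.0, 1.0, size=1000)  # stand-in for the parent column P["x1"]

def draw_rowwise(x1, rng):
    """One parameter computation and one draw per row (non-vectorized style)."""
    return np.array([rng.normal(0.5 + 0.9 * v, 0.3) for v in x1])

def draw_batched(x1, rng):
    """Parameters for all rows at once, then a single batched draw."""
    mu = 0.5 + 0.9 * x1            # shape (n_rows,)
    sigma = np.full_like(mu, 0.3)  # same shape as mu
    return rng.normal(mu, sigma)   # one sample per row

x2_rowwise = draw_rowwise(x1, rng)
x2_batched = draw_batched(x1, rng)

print(x2_rowwise.shape, x2_batched.shape)
```

Both versions draw from the same conditional distribution; the batched one avoids a Python-level loop, which is where the speedup for large `n_samples` would be expected to come from.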



## Parameter learning with SVI

When CPDs contain **Pyro parameters** (`pyro.param(...)`), you can fit them to data using `model.fit(..., method="SVI")`.
Below, we synthesize data from a simple linear-Gaussian model and then recover the parameters.


In [5]:
# Generate synthetic data
true_mu, true_sigma = 0.8, 0.6
N = 2000
x1 = torch.normal(mean=true_mu, std=true_sigma, size=(N,))
# Vectorized draw for x2: when `mean` is a tensor, torch.normal takes no `size` argument
x2 = torch.normal(mean=1.2 + x1, std=0.7)  # equivalent: x2 = 1.2 + x1 + 0.7 * torch.randn_like(x1)

data = pd.DataFrame({"x1": x1.numpy(), "x2": x2.numpy()})

from torch.distributions import constraints

pyro.clear_param_store()  # helpful if you re-run the cell

svi_bn = FunctionalBayesianNetwork([("x1", "x2")])

def x1_fn(_):
    mu = pyro.param("x1_mu", torch.tensor(0.0))
    sigma = pyro.param("x1_sigma", torch.tensor(1.0), constraint=constraints.positive)
    return dist.Normal(mu, sigma)

def x2_fn(p):
    inter = pyro.param("x2_inter", torch.tensor(0.0))
    sigma = pyro.param("x2_sigma", torch.tensor(1.0), constraint=constraints.positive)
    return dist.Normal(inter + p["x1"], sigma)

svi_bn.add_cpds(
    FunctionalCPD("x1", fn=x1_fn),
    FunctionalCPD("x2", fn=x2_fn, parents=["x1"]),
)
svi_bn.check_model()

# Fit with SVI
params_svi = svi_bn.fit(data, method="SVI", num_steps=300)
{k: v.item() if torch.is_tensor(v) and v.ndim == 0 else v for k, v in params_svi.items()}

INFO:pgmpy:Step 0 | Loss: 6574.2528
INFO:pgmpy:Step 50 | Loss: 4983.4878
INFO:pgmpy:Step 100 | Loss: 4229.9033
INFO:pgmpy:Step 150 | Loss: 3987.9143
INFO:pgmpy:Step 200 | Loss: 3981.4799
INFO:pgmpy:Step 250 | Loss: 3981.4706


{'x1_mu': 0.7751224637031555,
 'x1_sigma': 0.6131476759910583,
 'x2_inter': 1.1970738172531128,
 'x2_sigma': 0.6990763545036316}
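
Because both CPDs are Gaussian and the slope on `x1` is fixed at 1, the maximum-likelihood estimates have closed forms, which gives a quick cross-check on the fitted values. The sketch below assumes that SVI on a model with only `pyro.param` sites (no priors or latent variables) behaves like maximum likelihood. It regenerates data with the standard library rather than reusing `data`, so the numbers will differ slightly from the cell above.

```python
import random
import statistics

rng = random.Random(0)
N = 2000

# Same generative process as the synthetic data above
x1 = [rng.gauss(0.8, 0.6) for _ in range(N)]
x2 = [rng.gauss(1.2 + a, 0.7) for a in x1]

# Closed-form maximum-likelihood estimates (slope on x1 fixed at 1)
x1_mu_hat = statistics.fmean(x1)
x1_sigma_hat = statistics.pstdev(x1)
resid = [b - a for a, b in zip(x1, x2)]  # x2 - x1 = intercept + noise
x2_inter_hat = statistics.fmean(resid)
x2_sigma_hat = statistics.pstdev(resid)

print({
    "x1_mu": x1_mu_hat,
    "x1_sigma": x1_sigma_hat,
    "x2_inter": x2_inter_hat,
    "x2_sigma": x2_sigma_hat,
})
```

Applying the same four formulas to the `data` DataFrame should give values close to the SVI output if the optimization has converged.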


## Bayesian inference with MCMC

If you prefer a fully Bayesian approach, provide **priors** and switch to `method="MCMC"`.  
Here, we place priors on the parameters, then sample from the posterior.


In [6]:
mcmc_bn = FunctionalBayesianNetwork([("x1", "x2")])

# Priors with matched dtype/device (avoid dtype mismatches)
dtype = config.get_dtype()
device = config.get_device()
def prior_fn():
    t = lambda v: torch.tensor(v, dtype=dtype, device=device)
    return {
        "x1_mu":   dist.Normal(t(0.0), t(5.0)),
        "x1_sigma": dist.HalfNormal(t(2.0)),
        "x2_inter": dist.Normal(t(0.0), t(5.0)),
        "x2_sigma": dist.HalfNormal(t(2.0)),
    }

# CPDs consume *sampled* prior values (from the model) + parents
def x1_fn_prior(priors, _):
    return dist.Normal(priors["x1_mu"], priors["x1_sigma"])

def x2_fn_prior(priors, p):
    return dist.Normal(priors["x2_inter"] + p["x1"], priors["x2_sigma"])

mcmc_bn.add_cpds(
    FunctionalCPD("x1", fn=x1_fn_prior),
    FunctionalCPD("x2", fn=x2_fn_prior, parents=["x1"]),
)

pyro.clear_param_store()

post = mcmc_bn.fit(
    data,
    method="MCMC",
    prior_fn=prior_fn,
    num_steps=200,
    nuts_kwargs={"target_accept_prob": 0.8},
    mcmc_kwargs={"num_chains": 1, "warmup_steps": 200},
)

# Peek at posterior summaries
{k: (v.mean().item(), v.std().item()) for k, v in post.items() if torch.is_tensor(v)}


Sample: 100%|█████████████████████████████████████████| 400/400 [00:03, 109.04it/s, step size=6.89e-01, acc. prob=0.881]


{'x1_mu': (0.777364034922299, 0.0144704934171712),
 'x1_sigma': (0.6129607990640703, 0.010122547469210461),
 'x2_inter': (1.1970455001675568, 0.014884526650204519),
 'x2_sigma': (0.7008062278971647, 0.010433891834526142)}
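
With weak priors and 2000 observations, the posterior standard deviations should be close to the usual large-sample standard errors: roughly sigma / sqrt(N) for a Gaussian location parameter and sigma / sqrt(2N) for a Gaussian scale parameter. The sketch below computes these approximations using the true noise scales from the data-generating cell as an assumption. It is a rough check, not something pgmpy reports.

```python
import math

N = 2000
true_x1_sigma, true_x2_sigma = 0.6, 0.7  # noise scales used to generate the data

# Large-sample standard errors for Gaussian location and scale parameters
approx_sd = {
    "x1_mu": true_x1_sigma / math.sqrt(N),
    "x1_sigma": true_x1_sigma / math.sqrt(2 * N),
    "x2_inter": true_x2_sigma / math.sqrt(N),
    "x2_sigma": true_x2_sigma / math.sqrt(2 * N),
}

for name, sd in approx_sd.items():
    print(f"{name}: {sd:.4f}")
```

These come out at about 0.013, 0.009, 0.016, and 0.011, the same order of magnitude as the posterior standard deviations above. Exact agreement is not expected from only 200 posterior draws.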


## Interventions and conditioning (preview)

Conceptually, an intervention `do(X = value)` **severs incoming edges into `X`** and replaces its CPD with a constant or a new distribution.  
A simple approach for simulations is to *temporarily swap the CPD of `X`* with a `FunctionalCPD` that ignores parents and returns `dist.Delta(value)` (or any desired distribution).
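
The CPD-swapping idea can be sketched on a toy sampler. The code below does not use pgmpy: the `cpds` dictionary, `do`, and `simulate` are hypothetical stand-ins that mirror the Example 1 chain, with plain Python callables in place of `FunctionalCPD` objects. Replacing the entry for `x2` with a parentless constant is the analogue of swapping in a CPD that returns `dist.Delta(value)`.

```python
import random

# node -> (parents, sampler taking (parent_values, rng)); a toy stand-in for CPDs
cpds = {
    "x1": ([], lambda p, rng: rng.gauss(0.0, 1.0)),
    "x2": (["x1"], lambda p, rng: rng.gauss(1.0 + 0.8 * p["x1"], 0.5)),
    "x3": (["x2"], lambda p, rng: rng.gauss(0.3 + 1.0 * p["x2"], 1.0)),
}
ORDER = ["x1", "x2", "x3"]  # a topological order of the chain

def do(cpds, node, value):
    """Return a copy of cpds where `node` has no parents and always equals `value`."""
    new_cpds = dict(cpds)
    new_cpds[node] = ([], lambda p, rng: value)
    return new_cpds

def simulate(cpds, n, seed=0):
    """Ancestral sampling: visit nodes in topological order, feeding parent values."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        row = {}
        for node in ORDER:
            parents, sampler = cpds[node]
            row[node] = sampler({k: row[k] for k in parents}, rng)
        rows.append(row)
    return rows

rows_do = simulate(do(cpds, "x2", 2.0), n=20_000)
mean_x1 = sum(r["x1"] for r in rows_do) / len(rows_do)
mean_x3 = sum(r["x3"] for r in rows_do) / len(rows_do)
print(mean_x1, mean_x3)
```

Under `do(x2 = 2.0)` the mean of `x1` stays near 0 because the edge `x1 → x2` is cut, while the mean of `x3` moves to about 2.3. Conditioning on `x2 = 2.0` would instead shift the distribution of `x1` as well.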

Conditioning on continuous **point evidence** should use **likelihood weighting** or proper inference (SVI/MCMC), not plain rejection sampling, because a sampled continuous value matches the evidence exactly with probability zero. The high-level recipe for likelihood weighting is:
1. Sample all **non-evidence** nodes in topological order (respecting any `do(...)` replacements).
2. Clamp evidence nodes to their observed values.
3. Weight each draw by the product of evidence likelihoods under their parents’ sampled values.
4. Normalize weights and either keep **weighted samples** or **resample** for an unweighted posterior sample.
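
The four steps above can be sketched for the Example 1 chain with evidence on `x3`. This is a standard-library illustration under stated assumptions, not a pgmpy API: `lw_posterior_mean_x1` and `exact_posterior_mean_x1` are hypothetical helpers, and the closed-form answer is available only because the chain is linear-Gaussian.

```python
import math
import random

def normal_pdf(x, mean, sd):
    """Density of Normal(mean, sd) at x."""
    return math.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))

def lw_posterior_mean_x1(x3_obs, n=50_000, seed=0):
    """Estimate E[x1 | x3 = x3_obs] by likelihood weighting."""
    rng = random.Random(seed)
    num = den = 0.0
    for _ in range(n):
        # Step 1: sample non-evidence nodes in topological order
        x1 = rng.gauss(0.0, 1.0)
        x2 = rng.gauss(1.0 + 0.8 * x1, 0.5)
        # Steps 2-3: clamp x3 to the evidence and weight by its likelihood
        weight = normal_pdf(x3_obs, 0.3 + 1.0 * x2, 1.0)
        num += weight * x1
        den += weight
    # Step 4: self-normalized weighted average
    return num / den

def exact_posterior_mean_x1(x3_obs):
    """Closed form: x3 = 1.3 + 0.8*x1 + noise, with noise variance 0.5**2 + 1.0**2."""
    noise_var = 0.5**2 + 1.0**2
    precision = 1.0 + 0.8**2 / noise_var  # prior precision of x1 is 1
    return (0.8 * (x3_obs - 1.3) / noise_var) / precision

estimate = lw_posterior_mean_x1(2.5)
exact = exact_posterior_mean_x1(2.5)
print(estimate, exact)
```

The weighted estimate should land close to the exact value of about 0.508. With evidence far in the tails, most weights become tiny and many more samples are needed for a stable estimate.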

Future versions may expose `simulate(do=..., evidence=...)` directly in the API.



---

### Key takeaways
- Functional CPDs let you specify **any** distribution you can write in Pyro.
- Mixed types (discrete/continuous) are straightforward.
- Use **vectorized CPDs** for performance on large simulations.
- For learning: quick **SVI** with `pyro.param(...)` or fully Bayesian **MCMC** with `prior_fn`.
