In [1]:
import torch
import torch.nn as nn
from representation.bayesian_nn import BayesLinearGMM, BNN_GMM
from representation.gaussian_mixture import GaussianMix

device="cpu"

In [2]:
# Smoke test for BNN_GMM: print a few statistics of the default initialisation,
# then re-initialise the mixture parameters from a deterministic MLP and check
# that mixture component 0 reproduces the MLP weights and biases.
# NOTE(review): no seed is set in this cell, so the printed statistics change from run to run.
# 1) Create a toy BNN: input=3, hidden=5, output=2, K=4 mixture components
bnn = BNN_GMM(layer_sizes=[3, 5, 2], K=4, bias=True).to(device)

print("=== From-scratch init (default reset_parameters) ===")
print("Layer 0 pi_w shape:", bnn.layers[0].pi_w.shape)   # (5, 3, 4)
print("Layer 0 mu_w mean/std:", bnn.layers[0].mu_w.mean().item(), bnn.layers[0].mu_w.std().item())
print("Layer 0 sigma_w min:", bnn.layers[0].sigma_w.min().item())
# Only the K weights of entry [0, 0] are summed here, not every (out, in) pair.
print("Layer 0 pi_w row-sum (should be ~1):",
        bnn.layers[0].pi_w[0, 0, :].sum().item())

# 2) Create a deterministic MLP with matching shapes
det = nn.Sequential(
    nn.Linear(3, 5),
    nn.ReLU(),
    nn.Linear(5, 2),
).to(device)

# (Optional) pretend it's "trained" by just doing one fake optimizer step
# A single SGD step on random data, only so the weights differ from their initial values.
opt = torch.optim.SGD(det.parameters(), lr=0.1)
x = torch.randn(16, 3, device=device)
y = torch.randn(16, 2, device=device)
loss = ((det(x) - y) ** 2).mean()
loss.backward()
opt.step()

# 3) Initialize BNN mixture params from deterministic net (component 0 matches det weights)
bnn.init_from_deterministic_mlp(det, sigma0=1e-3, main_comp=0)

print("\n=== Init-from-deterministic ===")
# Check that component 0 matches
# Both differences are expected to be exactly 0 if component 0 is a copy of the MLP parameters.
diff_W0 = (bnn.layers[0].mu_w[..., 0] - det[0].weight.data).abs().max().item()
diff_b0 = (bnn.layers[0].mu_b[:, 0] - det[0].bias.data).abs().max().item()
print("Max |mu_w[...,0] - W_det|:", diff_W0)
print("Max |mu_b[:,0]  - b_det|:", diff_b0)
print("Layer 0 pi_w[0,0,:]:", bnn.layers[0].pi_w[0, 0, :].detach().cpu())
print("Layer 0 sigma_w[0,0,:]:", bnn.layers[0].sigma_w[0, 0, :].detach().cpu())

=== From-scratch init (default reset_parameters) ===
Layer 0 pi_w shape: torch.Size([5, 3, 4])
Layer 0 mu_w mean/std: 0.000469229620648548 0.3450576961040497
Layer 0 sigma_w min: 0.028867512941360474
Layer 0 pi_w row-sum (should be ~1): 1.0

=== Init-from-deterministic ===
Max |mu_w[...,0] - W_det|: 0.0
Max |mu_b[:,0]  - b_det|: 0.0
Layer 0 pi_w[0,0,:]: tensor([1.0000e+00, 1.0000e-12, 1.0000e-12, 1.0000e-12])
Layer 0 sigma_w[0,0,:]: tensor([0.0010, 0.0010, 0.0010, 0.0010])


In [3]:
# Rebind `bnn` to a smaller 2 -> 3 -> 1 network with K=4 components so its
# parameters are easy to inspect by eye in the next cell.
bnn = BNN_GMM(layer_sizes=[2, 3, 1], K=4, bias=True).to(device)

In [4]:
# Display the layer-0 mixture weights (the recorded output is a (3, 2, 4) tensor of 0.25s).
bnn.layers[0].pi_w

Parameter containing:
tensor([[[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]]], requires_grad=True)

In [5]:
def compute_neuron_preactivation(input : GaussianMix, weights : torch.Tensor, bias=None):
    """
    Placeholder for the pre-activation Gaussian mixture of a single neuron.

    The body is not implemented yet: every call evaluates to ``None`` and none
    of the arguments is read.

    Args:
        input (GaussianMix): Input Gaussian mixture, intended shape (batch_size, input_dim, K).
        weights (torch.Tensor): Neuron weights; each weight is meant to be a Gaussian
            mixture described by pi, mu and sigma.
        bias (torch.Tensor, optional): Neuron bias, also meant to be a Gaussian mixture.
            Defaults to None.

    Returns:
        None
    """
    # TODO: implement; see preactivation_gm_for_neuron_tree for a working per-neuron version.
    return None

def gm_product(pi_w, mu_w, sigma_w, pi_x, mu_x, sigma_x, eps=1e-12):
    """
    Product of two independent scalar Gaussian mixtures (moment-matched).

    Every pair (w-component, x-component) yields one output component whose mean
    and variance equal those of the product of the two independent Gaussians:
        mean = mu_w * mu_x
        var  = sigma_w^2 * sigma_x^2 + sigma_w^2 * mu_x^2 + sigma_x^2 * mu_w^2
    The product of two Gaussians is not Gaussian, so only the first two moments
    of each pair are preserved.

    Inputs:
      pi_w, mu_w, sigma_w: (Kw,)  (any shape is accepted and flattened)
      pi_x, mu_x, sigma_x: (Kx,)  (any shape is accepted and flattened)
      eps: lower bound applied to the variances and to the output weights.

    Returns:
      pi_y:     (Kw*Kx,)  weights, clamped to >= eps and renormalised to sum to 1
      mu_y:     (Kw*Kx,)
      sigma_y:  (Kw*Kx,)
    Output components are ordered with the w index varying slowest.
    """
    # reshape(-1) instead of view(-1): view raises on non-contiguous inputs
    # (e.g. transposed parameter slices), reshape copies only when needed.
    pi_w, mu_w, sigma_w = (t.reshape(-1) for t in (pi_w, mu_w, sigma_w))
    pi_x, mu_x, sigma_x = (t.reshape(-1) for t in (pi_x, mu_x, sigma_x))

    # Column vectors for w, row vectors for x -> (Kw, Kx) Cartesian grids.
    mean_w, mean_x = mu_w[:, None], mu_x[None, :]
    var_w, var_x = sigma_w[:, None] ** 2, sigma_x[None, :] ** 2

    pi_y = (pi_w[:, None] * pi_x[None, :]).reshape(-1)
    mu_y = (mean_w * mean_x).reshape(-1)
    var_y = (var_w * var_x + var_w * mean_x ** 2 + var_x * mean_w ** 2).reshape(-1)

    # Clamp before the square root so sigma stays strictly positive.
    sigma_y = torch.sqrt(var_y.clamp_min(eps))

    # Normalize weights (important numerically)
    pi_y = pi_y.clamp_min(eps)
    pi_y = pi_y / pi_y.sum()

    return pi_y, mu_y, sigma_y


def gm_times_scalar(pi, mu, sigma, x):
    """
    Scale a Gaussian mixture W by a deterministic scalar x, i.e. Y = x * W.

    Args:
        pi, mu, sigma: weights, means and standard deviations of W.
        x: deterministic multiplier; a Python number or a tensor.

    Returns:
        (pi_y, mu_y, sigma_y): the weights are returned unchanged (same object),
        the means are multiplied by x and the standard deviations by |x|.
    """
    # Builtin abs() works for Python numbers and tensors alike, whereas
    # x.abs() only exists on tensors.
    scale = abs(x)
    return pi, mu * x, sigma * scale



In [6]:
# ----------------------------
# Utilities: mixture moments
# ----------------------------
def mix_mean_var(pi, mu, sigma):
    """
    Mean and variance of a scalar Gaussian mixture.

    Args:
        pi, mu, sigma: (K,) weights, component means, component standard deviations.

    Returns:
        (mean, var) as 0-dim tensors, with var = E[X^2] - E[X]^2.
    """
    first_moment = (pi * mu).sum()
    # E[X^2] of each component is sigma^2 + mu^2.
    raw_second_moment = (pi * (sigma ** 2 + mu ** 2)).sum()
    return first_moment, raw_second_moment - first_moment ** 2


# ----------------------------
# Monte Carlo sampling of a mixture
# ----------------------------
def sample_gm(pi, mu, sigma, N: int, generator=None):
    """
    Sample N points from a scalar Gaussian mixture.

    Args:
        pi, mu, sigma: (K,) weights, means and standard deviations.
        N: number of samples to draw.
        generator: optional ``torch.Generator`` used for both the component
            choice and the Gaussian noise. ``None`` (default) uses the global
            RNG, which is the previous behaviour.

    Returns:
        (N,) tensor of samples.
    """
    # Sample component indices (with replacement, proportional to pi)
    idx = torch.multinomial(pi, num_samples=N, replacement=True, generator=generator)  # (N,)
    # Sample from the chosen normal component
    noise = torch.randn(N, device=pi.device, dtype=pi.dtype, generator=generator)
    return mu[idx] + sigma[idx] * noise


# ----------------------------
# Test / demo
# ----------------------------
# Compares the mean/variance implied by gm_product's output mixture with a
# Monte Carlo estimate obtained by multiplying independent samples of W and X.
# Only the first two moments are compared.
torch.manual_seed(0)

# Define two mixtures with small K (easy to inspect)
pi_w = torch.tensor([0.3, 0.7], dtype=torch.float64)
mu_w = torch.tensor([-1.0,  2.0], dtype=torch.float64)
sg_w = torch.tensor([ 0.5,  0.2], dtype=torch.float64)

pi_x = torch.tensor([0.4, 0.6], dtype=torch.float64)
mu_x = torch.tensor([ 0.5, -1.5], dtype=torch.float64)
sg_x = torch.tensor([ 0.3,  0.4], dtype=torch.float64)

# (Optional sanity checks)
assert torch.isclose(pi_w.sum(), torch.tensor(1.0, dtype=torch.float64))
assert torch.isclose(pi_x.sum(), torch.tensor(1.0, dtype=torch.float64))
assert (sg_w > 0).all() and (sg_x > 0).all()

# Compute product-mixture (approx)
pi_y, mu_y, sg_y = gm_product(pi_w, mu_w, sg_w, pi_x, mu_x, sg_x)

# Predicted moments from the output mixture
m_pred, v_pred = mix_mean_var(pi_y, mu_y, sg_y)

# Monte Carlo check
# W and X are sampled independently, matching the independence assumed by gm_product.
N = 500_000
w_s = sample_gm(pi_w, mu_w, sg_w, N)
x_s = sample_gm(pi_x, mu_x, sg_x, N)
y_s = w_s * x_s

m_mc = y_s.mean()
# Population (biased) variance, to match the definition used by mix_mean_var.
v_mc = y_s.var(unbiased=False)

print("=== Product GM sanity check (moment test) ===")
print(f"Output components: {pi_y.numel()} (should be Kw*Kx = {pi_w.numel()*pi_x.numel()})")
print(f"pi_y sums to: {pi_y.sum().item():.12f}")
print()
print(f"Predicted mean: {m_pred.item(): .6f}   | MC mean: {m_mc.item(): .6f}   | abs diff: {abs(m_pred-m_mc).item():.6e}")
print(f"Predicted var : {v_pred.item(): .6f}   | MC var : {v_mc.item(): .6f}   | abs diff: {abs(v_pred-v_mc).item():.6e}")

=== Product GM sanity check (moment test) ===
Output components: 4 (should be Kw*Kx = 4)
pi_y sums to: 1.000000000000

Predicted mean: -0.770000   | MC mean: -0.765026   | abs diff: 4.973549e-03
Predicted var :  4.474246   | MC var :  4.459763   | abs diff: 1.448291e-02


In [7]:
def gm_product_per_neuron(pi_w_j, mu_w_j, sigma_w_j, pi_x, mu_x, sigma_x, eps=1e-12, normalize_pi=True):
    """
    Moment-matched product w_{j,i} * x_i for every input i of one neuron.

    For each input i, every (weight component, input component) pair yields one
    output component with
        mean = mu_w * mu_x
        var  = sigma_w^2 * sigma_x^2 + sigma_w^2 * mu_x^2 + sigma_x^2 * mu_w^2

    Args:
        pi_w_j, mu_w_j, sigma_w_j: (I, Kw) weight mixtures of neuron j.
        pi_x, mu_x, sigma_x: (I, Kx) input mixtures.
        eps: floor for the variances and threshold below which weights are dropped.
        normalize_pi: if True (default), zero out weights <= eps and renormalise
            each row to sum to 1; if False, return the raw weight products.

    Returns:
        pi, mu, sigma: (I, Kw*Kx), weight-component index varying slowest.
        A row whose weights are all <= eps is returned as all zeros, not NaN.

    Raises:
        AssertionError: if weights and inputs disagree on I.
    """
    I, Kw = pi_w_j.shape
    I2, Kx = pi_x.shape
    assert I == I2, f"I mismatch: weights have {I}, inputs have {I2}"

    # Weight tensors become (I, Kw, 1), input tensors (I, 1, Kx).
    mean_w, mean_x = mu_w_j.unsqueeze(-1), mu_x.unsqueeze(-2)
    var_w, var_x = sigma_w_j.unsqueeze(-1) ** 2, sigma_x.unsqueeze(-2) ** 2

    pi = pi_w_j.unsqueeze(-1) * pi_x.unsqueeze(-2)
    mu = mean_w * mean_x
    var = var_w * var_x + var_w * mean_x ** 2 + var_x * mean_w ** 2
    sigma = torch.sqrt(var.clamp_min(eps))

    # Flatten (Kw, Kx) -> (Kw*Kx)
    pi = pi.reshape(I, Kw * Kx)
    mu = mu.reshape(I, Kw * Kx)
    sigma = sigma.reshape(I, Kw * Kx)

    if normalize_pi:
        # Drop negligible weights, then renormalise. Rows left with no mass stay
        # all-zero instead of turning into NaN through 0 / 0.
        pi = pi * (pi > eps)
        total = pi.sum(dim=-1, keepdim=True)
        pi = torch.where(total > 0, pi / total.clamp_min(eps), pi)

    return pi, mu, sigma


In [8]:
# Demo: per-input products for neuron 0 of a freshly initialised K=1 layer,
# with three single-component (plain Gaussian) inputs.
I  = 3
Kx = 1

# NOTE(review): pi_x is built from Python ints, so it is an integer tensor;
# consider 1.0 literals so mixture weights are floating point throughout.
pi_x = torch.tensor([
    [1],   # x_0
    [1],   # x_1
    [1],   # x_2
])                      # shape: (I, Kx)

mu_x = torch.tensor([
    [ 1.0],
    [ 0.5],
    [-0.3],
])                      # shape: (I, Kx)

sigma_x = torch.tensor([
    [0.2],
    [0.1],
    [0.3],
])                      # shape: (I, Kx)

# Rebinds `bnn`; the weights are random (no seed is set in this cell).
bnn = BNN_GMM(layer_sizes=[3, 2, 1], K=1, bias=True).to(device)

layer = bnn.layers[0]
j = 0  # neuron index

pi_y, mu_y, sigma_y = gm_product_per_neuron(
    layer.pi_w[j], layer.mu_w[j], layer.sigma_w[j],   # (I, Kw)
    pi_x, mu_x, sigma_x)

# With Kw = Kx = 1 each row of mu_y should equal mu_w[j, i] * mu_x[i].
print(layer.pi_w[j])
print(layer.mu_w[j])
print(pi_y)
print(mu_y)
print(sigma_y)


tensor([[1.],
        [1.],
        [1.]], grad_fn=<SelectBackward0>)
tensor([[-0.4496],
        [-0.4856],
        [ 0.0758]], grad_fn=<SelectBackward0>)
tensor([[1.],
        [1.],
        [1.]], grad_fn=<DivBackward0>)
tensor([[-0.4496],
        [-0.2428],
        [-0.0227]], grad_fn=<ViewBackward0>)
tensor([[0.0946],
        [0.0507],
        [0.0258]], grad_fn=<ViewBackward0>)


In [9]:
# Display the layer-0 bias mixture weights (the recorded output is a (2, 1) tensor of ones).
bnn.layers[0].pi_b

Parameter containing:
tensor([[1.],
        [1.]], requires_grad=True)

In [None]:
def gm_product_per_neuron(pi_w_j, mu_w_j, sigma_w_j, pi_x, mu_x, sigma_x, eps=1e-12, normalize_pi=True):
    """
    Per-input product of a neuron's weight mixtures with the input mixtures.

    Each (weight component, input component) pair of input i is replaced by a
    Gaussian with the mean and variance of the product of the two independent
    Gaussians:
        mean = mu_w * mu_x
        var  = sigma_w^2 * sigma_x^2 + sigma_w^2 * mu_x^2 + sigma_x^2 * mu_w^2

    Args:
        pi_w_j, mu_w_j, sigma_w_j: (I, Kw) weight mixtures of one neuron.
        pi_x, mu_x, sigma_x: (I, Kx) input mixtures.
        eps: variance floor and weight threshold.
        normalize_pi: when True (default) weights <= eps are set to 0 and every
            row is renormalised; when False the raw products are returned.

    Returns:
        pi, mu, sigma: (I, Kw*Kx). Rows with no weight above eps come back as
        zeros rather than NaN.

    Raises:
        AssertionError: if the two I dimensions differ.
    """
    I, Kw = pi_w_j.shape
    I2, Kx = pi_x.shape
    assert I == I2, f"I mismatch: weights have {I}, inputs have {I2}"

    def _grid(w_part, x_part):
        # (I, Kw) and (I, Kx) -> broadcastable (I, Kw, 1) and (I, 1, Kx)
        return w_part.unsqueeze(-1), x_part.unsqueeze(-2)

    w_pi, x_pi = _grid(pi_w_j, pi_x)
    w_mu, x_mu = _grid(mu_w_j, mu_x)
    w_sg, x_sg = _grid(sigma_w_j, sigma_x)

    n_out = Kw * Kx
    pi = (w_pi * x_pi).reshape(I, n_out)
    mu = (w_mu * x_mu).reshape(I, n_out)
    var = (w_sg ** 2) * (x_sg ** 2) + (w_sg ** 2) * (x_mu ** 2) + (x_sg ** 2) * (w_mu ** 2)
    sigma = torch.sqrt(var.clamp_min(eps)).reshape(I, n_out)

    if not normalize_pi:
        return pi, mu, sigma

    # put to 0 weights smaller than eps, then renormalize; a row that ends up
    # empty is kept at zero because dividing by its zero sum would produce NaN.
    pi = pi * (pi > eps)
    row_mass = pi.sum(dim=-1, keepdim=True)
    pi = torch.where(row_mass > 0, pi / row_mass.clamp_min(eps), pi)

    return pi, mu, sigma

def gm_add(pi1, mu1, sg1, pi2, mu2, sg2, eps=1e-12):
    """
    Add two independent scalar Gaussian mixtures.

    Every pair of components contributes one Gaussian whose mean is the sum of
    the means and whose variance is the sum of the variances.

    Args:
        pi1, mu1, sg1: (K1,) first mixture.
        pi2, mu2, sg2: (K2,) second mixture.
        eps: variance floor and threshold below which a weight is set to 0.

    Returns:
        (pi, mu, sg): each (K1*K2,), first-mixture index varying slowest.
        No component pruning is performed, so the size is always K1*K2.
        If no weight exceeds eps the weights are returned as zeros (not NaN).
    """
    # Cartesian product
    w = (pi1[:, None] * pi2[None, :]).reshape(-1)
    mu = (mu1[:, None] + mu2[None, :]).reshape(-1)
    var = (sg1[:, None] ** 2 + sg2[None, :] ** 2).reshape(-1)
    sg = torch.sqrt(var.clamp_min(eps))

    # put to 0 weights smaller than eps, then renormalize (guarding against a zero sum)
    w = w * (w > eps)
    total = w.sum(dim=-1, keepdim=True)
    w = torch.where(total > 0, w / total.clamp_min(eps), w)

    return w, mu, sg


def sum_over_inputs_tree(mixtures, eps=1e-12):
    """
    Sum independent Gaussian mixtures by pairwise (tree) reduction.

    Args:
        mixtures: list of (pi, mu, sg) triples, one per term of the sum.
        eps: forwarded to gm_add.

    Returns:
        The (pi, mu, sg) triple of the total. Neighbours are merged two by two
        on every pass; with an odd count the last element is carried over
        unchanged to the next pass. A single-element list is returned as is.
    """
    pending = mixtures
    while len(pending) > 1:
        # Merge neighbours (0,1), (2,3), ...
        merged = [
            gm_add(*pending[k], *pending[k + 1], eps=eps)
            for k in range(0, len(pending) - 1, 2)
        ]
        if len(pending) % 2 == 1:
            # Odd one out: no partner on this pass.
            merged.append(pending[-1])
        pending = merged
    return pending[0]


def preactivation_gm_for_neuron_tree(layer, j, pi_x, mu_x, sg_x, eps=1e-12):
    """
    Gaussian-mixture pre-activation of neuron j:  z_j = sum_i w_{j,i} x_i + b_j.

    Args:
        layer: object exposing pi_w, mu_w, sigma_w (indexed by neuron), a truthy
            or falsy `bias` flag and, when the flag is set, pi_b, mu_b, sigma_b.
        j: neuron index.
        pi_x, mu_x, sg_x: (I, Kx) input mixtures.
        eps: forwarded to the helper functions.

    Returns:
        (pi_z, mu_z, sg_z) of the pre-activation mixture.
    """
    # One product mixture per input i, each row of shape (Kprod,).
    prod_pi, prod_mu, prod_sg = gm_product_per_neuron(
        layer.pi_w[j], layer.mu_w[j], layer.sigma_w[j],
        pi_x, mu_x, sg_x,
        eps=eps, normalize_pi=True,
    )
    n_inputs = prod_pi.shape[0]
    terms = [(prod_pi[i], prod_mu[i], prod_sg[i]) for i in range(n_inputs)]

    # Pairwise reduction over the inputs.
    w_sum, m_sum, s_sum = sum_over_inputs_tree(terms, eps=eps)

    if not layer.bias:
        return w_sum, m_sum, s_sum

    # The bias is itself a mixture and is added like any other term.
    w_out, m_out, s_out = gm_add(
        w_sum, m_sum, s_sum,
        layer.pi_b[j], layer.mu_b[j], layer.sigma_b[j], eps=eps,
    )
    return w_out, m_out, s_out

# Quick run of the unbatched pipeline on a freshly initialised K=1 layer.
# NOTE(review): pi_x, mu_x and sigma_x are not defined in this cell; they are
# expected to come from the earlier cell that builds the three (I, Kx) inputs.
bnn = BNN_GMM(layer_sizes=[3, 2, 1], K=1, bias=True).to(device)
layer = bnn.layers[0]
pi_z, mu_z, sg_z = preactivation_gm_for_neuron_tree(layer, j=0, pi_x=pi_x, mu_x=mu_x, sg_x=sigma_x)

# Show the pre-activation means (weights are random, so the value changes between runs).
mu_z


tensor([0.2498], grad_fn=<ViewBackward0>)

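Now let's handle batches: the next cell redefines the helpers above so that they also accept inputs with a leading batch dimension.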

In [None]:
def _safe_renorm(w, dim=-1, eps=1e-12):
    """
    Renormalize along `dim`, safely avoiding division by zero.
    If a row sums to 0 after thresholding, it leaves it as all-zeros.
    """
    s = w.sum(dim=dim, keepdim=True)
    return torch.where(s > 0, w / s.clamp_min(eps), w)


def gm_product_per_neuron(pi_w_j, mu_w_j, sigma_w_j,
                          pi_x, mu_x, sigma_x,
                          eps=1e-12, normalize_pi=True):
    """
    Product per input i (per-neuron), batched over B.

    Each (weight component, input component) pair is moment-matched:
        mean = mu_w * mu_x
        var  = sigma_w^2 * sigma_x^2 + sigma_w^2 * mu_x^2 + sigma_x^2 * mu_w^2

    Weights:
      pi_w_j, mu_w_j, sigma_w_j: (I, Kw)

    Inputs x:
      pi_x, mu_x, sigma_x: either (I, Kx) or (B, I, Kx)

    Other arguments:
      eps: variance floor and weight threshold.
      normalize_pi: if True, zero out weights <= eps and renormalise each
        (b, i) row via _safe_renorm; if False, return the raw products.

    Returns:
      pi, mu, sigma: (B, I, Kw*Kx) when x was given as (B, I, Kx), including
      B == 1, or (I, Kw*Kx) when x was given unbatched as (I, Kx).

    Raises:
      AssertionError: if weights and inputs disagree on I.
    """
    I, Kw = pi_w_j.shape

    # Record batched-ness BEFORE promoting: afterwards a real batch of size 1
    # and a promoted unbatched input are indistinguishable.
    unbatched_input = pi_x.dim() == 2
    if unbatched_input:
        # (I, Kx) -> (1, I, Kx)
        pi_x = pi_x.unsqueeze(0)
        mu_x = mu_x.unsqueeze(0)
        sigma_x = sigma_x.unsqueeze(0)

    B, I2, Kx = pi_x.shape
    assert I == I2, f"I mismatch: weights have {I}, inputs have {I2}"

    # Weights: (I, Kw) -> (1, I, Kw, 1); inputs: (B, I, Kx) -> (B, I, 1, Kx).
    # Their products broadcast to (B, I, Kw, Kx).
    pi = pi_w_j.unsqueeze(0).unsqueeze(-1) * pi_x.unsqueeze(-2)
    mu = mu_w_j.unsqueeze(0).unsqueeze(-1) * mu_x.unsqueeze(-2)

    sw2 = (sigma_w_j.unsqueeze(0).unsqueeze(-1) ** 2)  # (1, I, Kw, 1)
    sx2 = (sigma_x.unsqueeze(-2) ** 2)                 # (B, I, 1, Kx)
    mw2 = (mu_w_j.unsqueeze(0).unsqueeze(-1) ** 2)     # (1, I, Kw, 1)
    mx2 = (mu_x.unsqueeze(-2) ** 2)                    # (B, I, 1, Kx)

    var = sw2 * sx2 + sw2 * mx2 + sx2 * mw2
    sigma = torch.sqrt(var.clamp_min(eps))

    # Flatten (Kw, Kx) -> (Kw*Kx): (B, I, Kprod)
    Kprod = Kw * Kx
    pi = pi.reshape(B, I, Kprod)
    mu = mu.reshape(B, I, Kprod)
    sigma = sigma.reshape(B, I, Kprod)

    if normalize_pi:
        # Threshold tiny weights, then renormalize per (B, I, :)
        pi = pi * (pi > eps)
        pi = _safe_renorm(pi, dim=-1, eps=eps)

    # Only drop the batch axis when the caller did not supply one.
    if unbatched_input:
        return pi.squeeze(0), mu.squeeze(0), sigma.squeeze(0)

    return pi, mu, sigma


def gm_add(pi1, mu1, sg1, pi2, mu2, sg2, eps=1e-12):
    """
    Batch-aware sum of two independent scalar Gaussian mixtures.

    Accepted shape combinations:
      - (K1,)   + (K2,)   -> (K1*K2,)
      - (B, K1) + (B, K2) -> (B, K1*K2)
      - (B, K1) + (K2,)   -> (B, K1*K2), second operand shared by the batch
      - (K1,)   + (B, K2) -> (B, K1*K2), first operand shared by the batch

    Each component pair gets the summed mean and the summed variance. Weights
    not exceeding eps are zeroed and every row is renormalised; a row left
    without mass stays all-zero.

    Raises:
        AssertionError: if both operands are batched with different sizes,
            neither of which is 1.
    """
    first = (pi1, mu1, sg1)
    second = (pi2, mu2, sg2)

    # The result is only unbatched if neither operand carried a batch axis.
    any_batched = pi1.dim() == 2 or pi2.dim() == 2

    # Give 1-D operands a leading batch axis of size 1.
    if pi1.dim() == 1:
        first = tuple(t.unsqueeze(0) for t in first)
    if pi2.dim() == 1:
        second = tuple(t.unsqueeze(0) for t in second)

    B, K1 = first[0].shape
    B2, K2 = second[0].shape

    # Align the batch sizes by expanding whichever side has size 1.
    if B != B2:
        if B == 1:
            first = tuple(t.expand(B2, -1) for t in first)
            B = B2
        elif B2 == 1:
            second = tuple(t.expand(B, -1) for t in second)
        else:
            raise AssertionError(f"Batch mismatch: {B} vs {B2}")

    w_a, m_a, s_a = first
    w_b, m_b, s_b = second
    n_out = K1 * K2

    # Per-row Cartesian product: (B, K1, 1) against (B, 1, K2).
    weights = (w_a[:, :, None] * w_b[:, None, :]).reshape(B, n_out)
    means = (m_a[:, :, None] + m_b[:, None, :]).reshape(B, n_out)
    variances = (s_a[:, :, None] ** 2 + s_b[:, None, :] ** 2).reshape(B, n_out)
    stds = torch.sqrt(variances.clamp_min(eps))

    # Threshold, then renormalise each row that still has mass.
    weights = weights * (weights > eps)
    row_mass = weights.sum(dim=-1, keepdim=True)
    weights = torch.where(row_mass > 0, weights / row_mass.clamp_min(eps), weights)

    if any_batched:
        return weights, means, stds
    return weights.squeeze(0), means.squeeze(0), stds.squeeze(0)


def sum_over_inputs_tree(pi, mu, sg, eps=1e-12):
    """
    Batch-aware pairwise (tree) reduction of the per-input mixtures.

    Args:
        pi, mu, sg: (B, I, K) or (I, K); axis I indexes the terms of the sum.
        eps: forwarded to gm_add.

    Returns:
        pi_z, mu_z, sg_z: (B, Kout) for batched input, (Kout,) for unbatched
        input. Neighbouring terms are merged two by two on each pass; with an
        odd count the last term moves on to the next pass unchanged.
    """
    was_batched = pi.dim() == 3
    if pi.dim() == 2:
        pi, mu, sg = (t.unsqueeze(0) for t in (pi, mu, sg))

    _, n_inputs, _ = pi.shape

    # One (B, K) triple per input index.
    pending = [(pi[:, i], mu[:, i], sg[:, i]) for i in range(n_inputs)]

    while len(pending) > 1:
        merged = [
            gm_add(*pending[k], *pending[k + 1], eps=eps)
            for k in range(0, len(pending) - 1, 2)
        ]
        if len(pending) % 2 == 1:
            # Unpaired last term.
            merged.append(pending[-1])
        pending = merged

    total_pi, total_mu, total_sg = pending[0]

    if was_batched:
        return total_pi, total_mu, total_sg
    return total_pi.squeeze(0), total_mu.squeeze(0), total_sg.squeeze(0)


def preactivation_gm_for_neuron_tree(layer, j, pi_x, mu_x, sg_x, eps=1e-12):
    """
    Gaussian-mixture pre-activation of neuron j:  z_j = sum_i w_{j,i} x_i + b_j.

    Batch-aware: the inputs may be (I, Kx) or (B, I, Kx).

    Args:
        layer: object exposing pi_w, mu_w, sigma_w (indexed by neuron), a truthy
            or falsy `bias` flag and, when the flag is set, pi_b, mu_b, sigma_b.
        j: neuron index.
        pi_x, mu_x, sg_x: input mixtures.
        eps: forwarded to the helper functions.

    Returns:
        (pi_z, mu_z, sg_z) of the pre-activation mixture.
    """
    # Step 1: moment-matched products w_{j,i} * x_i for every input i.
    products = gm_product_per_neuron(
        layer.pi_w[j], layer.mu_w[j], layer.sigma_w[j],
        pi_x, mu_x, sg_x,
        eps=eps, normalize_pi=True,
    )
    prod_pi, prod_mu, prod_sg = products

    # Step 2: reduce over the input axis.
    w_sum, m_sum, s_sum = sum_over_inputs_tree(prod_pi, prod_mu, prod_sg, eps=eps)

    if not layer.bias:
        return w_sum, m_sum, s_sum

    # Step 3: the bias mixture is added as one more independent term.
    w_out, m_out, s_out = gm_add(
        w_sum, m_sum, s_sum,
        layer.pi_b[j], layer.mu_b[j], layer.sigma_b[j],
        eps=eps,
    )
    return w_out, m_out, s_out


In [None]:
# Batched test inputs: two samples, three inputs each, one Gaussian per input.
device = "cpu"

# Batch size 2, 3 inputs, 1 Gaussian each
pi_x = torch.ones(2, 3, 1, device=device)        # deterministic mixture
mu_x = torch.tensor([
    [[1.0], [2.0], [3.0]],   # batch 0
    [[-1.0], [0.5], [1.5]]   # batch 1
], device=device)

# Same small standard deviation for every input.
sigma_x = 0.01 * torch.ones_like(mu_x)


In [None]:
class FakeLayer:
    """
    Minimal stand-in for a one-neuron layer with fixed, hand-picked parameters.

    Exposes the attributes read by preactivation_gm_for_neuron_tree:
    `bias`, `pi_w`, `mu_w`, `sigma_w` and, when `bias` is true, `pi_b`,
    `mu_b`, `sigma_b`.

    The means are hard-coded for exactly I=3 inputs and Kw=1 component, so any
    other combination is rejected instead of silently producing tensors with
    inconsistent shapes.

    Args:
        I: number of inputs; must be 3.
        Kw: number of mixture components per weight; must be 1.
        bias: whether to create the bias mixture.
        device: device on which the tensors are created.

    Raises:
        ValueError: if I != 3 or Kw != 1.
    """

    def __init__(self, I, Kw=1, bias=True, device="cpu"):
        if I != 3 or Kw != 1:
            raise ValueError(
                f"FakeLayer hard-codes its means for I=3 and Kw=1; got I={I}, Kw={Kw}"
            )

        self.bias = bias

        # weights: (neurons=1, I, Kw)
        self.pi_w = torch.ones(1, I, Kw, device=device)
        self.mu_w = torch.tensor(
            [[[0.5], [-1.0], [2.0]]], device=device
        )
        self.sigma_w = 0.5 * torch.ones(1, I, Kw, device=device)

        if bias:
            # bias: (neurons=1, Kw)
            self.pi_b = torch.ones(1, Kw, device=device)
            self.mu_b = torch.tensor([[0.1]], device=device)
            self.sigma_b = torch.tensor([[0.05]], device=device)

# Run the batched pipeline on the hand-built layer so the result can be checked by hand:
# mu_z should be sum_i mu_w[i] * mu_x[b, i] + mu_b for each batch element b.
layer = FakeLayer(I=3, Kw=1, bias=True, device=device)

pi_z, mu_z, sg_z = preactivation_gm_for_neuron_tree(
    layer,
    j=0,
    pi_x=pi_x,
    mu_x=mu_x,
    sg_x=sigma_x
)

# NOTE(review): the recorded output shows sg_z = 0.05 (the bias sigma alone), although
# sigma_w = 0.5 should add variance from every product term -- the saved output looks
# stale; re-run this cell to confirm.
print("pi_z:", pi_z)
print("mu_z:", mu_z)
print("sg_z:", sg_z)



pi_z: tensor([[1.],
        [1.]])
mu_z: tensor([[4.6000],
        [2.1000]])
sg_z: tensor([[0.0500],
        [0.0500]])


In [29]:
import torch
from torch.distributions.normal import Normal

def truncate_0(pi, mu, sigma, eps=1e-12):
    """
    Distribution of Y = max(X, 0) when X is a scalar Gaussian mixture.

    Each input component is replaced by a Gaussian matching the mean and
    variance of that component truncated to (0, +inf), weighted by the mass
    the component puts on (0, +inf). The mass at or below 0 of all components
    is collected in one extra component: a delta at 0, encoded as mu=0, sigma=0.

    Args:
        pi, mu, sigma: (K,) weights, means and standard deviations. Components
            with sigma == 0 (such as the delta produced by a previous call)
            are accepted: sigma is floored at eps before dividing.
        eps: floor for sigma, for the positive mass and for the variances, and
            threshold below which an output weight is set to 0.

    Returns:
        pi_new, mu_new, sigma_new: (K+1,), delta-at-zero component last.
        Weights are renormalised to sum to 1 (left untouched if all are zero).
    """
    zero = torch.zeros((), device=pi.device, dtype=pi.dtype)
    one = torch.ones((), device=pi.device, dtype=pi.dtype)
    # Standard normal in the dtype/device of the input, used for the pdf.
    std = Normal(zero, one)
    sqrt2 = 2.0 ** 0.5

    new_mixture = []
    prob_zero = zero

    for k in range(pi.numel()):
        # Floor sigma so a zero-width component gives a finite threshold
        # instead of +-inf / NaN.
        sigma_k = sigma[k].clamp_min(eps)
        alpha = -mu[k] / sigma_k  # threshold in standard-normal units

        # P(X>0) = 1 - Phi(alpha) = 0.5 * erfc(alpha / sqrt(2)); the erfc form
        # avoids the cancellation of "1 - cdf" when the mass is tiny.
        prob_mass = 0.5 * torch.erfc(alpha / sqrt2)

        # NOTE: if prob_mass is extremely tiny, moments can blow up numerically.
        prob_mass = prob_mass.clamp_min(eps)

        # phi(alpha) is the standard normal pdf at alpha
        phi = std.log_prob(alpha).exp()

        # Truncated normal mean/var for truncation at 0 from below
        ratio = phi / prob_mass
        mu_trunc = mu[k] + sigma_k * ratio
        var_trunc = sigma_k ** 2 * (1 + alpha * ratio - ratio ** 2)
        sigma_trunc = torch.sqrt(var_trunc.clamp_min(eps))

        pi_trunc = pi[k] * prob_mass
        new_mixture.append((pi_trunc, mu_trunc, sigma_trunc))

        # P(Y=0) accumulates the mass that was <=0
        prob_zero = prob_zero + pi[k] * (1 - prob_mass)

    # Add the delta at zero (same dtype as the other components)
    new_mixture.append((prob_zero, zero, zero))

    pi_new = torch.stack([c[0] for c in new_mixture])
    mu_new = torch.stack([c[1] for c in new_mixture])
    sigma_new = torch.stack([c[2] for c in new_mixture])

    # threshold + renorm
    pi_new = pi_new * (pi_new > eps)
    s = pi_new.sum(dim=-1, keepdim=True)
    pi_new = torch.where(s > 0, pi_new / s.clamp_min(eps), pi_new)

    return pi_new, mu_new, sigma_new


# --- helpers to test correctness ---

def sample_gmm(pi, mu, sigma, N, generator=None):
    """
    Sample X from a 1D Gaussian mixture defined by (pi, mu, sigma).

    Args:
        pi, mu, sigma: (K,) weights, means and standard deviations.
        N: number of samples.
        generator: optional ``torch.Generator`` driving both the component
            choice and the Gaussian noise; ``None`` uses the global RNG.

    Returns:
        (N,) tensor of samples.
    """
    # Component indices, drawn with replacement proportionally to pi.
    k = torch.multinomial(pi, num_samples=N, replacement=True, generator=generator)
    # Standard-normal noise in the dtype of the mixture parameters.
    noise = torch.randn(N, device=pi.device, dtype=pi.dtype, generator=generator)
    return mu[k] + sigma[k] * noise


def mixture_moments(pi, mu, sigma):
    """
    Mean and variance of a 1D Gaussian mixture.

    A component with sigma=0 acts as a point mass and needs no special
    handling: it simply contributes mu^2 to the second moment.

    Returns:
        (mean, var) as 0-dim tensors.
    """
    expectation = torch.sum(pi * mu)
    # Per-component E[X^2] = sigma^2 + mu^2, averaged with the mixture weights.
    expectation_sq = torch.sum(pi * (sigma ** 2 + mu ** 2))
    return expectation, expectation_sq - expectation ** 2


# --- a simple test you can run ---
# Compares P(Y=0), E[Y] and Var[Y] implied by truncate_0 with Monte Carlo
# estimates of Y = max(X, 0).
# NOTE(review): no seed is set here, so the Monte Carlo numbers vary between runs.

device = "cpu"
# Global side effect: every tensor created without an explicit dtype after this
# line (in this and in later cells) is float64.
torch.set_default_dtype(torch.float64)

# A small GMM (3 components) including negative/positive means (good stress test)
pi = torch.tensor([0.2, 0.5, 0.3], device=device)
mu = torch.tensor([-1.5, 0.2, 2.0], device=device)
sigma = torch.tensor([0.7, 0.4, 0.8], device=device)

# Compute truncated mixture
pi_t, mu_t, sigma_t = truncate_0(pi, mu, sigma)

# Sanity checks (must pass)
print("weights sum:", pi_t.sum().item())
print("all weights >= 0:", bool((pi_t >= 0).all()))
print("delta at zero weight (last comp):", pi_t[-1].item(), "mu:", mu_t[-1].item(), "sigma:", sigma_t[-1].item())

# Compare moments to Monte Carlo
N = 2_000_000

# Rebinds `x` and `y` (used for the regression toy data in an earlier cell).
x = sample_gmm(pi, mu, sigma, N)
y = torch.clamp(x, min=0.0)  # Y = max(X,0)

# Clamped samples are exactly 0, so an equality test counts the point mass.
p0_mc = (y == 0).double().mean()
mean_mc = y.mean()
var_mc = y.var(unbiased=False)

mean_mix, var_mix = mixture_moments(pi_t, mu_t, sigma_t)

print("\nMonte Carlo vs mixture implied:")
print("P(Y=0):", p0_mc.item(), " vs ", pi_t[-1].item())
print("E[Y]:  ", mean_mc.item(), " vs ", mean_mix.item())
print("Var[Y]:", var_mc.item(),  " vs ", var_mix.item())


weights sum: 1.0
all weights >= 0: True
delta at zero weight (last comp): 0.3529192118399606 mu: 0.0 sigma: 0.0

Monte Carlo vs mixture implied:
P(Y=0): 0.3527345  vs  0.3529192118399606
E[Y]:   0.7405637686712465  vs  0.74084416998383
Var[Y]: 0.9248443516899084  vs  0.9265168107206846


In [30]:
# ReLU of a standard normal: half of the mass is expected in the truncated
# component and half in the delta at 0.
relu_mixture = truncate_0(torch.tensor([1.]), torch.tensor([0.]), torch.tensor([1.]))

relu_mixture

(tensor([0.5000, 0.5000]), tensor([0.7979, 0.0000]), tensor([0.6028, 0.0000]))

In [31]:
import torch
from torch.distributions.normal import Normal

def truncate_0_vectorized(pi, mu, sigma, eps=1e-12):
    """
    Vectorized + batch-friendly version.

    Computes distribution of Y = max(X, 0) where X is a Gaussian mixture.
    Each component is replaced by a Gaussian matching the mean and variance of
    that component truncated to (0, +inf); the mass at or below 0 of all
    components is merged into a single delta at 0.

    Inputs:
      pi, mu, sigma: (..., K)
      eps: floor for sigma, for the positive mass and for the variances, and
        threshold below which an output weight is set to 0.
    Returns:
      pi_new, mu_new, sigma_new: (..., K+1)
        - first K components: truncated normals on (0, +inf)
        - last component: delta at 0 (mu=0, sigma=0)
      Weights are renormalised along the last dimension (rows without any
      weight above eps are left all-zero).
    """
    # Standard normal (broadcasts over tensors), used for the pdf.
    std = Normal(torch.tensor(0.0, device=pi.device, dtype=pi.dtype),
                 torch.tensor(1.0, device=pi.device, dtype=pi.dtype))

    # Avoid division by 0 in alpha
    sigma_safe = sigma.clamp_min(eps)

    # alpha = (0 - mu)/sigma = -mu/sigma
    alpha = -mu / sigma_safe

    # prob_mass = P(X > 0) = 1 - Phi(alpha) = 0.5 * erfc(alpha / sqrt(2)).
    # Computing "1 - cdf" directly cancels catastrophically once the mass is
    # below the floating-point resolution near 1 (about 1e-7 in float32), which
    # corrupts the ratio below and can even make the truncated mean negative.
    prob_mass = (0.5 * torch.erfc(alpha / (2.0 ** 0.5))).clamp_min(eps)

    # phi(alpha) = standard normal pdf at alpha
    phi = std.log_prob(alpha).exp()

    # ratio = phi / prob_mass (inverse Mills ratio)
    ratio = phi / prob_mass

    # Truncated component moments (lower truncation at 0)
    mu_trunc = mu + sigma_safe * ratio
    var_trunc = sigma_safe**2 * (1.0 + alpha * ratio - ratio**2)
    sigma_trunc = torch.sqrt(var_trunc.clamp_min(eps))

    # New weights for truncated components
    pi_trunc = pi * prob_mass

    # Mass that collapses to 0 (the "clamped" part)
    prob_zero = (pi * (1.0 - prob_mass)).sum(dim=-1, keepdim=True)  # (..., 1)

    # Append delta-at-zero as last component
    pi_new = torch.cat([pi_trunc, prob_zero], dim=-1)               # (..., K+1)

    zeros = torch.zeros_like(prob_zero)
    mu_new = torch.cat([mu_trunc, zeros], dim=-1)                   # (..., K+1)
    sigma_new = torch.cat([sigma_trunc, zeros], dim=-1)             # (..., K+1)

    # Threshold + renormalize along last dim
    pi_new = pi_new * (pi_new > eps)
    s = pi_new.sum(dim=-1, keepdim=True)
    pi_new = torch.where(s > 0, pi_new / s.clamp_min(eps), pi_new)

    return pi_new, mu_new, sigma_new



# Example batched test
# Row 0 reuses the mixture checked against Monte Carlo above, so its delta
# weight should match the P(Y=0) printed there.
B, K = 2, 3
pi = torch.tensor([[0.2, 0.5, 0.3],
                   [0.1, 0.2, 0.7]], dtype=torch.float64)
mu = torch.tensor([[-1.5, 0.2, 2.0],
                   [-0.3, 0.1, 1.0]], dtype=torch.float64)
sigma = torch.tensor([[0.7, 0.4, 0.8],
                      [0.5, 0.2, 0.3]], dtype=torch.float64)

pi_new, mu_new, sigma_new = truncate_0_vectorized(pi, mu, sigma)
print(pi_new.shape, mu_new.shape, sigma_new.shape)  # (2, 4)
print("delta weights:", pi_new[:, -1])              # P(Y=0) per batch


torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([2, 4])
delta weights: tensor([0.3529, 0.1346])


In [32]:
# End-to-end check: pre-activation of the hand-built layer, followed by the ReLU push-forward.
# NOTE(review): relies on `layer`, `pi_x`, `mu_x` and `sigma_x` from earlier cells; `pi_x`
# must still be the batched (2, 3, 1) input, not one of the later rebindings.
pi_z, mu_z, sg_z = preactivation_gm_for_neuron_tree(
    layer,
    j=0,
    pi_x=pi_x,
    mu_x=mu_x,
    sg_x=sigma_x
)
print("Before ReLU:")
print("pi_z:", pi_z)
print("mu_z:", mu_z)
print("sg_z:", sg_z)
# Both pre-activation means are far above 0 relative to sg_z, so almost no mass
# is expected to move to the delta-at-zero component.
pi_a, mu_a, sg_a = truncate_0_vectorized(pi_z, mu_z, sg_z)
print("\nAfter ReLU:")
print("pi_a:", pi_a)
print("mu_a:", mu_a)
print("sg_a:", sg_a)


Before ReLU:
pi_z: tensor([[1.],
        [1.]], dtype=torch.float32)
mu_z: tensor([[4.6000],
        [2.1000]], dtype=torch.float32)
sg_z: tensor([[0.0500],
        [0.0500]], dtype=torch.float32)

After ReLU:
pi_a: tensor([[1., 0.],
        [1., 0.]], dtype=torch.float32)
mu_a: tensor([[4.6000, 0.0000],
        [2.1000, 0.0000]], dtype=torch.float32)
sg_a: tensor([[0.0500, 0.0000],
        [0.0500, 0.0000]], dtype=torch.float32)
