In [1]:
from representation.bayesian_nn import BNN_GMM
import torch

In [2]:
device = "cpu"

#DATASET
pi_x = torch.ones(2, 3, 1, device=device)        # deterministic mixture
mu_x = torch.tensor([
    [[1.0], [2.0], [3.0]],   # batch 0
    [[-1.0], [0.5], [1.5]]   # batch 1
], device=device)

sigma_x = 0.01 * torch.ones_like(mu_x)

#BNN FORWARD
bnn = BNN_GMM(layer_sizes=[3, 10, 1], K=1, bias=True).to(device)

pi_out, mu_out, sg_out = bnn(pi_x, mu_x, sigma_x, max_components=100)
print(pi_out.shape)
print(mu_out.shape)
print(sg_out.shape)

#count how many components have non-zero weight
non_zero_components = (pi_out > 1e-12).sum(dim=2)
print("Non-zero components per batch sample:", non_zero_components)

torch.Size([2, 1, 4])
torch.Size([2, 1, 4])
torch.Size([2, 1, 4])
Non-zero components per batch sample: tensor([[2],
        [4]])


In [3]:
def gmm_nll(pi, mu, sg, y, eps=1e-12):
    """
    Negative log-likelihood of targets y under a Gaussian mixture.
    Shapes:
      pi,mu,sg: (B, D, K)
      y:        (B, D) or (B, D, 1)

    Returns:
      scalar loss
    """
    if y.dim() == 2:
        y = y.unsqueeze(-1)  # (B, D, 1)

    sg = sg.clamp_min(eps)
    logpi = torch.log(pi.clamp_min(eps))  # (B, D, K)

    # log N(y | mu, sg)
    log_norm = -0.5 * torch.log(2.0 * torch.pi * (sg ** 2))
    log_exp = -0.5 * ((y - mu) / sg) ** 2
    logp = log_norm + log_exp  # (B, D, K)

    # logsumexp over components
    log_mix = torch.logsumexp(logpi + logp, dim=-1)  # (B, D)
    return (-log_mix).mean()


In [4]:
@torch.no_grad()
def project_bnn_params_(bnn, eps=1e-12):
    for layer in bnn.layers:
        # keep sigmas positive
        layer.sigma_w.clamp_(min=eps)
        if layer.bias:
            layer.sigma_b.clamp_(min=eps)

        # keep pis valid probabilities
        layer.pi_w.clamp_(min=0.0)
        layer.pi_w.div_(layer.pi_w.sum(dim=-1, keepdim=True).clamp_min(eps))

        if layer.bias:
            layer.pi_b.clamp_(min=0.0)
            layer.pi_b.div_(layer.pi_b.sum(dim=-1, keepdim=True).clamp_min(eps))


In [5]:
def train_bnn_gmm_regression(
    bnn,
    dataloader,
    epochs=10,
    lr=1e-3,
    eps=1e-12,
    max_components=None,
    last_relu=False,
    device="cpu",
):
    bnn.train()
    opt = torch.optim.Adam(bnn.parameters(), lr=lr)

    for ep in range(epochs):
        total = 0.0
        n_batches = 0

        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            # Convert deterministic inputs x (B, d_in) to GMM input (B, d_in, 1)
            pi_x = torch.ones(x.shape[0], x.shape[1], 1, device=device, dtype=x.dtype)
            mu_x = x.unsqueeze(-1)
            sg_x = 1e-2 * torch.ones_like(mu_x)  # choose your input noise

            # Forward through mixture network
            pi_out, mu_out, sg_out = bnn(
                pi_x, mu_x, sg_x,
                eps=eps,
                last_relu=last_relu,
                max_components=100
            )

            # Loss (regression NLL)
            loss = gmm_nll(pi_out, mu_out, sg_out, y, eps=eps)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            project_bnn_params_(bnn, eps=eps)

            total += loss.item()
            n_batches += 1

        print(f"epoch {ep+1}/{epochs} | loss {total / max(1, n_batches):.6f}")


In [6]:
from torch.utils.data import TensorDataset, DataLoader

# Example: y = x0 - 2*x1 + 0.5*x2
B = 512
d_in = 3
X = torch.randn(B, d_in)
y = (X[:, 0] - 2.0 * X[:, 1] + 0.5 * X[:, 2]).unsqueeze(1)  # (B,1)

ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=64, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
bnn = BNN_GMM(layer_sizes=[3, 4, 1], K=1, bias=True).to(device)

train_bnn_gmm_regression(bnn, dl, epochs=20, lr=1e-2, device=device)


epoch 1/20 | loss 226.037613
epoch 2/20 | loss 56.191576
epoch 3/20 | loss 40.072413
epoch 4/20 | loss 32.063205
epoch 5/20 | loss 27.285764
epoch 6/20 | loss 24.485750
epoch 7/20 | loss 22.284459
epoch 8/20 | loss 20.475117
epoch 9/20 | loss 19.025531
epoch 10/20 | loss 17.837642
epoch 11/20 | loss 16.763264
epoch 12/20 | loss 15.797644
epoch 13/20 | loss 14.897493
epoch 14/20 | loss 14.049182
epoch 15/20 | loss 13.282319
epoch 16/20 | loss 12.586281
epoch 17/20 | loss 11.913484
epoch 18/20 | loss 11.274634
epoch 19/20 | loss 10.740121
epoch 20/20 | loss 10.250258


In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

# ---------- metrics helpers ----------

def classification_accuracy_from_gmm_output(pi, mu, y_onehot):
    """
    pi, mu: (B, 10, K)
    y_onehot: (B, 10)
    """
    # Mixture mean per class => (B,10)
    y_hat = (pi * mu).sum(dim=-1)

    pred = torch.argmax(y_hat, dim=1)          # (B,)
    true = torch.argmax(y_onehot, dim=1)       # (B,)
    return (pred == true).float().mean().item()

def mixture_mean(pi, mu):
    # pi, mu: (B, D, K) -> (B, D)
    return (pi * mu).sum(dim=-1)

def mixture_variance(pi, mu, sg, eps=1e-12):
    mean = (pi * mu).sum(dim=-1)
    second = (pi * (sg**2 + mu**2)).sum(dim=-1)
    return second - mean**2

def regression_metrics_from_output(pi, mu, sg, y, eps=1e-12):
    """
    pi,mu,sg: (B,D,K)
    y:        (B,D)
    """
    y_hat = mixture_mean(pi, mu)
    mse = ((y_hat - y) ** 2).mean()
    rmse = torch.sqrt(mse)
    mae = (y_hat - y).abs().mean()

    var = mixture_variance(pi, mu, sg, eps=eps)
    std = torch.sqrt(var.clamp_min(eps))
    cov_1sigma = (((y >= y_hat - std) & (y <= y_hat + std)).float().mean())

    return {
        "MSE": mse.item(),
        "RMSE": rmse.item(),
        "MAE": mae.item(),
        "Cov@1σ": cov_1sigma.item(),
    }

def pad_gmm_to_K(pi, mu, sg, K_target, eps=1e-12):
    B, D, K = pi.shape
    if K == K_target:
        return pi, mu, sg
    if K > K_target:
        # better not to happen if K_target = global max
        pi = pi[:, :, :K_target]
        mu = mu[:, :, :K_target]
        sg = sg[:, :, :K_target]
        s = pi.sum(dim=-1, keepdim=True)
        pi = torch.where(s > eps, pi / s, pi)
        return pi, mu, sg

    pad = K_target - K
    pi_pad = torch.zeros(B, D, pad, device=pi.device, dtype=pi.dtype)
    mu_pad = torch.zeros(B, D, pad, device=mu.device, dtype=mu.dtype)
    sg_pad = torch.ones(B, D, pad, device=sg.device, dtype=sg.dtype)
    return (
        torch.cat([pi, pi_pad], dim=-1),
        torch.cat([mu, mu_pad], dim=-1),
        torch.cat([sg, sg_pad], dim=-1),
    )




def train_test_split_train_and_eval(
    bnn,
    X, y,
    test_ratio=0.2,
    batch_size=64,
    epochs=20,
    lr=1e-3,
    input_sigma=1e-2,
    eps=1e-12,
    last_relu=False,
    device=None,
    seed=0,
    max_components=None,
):
    """
    Takes bnn + data, splits into train/test, trains, then evaluates on test at the end.

    X: (N, d_in)
    y: (N, d_out)
    Returns: dict with test metrics (including NLL)
    """
    if device is None:
        device = next(bnn.parameters()).device

    # Dataset + split
    ds = TensorDataset(X, y)
    n_total = len(ds)
    n_test = int(round(test_ratio * n_total))
    n_train = n_total - n_test

    g = torch.Generator().manual_seed(seed)
    ds_train, ds_test = random_split(ds, [n_train, n_test], generator=g)

    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
    dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    # Optimizer
    opt = torch.optim.Adam(bnn.parameters(), lr=lr)

    # ---- train ----
    bnn.train()
    for ep in range(epochs):
        running = 0.0
        n_batches = 0

        for xb, yb in dl_train:
            xb = xb.to(device)
            yb = yb.to(device)

            # Deterministic inputs -> GM inputs (B, d_in, 1)
            pi_x = torch.ones(xb.shape[0], xb.shape[1], 1, device=device, dtype=xb.dtype)
            mu_x = xb.unsqueeze(-1)
            sg_x = input_sigma * torch.ones_like(mu_x)

            pi_out, mu_out, sg_out = bnn(pi_x, mu_x, sg_x, eps=eps, last_relu=last_relu, max_components = max_components)
            print("sg min/median:", sg_out.min().item(), sg_out.median().item())

            loss = gmm_nll(pi_out, mu_out, sg_out, yb, eps=eps)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            project_bnn_params_(bnn, eps=eps)

            running += loss.item()
            n_batches += 1

            print(f"epoch {ep+1}/{epochs} | train NLL {running / max(1,n_batches):.6f}")

    # ---- test eval (only at end) ----
    bnn.eval()
    all_pi, all_mu, all_sg, all_y = [], [], [], []
    K_max = 0

    with torch.no_grad():
        for xb, yb in dl_test:
            xb = xb.to(device)
            yb = yb.to(device)

            pi_x = torch.ones(xb.shape[0], xb.shape[1], 1, device=device, dtype=xb.dtype)
            mu_x = xb.unsqueeze(-1)
            sg_x = input_sigma * torch.ones_like(mu_x)

            pi_out, mu_out, sg_out = bnn(
                pi_x, mu_x, sg_x, eps=eps, last_relu=last_relu, max_components=max_components
            )

            K_max = max(K_max, pi_out.shape[-1])

            all_pi.append(pi_out)
            all_mu.append(mu_out)
            all_sg.append(sg_out)
            all_y.append(yb)

    # pad to the largest K seen in the whole test set
    all_pi2, all_mu2, all_sg2 = [], [], []
    for pi_b, mu_b, sg_b in zip(all_pi, all_mu, all_sg):
        pi_b, mu_b, sg_b = pad_gmm_to_K(pi_b, mu_b, sg_b, K_max, eps=eps)
        all_pi2.append(pi_b)
        all_mu2.append(mu_b)
        all_sg2.append(sg_b)

    pi_out = torch.cat(all_pi2, dim=0)
    mu_out = torch.cat(all_mu2, dim=0)
    sg_out = torch.cat(all_sg2, dim=0)
    y_test = torch.cat(all_y, dim=0)


    test_nll = gmm_nll(pi_out, mu_out, sg_out, y_test, eps=eps).item()
    test_metrics = regression_metrics_from_output(pi_out, mu_out, sg_out, y_test, eps=eps)
    test_metrics["NLL"] = test_nll

    acc = classification_accuracy_from_gmm_output(pi_out, mu_out, y_test)
    test_metrics["Acc"] = acc
    print("TEST Acc:", acc)


    print("\nTEST metrics:", test_metrics)
    return test_metrics


In [8]:
# X: (N,3), y: (N,1)
metrics = train_test_split_train_and_eval(
    bnn, X, y,
    test_ratio=0.2,
    epochs=30,
    lr=1e-2,
    batch_size=64,
    input_sigma=1e-2,
    max_components=100,

)


epoch 1/30 | train NLL 8.870159
epoch 1/30 | train NLL 8.565257
epoch 1/30 | train NLL 8.478438
epoch 1/30 | train NLL 7.542206
epoch 1/30 | train NLL 7.145145
epoch 1/30 | train NLL 7.034742
epoch 1/30 | train NLL 6.804343
epoch 2/30 | train NLL 4.109103
epoch 2/30 | train NLL 3.626495
epoch 2/30 | train NLL 3.439876
epoch 2/30 | train NLL 3.211209
epoch 2/30 | train NLL 3.215924
epoch 2/30 | train NLL 3.187468
epoch 2/30 | train NLL 2.974972
epoch 3/30 | train NLL 2.734701
epoch 3/30 | train NLL 2.306779
epoch 3/30 | train NLL 2.163274
epoch 3/30 | train NLL 2.068856
epoch 3/30 | train NLL 2.095836
epoch 3/30 | train NLL 2.041606
epoch 3/30 | train NLL 1.972659
epoch 4/30 | train NLL 1.836417
epoch 4/30 | train NLL 1.836692
epoch 4/30 | train NLL 1.766028
epoch 4/30 | train NLL 1.734610
epoch 4/30 | train NLL 1.733114
epoch 4/30 | train NLL 1.666056
epoch 4/30 | train NLL 1.595639
epoch 5/30 | train NLL 1.383654
epoch 5/30 | train NLL 1.468257
epoch 5/30 | train NLL 1.435868
epoch 5/

In [9]:
import torch
from torchvision import datasets, transforms

def load_mnist_onehot_tensors(
    root="./data",
    train=True,
    n_max=None,
    normalize=True,
    device="cpu",
    dtype=torch.float32,
):
    """
    Returns:
      X: (N, 784) float
      y: (N, 10)  one-hot float
    """
    if normalize:
        tfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST stats
        ])
    else:
        tfm = transforms.ToTensor()

    ds = datasets.MNIST(root=root, train=train, download=True, transform=tfm)

    if n_max is None:
        n_max = len(ds)
    n_max = min(n_max, len(ds))

    X_list = []
    y_list = []

    for i in range(n_max):
        x_i, y_i = ds[i]              # x_i: (1,28,28), y_i: int in [0..9]
        X_list.append(x_i.view(-1))   # -> (784,)
        y_list.append(int(y_i))

    X = torch.stack(X_list, dim=0).to(device=device, dtype=dtype)  # (N,784)

    y_idx = torch.tensor(y_list, device=device, dtype=torch.long)  # (N,)
    y = torch.zeros((n_max, 10), device=device, dtype=dtype)
    y.scatter_(1, y_idx.unsqueeze(1), 1.0)                          # (N,10) one-hot

    return X, y



def run_mnist_classification_with_your_trainer(
    bnn,
    device=None,
    root="./data",
    n_max_train=1000,
    test_ratio=0.2,
    batch_size=128,
    epochs=10,
    lr=1e-3,
    input_sigma=1e-2,
    eps=1e-12,
    last_relu=False,
    seed=0,
    normalize=True,
    max_components=None,
):
    if device is None:
        device = next(bnn.parameters()).device
    bnn.to(device)

    # Load MNIST train split (we'll do train/test split inside your function)
    X, y = load_mnist_onehot_tensors(
        root=root,
        train=True,
        n_max=n_max_train,
        normalize=normalize,
        device=device,
        dtype=torch.float32,
    )

    # Train + final test metrics (your function)
    test_metrics = train_test_split_train_and_eval(
        bnn=bnn,
        X=X,
        y=y,
        test_ratio=test_ratio,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        input_sigma=input_sigma,
        eps=eps,
        last_relu=last_relu,
        device=device,
        seed=seed,
        max_components=max_components,
    )

    return test_metrics



In [10]:
# Example architecture
# NOTE: pick sizes you like; this is a reasonable start.
layer_sizes = [784, 10, 5, 10]
K = 5

bnn = BNN_GMM(layer_sizes=layer_sizes, K=K, bias=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
bnn.to(device)

metrics = run_mnist_classification_with_your_trainer(
    bnn=bnn,
    device=device,
    epochs=10,
    batch_size=12,
    lr=1e-4,
    input_sigma=1e-2,
    last_relu=False,     # keep last layer linear (recommended)
    normalize=True,
    max_components=50,
)

print("Final test metrics:", metrics)


epoch 1/10 | train NLL 61.594402
epoch 1/10 | train NLL 62.799814
epoch 1/10 | train NLL 67.541935
epoch 1/10 | train NLL 72.295737
epoch 1/10 | train NLL 74.231345
epoch 1/10 | train NLL 77.124811
epoch 1/10 | train NLL 77.774408
epoch 1/10 | train NLL 77.481841
epoch 1/10 | train NLL 77.096055
epoch 1/10 | train NLL 75.528369
epoch 1/10 | train NLL 75.227519
epoch 1/10 | train NLL 75.122584
epoch 1/10 | train NLL 74.673053
epoch 1/10 | train NLL 74.488170
epoch 1/10 | train NLL 74.376057
epoch 1/10 | train NLL 73.617378
epoch 1/10 | train NLL 72.905398
epoch 1/10 | train NLL 71.687535
epoch 1/10 | train NLL 70.749522
epoch 1/10 | train NLL 70.066523
epoch 1/10 | train NLL 68.918155
epoch 1/10 | train NLL 67.943067
epoch 1/10 | train NLL 67.215976
epoch 1/10 | train NLL 66.552923
epoch 1/10 | train NLL 65.830845
epoch 1/10 | train NLL 65.368674
epoch 1/10 | train NLL 64.725141
epoch 1/10 | train NLL 64.287114
epoch 1/10 | train NLL 63.507536
epoch 1/10 | train NLL 62.576983
epoch 1/10

In [14]:
bnn.layers[1].pi_w.shape

torch.Size([5, 10, 5])