# GLLVM Simulations

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from gllvm import GLLVM, PoissonGLM, BinomialGLM, GaussianGLM

# One seed for every torch generator used by the notebook.
seed = 123

torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)

# cuDNN: disable benchmark mode and request deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ----------------------------------------------------------
# 1. Generate synthetic data from the ground-truth model
# ----------------------------------------------------------

# Response layout shared by the ground-truth model and the model that is
# fitted: (GLM family, response columns it covers, registration name).
_GLM_LAYOUT = (
    (PoissonGLM, (0, 1), "Poisson1"),
    (BinomialGLM, (2,), "Binomial1"),
    (GaussianGLM, (3, 4), "Gaussian1"),
)


def _build_gllvm(latent_dim=1, output_dim=5):
    """Return a new GLLVM with every family of ``_GLM_LAYOUT`` attached.

    The simulation needs two models with the same structure (the data
    generator and the one being trained). Building both through this helper
    keeps their specifications from drifting apart.

    Args:
        latent_dim: forwarded to ``GLLVM(latent_dim=...)``.
        output_dim: forwarded to ``GLLVM(output_dim=...)``.

    Returns:
        The freshly constructed ``GLLVM`` instance.
    """
    model = GLLVM(latent_dim=latent_dim, output_dim=output_dim)
    for glm_cls, columns, label in _GLM_LAYOUT:
        # A fresh list per call, exactly as the inline literals provided.
        model.add_glm(glm_cls, idx=list(columns), name=label)
    return model


gllvm0 = _build_gllvm()

num_samples = 10_000
z0 = gllvm0.sample_z(num_samples)
y0 = gllvm0.sample(z=z0)

# ----------------------------------------------------------
# 2. Build a fresh model that will be trained with VI
# ----------------------------------------------------------

# Built only after the data has been drawn, so the order of model
# construction and sampling calls is the same as before.
gllvm = _build_gllvm()

# ----------------------------------------------------------
# 3. Build the encoder q(z|y)
# ----------------------------------------------------------

class Encoder(nn.Module):
    """Encoder network for q(z|y).

    A shared two-layer ReLU trunk feeds two independent linear heads; the
    forward pass returns their outputs as ``(mu, logvar)``.

    Args:
        input_dim: size of the last axis of the input ``y``.
        latent_dim: size of the last axis of each returned tensor.
        hidden: width of both trunk layers.
    """

    def __init__(self, input_dim=5, latent_dim=1, hidden=32):
        super().__init__()
        # Trunk: Linear -> ReLU -> Linear -> ReLU.
        trunk_layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*trunk_layers)
        # Heads, both reading the same trunk features.
        self.mean = nn.Linear(hidden, latent_dim)
        self.logvar = nn.Linear(hidden, latent_dim)

    def forward(self, y):
        """Return ``(mu, logvar)`` computed from ``y`` via the shared trunk."""
        features = self.net(y)
        return self.mean(features), self.logvar(features)

encoder = Encoder(input_dim=5, latent_dim=1)


# --- split the gllvm parameters into two optimiser groups ---
# `scale` is handed to the encoder's optimiser (it is learned through the
# ELBO); every other named parameter goes to the gllvm optimiser.
gllvm_scale = [gllvm.scale]
gllvm_no_scale = [
    param for param_name, param in gllvm.named_parameters() if param_name != "scale"
]

optimizer_gllvm = optim.Adam(gllvm_no_scale, lr=1e-4)
optimizer_encoder = optim.Adam([*encoder.parameters(), *gllvm_scale], lr=1e-4)

# ----------------------------------------------------------
# 4. VI training loop: ZQE + ELBO Encoder
# ----------------------------------------------------------
# Every mini-batch performs two separate optimiser steps:
#   1. optimizer_encoder steps on the negative ELBO
#      (encoder weights plus gllvm.scale);
#   2. optimizer_gllvm steps on the centred ZQ loss
#      (the remaining gllvm parameters), with all latent draws detached.

batch_size = 256
num_epochs = 500

dataset = y0
n = len(dataset)

for epoch in range(num_epochs):
    # Fresh random visiting order each epoch; the final batch is shorter
    # whenever n is not a multiple of batch_size.
    perm = torch.randperm(n)
    total_elbo = 0.0

    for i in range(0, n, batch_size):
        idx = perm[i:i+batch_size]
        y = dataset[idx]

        # ======================================================
        # 1. ENCODER UPDATE (phi) using ELBO ONLY
        # ======================================================
        optimizer_encoder.zero_grad()

        # Reparameterised draw z = mu + eps * std, so the gradient reaches
        # the encoder through z.
        mu, logvar = encoder(y)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        # decoder log-likelihood for ELBO (no detach), summed over the last
        # axis -- presumably one term per response column; confirm against
        # gllvm.log_prob.
        logpy_z = gllvm.log_prob(y, z=z).sum(dim=-1)

        # KL(q||p) in closed form for a diagonal Gaussian q against a
        # standard-normal p.
        # NOTE(review): assumes gllvm.sample_z draws from N(0, I) -- confirm.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

        elbo = (logpy_z - kl).mean()
        loss_elbo = -elbo

        # NOTE(review): if log_prob depends on the non-scale gllvm parameters,
        # this backward also fills their .grad; optimizer_encoder does not
        # step them and optimizer_gllvm.zero_grad() below discards those
        # gradients.
        loss_elbo.backward()
        optimizer_encoder.step()

        # ======================================================
        # 2. DECODER UPDATE (theta) using CENTERED ZQ LOSS ONLY
        # ======================================================

        optimizer_gllvm.zero_grad()

        # ---- sample z ~ q(z|y) but encoder is frozen ----
        # The encoder is re-evaluated here, i.e. with the weights just
        # updated by step 1.
        with torch.no_grad():
            mu, logvar = encoder(y)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std  # detached latent samples from q

        # ---- sample model-generated pairs (zq,yq) ----
        # Same number of synthetic rows as the current data batch.
        with torch.no_grad():
            zq0 = gllvm.sample_z(num_samples=len(y))   # prior
            yq = gllvm.sample(z=zq0)                  # model sample

            # Encode the synthetic responses and draw a latent for each.
            muq, logvarq = encoder(yq)
            stdq = torch.exp(0.5 * logvarq)
            epsq = torch.randn_like(stdq)
            zq = muq + epsq * stdq                    # detached

        # ---- compute the centered ZQ estimating function ----
        # Only these two zq_log calls are outside no_grad in step 2, so
        # gradients can only reach gllvm's own parameters.
        logpy_zqe  = gllvm.zq_log(y,  z=z).sum(dim=-1)
        logpy_zqe2 = gllvm.zq_log(yq, z=zq).sum(dim=-1)

        # ZQ loss = -(m_q - m_q_model): data-batch mean minus model-sample mean.
        loss_zqe = -(logpy_zqe.mean() - logpy_zqe2.mean())

        # NOTE(review): if zq_log depends on gllvm.scale, its .grad is filled
        # here too; it is cleared by optimizer_encoder.zero_grad() at the
        # start of the next batch and never applied.
        loss_zqe.backward()
        optimizer_gllvm.step()

        # NOTE(review): this adds up per-batch *mean* ELBOs, so the value
        # printed below grows with the number of batches and gives the short
        # final batch the same weight as a full one.
        total_elbo += elbo.item()

    print(f"Epoch {epoch+1}: ELBO={total_elbo:.2f}")




Epoch 1: ELBO=-350.64
Epoch 2: ELBO=-340.26
Epoch 3: ELBO=-330.61
Epoch 4: ELBO=-323.37
Epoch 5: ELBO=-312.27
Epoch 6: ELBO=-304.79
Epoch 7: ELBO=-297.22
Epoch 8: ELBO=-289.47
Epoch 9: ELBO=-285.26
Epoch 10: ELBO=-280.60
Epoch 11: ELBO=-276.68
Epoch 12: ELBO=-275.71
Epoch 13: ELBO=-273.43
Epoch 14: ELBO=-272.28
Epoch 15: ELBO=-271.75
Epoch 16: ELBO=-269.73
Epoch 17: ELBO=-270.53
Epoch 18: ELBO=-270.81
Epoch 19: ELBO=-270.25
Epoch 20: ELBO=-270.38
Epoch 21: ELBO=-268.44
Epoch 22: ELBO=-268.81
Epoch 23: ELBO=-269.56
Epoch 24: ELBO=-269.10
Epoch 25: ELBO=-267.94
Epoch 26: ELBO=-268.77
Epoch 27: ELBO=-268.68
Epoch 28: ELBO=-268.94
Epoch 29: ELBO=-268.42
Epoch 30: ELBO=-268.24
Epoch 31: ELBO=-267.65
Epoch 32: ELBO=-269.56
Epoch 33: ELBO=-268.56
Epoch 34: ELBO=-267.16
Epoch 35: ELBO=-267.36
Epoch 36: ELBO=-268.15
Epoch 37: ELBO=-267.55
Epoch 38: ELBO=-267.52
Epoch 39: ELBO=-266.88
Epoch 40: ELBO=-266.43
Epoch 41: ELBO=-267.38
Epoch 42: ELBO=-267.33
Epoch 43: ELBO=-267.14
Epoch 44: ELBO=-268.

KeyboardInterrupt: 

In [7]:
# Walk the named parameters once, then report the two optimiser groups.
param_names = [param_name for param_name, _ in gllvm.named_parameters()]
print("Decoder params:", [param_name for param_name in param_names if param_name != "scale"])
print("Scale param:", [param_name for param_name in param_names if param_name == "scale"])


Decoder params: ['wz', 'bias']
Scale param: ['scale']


In [2]:
# Scale parameter of the ground-truth model (shown as the cell output).
gllvm0.scale

Parameter containing:
tensor([1., 1., 1., 1., 1.], requires_grad=True)

In [5]:
# Scale parameter of the trained model, for comparison with gllvm0.scale.
gllvm.scale

Parameter containing:
tensor([1.0000, 1.0000, 1.0000, 0.6235, 1.4561], requires_grad=True)

In [3]:
# Ground-truth loadings with the sign flipped.
# NOTE(review): presumably negated because the sign of the latent variable is
# not identified, so that the values line up with the estimate -- confirm.
gllvm0.wz * -1

tensor([[ 0.1115, -0.1204,  0.3696,  0.2404,  1.1969]], grad_fn=<MulBackward0>)

In [4]:
# Loadings of the trained model; the `* 1` makes the cell display a plain
# tensor instead of the Parameter object.
gllvm.wz * 1

tensor([[ 0.1375, -0.1140,  0.1716,  0.3320,  1.1280]], grad_fn=<MulBackward0>)

In [11]:
# Ground-truth value followed by its estimate, for each parameter in turn.
comparisons = [
    ("True W_z:\n", gllvm0.wz),
    ("Estimated W_z:\n", gllvm.wz),
    ("True bias:\n", gllvm0.bias),
    ("Estimated bias:\n", gllvm.bias),
    ("True scale:\n", gllvm0.scale),
    ("Estimated scale:\n", gllvm.scale),
]
for heading, value in comparisons:
    print(heading, value)


True W_z:
 Parameter containing:
tensor([[-0.7856, -0.1608, -0.7251,  0.5009, -1.0102]], requires_grad=True)
Estimated W_z:
 Parameter containing:
tensor([[ 0.7671,  0.1664,  0.6880, -0.4997,  1.0128]], requires_grad=True)
True bias:
 Parameter containing:
tensor([0., 0., 0., 0., 0.], requires_grad=True)
Estimated bias:
 Parameter containing:
tensor([ 0.0222,  0.0058,  0.0084,  0.0049, -0.0112], requires_grad=True)
True scale:
 Parameter containing:
tensor([1., 1., 1., 1., 1.], requires_grad=True)
Estimated scale:
 Parameter containing:
tensor([1.0000, 1.0000, 1.0000, 0.9924, 1.0000], requires_grad=True)
