# In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# -------------------------
# Data (matches your notebook)
# -------------------------
def sample_mog(n=5000, pi=0.4, mu1=-3.0, sigma1=1.0, mu2=3.0, sigma2=1.5, seed=538):
    """Draw ``n`` samples from a two-component 1-D Gaussian mixture.

    Each sample comes from N(mu1, sigma1^2) with probability ``pi`` and from
    N(mu2, sigma2^2) otherwise.

    Note: this reseeds torch's *global* RNG with ``seed``, so any random
    operation that runs afterwards (e.g. model initialisation) is also
    made deterministic.

    Returns:
        A float tensor of shape (n,).
    """
    torch.manual_seed(seed)
    # The three draws happen in this exact order (component labels, then
    # component-1 values, then component-2 values) so results are reproducible.
    pick_first = torch.bernoulli(torch.full((n,), pi)) == 1
    from_first = torch.normal(mu1, sigma1, size=(n,))
    from_second = torch.normal(mu2, sigma2, size=(n,))
    return torch.where(pick_first, from_first, from_second)

# -------------------------
# Generator g(eps) -> x
# -------------------------
class Generator(nn.Module):
    """MLP generator g(eps) -> x mapping latent noise to scalar samples.

    Architecture: Linear(z_dim, hidden) -> LeakyReLU(0.2) ->
    Linear(hidden, hidden) -> LeakyReLU(0.2) -> Linear(hidden, 1).

    Args:
        z_dim: size of the last dimension of the noise input.
        hidden: width of the two hidden layers.
    """

    def __init__(self, z_dim=8, hidden=128):
        super().__init__()
        # Layers are created in a fixed order: their parameters are drawn from
        # the global RNG in this order, and the positional indices become the
        # ``net.<idx>`` state_dict keys.
        widths = [z_dim, hidden, hidden]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise of shape (..., z_dim) to samples of shape (...)."""
        scores = self.net(z)  # (..., 1)
        return scores.squeeze(-1)  # (B,) for a (B, z_dim) input

# -------------------------
# Engression loss (unconditional, 1D)
# ES ≈ mean |x - g(eps)| - 0.5 * mean |g(eps) - g(eps')|
# We use m samples per data point in the minibatch (algorithm in notes).
# -------------------------
def engression_loss(x_batch, G, z_dim=8, m=8):
    """Monte-Carlo energy-score (engression) loss for an unconditional 1-D generator.

    For each x_i in the minibatch, draw m noise vectors z_{i,1..m} and
    generate yhat_{i,j} = G(z_{i,j}). Then

        term1 = mean_i mean_j     |x_i - yhat_{i,j}|
        term2 = mean_i mean_{j<k} |yhat_{i,j} - yhat_{i,k}|
        loss  = term1 - 0.5 * term2

    Args:
        x_batch: 1-D tensor of shape (B,) with the observed samples.
        G: callable mapping noise of shape (B*m, z_dim) to samples of
            shape (B*m,).
        z_dim: dimensionality of the noise fed to G.
        m: generated samples per data point; must be >= 2 because term2
            needs at least one (j, k) pair.

    Returns:
        Scalar tensor, differentiable with respect to G's parameters.

    Raises:
        ValueError: if m < 2 or x_batch is empty. Either case would take a
            mean over an empty tensor and silently produce a NaN loss.
    """
    if m < 2:
        raise ValueError(f"m must be >= 2 to estimate the pairwise term, got m={m}")
    B = x_batch.shape[0]
    if B == 0:
        raise ValueError("x_batch must contain at least one sample")

    # (B, m, z_dim) noise, flattened for G and reshaped back to (B, m).
    z = torch.randn(B, m, z_dim, device=x_batch.device)
    yhat = G(z.view(B * m, z_dim)).view(B, m)

    term1 = (yhat - x_batch[:, None]).abs().mean()

    # Pairwise |yhat_ij - yhat_ik| within each i: (B, m, m).
    diffs = (yhat[:, :, None] - yhat[:, None, :]).abs()
    # The diagonal is exactly 0 and each unordered pair appears twice, so
    # dividing the per-row sum by m*(m-1) gives the mean over the j<k pairs.
    term2 = diffs.sum(dim=(1, 2)).mean() / (m * (m - 1))

    return term1 - 0.5 * term2

# -------------------------
# Train
# -------------------------
def train_engression(x, steps=4000, batch_size=256, lr=2e-4, z_dim=8, m=8, device="cpu"):
    """Fit a Generator to 1-D samples by minimising the engression loss.

    Args:
        x: 1-D tensor of training samples; it is moved to ``device``.
        steps: total number of optimizer steps. The DataLoader is re-iterated
            (reshuffled each pass) until this many steps have been taken.
        batch_size: minibatch size; an incomplete final batch is dropped.
        lr: Adam learning rate (betas fixed at (0.5, 0.9)).
        z_dim: noise dimension of the generator.
        m: generated samples per data point in the loss.
        device: torch device the data and model are placed on.

    Returns:
        The trained Generator.

    Raises:
        ValueError: if ``steps > 0`` but ``x`` holds fewer than ``batch_size``
            samples. With ``drop_last=True`` the loader would yield no batches,
            ``it`` would never advance, and the loop would never terminate.
    """
    ds = TensorDataset(x.to(device))
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    if steps > 0 and len(dl) == 0:
        raise ValueError(
            f"need at least batch_size={batch_size} samples to train, got {len(ds)}"
        )

    G = Generator(z_dim=z_dim).to(device)
    opt = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.9))

    it = 0
    while it < steps:
        for (xb,) in dl:
            opt.zero_grad()
            loss = engression_loss(xb, G, z_dim=z_dim, m=m)
            loss.backward()
            opt.step()

            it += 1
            if it % 500 == 0:
                print(f"[Engression] step={it:5d}  loss={loss.item():.4f}")
            if it >= steps:
                break
    return G

@torch.no_grad()
def sample_from_G(G, n=5000, z_dim=8, device="cpu"):
    """Draw ``n`` samples from generator ``G`` without tracking gradients.

    Standard-normal noise of shape (n, z_dim) is created on ``device``. The
    generator's output is returned moved to the CPU.
    """
    noise = torch.randn(n, z_dim, device=device)
    samples = G(noise)
    return samples.cpu()

if __name__ == "__main__":
    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Same mixture-of-Gaussians setup as the notebook (default parameters).
    x = sample_mog(n=5000)
    G = train_engression(
        x,
        steps=4000,
        batch_size=256,
        lr=2e-4,
        z_dim=8,
        m=8,
        device=device,
    )

    x_fake = sample_from_G(G, n=5000, z_dim=8, device=device)
    fake_mean = x_fake.mean().item()
    fake_std = x_fake.std().item()
    print("Generated mean/std:", fake_mean, fake_std)
