In [1]:
import os

import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from cvae import CVAE, kl_divergence

In [2]:
class TabDataset(Dataset):
    """In-memory dataset yielding row-aligned ``(X, y_cond, y_prop)`` float32 tensors.

    Each input is copied into a ``torch.float32`` tensor at construction time.
    Item ``i`` is the tuple ``(X[i], y_cond[i], y_prop[i])``.

    Args:
        X: array-like of features; indexed along its first axis.
        y_cond: array-like of conditioning values, one row per row of ``X``.
        y_prop: array-like of property targets, one row per row of ``X``.

    Raises:
        ValueError: if the three inputs do not share the same first-axis length.
    """

    def __init__(self, X, y_cond, y_prop):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y_cond = torch.tensor(y_cond, dtype=torch.float32)
        self.y_prop = torch.tensor(y_prop, dtype=torch.float32)
        # __len__ is derived from X alone, so a shorter y_* would only fail with an
        # IndexError in the middle of an epoch and a longer one would go unnoticed.
        # Fail fast here instead.
        leading = (self.X.shape[:1], self.y_cond.shape[:1], self.y_prop.shape[:1])
        if not (leading[0] == leading[1] == leading[2]):
            raise ValueError(
                "X, y_cond and y_prop must have the same number of rows, got "
                f"{tuple(self.X.shape)}, {tuple(self.y_cond.shape)}, {tuple(self.y_prop.shape)}"
            )

    def __len__(self):
        """Number of rows, taken from the first axis of ``X``."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return the ``i``-th ``(X, y_cond, y_prop)`` triple."""
        return self.X[i], self.y_cond[i], self.y_prop[i]

In [3]:
def train_epoch(model, loader, opt, device, epoch, warmup_epochs=100, beta_max=2.0, lambda_prop=0.5):
    """Run one optimisation pass over ``loader`` and report the mean per-batch losses.

    The per-batch objective is ``recon + beta * kl + lambda_prop * prop`` where
    ``recon`` and ``prop`` are MSE losses and ``kl`` comes from ``kl_divergence``.
    ``beta`` is annealed linearly with the epoch index,
    ``beta_max * min(1, (epoch + 1) / warmup_epochs)``.

    Args:
        model: module called as ``model(X, y_cond)`` and returning
            ``(X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(X, y_cond, y_prop)`` tensor batches.
        opt: optimizer over ``model.parameters()``.
        device: device the batches are moved to.
        epoch: zero-based epoch index used for the beta schedule.
        warmup_epochs: number of epochs over which beta ramps up to ``beta_max``.
            A value ``<= 0`` disables the warm-up (beta is ``beta_max`` immediately).
        beta_max: final weight of the KL term.
        lambda_prop: weight of the property-prediction loss.

    Returns:
        ``(recon, kl, prop, beta)``: the unweighted means of the per-batch loss
        values (``np.mean`` over batches) and the beta used for this epoch.
    """
    model.train()
    # Guard the schedule: warmup_epochs == 0 used to raise ZeroDivisionError and a
    # negative value produced a negative beta (i.e. rewarded KL divergence).
    if warmup_epochs > 0:
        beta = beta_max * min(1.0, (epoch + 1) / warmup_epochs)
    else:
        beta = beta_max * 1.0
    recon_losses, kl_losses, prop_losses = [], [], []
    for Xb, ycb, ypb in loader:
        Xb, ycb, ypb = Xb.to(device), ycb.to(device), ypb.to(device)
        X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, ycb)
        recon = F.mse_loss(X_hat, Xb)
        # NOTE(review): q_* / p_* are presumably posterior and prior Gaussian
        # parameters; the reduction used by kl_divergence is defined in cvae.py.
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = F.mse_loss(y_prop_pred, ypb)
        loss = recon + beta * kl + lambda_prop * prop
        opt.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        recon_losses.append(recon.item())
        kl_losses.append(kl.item())
        prop_losses.append(prop.item())
    return np.mean(recon_losses), np.mean(kl_losses), np.mean(prop_losses), beta

In [4]:

@torch.no_grad()
def eval_epoch(model, loader, device, lambda_prop=0.5):
    """Score ``model`` on ``loader`` without gradient tracking.

    The model is switched to eval mode, every batch is moved to ``device``, and
    the reconstruction MSE, the ``kl_divergence`` value and the property MSE are
    recorded per batch.

    Args:
        model: module called as ``model(X, y_cond)`` and returning
            ``(X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(X, y_cond, y_prop)`` tensor batches.
        device: device the batches are moved to.
        lambda_prop: accepted for signature parity with ``train_epoch``; not
            used by this function.

    Returns:
        ``(recon, kl, prop)``: unweighted means of the per-batch values.
    """
    model.eval()
    history = {"recon": [], "kl": [], "prop": []}
    for batch in loader:
        features, cond, target = (tensor.to(device) for tensor in batch)
        outputs = model(features, cond)
        reconstruction, prop_pred, latent_stats = outputs
        post_mu, post_logvar, prior_mu, prior_logvar = latent_stats

        recon_term = F.mse_loss(reconstruction, features)
        kl_term = kl_divergence(post_mu, post_logvar, prior_mu, prior_logvar)
        prop_term = F.mse_loss(prop_pred, target)

        history["recon"].append(recon_term.item())
        history["kl"].append(kl_term.item())
        history["prop"].append(prop_term.item())

    mean_recon = np.mean(history["recon"])
    mean_kl = np.mean(history["kl"])
    mean_prop = np.mean(history["prop"])
    return mean_recon, mean_kl, mean_prop

In [5]:
def main(data_path="../data/processed/data_splits.npz", model_path="../models/cvae_best.pt",
         epochs=200, patience=40, min_delta=1e-5, seed=None):
    """Train the CVAE with early stopping and checkpoint the best weights.

    Loads the train/validation splits from ``data_path``, trains for at most
    ``epochs`` epochs, and saves ``model.state_dict()`` to ``model_path`` whenever
    the validation score (reconstruction MSE + property MSE) improves by more
    than ``min_delta``. Training stops once ``patience`` consecutive epochs pass
    without such an improvement.

    Args:
        data_path: ``.npz`` archive holding ``X_train``, ``X_val``,
            ``y_cond_train``, ``y_cond_val``, ``y_prop_train`` and ``y_prop_val``.
        model_path: destination of the best checkpoint; its parent directory is
            created if missing.
        epochs: maximum number of training epochs.
        patience: number of non-improving epochs tolerated before stopping.
        min_delta: minimum decrease of the validation score that counts as an
            improvement.
        seed: if given, seeds NumPy and PyTorch before the model and loaders
            are built; ``None`` leaves the RNG state untouched.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # np.load on an .npz returns a lazily-read archive backed by an open file;
    # pull the arrays out inside a context manager so the handle is closed.
    with np.load(data_path) as data:
        X_train, X_val = data["X_train"], data["X_val"]
        y_cond_train, y_cond_val = data["y_cond_train"], data["y_cond_val"]
        y_prop_train, y_prop_val = data["y_prop_train"], data["y_prop_val"]

    train_loader = DataLoader(TabDataset(X_train, y_cond_train, y_prop_train), batch_size=32, shuffle=True)
    val_loader = DataLoader(TabDataset(X_val, y_cond_val, y_prop_val), batch_size=64, shuffle=False)

    model = CVAE(x_dim=X_train.shape[1], y_cond_dim=y_cond_train.shape[1], y_prop_dim=y_prop_train.shape[1],
                 z_dim=8, hidden=256).to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Create the checkpoint directory up front so the first torch.save cannot
    # fail on a missing folder after a full epoch of training.
    ckpt_dir = os.path.dirname(model_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    best_val, bad = 1e9, 0
    for epoch in range(epochs):
        tr_recon, tr_kl, tr_prop, beta = train_epoch(model, train_loader, opt, device, epoch)
        va_recon, va_kl, va_prop = eval_epoch(model, val_loader, device)
        # NOTE(review): the stopping score ignores the KL term while beta is still
        # ramping up in train_epoch, so the "best" checkpoint may come from a
        # low-beta epoch. Confirm this selection criterion is intended.
        val_score = va_recon + va_prop
        print(f"Ep {epoch:03d} | β={beta:.2f} | Train R {tr_recon:.4f} KL {tr_kl:.4f} P {tr_prop:.4f} | Val R {va_recon:.4f} KL {va_kl:.4f} P {va_prop:.4f}")
        if val_score < best_val - min_delta:
            best_val, bad = val_score, 0
            torch.save(model.state_dict(), model_path)
        else:
            bad += 1
        if bad >= patience:
            print("Early stopping.")
            break

In [6]:
# Entry-point guard: start training only when this code runs as the main program.
# NOTE(review): a notebook kernel normally executes cells with __name__ == "__main__",
# so running this cell launches the full training loop. Confirm that is intended.
if __name__ == "__main__":
    main()

Ep 000 | β=0.02 | Train R 0.9504 KL 0.7096 P 0.1779 | Val R 0.9813 KL 1.8434 P 0.1073
Ep 001 | β=0.04 | Train R 0.8500 KL 2.1829 P 0.0734 | Val R 0.8610 KL 2.0003 P 0.0664
Ep 002 | β=0.06 | Train R 0.7448 KL 1.8606 P 0.0531 | Val R 0.8342 KL 1.5353 P 0.0548
Ep 003 | β=0.08 | Train R 0.7444 KL 1.8643 P 0.0400 | Val R 0.7754 KL 2.0141 P 0.0571
Ep 004 | β=0.10 | Train R 0.6658 KL 1.8981 P 0.0427 | Val R 0.7798 KL 1.7623 P 0.0383
Ep 005 | β=0.12 | Train R 0.6861 KL 1.7717 P 0.0436 | Val R 0.7310 KL 1.7701 P 0.0505
Ep 006 | β=0.14 | Train R 0.6452 KL 1.6984 P 0.0382 | Val R 0.7512 KL 1.4545 P 0.0434
Ep 007 | β=0.16 | Train R 0.6863 KL 1.2853 P 0.0379 | Val R 0.7853 KL 1.3367 P 0.0534
Ep 008 | β=0.18 | Train R 0.7066 KL 1.2374 P 0.0365 | Val R 0.7493 KL 1.2271 P 0.0432
Ep 009 | β=0.20 | Train R 0.6805 KL 1.1107 P 0.0365 | Val R 0.7703 KL 1.2199 P 0.0386
Ep 010 | β=0.22 | Train R 0.7082 KL 1.0150 P 0.0328 | Val R 0.6846 KL 1.1076 P 0.0282
Ep 011 | β=0.24 | Train R 0.7087 KL 0.9080 P 0.0328 | 