In [1]:
import numpy as np
import joblib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from cvae import CVAE, kl_divergence

In [2]:
# Dataset Class
class TabDataset(Dataset):
    """In-memory dataset yielding ``(X, y_cond, y_prop)`` rows as float32 tensors.

    Args:
        X: Array-like of input features; the first axis indexes samples.
        y_cond: Array-like of conditioning targets, one row per sample.
        y_prop: Array-like of property targets, one row per sample.

    Raises:
        ValueError: If the three inputs do not have the same number of rows.
    """

    def __init__(self, X, y_cond, y_prop):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y_cond = torch.tensor(y_cond, dtype=torch.float32)
        self.y_prop = torch.tensor(y_prop, dtype=torch.float32)

        # Fail fast on misaligned splits: __len__ is driven by X alone, so a
        # shorter target array would otherwise only surface as an IndexError
        # mid-epoch, and a longer one would be silently truncated.
        n_rows = self.X.shape[0]
        if self.y_cond.shape[0] != n_rows or self.y_prop.shape[0] != n_rows:
            raise ValueError(
                "X, y_cond and y_prop must have the same number of rows, got "
                f"{n_rows}, {self.y_cond.shape[0]} and {self.y_prop.shape[0]}."
            )

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return the ``i``-th ``(X, y_cond, y_prop)`` triple."""
        return self.X[i], self.y_cond[i], self.y_prop[i]

In [3]:
# Training Loop
def train_epoch(model, loader, opt, device,
                epoch, warmup_epochs=120,
                beta_max=1.0, lambda_prop=1.0):
    """Run one optimisation pass over ``loader`` and return epoch-level losses.

    The per-batch objective is ``recon + beta * kl + lambda_prop * prop``,
    where ``recon`` and ``prop`` are mean-squared errors and ``kl`` is whatever
    ``kl_divergence`` returns for the posterior/prior parameters produced by
    the model. Gradients are clipped to a global norm of 1.0 before each step.

    Args:
        model: Module called as ``model(Xb, ycb)`` and returning
            ``(X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: Iterable of ``(X, y_cond, y_prop)`` batches.
        opt: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used for the KL warm-up schedule.
        warmup_epochs: Number of epochs over which ``beta`` ramps linearly up
            to ``beta_max``. A value ``<= 0`` disables the warm-up.
        beta_max: Final weight of the KL term.
        lambda_prop: Weight of the property-prediction loss.

    Returns:
        ``(recon, kl, prop, beta)``: the sample-weighted epoch means of the
        three loss terms and the KL weight used for this epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()

    # β warm-up schedule. Guard against warmup_epochs <= 0, which previously
    # raised ZeroDivisionError (0) or produced a negative KL weight (< 0).
    if warmup_epochs > 0:
        beta = beta_max * min(1.0, (epoch + 1) / warmup_epochs)
    else:
        beta = beta_max

    recon_losses, kl_losses, prop_losses = [], [], []
    batch_sizes = []

    for Xb, ycb, ypb in loader:
        Xb, ycb, ypb = Xb.to(device), ycb.to(device), ypb.to(device)

        X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, ycb)

        # Loss terms
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_prop_pred, ypb, reduction="mean")

        loss = recon + beta * kl + lambda_prop * prop

        # Backprop
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # Logging
        recon_losses.append(recon.item())
        kl_losses.append(kl.item())
        prop_losses.append(prop.item())
        batch_sizes.append(Xb.shape[0])

    if sum(batch_sizes) == 0:
        raise ValueError("train_epoch received an empty loader: nothing to train on.")

    # Weight each batch mean by its size so a short final batch is not
    # over-represented in the epoch average.
    # NOTE(review): for the KL term this assumes kl_divergence returns a
    # per-batch mean, like the two MSE terms — confirm against cvae.py.
    return (np.average(recon_losses, weights=batch_sizes),
            np.average(kl_losses, weights=batch_sizes),
            np.average(prop_losses, weights=batch_sizes),
            beta)

In [4]:
# Evaluation Loop
@torch.no_grad()
def eval_epoch(model, loader, device, lambda_prop=1.0):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(Xb, ycb)`` and returning
            ``(X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
            It is switched to eval mode and left in that mode.
        loader: Iterable of ``(X, y_cond, y_prop)`` batches.
        device: Device the batches are moved to.
        lambda_prop: Accepted for signature parity with ``train_epoch``; it is
            not used because the individual loss terms are returned unweighted.

    Returns:
        ``(recon, kl, prop)``: sample-weighted means over the whole loader of
        the reconstruction MSE, the KL term and the property MSE.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    recon_losses, kl_losses, prop_losses = [], [], []
    batch_sizes = []

    for Xb, ycb, ypb in loader:
        Xb, ycb, ypb = Xb.to(device), ycb.to(device), ypb.to(device)

        X_hat, y_prop_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, ycb)

        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_prop_pred, ypb, reduction="mean")

        recon_losses.append(recon.item())
        kl_losses.append(kl.item())
        prop_losses.append(prop.item())
        batch_sizes.append(Xb.shape[0])

    if sum(batch_sizes) == 0:
        raise ValueError("eval_epoch received an empty loader: nothing to evaluate.")

    # Weight each batch mean by its size: a plain mean of batch means lets a
    # short final batch skew the metric that drives model selection.
    # NOTE(review): for the KL term this assumes kl_divergence returns a
    # per-batch mean, like the two MSE terms — confirm against cvae.py.
    return (np.average(recon_losses, weights=batch_sizes),
            np.average(kl_losses, weights=batch_sizes),
            np.average(prop_losses, weights=batch_sizes))

In [5]:
# Main Training Script
def main(seed=42):
    """Train the CVAE on the preprocessed splits with early stopping.

    Loads the train/validation arrays from ``../data/processed/data_splits.npz``,
    trains for up to 200 epochs with a 120-epoch KL warm-up, prints the
    per-epoch losses, and saves the weights with the best validation
    ``recon + prop`` score to ``../models/cvae_best.pt``. Training stops early
    once that score has not improved by more than ``1e-5`` for 60 epochs.

    Args:
        seed: Seed for PyTorch's global RNG, set before the model and the
            loaders are created so runs are repeatable. Pass ``None`` to keep
            the previous unseeded behaviour. Bit-exact repeatability on GPU is
            not guaranteed by seeding alone.
    """
    # Local import: only needed here, to create the checkpoint directory.
    import os

    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load preprocessed data. np.load on an .npz returns a lazy archive that
    # holds the file open, so read every array inside the context manager.
    with np.load("../data/processed/data_splits.npz") as data:
        X_train, X_val = data["X_train"], data["X_val"]
        y_cond_train, y_cond_val = data["y_cond_train"], data["y_cond_val"]
        y_prop_train, y_prop_val = data["y_prop_train"], data["y_prop_val"]

    # Dataloaders
    train_loader = DataLoader(TabDataset(X_train, y_cond_train, y_prop_train),
                              batch_size=16, shuffle=True)
    val_loader = DataLoader(TabDataset(X_val, y_cond_val, y_prop_val),
                            batch_size=32, shuffle=False)

    # Model
    x_dim = X_train.shape[1]
    y_cond_dim = y_cond_train.shape[1]
    y_prop_dim = y_prop_train.shape[1]
    model = CVAE(x_dim=x_dim, y_cond_dim=y_cond_dim, y_prop_dim=y_prop_dim,
                 z_dim=4, hidden=128).to(device)

    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Make sure the checkpoint directory exists; torch.save does not create it.
    checkpoint_path = "../models/cvae_best.pt"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Training loop
    best_val = float("inf")
    patience, bad = 60, 0
    EPOCHS = 200

    for epoch in range(EPOCHS):
        tr_recon, tr_kl, tr_prop, beta = train_epoch(
            model, train_loader, opt, device,
            epoch, warmup_epochs=120, beta_max=1.0, lambda_prop=1.0
        )
        va_recon, va_kl, va_prop = eval_epoch(model, val_loader, device, lambda_prop=1.0)

        print(f"Ep {epoch:03d} | beta={beta:.3f} | "
              f"train R {tr_recon:.4f} KL {tr_kl:.4f} P {tr_prop:.4f} | "
              f"val R {va_recon:.4f} KL {va_kl:.4f} P {va_prop:.4f}")

        # Early stopping on (recon + prop); the KL term is excluded, so the
        # score does not depend on the current warm-up weight.
        val_score = va_recon + va_prop
        if val_score < best_val - 1e-5:
            best_val = val_score
            bad = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            # A NaN score also lands here, since NaN comparisons are False.
            bad += 1

        if bad >= patience:
            print("Early stopping.")
            break

In [6]:
# Entry point: start training when this cell/module is executed as the main program.
if __name__ == "__main__":
    main()

Ep 000 | beta=0.008 | train R 0.9850 KL 0.2505 P 0.0955 | val R 1.0165 KL 0.4006 P 0.0657
Ep 001 | beta=0.017 | train R 0.9196 KL 0.8059 P 0.0446 | val R 0.9299 KL 1.1986 P 0.0504
Ep 002 | beta=0.025 | train R 0.7924 KL 1.9193 P 0.0371 | val R 0.8143 KL 2.1898 P 0.0396
Ep 003 | beta=0.033 | train R 0.6818 KL 2.5110 P 0.0289 | val R 0.7096 KL 2.5562 P 0.0407
Ep 004 | beta=0.042 | train R 0.5997 KL 2.8608 P 0.0287 | val R 0.6675 KL 2.6972 P 0.0328
Ep 005 | beta=0.050 | train R 0.5638 KL 2.7443 P 0.0250 | val R 0.6357 KL 2.4682 P 0.0217
Ep 006 | beta=0.058 | train R 0.5551 KL 2.5383 P 0.0222 | val R 0.6008 KL 2.5052 P 0.0281
Ep 007 | beta=0.067 | train R 0.5365 KL 2.4971 P 0.0261 | val R 0.5976 KL 2.4773 P 0.0240
Ep 008 | beta=0.075 | train R 0.5069 KL 2.4326 P 0.0208 | val R 0.5426 KL 2.3877 P 0.0224
Ep 009 | beta=0.083 | train R 0.5078 KL 2.4176 P 0.0198 | val R 0.5385 KL 2.5060 P 0.0342
Ep 010 | beta=0.092 | train R 0.5042 KL 2.3180 P 0.0225 | val R 0.5441 KL 2.2851 P 0.0213
Ep 011 | b