In [1]:
import numpy as np, joblib, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from cvae import CVAE, kl_divergence

In [2]:
class TabDataset(Dataset):
    """Map-style dataset over in-memory tabular arrays.

    Both inputs are copied into float32 tensors once, at construction time, so
    ``__getitem__`` is a plain tensor index with no per-item conversion.

    Args:
        X: feature array-like; its first dimension indexes samples.
        y: target array-like, indexed by the same sample position as ``X``.
           NOTE(review): the lengths of ``X`` and ``y`` are not checked against
           each other here.
    """

    def __init__(self, X, y):
        as_float = torch.float32
        self.X = torch.tensor(X, dtype=as_float)
        self.y = torch.tensor(y, dtype=as_float)

    def __len__(self):
        # The sample count comes from X's leading dimension only.
        n_rows = self.X.size(0)
        return n_rows

    def __getitem__(self, i):
        features, target = self.X[i], self.y[i]
        return features, target

In [3]:
def train_epoch(model, loader, opt, device, epoch, warmup_epochs=120, beta_max=1.0, lambda_prop=1.0,
                max_grad_norm=1.0):
    """Run one optimisation pass over ``loader`` and return epoch-level loss averages.

    The KL weight ``beta`` is annealed linearly. It equals
    ``beta_max * (epoch + 1) / warmup_epochs`` while ``epoch < warmup_epochs`` and
    ``beta_max`` afterwards, so it first reaches ``beta_max`` at epoch
    ``warmup_epochs - 1``.

    The per-batch objective is ``recon + beta * kl + lambda_prop * prop``.
    ``recon`` and ``prop`` are element-wise mean squared errors of the
    reconstruction and of the property prediction.

    Args:
        model: module called as ``model(Xb, yb)`` that returns
            ``(X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(Xb, yb)`` tensor batches.
        opt: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch: 0-based epoch index, used only for the beta schedule.
        warmup_epochs: number of epochs over which beta ramps up to ``beta_max``.
        beta_max: final KL weight.
        lambda_prop: weight of the property-prediction loss.
        max_grad_norm: gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        ``(recon, kl, prop, beta)``. The three losses are averaged over the epoch
        with each batch weighted by its row count, so a short final batch is not
        over-weighted. ``beta`` is the KL weight used this epoch. The averages are
        NaN if ``loader`` yields no batches.
    """
    model.train()
    beta = beta_max * (epoch + 1) / warmup_epochs if epoch < warmup_epochs else beta_max

    totals = np.zeros(3)  # running batch-size-weighted sums of (recon, kl, prop)
    n_seen = 0
    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_pred, yb, reduction="mean")

        loss = recon + beta * kl + lambda_prop * prop

        opt.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        opt.step()

        # NOTE(review): weighting kl by batch size assumes kl_divergence returns a
        # per-sample (batch-averaged) scalar like the MSE terms -- confirm in cvae.py.
        bs = Xb.shape[0]
        totals += bs * np.array([recon.item(), kl.item(), prop.item()])
        n_seen += bs

    if n_seen == 0:
        nan = float("nan")
        return nan, nan, nan, beta
    recon_avg, kl_avg, prop_avg = totals / n_seen
    return recon_avg, kl_avg, prop_avg, beta

In [4]:
@torch.no_grad()
def eval_epoch(model, loader, device, lambda_prop=1.0):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Computes the same three loss terms as ``train_epoch``:
    - element-wise reconstruction MSE,
    - ``kl_divergence`` between posterior and prior,
    - element-wise property-prediction MSE.

    Args:
        model: module called as ``model(Xb, yb)`` that returns
            ``(X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(Xb, yb)`` tensor batches.
        device: device the batches are moved to.
        lambda_prop: unused. Accepted only so calls can mirror ``train_epoch``'s
            keyword arguments; the property loss is returned unweighted.

    Returns:
        ``(recon, kl, prop)`` averaged over the whole loader, with each batch
        weighted by its row count. This keeps the validation score an exact
        per-sample mean even when the last batch is short. All three are NaN if
        ``loader`` yields no batches.
    """
    model.eval()
    totals = np.zeros(3)  # running batch-size-weighted sums of (recon, kl, prop)
    n_seen = 0
    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_pred, yb, reduction="mean")
        # NOTE(review): weighting kl by batch size assumes kl_divergence returns a
        # per-sample (batch-averaged) scalar -- confirm in cvae.py.
        bs = Xb.shape[0]
        totals += bs * np.array([recon.item(), kl.item(), prop.item()])
        n_seen += bs

    if n_seen == 0:
        nan = float("nan")
        return nan, nan, nan
    recon_avg, kl_avg, prop_avg = totals / n_seen
    return recon_avg, kl_avg, prop_avg

In [7]:
def main(data_path="../data/processed/data_splits.npz", ckpt_path="../models/cvae_best.pt",
         epochs=200, warmup_epochs=120, patience=60, seed=0):
    """Train the CVAE on the saved train/val splits and checkpoint the best model.

    Early stopping and checkpoint selection use ``val recon + val prop``. They
    only start once the KL weight has finished annealing, at epoch
    ``warmup_epochs - 1``. Before that point the training objective changes
    every epoch, so earlier validation scores are not comparable.

    Tracking them from epoch 0 caused two problems. The "best" checkpoint could
    come from a barely-regularised early epoch. And with ``patience`` (60) shorter
    than ``warmup_epochs`` (120), training could stop before beta ever reached
    its maximum.

    Args:
        data_path: ``.npz`` file with ``X_train``, ``y_train``, ``X_val``, ``y_val``.
        ckpt_path: where the best ``state_dict`` is written; parent dirs are created.
        epochs: maximum number of training epochs.
        warmup_epochs: KL-annealing length passed to ``train_epoch``.
        patience: post-warmup epochs without improvement before stopping.
        seed: seeds NumPy and PyTorch RNGs for reproducibility; ``None`` skips seeding.
    """
    from pathlib import Path

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `with` closes the NpzFile handle once the arrays have been read.
    with np.load(data_path) as data:
        X_train, y_train = data["X_train"], data["y_train"]
        X_val, y_val = data["X_val"], data["y_val"]
    # Accept 1-D targets by treating them as a single property column.
    y_train = y_train.reshape(len(y_train), -1)
    y_val = y_val.reshape(len(y_val), -1)
    x_dim, y_dim = X_train.shape[1], y_train.shape[1]

    train_loader = DataLoader(TabDataset(X_train, y_train), batch_size=16, shuffle=True)
    val_loader = DataLoader(TabDataset(X_val, y_val), batch_size=32, shuffle=False)

    model = CVAE(x_dim=x_dim, y_dim=y_dim, z_dim=4, hidden=128).to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    Path(ckpt_path).parent.mkdir(parents=True, exist_ok=True)
    best_val, bad = float("inf"), 0
    for epoch in range(epochs):
        tr_recon, tr_kl, tr_prop, beta = train_epoch(model, train_loader, opt, device, epoch,
                                                     warmup_epochs=warmup_epochs, beta_max=1.0,
                                                     lambda_prop=1.0)
        va_recon, va_kl, va_prop = eval_epoch(model, val_loader, device, lambda_prop=1.0)
        print(f"Ep {epoch:03d} | beta={beta:.3f} | "
              f"train R {tr_recon:.4f} KL {tr_kl:.4f} P {tr_prop:.4f} | "
              f"val R {va_recon:.4f} KL {va_kl:.4f} P {va_prop:.4f}")

        if epoch < warmup_epochs - 1:
            continue  # beta still ramping up; scores not yet comparable (see docstring)

        # Early stop on (recon + prop) to respect both; beta is fixed from here on.
        val_score = va_recon + va_prop
        if val_score < best_val - 1e-5:
            best_val, bad = val_score, 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            bad += 1
        if bad >= patience:
            print("Early stopping.")
            break

    if best_val == float("inf"):
        # Training ended before annealing finished (or scores were NaN):
        # keep the last weights rather than leaving no checkpoint at all.
        torch.save(model.state_dict(), ckpt_path)
        print("No post-warmup checkpoint; saved final weights.")


main()

Ep 000 | beta=0.008 | train R 0.9515 KL 0.3671 P 0.1742 | val R 1.7081 KL 0.6130 P 0.0542
Ep 001 | beta=0.017 | train R 1.5597 KL 0.8637 P 0.0544 | val R 1.6844 KL 0.9779 P 0.0192
Ep 002 | beta=0.025 | train R 1.5184 KL 1.2003 P 0.0414 | val R 1.6218 KL 1.0488 P 0.0293
Ep 003 | beta=0.033 | train R 0.7947 KL 1.7120 P 0.0381 | val R 1.6676 KL 1.3464 P 0.0358
Ep 004 | beta=0.042 | train R 0.6625 KL 1.7029 P 0.0373 | val R 1.6152 KL 1.2205 P 0.0273
Ep 005 | beta=0.050 | train R 0.9143 KL 1.8811 P 0.0338 | val R 1.5870 KL 1.2749 P 0.0341
Ep 006 | beta=0.058 | train R 0.6377 KL 1.6419 P 0.0316 | val R 1.6057 KL 1.4838 P 0.0307
Ep 007 | beta=0.067 | train R 0.5295 KL 2.1282 P 0.0317 | val R 1.5707 KL 1.7599 P 0.0213
Ep 008 | beta=0.075 | train R 0.5251 KL 2.3717 P 0.0227 | val R 1.5178 KL 1.9200 P 0.0203
Ep 009 | beta=0.083 | train R 0.4903 KL 2.3777 P 0.0247 | val R 1.4726 KL 1.8635 P 0.0318
Ep 010 | beta=0.092 | train R 0.4249 KL 2.3038 P 0.0211 | val R 1.4766 KL 1.6755 P 0.0280
Ep 011 | b