In [1]:
import os

import numpy as np, joblib, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from cvae import CVAE, kl_divergence

In [2]:
class TabDataset(Dataset):
    """In-memory map-style dataset pairing feature rows with target rows.

    ``X`` and ``y`` are converted once, at construction, to ``float32``
    tensors; item ``i`` is the tuple ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        # Convert eagerly so per-item access is a plain tensor index.
        self.X, self.y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))

    def __len__(self):
        # Number of rows = first dimension of the feature tensor.
        return self.X.size(0)

    def __getitem__(self, i):
        features = self.X[i]
        target = self.y[i]
        return features, target

In [3]:
def train_epoch(model, loader, opt, device, epoch, warmup_epochs=120, beta_max=1.0, lambda_prop=1.0,
                max_grad_norm=1.0):
    """Run one optimisation pass over ``loader`` and return epoch-level losses.

    The KL weight follows a linear warm-up: for ``epoch < warmup_epochs`` it is
    ``beta_max * (epoch + 1) / warmup_epochs``; afterwards it is ``beta_max``.
    Each step minimises ``recon + beta * kl + lambda_prop * prop``.

    Args:
        model: CVAE-style module; ``model(X, y)`` must return
            ``(X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(X, y)`` tensor batches.
        opt: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        epoch: zero-based epoch index; only used for the beta schedule.
        warmup_epochs: number of epochs over which beta ramps up to ``beta_max``.
        beta_max: final KL weight.
        lambda_prop: weight of the property-prediction MSE term.
        max_grad_norm: gradient-norm clipping threshold (defaults to the
            previously hard-coded 1.0).

    Returns:
        ``(recon, kl, prop, beta)``: the three loss terms averaged over samples
        (each batch weighted by its size, so a short final batch is not
        over-weighted), plus the beta used this epoch. The loss terms are NaN
        if ``loader`` yields no batches.
    """
    model.train()
    beta = beta_max * (epoch + 1) / warmup_epochs if epoch < warmup_epochs else beta_max

    # Running sums of (batch-mean loss * batch size) -> exact per-sample averages.
    sum_recon = sum_kl = sum_prop = 0.0
    n_seen = 0
    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        # NOTE(review): kl_divergence lives in cvae.py; assumed to return a
        # per-batch mean like the MSE terms — confirm its reduction.
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_pred, yb, reduction="mean")

        loss = recon + beta * kl + lambda_prop * prop

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        opt.step()

        n = Xb.shape[0]
        sum_recon += recon.item() * n
        sum_kl += kl.item() * n
        sum_prop += prop.item() * n
        n_seen += n

    if n_seen == 0:
        nan = float("nan")
        return nan, nan, nan, beta
    return sum_recon / n_seen, sum_kl / n_seen, sum_prop / n_seen, beta

In [4]:
@torch.no_grad()
def eval_epoch(model, loader, device, lambda_prop=1.0):
    """Evaluate ``model`` on ``loader`` without gradients and return epoch losses.

    Args:
        model: CVAE-style module; ``model(X, y)`` must return
            ``(X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: iterable of ``(X, y)`` tensor batches.
        device: device each batch is moved to.
        lambda_prop: unused; kept so the signature mirrors ``train_epoch``
            and existing callers keep working.

    Returns:
        ``(recon, kl, prop)`` averaged over samples (each batch weighted by
        its size, so the validation score used for checkpointing is not biased
        by a short final batch). All three are NaN if ``loader`` is empty.
    """
    model.eval()
    sum_recon = sum_kl = sum_prop = 0.0
    n_seen = 0
    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, y_pred, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        # NOTE(review): assumes kl_divergence (cvae.py) returns a per-batch mean.
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        prop = nn.functional.mse_loss(y_pred, yb, reduction="mean")

        n = Xb.shape[0]
        sum_recon += recon.item() * n
        sum_kl += kl.item() * n
        sum_prop += prop.item() * n
        n_seen += n

    if n_seen == 0:
        nan = float("nan")
        return nan, nan, nan
    return sum_recon / n_seen, sum_kl / n_seen, sum_prop / n_seen

In [5]:
def main(data_path="../data/processed/data_splits.npz",
         model_path="../models/cvae_best.pt",
         epochs=200, patience=60, min_delta=1e-5,
         warmup_epochs=120, beta_max=1.0, lambda_prop=1.0,
         seed=0):
    """Train the CVAE with KL warm-up and early stopping, checkpointing the best model.

    Reads ``X_train``/``y_train``/``X_val``/``y_val`` from the ``.npz`` archive at
    ``data_path``. It trains for up to ``epochs`` epochs. Whenever the validation
    score (recon + prop) improves by more than ``min_delta``, it saves
    ``model.state_dict()`` to ``model_path``. It stops after ``patience``
    consecutive epochs without improvement.

    Args:
        data_path: path to the ``.npz`` with the train/val splits (relative to the notebook).
        model_path: where the best state dict is written; its directory is created if missing.
        epochs: maximum number of epochs.
        patience: epochs without improvement before early stopping.
        min_delta: minimum decrease in validation score that counts as improvement.
        warmup_epochs: KL warm-up length, forwarded to ``train_epoch``.
        beta_max: final KL weight, forwarded to ``train_epoch``.
        lambda_prop: property-loss weight, forwarded to ``train_epoch``/``eval_epoch``.
        seed: seed for the numpy and torch RNGs (weight init, shuffling);
            ``None`` leaves global RNG state untouched.

    Returns:
        The best validation score reached (``inf`` if no epoch ran).
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NpzFile keeps the archive open; read the arrays out inside a context
    # manager so the file handle is released.
    with np.load(data_path) as data:
        X_train, y_train = data["X_train"], data["y_train"]
        X_val, y_val = data["X_val"], data["y_val"]
    x_dim, y_dim = X_train.shape[1], y_train.shape[1]

    train_loader = DataLoader(TabDataset(X_train, y_train), batch_size=16, shuffle=True)
    val_loader = DataLoader(TabDataset(X_val, y_val), batch_size=32, shuffle=False)

    model = CVAE(x_dim=x_dim, y_dim=y_dim, z_dim=4, hidden=128).to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # torch.save does not create missing directories.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    best_val = float("inf")
    bad = 0
    for epoch in range(epochs):
        tr_recon, tr_kl, tr_prop, beta = train_epoch(
            model, train_loader, opt, device, epoch,
            warmup_epochs=warmup_epochs, beta_max=beta_max, lambda_prop=lambda_prop)
        va_recon, va_kl, va_prop = eval_epoch(model, val_loader, device, lambda_prop=lambda_prop)
        print(f"Ep {epoch:03d} | beta={beta:.3f} | "
              f"train R {tr_recon:.4f} KL {tr_kl:.4f} P {tr_prop:.4f} | "
              f"val R {va_recon:.4f} KL {va_kl:.4f} P {va_prop:.4f}")

        # Early stop on (recon + prop) to respect both objectives.
        # NOTE(review): this score ignores KL, and beta is still ramping for the
        # first `warmup_epochs` epochs. With patience < warmup_epochs, the saved
        # checkpoint (and the stop itself) can come from a weakly regularised
        # low-beta epoch. Consider tracking the best model only once
        # epoch >= warmup_epochs.
        val_score = va_recon + va_prop
        if val_score < best_val - min_delta:
            best_val, bad = val_score, 0
            torch.save(model.state_dict(), model_path)
        else:
            bad += 1
        if bad >= patience:
            print("Early stopping.")
            break
    return best_val


best_val_score = main()

Ep 000 | beta=0.008 | train R 0.9830 KL 0.2507 P 0.1428 | val R 1.3181 KL 0.4729 P 0.0644
Ep 001 | beta=0.017 | train R 0.8954 KL 0.6110 P 0.0544 | val R 1.2966 KL 0.5310 P 0.0302
Ep 002 | beta=0.025 | train R 0.8768 KL 0.6492 P 0.0367 | val R 1.2750 KL 0.6341 P 0.0259
Ep 003 | beta=0.033 | train R 0.8782 KL 1.0380 P 0.0256 | val R 1.2455 KL 0.9645 P 0.0284
Ep 004 | beta=0.042 | train R 0.8957 KL 1.6814 P 0.0225 | val R 1.2272 KL 1.1535 P 0.0362
Ep 005 | beta=0.050 | train R 0.7506 KL 1.5714 P 0.0207 | val R 1.2274 KL 1.2477 P 0.0191
Ep 006 | beta=0.058 | train R 0.7835 KL 1.5979 P 0.0249 | val R 1.1943 KL 1.1999 P 0.0264
Ep 007 | beta=0.067 | train R 0.7182 KL 2.0141 P 0.0245 | val R 1.1793 KL 1.2392 P 0.0209
Ep 008 | beta=0.075 | train R 0.6798 KL 1.8302 P 0.0295 | val R 1.1526 KL 1.2697 P 0.0233
Ep 009 | beta=0.083 | train R 0.6434 KL 1.6847 P 0.0224 | val R 1.1539 KL 1.2098 P 0.0199
Ep 010 | beta=0.092 | train R 0.6542 KL 1.5838 P 0.0235 | val R 1.0921 KL 1.2704 P 0.0235
Ep 011 | b