In [1]:
import numpy as np, joblib, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from cvae import CVAE, kl_divergence

In [2]:
class TabDataset(Dataset):
    """Tabular dataset that holds features and targets as float32 tensors."""

    def __init__(self, X, y):
        """Convert ``X`` and ``y`` to float32 tensors and keep them on the instance."""
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.float32)
        self.X = features
        self.y = targets

    def __len__(self):
        """Return the number of rows, i.e. the size of the first feature dimension."""
        return self.X.size(0)

    def __getitem__(self, i):
        """Return the ``(features, target)`` pair stored at index ``i``."""
        sample = (self.X[i], self.y[i])
        return sample

In [3]:
def train_epoch(model, loader, opt, device, epoch, total_epochs, beta_max=1.0, warmup_epochs=50):
    """Run one optimisation pass over ``loader`` with a linearly annealed KL weight.

    Args:
        model: Module called as ``model(Xb, yb)`` that returns
            ``(X_hat, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: Iterable yielding ``(Xb, yb)`` tensor batches.
        opt: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.
        epoch: Index of the current epoch; drives the KL schedule.
        total_epochs: Not used by this function; kept so existing call sites
            keep working.
        beta_max: KL weight used once the warm-up is over.
        warmup_epochs: Number of epochs over which beta ramps linearly up to
            ``beta_max``.

    Returns:
        Tuple ``(mean_recon, mean_kl, beta)``. Both means are weighted by
        batch size so that a short final batch does not skew the epoch
        figure; both are NaN when ``loader`` yields no batches.
    """
    model.train()
    recon_losses, kl_losses, batch_sizes = [], [], []

    # KL annealing: linear ramp during warm-up, constant afterwards.
    if epoch < warmup_epochs:
        beta = beta_max * (epoch + 1) / warmup_epochs
    else:
        beta = beta_max

    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        loss = recon + beta * kl

        opt.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        recon_losses.append(recon.item())
        kl_losses.append(kl.item())
        batch_sizes.append(Xb.shape[0])

    if not batch_sizes:
        # Nothing was iterated: report NaN rather than dividing by zero.
        return float("nan"), float("nan"), beta

    # Per-batch values are means, so weight them by batch size to get the
    # per-sample epoch mean.
    # NOTE(review): for the KL term this assumes kl_divergence returns a
    # per-batch mean -- confirm against its definition in cvae.
    mean_recon = np.average(recon_losses, weights=batch_sizes)
    mean_kl = np.average(kl_losses, weights=batch_sizes)
    return mean_recon, mean_kl, beta

In [4]:
@torch.no_grad()
def eval_epoch(model, loader, device):
    """Compute reconstruction and KL losses over ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(Xb, yb)`` that returns
            ``(X_hat, (q_mu, q_logvar, p_mu, p_logvar))``.
        loader: Iterable yielding ``(Xb, yb)`` tensor batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_recon, mean_kl)``. Both means are weighted by batch
        size so that a short final batch does not skew the result; both are
        NaN when ``loader`` yields no batches.
    """
    model.eval()
    recon_losses, kl_losses, batch_sizes = [], [], []
    for Xb, yb in loader:
        Xb, yb = Xb.to(device), yb.to(device)
        X_hat, (q_mu, q_logvar, p_mu, p_logvar) = model(Xb, yb)
        recon = nn.functional.mse_loss(X_hat, Xb, reduction="mean")
        kl = kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
        recon_losses.append(recon.item())
        kl_losses.append(kl.item())
        batch_sizes.append(Xb.shape[0])

    if not batch_sizes:
        # Nothing was iterated: report NaN rather than dividing by zero.
        return float("nan"), float("nan")

    # Per-batch values are means, so weight them by batch size to get the
    # per-sample mean over the whole loader.
    # NOTE(review): for the KL term this assumes kl_divergence returns a
    # per-batch mean -- confirm against its definition in cvae.
    mean_recon = np.average(recon_losses, weights=batch_sizes)
    mean_kl = np.average(kl_losses, weights=batch_sizes)
    return mean_recon, mean_kl

In [5]:
def main(seed=None):
    """Train the CVAE on the saved splits with KL warm-up and early stopping.

    Reads ``X_train``, ``y_train``, ``X_val`` and ``y_val`` from
    ``../data/processed/data_splits.npz``, trains for up to 300 epochs and
    writes the weights with the lowest validation reconstruction loss to
    ``cvae_best.pt`` in the current working directory.

    Args:
        seed: Optional integer used to seed NumPy and PyTorch before the data
            loaders and the model are created. ``None`` (the default) leaves
            the random state untouched.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Load splits; the context manager closes the .npz archive once the
    # arrays have been read out of it.
    with np.load("../data/processed/data_splits.npz") as data:
        X_train, y_train = data["X_train"], data["y_train"]
        X_val, y_val = data["X_val"], data["y_val"]
    x_dim, y_dim = X_train.shape[1], y_train.shape[1]

    # Dataloaders
    train_loader = DataLoader(TabDataset(X_train, y_train), batch_size=16, shuffle=True)
    val_loader = DataLoader(TabDataset(X_val, y_val), batch_size=32, shuffle=False)

    # Model: input/condition widths come from the data, so the network always
    # matches the arrays that were actually loaded.
    model = CVAE(x_dim=x_dim, y_dim=y_dim, z_dim=6, hidden=128).to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    num_epochs = 300
    warmup_epochs = 80
    best_val = float("inf")
    patience, bad = 30, 0
    for epoch in range(num_epochs):
        tr_recon, tr_kl, beta = train_epoch(
            model, train_loader, opt, device, epoch, num_epochs,
            beta_max=1.0, warmup_epochs=warmup_epochs,
        )
        va_recon, va_kl = eval_epoch(model, val_loader, device)
        print(f"Epoch {epoch:03d} | beta={beta:.3f} | train recon {tr_recon:.4f} KL {tr_kl:.4f} | val recon {va_recon:.4f} KL {va_kl:.4f}")

        # Early stopping on val recon: an epoch counts as an improvement only
        # if it beats the best value by more than 1e-5.
        # NOTE(review): selecting on reconstruction alone ignores the KL term,
        # so checkpoints from early in the warm-up (small beta) may be
        # favoured -- consider monitoring recon + beta_max * KL instead.
        if va_recon < best_val - 1e-5:
            best_val = va_recon
            bad = 0
            torch.save(model.state_dict(), "cvae_best.pt")
        else:
            bad += 1
        if bad >= patience:
            print("Early stopping.")
            break
        
# Start the training run when this cell is executed.
main()

Epoch 000 | beta=0.013 | train recon 0.9876 KL 0.3145 | val recon 1.6971 KL 0.5850
Epoch 001 | beta=0.025 | train recon 0.9244 KL 0.9903 | val recon 1.6805 KL 1.3679
Epoch 002 | beta=0.037 | train recon 0.8144 KL 2.4619 | val recon 1.6211 KL 1.7434
Epoch 003 | beta=0.050 | train recon 0.6565 KL 2.3806 | val recon 1.6080 KL 1.5788
Epoch 004 | beta=0.062 | train recon 0.6531 KL 2.1718 | val recon 1.5561 KL 1.3928
Epoch 005 | beta=0.075 | train recon 0.6192 KL 1.9079 | val recon 1.5928 KL 1.5939
Epoch 006 | beta=0.087 | train recon 0.6111 KL 2.0559 | val recon 1.5828 KL 1.7814
Epoch 007 | beta=0.100 | train recon 0.5130 KL 2.0628 | val recon 1.5265 KL 1.5912
Epoch 008 | beta=0.113 | train recon 0.5245 KL 2.2858 | val recon 1.4926 KL 1.3846
Epoch 009 | beta=0.125 | train recon 0.5262 KL 1.7573 | val recon 1.5475 KL 1.1256
Epoch 010 | beta=0.138 | train recon 0.6017 KL 1.4892 | val recon 1.5504 KL 1.1089
Epoch 011 | beta=0.150 | train recon 0.5171 KL 1.5621 | val recon 1.3829 KL 1.0969
Epoc