# Tutorial 3: Training and Evaluating the CVAE

In this tutorial, we train a Conditional Variational Autoencoder (CVAE) using PyTorch.
We assume the model architecture, loss function, and dataset have already been defined in earlier tutorials.

We will:

- Train the CVAE using labeled data

- Monitor training and validation loss

- Save the best-performing model for later use

## Training Configuration

These are the main training hyperparameters:

- LATENT_DIM controls the dimensionality of the learned latent space

- LABEL_DIM specifies the size of the conditional input

- LR is the learning rate for the optimizer

- EPOCHS determines how many passes we make over the dataset

- CHECKPOINT_PATH specifies where to save the best model (the default path points to Google Drive, so it assumes Drive is mounted in a Colab session)

- LOAD_EXISTING allows us to resume training from a saved checkpoint
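To make the roles of LATENT_DIM and LABEL_DIM more concrete: one common way a CVAE consumes its condition is to concatenate the label vector with the latent code before decoding. The sketch below assumes concatenation; the CVAE class from the earlier tutorial may condition differently, and the helper condition_latent is hypothetical.

```python
import torch

LATENT_DIM = 32  # size of the latent code z
LABEL_DIM = 1    # size of the conditioning vector y


def condition_latent(z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: append the label vector to each latent code.

    z: (batch, LATENT_DIM), y: (batch, LABEL_DIM) -> (batch, LATENT_DIM + LABEL_DIM)
    """
    return torch.cat([z, y.float()], dim=1)


z = torch.randn(8, LATENT_DIM)           # a batch of 8 latent codes
y = torch.randint(0, 2, (8, LABEL_DIM))  # e.g. one binary label per image
decoder_input = condition_latent(z, y)
print(decoder_input.shape)  # torch.Size([8, 33])
```

Under this assumption, LABEL_DIM = 1 adds a single input to the decoder; conditioning on several labels at once would mean increasing LABEL_DIM to match.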

```python
LATENT_DIM = 32
LABEL_DIM = 1
LR = 1e-3
EPOCHS = 50
CHECKPOINT_PATH = "/content/drive/MyDrive/cvae_best2.pt"
LOAD_EXISTING = False
```
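The training cell saves only model.state_dict(), so setting LOAD_EXISTING = True restores the weights but restarts Adam's moment estimates and the epoch counter from scratch. If you want a true resume, one option is to store the optimizer state alongside the weights. The sketch below does this with two hypothetical helpers, save_checkpoint and load_checkpoint (not part of the tutorial's modules), and uses a tiny stand-in model so it runs on its own; the file name cvae_resume.pt is likewise illustrative.

```python
import os
import tempfile

import torch


def save_checkpoint(path, model, optimizer, epoch, best_val_loss):
    """Hypothetical helper: store everything needed to resume training."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "best_val_loss": best_val_loss,
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device):
    """Hypothetical helper: restore model/optimizer state, return (next_epoch, best_val_loss)."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1, ckpt["best_val_loss"]


# Tiny stand-in model so the sketch runs without the CVAE
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()  # populates Adam's moment estimates

path = os.path.join(tempfile.mkdtemp(), "cvae_resume.pt")
save_checkpoint(path, model, optimizer, epoch=4, best_val_loss=0.123)

restored = torch.nn.Linear(4, 2)
restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
start_epoch, best = load_checkpoint(path, restored, restored_opt, torch.device("cpu"))
print(start_epoch, best)  # 5 0.123
```

With helpers like these, the training loop would iterate over range(start_epoch, EPOCHS) and call save_checkpoint instead of torch.save. A checkpoint in this format is not compatible with the plain model.load_state_dict(torch.load(...)) call used in the training cell, so pick one format and use it consistently.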

## Train

We now bring together the model, optimizer, and data loaders to run the training loop. The code uses a GPU if one is available and falls back to the CPU otherwise. If LOAD_EXISTING is True, we load the previously saved model checkpoint before training. The training and validation data loaders come from the data module defined in earlier tutorials.

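In each epoch, we loop over mini-batches: the images and labels are moved to the selected device, the model returns reconstructions along with the latent mean and log-variance, the CVAE loss is computed and backpropagated, and Adam updates the model parameters. At the end of each epoch, we compute the validation loss and save the model whenever it improves on the best validation loss seen so far.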
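The cvae_loss function comes from an earlier tutorial and is not reproduced here. As a reminder of what such a loss typically combines, here is a minimal sketch of a standard CVAE objective: a reconstruction term plus the closed-form KL divergence between the approximate posterior N(mu, exp(logvar)) and a standard normal prior. The names cvae_loss_sketch and kl_to_standard_normal are hypothetical, and the choice of binary cross-entropy (summed per image, averaged over the batch) is an assumption; the tutorial's cvae_loss may use a different reconstruction term or weighting.

```python
import torch
import torch.nn.functional as F


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)


def cvae_loss_sketch(recon, target, mu, logvar):
    """Assumed form of a CVAE loss: BCE reconstruction + KL term, averaged over the batch.

    recon and target must lie in [0, 1]; mu and logvar have shape (batch, latent_dim).
    """
    # Reconstruction error summed over pixels, then averaged over the batch
    recon_term = F.binary_cross_entropy(recon, target, reduction="sum") / target.size(0)
    return recon_term + kl_to_standard_normal(mu, logvar)


# Toy check: 2 binary 4x4 "images", an uninformative reconstruction, and a
# posterior equal to the prior (mu = 0, logvar = 0), so the KL term is zero
imgs = torch.randint(0, 2, (2, 1, 4, 4)).float()
recon = torch.full_like(imgs, 0.5)
mu = torch.zeros(2, 32)
logvar = torch.zeros(2, 32)
loss = cvae_loss_sketch(recon, imgs, mu, logvar)
print(round(loss.item(), 2))  # 11.09, i.e. 16 pixels * ln 2
```

Because the label only influences the loss through the model's forward pass, the loss itself does not take labels, which matches the cvae_loss(recon, imgs, mu, logvar) call in the training loop.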

```python
import torch
from models.cvae import CVAE
from models.losses import cvae_loss
from models.eval import evaluate

from data.data_loader import get_chexpert_train_dataloader, get_chexpert_valid_dataloader

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CVAE(latent_dim=LATENT_DIM, label_dim=LABEL_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

train_loader = get_chexpert_train_dataloader()
valid_loader = get_chexpert_valid_dataloader()

best_val_loss = float("inf")
if LOAD_EXISTING:
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    # Start from the loaded model's validation loss so that a worse epoch
    # cannot overwrite the existing best checkpoint
    best_val_loss = evaluate(model, valid_loader, device)
    print(f"Loaded existing model (val loss {best_val_loss:.4f})")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0

    for imgs, labels in train_loader:
        # Move the batch to the same device as the model
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward pass: reconstruction plus latent mean and log-variance
        recon, mu, logvar = model(imgs, labels)
        loss = cvae_loss(recon, imgs, mu, logvar)

        # Backward pass and Adam update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Average of the per-batch losses for this epoch
    train_loss /= len(train_loader)

    # Validation
    val_loss = evaluate(model, valid_loader, device)

    print(
        f"Epoch {epoch+1}/{EPOCHS} | "
        f"Train: {train_loss:.4f} | Val: {val_loss:.4f}"
    )

    # Checkpoint: keep only the weights with the lowest validation loss so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print("Saved best model")
```
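The evaluate function is imported from models.eval and not shown in this tutorial. The sketch below is one plausible implementation, an assumption rather than the actual module: it switches the model to evaluation mode, disables gradient tracking, and averages the per-batch loss the same way the training loop averages train_loss, so the two numbers printed each epoch are comparable. The name evaluate_sketch and its extra loss_fn parameter are hypothetical (the training cell calls evaluate with only model, loader, and device); the toy model and loss exist only so the sketch runs on its own.

```python
import torch


@torch.no_grad()  # no gradients are needed for validation
def evaluate_sketch(model, loader, device, loss_fn):
    """Hypothetical stand-in for models.eval.evaluate: mean per-batch validation loss."""
    was_training = model.training
    model.eval()  # disable dropout, use running batch-norm statistics
    total = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        recon, mu, logvar = model(imgs, labels)
        total += loss_fn(recon, imgs, mu, logvar).item()
    model.train(was_training)  # restore the caller's mode
    return total / len(loader)


# Toy model and loss so the sketch runs without the tutorial's modules
class ToyCVAE(torch.nn.Module):
    def forward(self, x, y):
        batch = x.size(0)
        return x, torch.zeros(batch, 2), torch.zeros(batch, 2)


def toy_loss(recon, target, mu, logvar):
    return recon.mean()


loader = [
    (torch.ones(4, 3), torch.zeros(4, 1)),   # batch loss 1.0
    (torch.zeros(4, 3), torch.zeros(4, 1)),  # batch loss 0.0
]
model = ToyCVAE()
print(evaluate_sketch(model, loader, torch.device("cpu"), toy_loss))  # 0.5
print(model.training)  # True: the original mode is restored
```

Restoring the previous mode is a defensive choice; the training loop above already calls model.train() at the start of every epoch, so either behavior works with it.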


## Summary


In this tutorial, we:

- Trained a Conditional Variational Autoencoder using labeled data

- Monitored training and validation loss across epochs

- Saved model checkpoints based on validation performance

In the next tutorial, we will visualize reconstructions and explore how the learned latent space captures meaningful structure in chest X-ray images.