# Tutorial 4: Visualizing Reconstructions and Exploring the Latent Space

In this tutorial, we analyze the behavior of the trained Conditional Variational Autoencoder (CVAE) by visualizing reconstructed images and probing the learned latent space.

While training loss provides a quantitative measure of performance, qualitative inspection is also important for understanding what the model has actually learned.

We will:

- Visualize original chest X-ray images alongside their reconstructions

- Examine how reconstruction quality varies across samples

- Explore how the latent space changes as we modify the conditional label
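In a CVAE the decoder receives the condition alongside the latent code, so one latent vector can be decoded under different labels. The toy sketch below shows that idea. It assumes the common CVAE design in which the decoder input is `z` concatenated with the label `y`. This is an assumption; the actual `models.cvae.CVAE` decoder may condition differently. `toy_decode`, `W`, `b`, and `OUT_PIXELS` are hypothetical stand-ins (a single random linear layer plus sigmoid), not parts of the trained model. Holding `z` fixed and switching the label changes the decoded output, which is the effect the latent-space exploration examines.

```python
import numpy as np

LATENT_DIM = 32
LABEL_DIM = 1
OUT_PIXELS = 64  # hypothetical flattened output size for the toy decoder

rng = np.random.default_rng(0)
# Hypothetical decoder weights mapping concat([z, y]) -> flattened image.
W = rng.normal(size=(OUT_PIXELS, LATENT_DIM + LABEL_DIM))
b = rng.normal(size=OUT_PIXELS)


def toy_decode(z: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Decode latent z under condition y by concatenating them (toy linear decoder)."""
    zy = np.concatenate([z, y], axis=-1)
    # Sigmoid keeps the toy "pixel" values between 0 and 1.
    return 1.0 / (1.0 + np.exp(-(zy @ W.T + b)))


z = rng.normal(size=(1, LATENT_DIM))  # one fixed latent code
x_label0 = toy_decode(z, np.zeros((1, LABEL_DIM)))  # decode with label 0
x_label1 = toy_decode(z, np.ones((1, LABEL_DIM)))   # same z, label 1

# Same z, different label: the outputs differ only through the label's weights.
print(np.abs(x_label1 - x_label0).mean())
```

With the real model, the same experiment would encode an image, keep its latent code, and decode it once per label value.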

In [None]:
import torch
from models.cvae import CVAE

LATENT_DIM = 32
LABEL_DIM = 1
CHECKPOINT_PATH = "/content/drive/MyDrive/cvae_best.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CVAE(latent_dim=LATENT_DIM, label_dim=LABEL_DIM).to(device)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()

print("Model loaded successfully.")


## Visualizing Reconstructions

To evaluate reconstruction quality, we compare the original input images with their corresponding reconstructions produced by the CVAE.

In [None]:
import matplotlib.pyplot as plt

from data.data_loader import get_chexpert_valid_dataloader


valid_loader = get_chexpert_valid_dataloader()

imgs, labels = next(iter(valid_loader))
imgs = imgs.to(device)
labels = labels.to(device)

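with torch.no_grad():
    recon, _, _ = model(imgs, labels)

# Matplotlib cannot plot CUDA tensors, so move both batches to the CPU as NumPy arrays
imgs_np = imgs.cpu().numpy()
recon_np = recon.cpu().numpy()

# Plot a few examples (capped at the batch size)
n = min(5, imgs_np.shape[0])
plt.figure(figsize=(10, 4))

for i in range(n):
    # Top row: original image (first channel)
    plt.subplot(2, n, i + 1)
    plt.imshow(imgs_np[i, 0], cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.title("Original")

    # Bottom row: CVAE reconstruction (first channel)
    plt.subplot(2, n, i + 1 + n)
    plt.imshow(recon_np[i, 0], cmap="gray")
    plt.axis("off")
    if i == 0:
        plt.title("Reconstruction")

plt.tight_layout()
plt.show()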
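A per-image error score makes variation in reconstruction quality across samples easier to compare than scanning a grid by eye. The following is a minimal sketch. It assumes the originals and reconstructions are NumPy arrays of shape `(batch, channels, height, width)`. The helper names `per_sample_mse` and `rank_by_reconstruction_error` are hypothetical, and mean squared error is one choice of metric among several. The sketch scores each image and returns indices ordered from best to worst reconstruction.

```python
import numpy as np


def per_sample_mse(originals: np.ndarray, reconstructions: np.ndarray) -> np.ndarray:
    """Mean squared error per image, averaged over channel and spatial axes."""
    if originals.shape != reconstructions.shape:
        raise ValueError("originals and reconstructions must have the same shape")
    diff = originals.astype(np.float64) - reconstructions.astype(np.float64)
    # Flatten everything except the batch axis, then average per image.
    return (diff ** 2).reshape(diff.shape[0], -1).mean(axis=1)


def rank_by_reconstruction_error(originals: np.ndarray, reconstructions: np.ndarray):
    """Return (indices sorted best-to-worst, per-sample errors)."""
    errors = per_sample_mse(originals, reconstructions)
    order = np.argsort(errors, kind="stable")
    return order, errors


# Toy check with synthetic data: sample 1 is reconstructed perfectly,
# sample 0 is off by 0.5 everywhere, sample 2 by 0.1.
orig = np.zeros((3, 1, 4, 4))
rec = np.stack([np.full((1, 4, 4), 0.5), np.zeros((1, 4, 4)), np.full((1, 4, 4), 0.1)])
order, errors = rank_by_reconstruction_error(orig, rec)
print(order)  # → [1 2 0]
```

In the notebook, a call such as `rank_by_reconstruction_error(imgs.cpu().numpy(), recon.cpu().numpy())` would give the indices of the best and worst cases to plot. Because MSE is a pixel-level measure, it can score a blurry reconstruction well. The ranking therefore complements the visual comparison above rather than replacing it.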