# Tutorial 2: Setting up the Conditional Variational Autoencoder in PyTorch

In this tutorial, we implement a Conditional Variational Autoencoder (CVAE) for medical image generation. The CVAE extends the standard VAE by conditioning both the encoder and the decoder on class labels. In our case, we condition the model on a binary pathology label (pneumonia present or absent).

Conditioning allows the model to generate images that are not only realistic but also consistent with a specified clinical label, making CVAEs especially useful for synthetic medical image generation.

## Conditional Variational Autoencoder Architecture Overview

A CVAE consists of three main components:

**Encoder**
Maps an input image and label to a latent distribution parameterized by a mean and a log-variance.

**Latent Space with Reparameterization**
Samples a latent vector from the learned distribution using the reparameterization trick.

**Decoder**
Reconstructs the image from the latent sample and the conditioning label.

The latent representation in a VAE is probabilistic, which enables uncertainty modeling. We concatenate the label with the encoder features so that the approximate posterior $q(z \mid x, y)$ is explicitly conditioned on the class label.
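
The reparameterization step named above can be written explicitly. As a sketch, assume the encoder outputs a mean $\mu$ and a log-variance $\log \sigma^2$, which is the parameterization the code later in this tutorial uses. The sample is then a deterministic function of the encoder outputs and an independent noise draw $\epsilon$, so gradients can flow through $\mu$ and $\sigma$:

$$
z = \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\!\left(\tfrac{1}{2} \log \sigma^2\right), \qquad \epsilon \sim \mathcal{N}(0, I)
$$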

## Defining the Encoder

The encoder takes as input:

- A grayscale chest X-ray image $x \in \mathbb{R}^{1 \times 128 \times 128}$

- A conditioning label $y \in \mathbb{R}^{1}$

The image is passed through a series of convolutional layers to extract hierarchical feature representations.  
The conditioning label is then concatenated with the flattened feature vector before predicting the parameters of the latent distribution.
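
The size of the flattened feature vector follows from the standard convolution output-size formula. The sketch below is a plain-Python sanity check, assuming the kernel size 4, stride 2, and padding 1 of the encoder defined in this section. The helper name `conv_out_size` is made up for illustration and is not part of the tutorial code.

```python
def conv_out_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a convolution with dilation 1 (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 128
sizes = [size]
for _ in range(3):  # three stride-2 convolution layers
    size = conv_out_size(size)
    sizes.append(size)

# The last convolution has 128 output channels.
flat_features = 128 * size * size

print(sizes)          # [128, 64, 32, 16]
print(flat_features)  # 32768
```

This is where the `128*16*16` input size of the two linear layers comes from. Changing the input resolution or the number of convolution layers would require updating that value.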


```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, latent_dim=32, label_dim=1):
        super().__init__()
        # Three stride-2 convolutions halve the resolution each time: 128 -> 64 -> 32 -> 16
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU(),
        )
        self.flatten = nn.Flatten()

        # The label is appended to the flattened features, hence the + label_dim
        self.fc_mu = nn.Linear(128*16*16 + label_dim, latent_dim)
        self.fc_logvar = nn.Linear(128*16*16 + label_dim, latent_dim)

    def forward(self, x, y):
        # x: (batch, 1, 128, 128); y: float tensor of shape (batch, label_dim)
        h = self.conv(x)
        h = self.flatten(h)
        h = torch.cat([h, y], dim=1)
        return self.fc_mu(h), self.fc_logvar(h)
```


## Defining the Decoder

The decoder takes as input:

- A sampled latent vector $z \in \mathbb{R}^{\text{latent\_dim}}$

- The same conditioning label $y \in \mathbb{R}^{1}$

The latent vector and conditioning label are concatenated and projected back into a spatial feature map, which is then upsampled using transposed convolution layers.
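
The upsampling path can be checked the same way with the transposed-convolution output-size formula. This is a plain-Python sketch, assuming kernel size 4, stride 2, padding 1, and no output padding, as in the decoder defined in this section. The helper name `deconv_out_size` is hypothetical.

```python
def deconv_out_size(size, kernel=4, stride=2, padding=1):
    """Output size of a transposed convolution with output_padding=0 and dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel

size = 16
up_sizes = [size]
for _ in range(3):  # three stride-2 transposed convolution layers
    size = deconv_out_size(size)
    up_sizes.append(size)

print(up_sizes)  # [16, 32, 64, 128]
```

Each layer doubles the resolution, so a 16×16 feature map is restored to the original 128×128 image size.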


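```python
class Decoder(nn.Module):
    def __init__(self, latent_dim=32, label_dim=1):
        super().__init__()
        # Project the concatenated [z, y] vector back to a 128 x 16 x 16 feature map
        self.fc = nn.Linear(latent_dim + label_dim, 128*16*16)
        # Three stride-2 transposed convolutions double the resolution: 16 -> 32 -> 64 -> 128
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, z, y):
        # z: (batch, latent_dim); y: float tensor of shape (batch, label_dim)
        h = torch.cat([z, y], dim=1)
        h = self.fc(h).view(-1, 128, 16, 16)
        return self.deconv(h)
```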


## Defining the CVAE class

Training a CVAE involves minimizing the negative Evidence Lower Bound (ELBO), which consists of two terms:

**Reconstruction Loss**
Measures how well the reconstructed image matches the input.

**KL Divergence Loss**
Regularizes the latent distribution toward a standard normal prior $\mathcal{N}(0, I)$.
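
For reference, the two terms can be written out. This is a sketch assuming a diagonal Gaussian posterior with mean $\mu$ and variance $\sigma^2$ over a $d$-dimensional latent space, and a standard normal prior. Here $\hat{x}$ denotes the decoder output. The squared-error reconstruction term mirrors the MSE loss used in the code of this section and is one choice among several.

$$
\mathcal{L}(x, y) = \underbrace{\lVert x - \hat{x} \rVert_2^2}_{\text{reconstruction}} + \underbrace{D_{\mathrm{KL}}\big(q(z \mid x, y) \,\|\, \mathcal{N}(0, I)\big)}_{\text{regularization}}
$$

$$
D_{\mathrm{KL}}\big(q(z \mid x, y) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Both terms are summed per image and then averaged over the batch.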

A useful explanation of the concept can be found here: https://beckham.nz/2023/04/27/conditional-vaes.html

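```python
class CVAE(nn.Module):
    def __init__(self, latent_dim=32, label_dim=1):
        super().__init__()
        self.encoder = Encoder(latent_dim, label_dim)
        self.decoder = Decoder(latent_dim, label_dim)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, y):
        mu, logvar = self.encoder(x, y)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, y)
        return recon, mu, logvar


def cvae_loss(recon_x, x, mu, logvar):
    batch_size = x.size(0)

    # Reconstruction term: squared error summed per image, averaged over the batch
    recon_loss = nn.functional.mse_loss(
        recon_x, x, reduction='sum'
    ) / batch_size

    # KL divergence between q(z|x,y) and N(0, I), averaged over the batch
    kl = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp()
    ) / batch_size

    return recon_loss + kl


def evaluate(model, loader, device):
    """Return the mean per-batch loss over a data loader."""
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            # The model concatenates labels along dim=1, so it needs a float tensor
            # of shape (batch, label_dim). The reshape is a safeguard that assumes
            # the loader may yield labels of shape (batch,).
            labels = labels.to(device).float().view(imgs.size(0), -1)

            recon, mu, logvar = model(imgs, labels)
            loss = cvae_loss(recon, imgs, mu, logvar)
            total_loss += loss.item()

    return total_loss / len(loader)
```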
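
To build intuition for the KL term in `cvae_loss`, the closed-form expression can be evaluated by hand for a single sample. The following is a plain-Python sketch that needs no PyTorch. `kl_to_standard_normal` is an illustrative helper and not part of the tutorial code.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one sample given as lists of floats.

    Same expression as in cvae_loss, written with the signs flipped.
    """
    return 0.5 * sum(m ** 2 + math.exp(lv) - lv - 1 for m, lv in zip(mu, logvar))

# A posterior equal to the prior incurs no penalty.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Shifting one mean to 1.0 adds 0.5 * mu^2 = 0.5.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

# Shrinking the variance below 1 is also penalized.
kl_narrow = kl_to_standard_normal([0.0], [-2.0])

print(kl_zero, kl_shifted, kl_narrow)
```

The term is zero only when the posterior matches the prior exactly, and it grows as the mean moves away from zero or the variance moves away from one.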

## Summary

In this tutorial, we implemented a Conditional Variational Autoencoder in PyTorch, including:

- A label-conditioned encoder and decoder
- The reparameterization trick for latent sampling
- A principled loss function based on variational inference

In the next tutorial, we will train the CVAE on the CheXpert dataset and visualize reconstruction and generation results.