# GdVAE CelebA

In [1]:
# Copyright (C) 2024 Ruhr West University of Applied Sciences, Bottrop, Germany
# AND e:fs TechHub GmbH, Gaimersheim, Germany
#
# This Source Code Form is subject to the terms of the Apache License 2.0
# If a copy of the APL2 was not distributed with this
# file, You can obtain one at https://www.apache.org/licenses/LICENSE-2.0.txt.

## Imports
All imports required by the rest of the notebook.

In [2]:
import torch, torchvision
from sklearn.metrics import accuracy_score
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from typing import Tuple
%matplotlib inline

## CUDA
Check if CUDA is available

In [3]:
# Report CUDA availability and pick the device used by the rest of the notebook.
print(f"Is Cuda available? {torch.cuda.is_available()}")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Is Cuda available? True


## Initialize Dataset
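The automatic download of CelebA frequently fails because the daily download quota of its Google Drive host is exceeded. It is therefore recommended to download the dataset manually from the official CelebA Google Drive and place it below the `data_root` directory (`data/`); the datasets below are created with `download=False`.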

In [4]:
batch_size = 64

# Column indices of the CelebA attribute table that are kept as class labels
# e.g. Class 31 = Smiling
selected_classes = [31]
data_root = 'data/'

# Build the train and test split with identical preprocessing:
# center crop to 178x178, resize to 64x64 and convert to a tensor.
# The dataset is expected to be present below `data_root` (download=False).
train_dataset, test_dataset = (
    torchvision.datasets.CelebA(
        data_root,
        target_type='attr',
        download=False,
        split=split_name,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop(178),
            torchvision.transforms.Resize(64),
            torchvision.transforms.ToTensor(),
        ]),
    )
    for split_name in ('train', 'test')
)

# Reduce the attribute table of both splits to the requested class column(s)
train_dataset.attr = train_dataset.attr[:, selected_classes]
test_dataset.attr = test_dataset.attr[:, selected_classes]

# Init Dataloader: only the training data is shuffled, incomplete final
# batches are dropped for both splits.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        dataset=split_dataset,
        batch_size=batch_size,
        shuffle=shuffle_split,
        drop_last=True,
    )
    for split_dataset, shuffle_split in ((train_dataset, True), (test_dataset, False))
)

## View Helper
Helper module that makes `.view()` usable as a layer inside `torch.nn.Sequential`.

In [5]:
class ViewCeleba(torch.nn.Module):
    """
    Layer wrapper around ``Tensor.view`` so that a reshape can be placed
    inside of torch.nn.Sequential().
    """
    def __init__(self, size: Tuple[int, ...]):
        """
        Args:
            size (Tuple[int, ...]): target shape handed to ``.view()``
        """
        super().__init__()
        self.size = size

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs (torch.Tensor): tensor to reshape

        Returns:
            (torch.Tensor): view of 'inputs' with shape ``self.size``
        """
        reshaped = inputs.view(self.size)
        return reshaped

## CelebA CVAE-Encoder

In [6]:
class CelebAEncoder(torch.nn.Module):
    """
    Encoder q(z|x,y) to encode a given image to the latent space
    """

    def __init__(self,
                 input_shape: Tuple[int, int, int] = (3, 64, 64),
                 latent_size: int = 64,
                 num_classes: int = 2):
        """
        CelebA Encoder
        Args:
            input_shape (Tuple[int, int, int]): Size of Input Images as (channels, height, width)
            latent_size (int): Number of latent dimensions.
            num_classes (int): Number of classes.
        """
        super().__init__()

        self.latent_size = latent_size
        self.in_channels, self.in_height, self.in_width = input_shape

        self.DIM = 32

        base = self.DIM
        # Settings shared by the four stride-2 convolutions
        downsample = dict(kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))

        # CelebA Encoder
        # The image is concatenated with one extra channel produced by the
        # label encoder, hence `in_channels + 1`.
        # For the default 64x64 input the stride-2 convolutions reduce the
        # spatial size to 4x4 and the final 4x4 convolution to 1x1, which is
        # what the flatten step below relies on.
        self.conv_encoder = torch.nn.Sequential(
            torch.nn.Conv2d(self.in_channels + 1, base, **downsample),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(base, base, **downsample),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(base, base * 2, **downsample),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(base * 2, base * 2, **downsample),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(base * 2, base * 8, kernel_size=(4, 4), stride=(1, 1)),
            torch.nn.ReLU(True),
            ViewCeleba((-1, base * 8)),  # Flatten
            # First half of the output is the mean, second half the log variance
            torch.nn.Linear(base * 8, self.latent_size * 2),
        )

        # Label Encoder: maps the label vector to a single image-sized channel
        self.label_encoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=num_classes,
                            out_features=1 * self.in_height * self.in_width),
            ViewCeleba((-1, 1, self.in_height, self.in_width)),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward step of the encoder returns the parameters of the latent
        distribution q(z|x,y) generated by x and y

        Args:
            x (torch.tensor): set of images
            y (torch.tensor): set of one hot encoded class labels

        Returns:
            (torch.tensor): mean
            (torch.tensor): log variance (clamped to a maximum of 5)
        """
        # Turn the labels into an additional image channel and append it to x
        label_plane = self.label_encoder(y)
        encoder_input = torch.cat([x, label_plane], dim=1)

        # Actual encoding of x and y to the latent distribution parameters
        latent_params = self.conv_encoder(encoder_input)

        # Split into Mean and LogVar
        mean = latent_params[:, :self.latent_size]
        log_variance = latent_params[:, self.latent_size:]

        # Clamp LogVar from above (in place, no lower bound)
        log_variance.clamp_(max=5)

        return mean, log_variance

## CelebA CVAE-Decoder

In [7]:
class CelebADecoder(torch.nn.Module):
    """
    Decoder p(x|z,y) to decode a given latent space vector back to the image space
    """

    def __init__(self,
                 input_shape: Tuple[int, int, int] = (3, 64, 64),
                 latent_size: int = 64,
                 num_classes: int = 2,
                 conditional_decoder: bool=True):
        """
        Args:
            input_shape (Tuple[int, int, int]): Size of Input Images as (channels, height, width)
            latent_size (int): Number of latent dimensions.
            num_classes (int): Number of classes.
            conditional_decoder (bool): If the model should use a conditional decoder
        """
        super().__init__()
        self.out_channels, self.out_height, self.out_width = input_shape

        self.num_features = self.out_channels * self.out_height * self.out_width
        self.latent_size = latent_size  # n_dims_latent

        self.conditional_decoder = conditional_decoder
        # A conditional decoder receives one additional input value that is
        # produced from the label by `label_decoder`.
        if self.conditional_decoder:
            self.latent_size_decoder = self.latent_size + 1
        else:
            self.latent_size_decoder = self.latent_size

        self.DIM = 32

        base = self.DIM
        # Settings shared by the four stride-2 transposed convolutions
        upsample = dict(kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))

        self.conv_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_size_decoder, base * 8),
            ViewCeleba((-1, base * 8, 1, 1)),  # back to a 1x1 feature map
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(base * 8, base * 2, kernel_size=(4, 4)),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(base * 2, base * 2, **upsample),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(base * 2, base, **upsample),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(base, base, **upsample),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(base, self.out_channels, **upsample),
            torch.nn.ReLU(True),
        )
        # Maps the label vector to the single conditioning value
        self.label_decoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=num_classes, out_features=1)
        )

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        The forward step of the decoder returns the image(x) generated by z and y
        p(x|z,y)

        Args:
            z (torch.tensor): set of latent space vectors
            y (torch.tensor): set of one hot encoded class labels
                (only used if the decoder is conditional)

        Returns:
            (torch.tensor): generated Image
        """
        decoder_input = z
        # Conditional Decoder: append the encoded label to the latent vector
        if self.conditional_decoder:
            label_feature = self.label_decoder(y)
            decoder_input = torch.cat([z, label_feature], dim=1)

        # Generate Image
        return self.conv_decoder(decoder_input)


## CelebA Prior-Encoder

In [8]:
class CelebAPriorEncoder(torch.nn.Module):
    """
    The Prior Encoder p(z|y) learns normal distributions for each dimension of the latent representation
    for each class based on the class labels 'y'.

    The GDA utilizes the prior distribution to classify samples drawn from the encoder distributions.

    The mean depends on the class label, whereas the log variance is computed
    from a constant input and is therefore identical for all classes.
    """

    def __init__(self,
                 latent_size: int = 64,
                 num_classes: int = 2):
        """

        Args:
            latent_size (int): number of latent dimensions
            num_classes (int): number of classes

        """
        super().__init__()

        self.num_classes = num_classes
        self.latent_size = latent_size
        self.DIM = 4

        # Hidden network for the class dependent mean
        self.prior_encoder_mean = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.num_classes, out_features=self.DIM),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=self.DIM, out_features=self.DIM),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=self.DIM, out_features=self.DIM),
            torch.nn.ReLU(),
        )

        # Hidden network for the log variance; its input is a constant 1
        # (see forward), so the result does not depend on the class.
        self.prior_encoder_log_variance = torch.nn.Sequential(
            torch.nn.Linear(in_features=1, out_features=self.DIM),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=self.DIM, out_features=self.DIM),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=self.DIM, out_features=self.DIM),
            torch.nn.ReLU(),
        )

        self.mean_layer = torch.nn.Linear(self.DIM, self.latent_size)
        self.log_variance_layer = torch.nn.Linear(self.DIM, self.latent_size)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward step of the prior network returns the latent distributions p(z|y) used by the GDA.


        Args:
            y (torch.tensor): set of one hot encoded class labels

        Returns:
            (torch.tensor): prior mean
            (torch.tensor): prior log variance (clamped to a maximum of 5)
        """

        z_mean = self.prior_encoder_mean(y)
        mean = self.mean_layer(z_mean)

        # Constant input for the log variance branch. It is created on the
        # device/dtype of the module's own parameters instead of relying on a
        # notebook-global `device` variable.
        reference = self.log_variance_layer.weight
        y_lv = torch.ones((y.shape[0], 1), device=reference.device, dtype=reference.dtype)
        z_log_variance = self.prior_encoder_log_variance(y_lv)
        log_variance = self.log_variance_layer(z_log_variance)
        log_variance = torch.clamp(log_variance, max=5)

        return mean, log_variance

    def get_prior_distributions_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the parameters of the prior distributions

        Returns:
            (torch.tensor): mean value of the normal distributions for each class
            (torch.tensor): standard deviation of the normal distributions for each class
        """
        # identity matrix with size equal to the number of classes
        # -> one one-hot row per class, used to get the prior distributions
        #    for all classes to calculate likelihood values.
        # Created on the device/dtype of the module's parameters.
        reference = self.mean_layer.weight
        identity_matrix = torch.eye(self.num_classes, self.num_classes,
                                    device=reference.device, dtype=reference.dtype)

        # calculate conditional probabilities for all classes p(z | identity_matrix)
        mean_prior, log_variance_prior = self.forward(identity_matrix)
        # std = sqrt(e^(ln(variance))); eps keeps the std strictly positive
        std_prior = log_variance_prior.exp().sqrt() + torch.finfo(torch.float32).eps
        return mean_prior, std_prior


## CelebA GDA

In [9]:
class CelebAGDA(torch.nn.Module):
    """
    Gaussian Discriminant Analysis.

    This class employs the encoder as a feature extractor, which is then classified by the GDA based
    on the distributions of the prior network. It constitutes the implementation of Algorithm 1
    (an EM-based classifier) from the main paper.

    Note: the class prior p(y) is stored as a plain tensor attribute (not as
    parameter or buffer) and is placed on the notebook-global `device`.
    """

    def __init__(self,
                 prior_network: torch.nn.Module,
                 encoder: torch.nn.Module,
                 train_loader: torch.utils.data.DataLoader,
                 input_shape: Tuple[int, int, int] = (3, 64, 64),
                 latent_size: int = 64,
                 num_classes: int = 2,
                 iterations: int = 3,
                 sample_amount: int = 20):
        """
        Args:
            prior_network: used prior network reference; must provide
                `get_prior_distributions_parameters()`
            encoder: used encoder reference, called as `encoder(x=..., y=...)`
            train_loader: used dataloader; the class frequencies are taken
                from `train_loader.dataset.attr`
            input_shape: input shape of used images
            latent_size: Number of latent dimensions.
            num_classes: Number of classes.
            iterations: GDA iterations T
            sample_amount: GDA number of samples S
        """

        super().__init__()
        self.num_features = input_shape[0]*input_shape[1]*input_shape[2]
        self.num_classes = num_classes
        self.latent_size = latent_size
        self.iterations = iterations
        self.sample_amount = sample_amount
        self.prior_network = prior_network
        self.encoder = encoder

        # Compute prior for all classes based on their likelihood of occurrence:
        # count how often each distinct label value appears in the training labels
        # NOTE(review): assumes the number of distinct values equals num_classes - confirm
        _, class_counts = torch.unique(train_loader.dataset.attr, return_counts=True)
        # determine p(y) as relative class frequency
        class_ratios = class_counts / class_counts.sum()

        # Restore dimensions used by the previously implemented code:
        # turn the 1-D ratio vector into a column (one row per class)
        class_ratios = class_ratios.unsqueeze(dim=0).T
        # Keep dtype consistent
        class_ratios = class_ratios.type(torch.float32)
        # set prior probability p(y); uses the global `device` variable
        self.prior = class_ratios.to(device)


    def forward(self, x: torch.Tensor, use_softmax: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        GDA forward pass.

        The forward pass implements Algorithm 1 from the paper and
        generates a sample from q(z|x) that is then classified for
        the loss calculation.

        Args:
            x : input batch
            use_softmax: if False, the first return value is always the
                result q_y_x_log of the iterative procedure, independent of
                training/eval mode. No softmax is applied inside this method
                in either case.
        Returns:
            torch.tensor: Class wise prediction (unnormalized log scores) for each sample in 'x'.
                Training: log p(z|y) + log p(y) of the final sample,
                evaluation: q_y_x_log.
            torch.tensor: latent space generated for input batch
                (sampled z during training, encoder mean during evaluation)
            torch.tensor: One hot encoded predictions for each sample
            torch.tensor: Encoder Mean for input batch
            torch.tensor: Encoder log variance for input batch
        """

        # Per-class mean and standard deviation of the prior p(z|y)
        mean_prior, std_prior = self.prior_network.get_prior_distributions_parameters()

        # Algorithm 1 (an EM-based classifier)
        # q(y|x) Init: start from log p(y), repeated for every sample of the batch
        q_y_x_log = torch.transpose(torch.log(self.prior), 0, 1)
        q_y_x_log = q_y_x_log.expand(x.shape[0], -1)

        # Sampling Start
        for iter_idx in range(self.iterations):
            z_samples = []
            for samp_idx in range(self.sample_amount):
                # q(y|x): draw a one hot class sample (straight-through Gumbel-Softmax)
                y_pred = torch.nn.functional.gumbel_softmax(logits=q_y_x_log, tau=1, hard=True)

                # q(z|x,y) for sampled y
                enc_mean, enc_logvar = self.encoder(x=x, y=y_pred)

                # z(s) => sampled z
                z_sample = reparameterize(enc_mean, enc_logvar)
                z_samples.append(z_sample)
            # Stack the S samples and insert an axis at dim=2 so the samples
            # broadcast against the per-class prior parameters
            z_pred = torch.stack(z_samples)
            z_pred = torch.unsqueeze(z_pred, dim=2)

            # calculate the log likelihood of every sample to be drawn from
            # the respective class distributions
            # p(z(s)|y)
            loglikelihood = log_prob_normal(mean_prior, std_prior, z_pred)

            # log(p(z|y)) = log prod_i p(z_i|y) -> sum over the latent dimensions
            loglikelihood_z_given_y = torch.sum(loglikelihood, dim=3)

            # p(y|z(s)) <- p(z(s)|y) p(y)   (in log space, unnormalized)
            loglikelihood_weighted_z_given_y = torch.add(loglikelihood_z_given_y, self.prior.log().view(1, -1))

            # q(y|x) <- average of the log scores over the S samples
            q_y_x_log = torch.mean(loglikelihood_weighted_z_given_y, dim=0)

        # After implementing Algorithm 1, use q(y|x)
        # to generate samples for the loss calculation.
        if self.training:
            y_preds_hard = torch.nn.functional.gumbel_softmax(logits=q_y_x_log, tau=1, hard=True)
            enc_means, enc_logvars = self.encoder(x=x, y=y_preds_hard)
            z_preds = reparameterize(enc_means, enc_logvars)
            # insert an axis at dim=1 to broadcast against the per-class prior parameters
            z_preds = torch.unsqueeze(z_preds, dim=1)

            # p(z(s)|y)
            loglikelihood = log_prob_normal(mean_prior, std_prior, z_preds)

            # log(p(z|y)) = log prod_i p(z_i|y) -> sum over the latent dimensions
            loglikelihood_z_given_y = torch.sum(loglikelihood, dim=2)  # (r, k)

            # p(y|z(s)) <- p(z(s)|y) p(y)
            loglikelihood_weighted_z_given_y = torch.add(loglikelihood_z_given_y, self.prior.log().view(1, -1))  # (r,k)
            y_preds = loglikelihood_weighted_z_given_y

            # remove the broadcasting axis again
            z_preds = z_preds.squeeze(dim=1)
        else:
            # during inference only use encoded means and mean class prediction
            y_preds = q_y_x_log
            y_preds_hard = torch.argmax(y_preds, dim=1)
            y_preds_hard = torch.nn.functional.one_hot(y_preds_hard, num_classes=self.num_classes).float()  # (n, k)
            enc_means, enc_logvars = self.encoder(x=x, y=y_preds_hard)
            z_preds = enc_means

        # Optionally return the result of the iterative procedure instead of
        # the single-sample scores computed in training mode
        if not use_softmax:
            y_preds = q_y_x_log

        return y_preds, z_preds, y_preds_hard, enc_means, enc_logvars


def log_prob_normal(means: torch.tensor, std_prior: torch.tensor, z: torch.tensor) -> torch.tensor:
    """
    Element-wise log density of 'z' under normal distributions N(means, std_prior^2).

    The arguments are combined with regular broadcasting.

    Args:
        means (torch.tensor): mean values of the normal distributions
        std_prior (torch.tensor): standard deviations of the normal distributions
        z (torch.tensor): values at which the log density is evaluated

    Returns:
        log probability for every element
    """
    # -0.5 * log(2 * pi): constant part of the Gaussian log density
    normalization = -0.5 * torch.log(2. * torch.tensor(torch.pi))
    log_std = torch.log(std_prior)
    # 0.5 * (z - mu)^2 / sigma^2
    half_squared_error = 0.5 * (z - means) ** 2
    variance = std_prior ** 2
    return normalization - log_std - half_squared_error / variance


def reparameterize(mean: torch.tensor, log_variance: torch.tensor) -> torch.tensor:
    """
    Draws samples from normal distributions via the reparameterization trick.

    Standard normal noise is scaled with the standard deviation described by
    'log_variance' (plus a small epsilon) and shifted by 'mean'.

    Args:
        mean: mean values of each latent dimension of each sample in a batch
        log_variance: logarithm of the variance of each latent dimension of each
            sample in a batch

    Returns:
        Set of samples derived from the described normal distributions.
        Same size as mean and log_variance.

    """
    # std = exp(0.5 * ln(variance))
    std = torch.exp(0.5 * log_variance)
    noise = torch.randn_like(std)
    scale = std + torch.finfo(torch.float32).eps
    return mean + noise * scale

### Counterfactual Generation

In [10]:
class CounterfactualGenerator:
    """
    Generates counterfactuals in the latent space.

    The object caches the quantities of the linear decision function
    f(z) = w*z + b (`w`, `b`) together with the prior parameters
    (`mean_prior`, `std_prior`). `mean_prior` and `std_prior` have to be
    assigned by the caller, afterwards `calculate_w_and_b()` has to be called
    before any of the counterfactual methods is used.
    """
    def __init__(self, prior, probability_cf: torch.Tensor = torch.tensor([0.99])):
        """
        Args:
            prior (torch.tensor): class prior p(y), its logarithm is added to the bias `b`
            probability_cf (torch.tensor): probability that bounds the range used
                for the counterfactual generation
        """
        self.probability_cf = probability_cf
        self.prior = prior
        # logit of probability_cf: value of the decision function that
        # corresponds to this probability
        self.epsilon = torch.log(self.probability_cf / (1 - self.probability_cf))
        self.w = None
        self.b = None
        self.mean_prior = None
        self.std_prior = None

    def _class_difference(self, y_preds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return classes c/k and the weight/bias difference (w_c - w_k, b_c - b_k) per sample."""
        if self.w is None or self.b is None:
            raise ValueError("w and b not set in Counterfactual Generator, call calculate_w_and_b() first")
        # select the class c (predicted) and k (counterfactual)
        c, k = get_predicted_class_and_cf_class(y_preds)

        # decision function for class c -> log(p(y|c)) and class k
        # covariances are shared
        # f(z) = w*z + b
        # with w = COV^-1*(mu_c-mu_k)
        # and b=b_c-b_k with b_i =-0.5*mu_i^T *COV^-1*mu_i+log(p(y=i))
        wn = self.w[c] - self.w[k]
        b = self.b[c] - self.b[k]
        return c, k, wn, b

    def local_l2(self,
                 z_preds: torch.Tensor,
                 y_preds: torch.Tensor,
                 delta: torch.Tensor) -> torch.Tensor:
        """
        Determine counterfactual examples in latent space by walking in the
        direction of the gradient of the decision function (local_l2)
        Args:
            z_preds (torch.tensor): latent representations of a batch with input data
            y_preds (torch.tensor): class prediction only used to select the classes
            delta (torch.tensor): target values of the decision function for the counterfactuals
        Returns:
            z_deltas (torch.tensor): latent variable for counterfactual
        """
        _, _, wn, b = self._class_difference(y_preds)

        # step size along wn such that f(z_deltas) = delta
        kappa = -(torch.sum(wn * z_preds, dim=1, keepdim=True) + b - delta) / (
                    torch.sum(wn ** 2, dim=1, keepdim=True))

        z_deltas = z_preds + kappa * wn

        return z_deltas

    def glob(self,
                 z_preds: torch.Tensor,
                 y_preds: torch.Tensor,
                 delta: torch.Tensor) -> torch.Tensor:
        """
        Determine counterfactual examples in latent space using the global method,
        i.e. by walking from z towards the prior mean of the counterfactual class
        Args:
            z_preds (torch.tensor): latent representations of a batch with input data
            y_preds (torch.tensor): class prediction only used to select the classes
            delta (torch.tensor): target values of the decision function for the counterfactuals
        Returns:
            z_deltas (torch.tensor): latent variable for counterfactual
        """
        _, k, wn, b = self._class_difference(y_preds)
        zk = self.mean_prior[k]

        # counterfactuals when walking to the cf baseline zk:
        # step size along (zk - z) such that f(z_deltas) = delta
        kappa = (torch.sum(wn * z_preds, dim=1, keepdim=True) + b - delta) / (
                torch.sum(wn * (z_preds - zk), dim=1, keepdim=True))
        z_deltas = z_preds + kappa * (zk - z_preds)

        return z_deltas

    def local_m(self,
                z_preds: torch.Tensor,
                y_preds: torch.Tensor,
                delta: torch.Tensor) -> torch.Tensor:
        """
        Determine counterfactual examples in latent space using the generalized version (local_m)
        local_m optimizes the mahalanobis distance
        Args:
            z_preds (torch.tensor): latent representations of a batch with input data
            y_preds (torch.tensor): class prediction only used to select the classes
            delta (torch.tensor): target values of the decision function for the counterfactuals
        Returns:
            z_deltas (torch.tensor): latent variable for counterfactual
        """
        c, k, wn, b = self._class_difference(y_preds)

        # the first row of std_prior is used for all classes (shared covariance)
        std_prior = self.std_prior[0].unsqueeze(dim=0)
        sigma = (std_prior ** 2)

        # z shifted by the difference of the class means
        z_shifted = self.mean_prior[k] + z_preds - self.mean_prior[c]

        lmbd = -2*(1/torch.sum(sigma*wn*wn, dim=1, keepdim=True))*((torch.sum(-wn * z_shifted, dim=1, keepdim=True) - b + delta))

        z_deltas = -(lmbd/2)*sigma*wn + z_shifted

        return z_deltas

    def calculate_w_and_b(self):
        """
        Calculates w and b from `mean_prior`, `std_prior` and the class prior.

        This is done to save time during training, since the decision boundary only changes during gradient updates.
        After the model is completely trained w and b stay the same during inference, therefore there is no need
        to recalculate them with the prior_encoder.

        Raises:
            ValueError: if `mean_prior` or `std_prior` has not been set
        """
        if self.mean_prior is None:
            raise ValueError("Error Mean-Prior not set in Counterfactual Generator")
        if self.std_prior is None:
            raise ValueError("Error Std-Prior not set in Counterfactual Generator")
        # w_i = COV^-1 * mu_i (diagonal covariance)
        self.w = (1 / (self.std_prior ** 2)) * self.mean_prior
        # b_i = -0.5 * mu_i^T * COV^-1 * mu_i + log(p(y=i))
        self.b = -0.5 * torch.sum(
            self.mean_prior * (1 / (self.std_prior ** 2)) * self.mean_prior,
            dim=1, keepdim=True) + torch.log(self.prior)

    def calculate_counterfactual_delta_distribution(self,
                                                    z_preds: torch.Tensor,
                                                    y_preds: torch.Tensor) -> torch.distributions.Uniform:
        """
        During training the target values for the counterfactual methods are sampled randomly.
        This method builds the uniform distribution to sample them from.
        Args:
            z_preds (torch.tensor): latent representations of a batch with input data
            y_preds (torch.tensor): class prediction only used to select the classes
        Returns:
            delta (torch.distributions.Uniform): Uniform distribution to sample from during training
        """
        _, k, wn, b = self._class_difference(y_preds)
        zk = self.mean_prior[k]

        # decision function at the prior mean of the counterfactual class ...
        fmin = torch.sum(wn * zk, dim=1, keepdim=True) + b
        # ... and at the latent representation itself
        fmax = torch.sum(wn * z_preds, dim=1, keepdim=True) + b

        # epsilon is created on the CPU in __init__; move it next to the data
        epsilon = self.epsilon.to(device=fmax.device, dtype=fmax.dtype)

        # restrict both values to [-epsilon, epsilon]
        epsilon_max = torch.min(fmax, epsilon)
        epsilon_min = torch.max(fmin, -epsilon)
        # sort into lower / upper bound
        el = torch.min(epsilon_max, epsilon_min)
        eu = torch.max(epsilon_max, epsilon_min)

        # fall back to -epsilon / epsilon where the interval would be empty
        el = torch.where(el < eu, el, -epsilon)
        eu = torch.where(eu > el, eu, epsilon)
        delta = torch.distributions.Uniform(el, eu)

        return delta

### Counterfactual Helper

In [11]:
def get_predicted_class_and_cf_class(y_preds: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
    """
    Select the predicted class and the counterfactual class of every sample.

    Args:
        y_preds (torch.tensor): class prediction, one row per sample
    Returns:
        c (torch.tensor): Original Class (index of the highest prediction)
        k (torch.tensor): Counterfactual Class (index of the lowest prediction)
    """
    # Works with multiclass: the counterfactual class is the one with the
    # lowest prediction. Both index tensors are detached from the graph.
    c = torch.argmax(y_preds, dim=1).detach().to(torch.long)
    k = torch.argmin(y_preds, dim=1).detach().to(torch.long)
    return c, k

def get_prediction_delta(y_preds: torch.tensor, delta: torch.tensor):
    """
    Build the class probabilities that belong to the counterfactual values 'delta'.

    Args:
        y_preds (torch.tensor): class prediction only used to select the classes works for multiple classes
        delta (torch.tensor): pregenerated by calculate_counterfactuals_delta
    Returns:
        y_preds_delta (torch.tensor): probability values formatted to be passed to the decoder
    """
    c, k = get_predicted_class_and_cf_class(y_preds)
    sample_idx = torch.arange(y_preds.shape[0])

    # Start from all zeros: only the entries of class c and class k are filled
    y_preds_delta = torch.zeros_like(y_preds).float()

    # predicted class c receives sigmoid(delta), the counterfactual class k
    # the remaining probability mass
    y_preds_delta[sample_idx, c] = torch.sigmoid(delta.squeeze())
    y_preds_delta[sample_idx, k] = 1.0 - y_preds_delta[sample_idx, c]

    return y_preds_delta



## GdVAE

In [12]:
class GDVAE(torch.nn.Module):
    """
    Gaussian Discriminant Variational Autoencoder (GDVAE).

    This class encapsulates all models essential for constructing a GDVAE, managing their interaction.
    """


    def __init__(self,
                 input_shape: Tuple[int, int, int] = (3, 64, 64),
                 latent_size: int = 64,
                 num_classes: int = 2,
                 ALPHA: float = 1.0,
                 BETA: float = 1.0,
                 GAMMA: float = 1.0
                 ):
        """

        Args:
            input_shape (Tuple[int, int, int]): input size of the images as (channels, height, width)
            latent_size (int): size of the latent space
            num_classes (int): amount of classes
            ALPHA (float): Weight for MSE and KLD-Loss
            BETA (float): Weight for MSE and KLD_Prior-Loss
            GAMMA (float): Weight for Consistency-Loss

        """
        super().__init__()

        # Init Variables
        self.num_classes = num_classes
        self.ALPHA = ALPHA
        self.BETA = BETA
        self.GAMMA = GAMMA

        # Settings
        num_pixel = (input_shape[0]*input_shape[1]*input_shape[2])
        self.global_mode = True
        self.sample_amount_cf = 10
        # 10% of num_pixel
        self.ce_weight = 0.1
        # Weights
        self.kld_weight = latent_size / num_pixel

        # Init Probability for Counterfactuals
        # probability range for generation of counterfactuals during training
        # e.g. for 0.99 counterfactuals are generated in the range [0.99,0.01]
        self.probability_cf = torch.tensor(0.99)

        # Init Dataloaders
        # NOTE: `train_loader` and `test_loader` are the notebook-level
        # variables created in the dataset cell, not constructor arguments.
        self.train_loader, self.test_loader = train_loader, test_loader
        # Init Encoder
        self.encoder = CelebAEncoder(input_shape=input_shape, latent_size=latent_size, num_classes=num_classes)
        # Init Decoder
        self.decoder = CelebADecoder(input_shape=input_shape, latent_size=latent_size, num_classes=num_classes)
        # Init Prior Network
        self.prior_network = CelebAPriorEncoder(latent_size=latent_size, num_classes=num_classes)
        # Init GDA Sampling Method
        # The model configuration is forwarded explicitly so that the GDA does
        # not silently fall back to its own defaults (e.g. num_classes=2 for
        # the one hot encoding) when this model uses different settings.
        self.gda = CelebAGDA(self.prior_network, self.encoder, self.train_loader,
                             input_shape=input_shape,
                             latent_size=latent_size,
                             num_classes=num_classes)

        # Counterfactual generator shares the class prior p(y) of the GDA
        self.cf_generator = CounterfactualGenerator(self.gda.prior,self.probability_cf)

    def encode(self,
               x: torch.Tensor,
               y: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """
        Transform a batch with input images and their labels into latent representations consisting of normal
        distribution parameters per latent dimension.

        Args:
            x (torch.tensor): batch with input data
            y (torch.tensor): one hot encoded class labels for x

        Returns:
            (torch.tensor): mean values of the latent representation
            (torch.tensor): log(variances) of the latent representation
        """

        mean, log_variance = self.encoder(x, y)
        return mean, log_variance

    def decode(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Transform a batch with latent representations into reconstructions of the input images.

        Args:
            z (torch.tensor): latent representations of a batch with input data
            y (torch.tensor): class prediction
        Returns:
            (torch.tensor): Reconstructions of 'x'
        """

        reconstruction = self.decoder(z, y)
        return reconstruction

    

## Visualization Helper

In [13]:
def normalize(image):
    """
    Min-max normalizes a given input image to the range [0, 1].

    The normalization is done in place: 'image' itself is modified and
    returned. A constant image (max == min) becomes all zeros instead of
    NaN, since the division by the zero value range is skipped.

    Args:
        image (torch.tensor): input_image (floating point tensor)
    Returns:
         torch.tensor: normalized input image (same tensor object as 'image')
    """
    # shift the smallest value to 0
    image.sub_(torch.min(image))
    # scale the largest value to 1, unless the image is constant
    value_range = torch.max(image)
    if value_range > 0:
        image.div_(value_range)
    return image

## Loss Functions
Includes the CVAE-Loss and the Consistency-Loss Function

In [14]:
def loss_function_cvae(x: torch.tensor,
                       reconstruction: torch.tensor,
                       mean: torch.tensor,
                       log_variance: torch.tensor,
                       mean_prior: torch.tensor,
                       log_variance_prior: torch.tensor,
                       eps: float = 1.0e-5) \
        -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
    """
    Computes the three terms of the GDVAE loss (extended CVAE loss, Eq. 7+8).

    All terms are averaged over the non-batch dimensions and then summed over the batch.

    Args:
        x: Input batch (4-D; dimensions 1..3 are averaged for the reconstruction term).
        reconstruction: Decoder output for 'x', same shape as 'x'.
        mean: Latent means for each sample in 'x'.
        log_variance: ln of the latent variances for each sample in 'x'.
        mean_prior: Means of the conditional latent prior distributions.
        log_variance_prior: ln of the variances of the conditional latent prior distributions.
        eps: Value added to the prior variance in the denominator for numerical stability.
            Default: 1e-5.

    Returns:
        Reconstruction loss: negative log-likelihood of 'reconstruction' under a normal
            distribution centred on 'x' with fixed scale 0.6 (called "MSE" by the callers).
        KLD between the latent distribution N(mean, exp(log_variance)) and the prior
            N(mean_prior, exp(log_variance_prior)).
        KLD between the prior N(mean_prior, exp(log_variance_prior)) and the standard normal
            distribution.
    """
    # Reconstruction term: Gaussian log-likelihood with a fixed scale of 0.6.
    likelihood = torch.distributions.Normal(loc=x, scale=torch.ones_like(x) * 0.6)
    per_sample_log_prob = likelihood.log_prob(reconstruction).mean(dim=(1, 2, 3))
    mse_reconstruction = -per_sample_log_prob.sum()

    # KL(prior || N(0, 1)), closed form per latent dimension.
    prior_vs_standard = 1 + log_variance_prior - mean_prior.pow(2) - log_variance_prior.exp()
    kld_prior_normal = -0.5 * prior_vs_standard.mean(dim=1).sum()

    # KL(latent || prior), closed form per latent dimension; eps keeps the division stable.
    squared_mean_gap = (mean - mean_prior).pow(2)
    latent_vs_prior = 0.5 * (log_variance_prior - log_variance
                             + (log_variance.exp() + squared_mean_gap) /
                             (log_variance_prior.exp() + eps) - 1.)
    kld_prior_latent = latent_vs_prior.mean(dim=1).sum()

    return mse_reconstruction, kld_prior_latent, kld_prior_normal

def loss_function_consistency(mean_decoder: torch.tensor,
                              log_variance_decoder: torch.tensor,
                              mean: torch.tensor,
                              log_variance: torch.tensor,
                              eps: float = 1.0e-5) \
        -> torch.tensor:
    """
    Consistency loss for reconstructions and counterfactuals: the Kullback-Leibler
    Divergence between the latent distribution obtained by re-encoding a decoded image
    x^delta and the (possibly CF-modified) latent distribution of the input, q(z|x).

    Args:
        mean_decoder: Latent means obtained after reconstruction and GDA for 'x^delta'.
        log_variance_decoder: ln of the latent variances for 'x^delta'.
        mean: Latent means for each sample in 'x'.
        log_variance: ln of the latent variances for each sample in 'x'.
        eps: Value added to the denominator for numerical stability. Default: 1e-5.

    Returns:
        Kullback-Leibler Divergence (KLD) between q(z|x^delta) and q(z^delta|x), averaged
        over dim=1 and summed over the batch.
    """
    # Closed-form KL between two diagonal normal distributions, per latent dimension:
    # 0.5 * (ln var - ln var_dec + (var_dec + (mean_dec - mean)^2) / (var + eps) - 1)
    reference_variance = log_variance.exp() + eps
    squared_mean_gap = (mean_decoder - mean).pow(2)
    per_dimension_kld = 0.5 * (log_variance - log_variance_decoder
                               + (log_variance_decoder.exp() + squared_mean_gap) /
                               reference_variance - 1.)

    return per_dimension_kld.mean(dim=1).sum()

## Train Routine and Test Routine

In [15]:
def train_single_epoch(model: GDVAE, opti: torch.optim.Adam, with_generalized: bool) -> Tuple[torch.tensor, int, torch.tensor]:
    """
    Runs `single_epoch` in training mode (train=True) and hands back its result.
    Args:
        model: GDVAE model to train
        opti: Optimizer of the GDVAE model
        with_generalized: forwarded as `include_generalized`, i.e. whether the generalized
            cf generation should be included during training
    Returns:
        loss of the epoch (per sample)
        accuracy of the GDA predictions
        weighted reconstruction loss of the vae (per sample)

    """
    return single_epoch(model, opti, train=True, include_generalized=with_generalized)

@torch.no_grad()
def test_single_epoch(model: GDVAE, opti: torch.optim.Adam, with_generalized: bool) -> Tuple[torch.tensor, int, torch.tensor]:
    """
    Runs `single_epoch` in test mode (train=False) with gradient computation disabled.
    Args:
        model: GDVAE model to evaluate
        opti: Optimizer of the GDVAE model (passed through; no optimizer step is taken
            when train=False)
        with_generalized: forwarded as `include_generalized`, i.e. whether the generalized
            cf generation should be included
    Returns:
        loss of the epoch (per sample)
        accuracy of the GDA predictions
        weighted reconstruction loss of the vae (per sample)

    """
    return single_epoch(model, opti, train=False, include_generalized=with_generalized)

def single_epoch(model: GDVAE, opti: torch.optim.Adam, train: bool=True, include_generalized: bool=False) -> Tuple[torch.tensor, int, torch.tensor]:
    """
    Run the model over one full pass of its train or test loader.

    For every batch the CVAE loss, the classification loss and the consistency losses
    (for the reconstruction and for `model.sample_amount_cf` sampled counterfactuals) are
    computed. When `train` is True the optimizer is stepped once per batch.

    Args:
        model (GDVAE): Model
        opti (torch.optim.Adam): Optimizer (only used when `train` is True)
        train: running single epoch in training or test mode
        include_generalized: If the generalized cf generation (local_m) should be sampled
            as one of the counterfactual modes
    Returns:
        total loss summed over all batches divided by the number of processed samples
        accuracy of the GDA predictions (value returned by sklearn's `accuracy_score`)
        weighted reconstruction loss, (ALPHA + BETA) * reconstruction term, divided by the
            number of processed samples
    Raises:
        ValueError: if an unknown counterfactual mode is drawn (should be unreachable).
    """
    y_preds_cat = torch.tensor([]).to(device)  # GDA class predictions
    y_gts = torch.tensor([]).to(device)  # ground truth class labels

    running_loss = 0.0
    running_loss_cvae_mse = 0
    running_counter = 0
    loader = model.train_loader if train else model.test_loader

    for (batch_data, y) in tqdm(loader):
        batch_data = batch_data.to(device)  # (n,c,h,w)
        # Size of THIS batch; may differ from the notebook-level `batch_size`
        # (e.g. short last batch), so never rely on the global below.
        current_batch_size = batch_data.shape[0]
        y = y.to(device).squeeze()
        y_one_hot = torch.nn.functional.one_hot(y, num_classes=model.num_classes).float().to(device).squeeze()  # (n, k)
        if train:
            opti.zero_grad()

        y_preds, z_preds, y_preds_hard, enc_means, enc_logvars = model.gda(batch_data)

        y_gts = torch.cat([y_gts, y], dim=0)
        y_preds_cat = torch.cat([y_preds_cat, y_preds.detach()], dim=0)

        mean_prior, log_variance_prior = model.prior_network(y_one_hot)  # --> p(y)
        mean, log_variance = model.encode(batch_data, y_one_hot)  # --> q(z|x,y)

        reconstruction = model.decode(z_preds, y_preds_hard)  # --> p(x|z,y)

        # classification_loss
        classification_loss = torch.nn.functional.cross_entropy(y_preds, y, reduction='sum')

        # reconstruction loss, kld, kld_prior
        cvae_mse, cvae_kld, cvae_kld_prior = loss_function_cvae(batch_data, reconstruction, mean, log_variance,
                                                                mean_prior, log_variance_prior)

        loss_cvae = (model.ALPHA + model.BETA) * cvae_mse +\
                    model.ALPHA * model.kld_weight * cvae_kld +\
                    model.BETA * model.kld_weight * cvae_kld_prior

        loss_ce = model.ce_weight * model.BETA * classification_loss

        # consistency loss: re-encode the reconstruction and compare with the original encoding
        _, _, _, enc_means_decoder, enc_logvars_decoder = model.gda(reconstruction)

        # The reconstruction plus the sampled counterfactuals share one average,
        # hence the division by (sample_amount_cf + 1).
        consistency_kld = loss_function_consistency(enc_means_decoder, enc_logvars_decoder, enc_means,
                                                    enc_logvars) / (model.sample_amount_cf + 1)

        # calculate consistency loss to q(z|x) (cycle consistency for reconstructions and counterfactuals)
        consistency_delta_kld = torch.tensor(0.).to(device)
        # Pregenerate w and b
        model.cf_generator.mean_prior, model.cf_generator.std_prior = model.prior_network.get_prior_distributions_parameters()
        model.cf_generator.calculate_w_and_b()
        delta_distribution = model.cf_generator.calculate_counterfactual_delta_distribution(enc_means, y_one_hot)
        # Create `model.sample_amount_cf` counterfactuals for every image
        for _ in range(model.sample_amount_cf):
            delta = delta_distribution.rsample()
            # One delta per sample of the current batch.
            delta = delta.view(current_batch_size, 1)
            y_preds_delta = get_prediction_delta(y_preds, delta)
            # Randomly pick the counterfactual method: 0 = local_l2, 1 = glob,
            # 2 = local_m (only drawn when include_generalized is set).
            if include_generalized:
                random_mode = torch.randint(0, 3, ())
            else:
                random_mode = torch.randint(0, 2, ())
            if random_mode == 0:
                enc_means_delta = model.cf_generator.local_l2( enc_means, y_one_hot, delta)
            elif random_mode == 1:
                enc_means_delta = model.cf_generator.glob( enc_means, y_one_hot, delta)
            elif random_mode == 2:
                enc_means_delta = model.cf_generator.local_m( enc_means, y_one_hot, delta)
            else:
                raise ValueError("Random_Mode Error")

            # The counterfactual keeps the variances of the original encoding.
            enc_logvars_delta = enc_logvars.clone()

            x_deltas_rec = model.decode(enc_means_delta, y_preds_delta)

            _, _, _, enc_means_delta_rec, enc_logvars_delta_rec = model.gda(x_deltas_rec)

            consistency_delta_kld += loss_function_consistency(enc_means_delta_rec,
                                                               enc_logvars_delta_rec,
                                                               enc_means_delta,
                                                               enc_logvars_delta) / (model.sample_amount_cf + 1)

        loss_consistency_kld = model.kld_weight * model.GAMMA * consistency_kld + \
                               model.kld_weight * model.GAMMA * consistency_delta_kld

        loss = loss_cvae + loss_ce + loss_consistency_kld
        if train:
            loss.backward()
            opti.step()

        running_counter += current_batch_size
        running_loss_cvae_mse += (model.ALPHA + model.BETA) * cvae_mse.item()
        running_loss += loss.item()
    y_preds_cat = torch.argmax(y_preds_cat, dim=1).cpu()
    accuracy = accuracy_score(y_gts.cpu(), y_preds_cat)

    return (running_loss/running_counter), accuracy, (running_loss_cvae_mse/running_counter)

## Training GdVAE (CelebA)
To test the counterfactual generation process without prior training, we provide pre-trained weights (gdvae_celeba_pretrained.pth). Thus, it is unnecessary to perform this step, and you may proceed directly to the final stage of generating counterfactuals.

In [None]:
gdvae = GDVAE().to(device)

EPOCHS = 24

# Toggle for the generalized counterfactual generation (local_m) during training.
# NOTE: Paper results don't include the generalized cf (local_m) generation during training
with_generalized = False

# Optimizer
optimizer = torch.optim.Adam(gdvae.parameters(), lr=0.0005, weight_decay=0.0)

for epoch in range(EPOCHS):
    # Training pass (updates the weights)
    gdvae.train()
    loss, acc, loss_mse = train_single_epoch(gdvae, optimizer, with_generalized)
    print(f"Epoch {epoch:04d} train_loss {loss:.4f} train_accuracy {acc:.4f} train_mse {loss_mse:.4f}")

    # Evaluation pass on the test loader (no gradient computation)
    gdvae.eval()
    loss, acc, loss_mse = test_single_epoch(gdvae, optimizer, with_generalized)
    print(f"Epoch {epoch:04d} test_loss {loss:.4f} test_accuracy {acc:.4f} test_mse {loss_mse:.4f}")

    # Checkpoint after every epoch; the file is overwritten each time.
    torch.save(gdvae.state_dict(), f"./gdvae_celeba.pth")

  0%|          | 0/2543 [00:00<?, ?it/s]

## Generating Counterfactuals
The counterfactual generation method is selected with `cf_mode`: `True` uses the global method (`glob`), `False` uses `local_l2`.

In [None]:
gdvae = GDVAE().to(device)

# Load Weights
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
gdvae.load_state_dict(torch.load(f"./gdvae_celeba_pretrained.pth", map_location=device))
gdvae.eval()

# Counterfactual Mode
# Global = True
# Local_l2 = False
cf_mode = False

# Counterfactual probabilities
cf_probs = np.array([0.95, 0.75, 0.5, 0.25, 0.05])

# Figure 1: input image next to its reconstruction
fig, axs = plt.subplots(1, 2, figsize=(15, 8))
with torch.no_grad():
    (batch_data, y) = next(iter(gdvae.test_loader))
    batch_data = batch_data.to(device)
    # Hand a copy to normalize() so that display preprocessing can never alter
    # the batch that is fed to the model below.
    axs[0].imshow(torch.permute(normalize(batch_data[0].clone()), (1, 2, 0)).cpu().detach().numpy())
    axs[0].set_title(f"Input")
    axs[0].set_xticklabels([])
    axs[0].set_yticklabels([])
    y = y.to(device).squeeze()
    y_one_hot = torch.nn.functional.one_hot(y, num_classes=gdvae.num_classes).float().to(device).squeeze()
    _, _, _, enc_means, _ = gdvae.gda(batch_data)
    reco = gdvae.decode(enc_means, y_one_hot)
    axs[1].imshow(torch.permute(normalize(reco[0]), (1, 2, 0)).cpu().detach().numpy())
    axs[1].set_title(f"Recon")
    # Hide the tick labels of the reconstruction panel (both axes of axs[1]).
    axs[1].set_xticklabels([])
    axs[1].set_yticklabels([])
    plt.show()

    # Figure 2: one counterfactual per requested probability
    fig, axs = plt.subplots(1, len(cf_probs), figsize=(15, 8))
    gdvae.cf_generator.mean_prior, gdvae.cf_generator.std_prior = gdvae.prior_network.get_prior_distributions_parameters()
    gdvae.cf_generator.calculate_w_and_b()
    for i, probs in enumerate(cf_probs):
        # Convert the probability into log-odds, log(p / (1 - p)); this is the delta handed to the generator.
        probvalue = torch.log(torch.tensor([probs]) / (1 - torch.tensor([probs]))).float().to(device)
        delta = probvalue.repeat(enc_means.shape[0], 1)
        y_pred_delta_gt = get_prediction_delta(y_one_hot, delta)
        if cf_mode:
            z_deltas = gdvae.cf_generator.glob(enc_means, y_one_hot, delta)
        else:
            z_deltas = gdvae.cf_generator.local_l2(enc_means, y_one_hot, delta)
        counterfactual = gdvae.decode(z_deltas.to(torch.float32), y_pred_delta_gt)
        axs[i].imshow(torch.permute(normalize(counterfactual[0]), (1, 2, 0)).cpu().detach().numpy())
        axs[i].set_title(f"Prob:{probs}")
        axs[i].set_xticklabels([])
        axs[i].set_yticklabels([])
plt.show()

## Generating Counterfactuals (Generalized,local_m)
Create counterfactuals with the generalized counterfactual generator (`local_m`).

In [None]:
gdvae = GDVAE().to(device)

# Load Weights
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
gdvae.load_state_dict(torch.load(f"./gdvae_celeba_pretrained.pth", map_location=device))
gdvae.eval()
# NOTE: cf_mode is not read anywhere in this cell - the generalized method (local_m)
# is always used below. The assignment is kept so the notebook-level variable is unchanged.
cf_mode = True

# Counterfactual probabilities
cf_probs = np.array([0.95, 0.75, 0.5, 0.25, 0.05])

# Figure 1: input image next to its reconstruction
fig, axs = plt.subplots(1, 2, figsize=(15, 8))
with torch.no_grad():
    (batch_data, y) = next(iter(gdvae.test_loader))
    batch_data = batch_data.to(device)
    # Hand a copy to normalize() so that display preprocessing can never alter
    # the batch that is fed to the model below.
    axs[0].imshow(torch.permute(normalize(batch_data[0].clone()), (1, 2, 0)).cpu().detach().numpy())
    axs[0].set_title(f"Input")
    axs[0].set_xticklabels([])
    axs[0].set_yticklabels([])
    y = y.to(device).squeeze()
    y_one_hot = torch.nn.functional.one_hot(y, num_classes=gdvae.num_classes).float().to(device).squeeze()
    _, _, _, enc_means, _ = gdvae.gda(batch_data)
    reco = gdvae.decode(enc_means, y_one_hot)
    axs[1].imshow(torch.permute(normalize(reco[0]), (1, 2, 0)).cpu().detach().numpy())
    axs[1].set_title(f"Recon")
    # Hide the tick labels of the reconstruction panel (both axes of axs[1]).
    axs[1].set_xticklabels([])
    axs[1].set_yticklabels([])
    plt.show()

    # Figure 2: one counterfactual per requested probability
    fig, axs = plt.subplots(1, len(cf_probs), figsize=(15, 8))
    gdvae.cf_generator.mean_prior, gdvae.cf_generator.std_prior = gdvae.prior_network.get_prior_distributions_parameters()
    gdvae.cf_generator.calculate_w_and_b()
    for i, probs in enumerate(cf_probs):
        # Convert the probability into log-odds, log(p / (1 - p)); this is the delta handed to the generator.
        probvalue = torch.log(torch.tensor([probs]) / (1 - torch.tensor([probs]))).float().to(device)
        delta = probvalue.repeat(enc_means.shape[0], 1)
        y_pred_delta_gt = get_prediction_delta(y_one_hot, delta)
        z_deltas  = gdvae.cf_generator.local_m(enc_means, y_one_hot, delta)
        counterfactual = gdvae.decode(z_deltas.to(torch.float32), y_pred_delta_gt)
        axs[i].imshow(torch.permute(normalize(counterfactual[0]), (1, 2, 0)).cpu().detach().numpy())
        axs[i].set_title(f"Prob:{probs}")
        axs[i].set_xticklabels([])
        axs[i].set_yticklabels([])
plt.show()