# [BatchBALD](https://blackhc.github.io/batchbald_redux/batchbald.html)

BALD draws $K$ plausible parameter settings from the posterior distribution $p(\theta|D)$ and computes $K$ predictive class distributions, one per setting. Instead of training separate models, BALD samples from the distribution over models that are plausible given the data.
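For neural networks, one common way to obtain the $K$ parameter settings without training separate models is Monte Carlo dropout: dropout stays active at prediction time, and each forward pass acts as one posterior sample. This notebook never builds a model, so treat the following as an assumed setup. It is a minimal sketch with a hypothetical helper `mc_dropout_log_probs` and a toy network, and it produces the `[N, K, C]` log-probability tensor that every function below consumes.

```python
import torch
import torch.nn as nn


def mc_dropout_log_probs(model: nn.Module, x_N_D: torch.Tensor, k: int) -> torch.Tensor:
    # Keep dropout active so each forward pass samples a different dropout mask,
    # i.e. a different (approximate) posterior parameter setting theta_k.
    # Note: train() also puts layers such as BatchNorm into training mode.
    model.train()
    with torch.no_grad():
        samples = [torch.log_softmax(model(x_N_D), dim=-1) for _ in range(k)]
    # K tensors of shape [N, C] -> one tensor of shape [N, K, C]
    return torch.stack(samples, dim=1)


torch.manual_seed(0)
# Hypothetical toy classifier: 2 input features, 4 classes
model = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 4))
x_N_D = torch.randn(8, 2)
log_probs_N_K_C = mc_dropout_log_probs(model, x_N_D, k=20)
print(log_probs_N_K_C.shape)  # torch.Size([8, 20, 4])
```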

Each parameter setting produces a probability distribution over the classes. From these distributions we measure two quantities:
1. Marginal Entropy $H(y|x,D)$: 
    - Measures the overall uncertainty about the class of data point $x$, computed from the average of the probability distributions of all the "models."
    - This tells us how uncertain the ensemble is as a whole.
2. Expected Conditional Entropy $E_{\theta\sim p(\theta|D)}\big[H(y|x,\theta)\big]$: 
    - Measures the average uncertainty of an individual model whose parameter setting $\theta$ is drawn from the posterior distribution $p(\theta|D)$. 
    - This tells us how uncertain a single model is, on average.
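
The expectation integrates over all possible parameter settings $\theta$. In practice it is approximated by averaging over the $K$ sampled settings $\theta_k$, as the code below does:
$$
E_{\theta \sim p(\theta | D)} \big[H(y|x, \theta)\big] = \int p(\theta | D)\, H(y|x, \theta)\, d\theta \approx \frac{1}{K} \sum_{k=1}^{K} H(y|x, \theta_k)
$$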
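A minimal sketch of this Monte Carlo estimate for a single input, assuming the $K$ class distributions are stacked into a `[K, C]` probability tensor. The helper names `entropy` and `expected_conditional_entropy` are ours, not from `batchbald_redux`; `torch.special.entr(p)` computes $-p \log p$ and returns 0 at $p = 0$.

```python
import torch


def entropy(probs_C: torch.Tensor) -> torch.Tensor:
    # Shannon entropy in nats over the last (class) dimension
    return torch.special.entr(probs_C).sum(dim=-1)


def expected_conditional_entropy(probs_K_C: torch.Tensor) -> torch.Tensor:
    # (1/K) * sum_k H(y | x, theta_k): entropy of each sampled model, averaged over K
    return entropy(probs_K_C).mean(dim=0)


# Two sampled "models" for one input: one fairly confident, one maximally unsure
probs_K_C = torch.tensor([[0.9, 0.1], [0.5, 0.5]], dtype=torch.double)
print(expected_conditional_entropy(probs_K_C))
```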

The selection criterion is then given by

$$
\arg \max_x \; H(y|x,D) - E_{\theta \sim p(\theta|D)}\big[H(y|x,\theta)\big]
$$

This is the total uncertainty of the ensemble minus the expected uncertainty of a single model. BALD therefore looks for an $x$ where the ensemble is very uncertain but the individual models are each confident. That is, it finds the $x$ on which the parameter settings drawn from the posterior disagree most about the outcome, even though each one is individually certain.
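To make the contrast concrete, here is a small hypothetical comparison with $K = 2$ sampled models and two classes: in one case the models are each confident but disagree, in the other they agree that the label is a coin flip. This is our own sketch (`bald_score` is not a `batchbald_redux` function) that computes the criterion above for one input.

```python
import torch


def bald_score(probs_K_C: torch.Tensor) -> torch.Tensor:
    # H(y|x,D): entropy of the prediction averaged over the K samples
    marginal_entropy = torch.special.entr(probs_K_C.mean(dim=0)).sum()
    # E_theta[H(y|x,theta)]: entropy of each sample's prediction, averaged over K
    expected_entropy = torch.special.entr(probs_K_C).sum(dim=-1).mean()
    return marginal_entropy - expected_entropy


# Each model is almost certain, but they predict different classes
disagree = torch.tensor([[0.99, 0.01], [0.01, 0.99]], dtype=torch.double)
# Both models predict a 50/50 split: uncertain, but in agreement
agree_uncertain = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.double)

print(bald_score(disagree))         # ≈ 0.637 nats (upper bound for 2 classes: log 2 ≈ 0.693)
print(bald_score(agree_uncertain))  # 0: nothing to learn about theta from this label
```

The "agree but uncertain" point has maximal marginal entropy, yet BALD gives it zero score: its label is noise that no parameter setting can explain better than another.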

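This selection criterion therefore measures how much the plausible model configurations disagree with each other. It is the mutual information between the label $y$ and the parameters $\theta$, $I(y;\theta \mid x, D)$.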

**I bet there is a connection to Rashomon sets here: BALD *samples* the plausible parameter settings from the posterior, whereas TreeFarms exhaustively generates all trees.**

## Setup

```python
import math
from dataclasses import dataclass
from typing import List

import torch
from toma import toma
from tqdm.auto import tqdm

from batchbald_redux import joint_entropy
```



## Data generation

```python
# Number of sampled parameter settings (posterior samples) per data point
K = 20
```

```python
import numpy as np


def get_mixture_prob_dist(p1, p2, m):
    # Interpolate between two class distributions: m=0 gives p1, m=1 gives p2
    return (1.0 - m) * np.asarray(p1) + m * np.asarray(p2)


# Each data point gets K class distributions (one per simulated parameter setting)
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.2, 0.3, 0.4, 0.1]
p2 = [0.1, 0.4, 0.3, 0.2]
y5_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    # Convert through one NumPy array; building a tensor from a list of ndarrays is slow
    return torch.as_tensor(np.asarray(l))


# Shape [N=5 points, K=20 samples, C=4 classes]
ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws, y5_ws])
```



```python
ys_ws.shape
```

```
torch.Size([5, 20, 4])
```

```python
ys_ws
```

```
tensor([[[0.7000, 0.1000, 0.1000, 0.1000],
         [0.6789, 0.1105, 0.1053, 0.1053],
         [0.6579, 0.1211, 0.1105, 0.1105],
         [0.6368, 0.1316, 0.1158, 0.1158],
         [0.6158, 0.1421, 0.1211, 0.1211],
         [0.5947, 0.1526, 0.1263, 0.1263],
         [0.5737, 0.1632, 0.1316, 0.1316],
         [0.5526, 0.1737, 0.1368, 0.1368],
         [0.5316, 0.1842, 0.1421, 0.1421],
         [0.5105, 0.1947, 0.1474, 0.1474],
         [0.4895, 0.2053, 0.1526, 0.1526],
         [0.4684, 0.2158, 0.1579, 0.1579],
         [0.4474, 0.2263, 0.1632, 0.1632],
         [0.4263, 0.2368, 0.1684, 0.1684],
         [0.4053, 0.2474, 0.1737, 0.1737],
         [0.3842, 0.2579, 0.1789, 0.1789],
         [0.3632, 0.2684, 0.1842, 0.1842],
         [0.3421, 0.2789, 0.1895, 0.1895],
         [0.3211, 0.2895, 0.1947, 0.1947],
         [0.3000, 0.3000, 0.2000, 0.2000]],

        [[0.1000, 0.7000, 0.1000, 0.1000],
         [0.1053, 0.6789, 0.1105, 0.1053],
         [0.1105, 0.6579, 0.1211, 0.1105],
         ...
```

## Conditional Entropies and Batched Entropies


```python
def ComputeConditionalEntropyFunction(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Expected conditional entropy E_theta[H(y|x, theta)] for each of the N points."""

    ### Set Up ###
    N, K, C = log_probs_N_K_C.shape
    entropies_N = torch.empty(N, dtype=torch.double)
    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    ### Compute entropy ###
    # toma runs `compute` immediately over chunks of the first (N) dimension,
    # shrinking the chunk size if it runs out of memory.
    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        # p * log p for every sample k and class c
        EntropyVals = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
        # Sum over samples and classes, negate, divide by K to average over samples
        entropies_N[start:end].copy_(-torch.sum(EntropyVals, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.close()

    ### Return ###
    return entropies_N


def ComputeEntropyFunction(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Marginal (ensemble) entropy H(y|x, D) for each of the N points."""

    ### Set Up ###
    N, K, C = log_probs_N_K_C.shape
    entropies_N = torch.empty(N, dtype=torch.double)
    pbar = tqdm(total=N, desc="Entropy", leave=False)

    ### Compute entropy ###
    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        # log of the mean prediction: log((1/K) sum_k p_k) = logsumexp_k(log p_k) - log K
        mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
        nats_n_C = mean_log_probs_n_C * torch.exp(mean_log_probs_n_C)
        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
        pbar.update(end - start)

    pbar.close()

    ### Return ###
    return entropies_N
```
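The `toma` chunking only matters for memory on large pools. As a hedged cross-check, the same two quantities can be written as single vectorized expressions over a `[N, K, C]` log-probability tensor; `entropy_vectorized` and `conditional_entropy_vectorized` are our own names. One design difference: `log_probs * exp(log_probs)` evaluates to `nan` when a class has probability exactly 0 (`-inf * 0`), whereas `torch.special.entr` returns 0 there.

```python
import math

import torch


def conditional_entropy_vectorized(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    # E_theta[H(y|x, theta)]: per-sample entropy summed over classes, averaged over K
    return torch.special.entr(log_probs_N_K_C.exp()).sum(dim=2).mean(dim=1)


def entropy_vectorized(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    # H(y|x, D): entropy of the ensemble-averaged prediction
    K = log_probs_N_K_C.shape[1]
    mean_log_probs_N_C = torch.logsumexp(log_probs_N_K_C, dim=1) - math.log(K)
    return torch.special.entr(mean_log_probs_N_C.exp()).sum(dim=1)


# Random example: N=3 points, K=5 samples, C=4 classes
torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(3, 5, 4, dtype=torch.double), dim=-1)
bald_scores = entropy_vectorized(log_probs) - conditional_entropy_vectorized(log_probs)
print(bald_scores)  # mutual information per point; non-negative up to rounding
```

Applied to `ys_ws.log()`, these should reproduce the chunked results.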

### Examples

```python
conditional_entropies = ComputeConditionalEntropyFunction(ys_ws.log())

print(conditional_entropies)
```



```
tensor([1.2069, 1.2069, 1.2069, 1.2069, 1.2952], dtype=torch.float64)
```




```python
entropies = ComputeEntropyFunction(ys_ws.log())

print(entropies)
```



```
tensor([1.2376, 1.2376, 1.2376, 1.2376, 1.3040], dtype=torch.float64)
```




## BatchBALD (greedy batch selection)

```python
@dataclass
class CandidateBatch:
    scores: List[float]
    indices: List[int]


def get_batchbald_batch(
    log_probs_N_K_C: torch.Tensor, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # Per-point expected conditional entropies E_theta[H(y_n | x_n, theta)]
    conditional_entropies_N = ComputeConditionalEntropyFunction(log_probs_N_K_C)

    # Incrementally tracks the joint entropy of the candidates selected so far
    batch_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, K, C, dtype=dtype, device=device
    )

    # We always keep these on the CPU.
    scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if i > 0:
            # Add the most recently selected point to the joint-entropy state
            latest_index = candidate_indices[-1]
            batch_joint_entropy.add_variables(log_probs_N_K_C[latest_index : latest_index + 1])

        # Sum of the conditional entropies of the points already in the batch
        shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()

        # scores_N[n] <- joint entropy of (current batch + point n)
        batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)

        # Joint entropy minus all conditional entropies = mutual information of the extended batch
        scores_N -= conditional_entropies_N + shared_conditional_entropies
        # Never select the same point twice
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)
```
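Each BatchBALD score is the joint entropy $H(y_1, \dots, y_b \mid x_1, \dots, x_b, D)$ of the candidate batch minus the sum of the individual expected conditional entropies. For tiny batches the joint entropy can be computed exactly by enumerating all $C^b$ label configurations, using $p(y_1, \dots, y_b) \approx \frac{1}{K}\sum_k \prod_i p(y_i \mid x_i, \theta_k)$. The sketch below is our own brute-force version on probabilities (not log-probabilities), not the `batchbald_redux` implementation, and it is only practical when $C^b$ is small. The toy pool is hypothetical: point 1 duplicates point 0, and point 2 splits the four sampled models along a different axis.

```python
import itertools
from typing import List

import torch


def exact_joint_entropy(probs_B_K_C: torch.Tensor) -> torch.Tensor:
    """Brute-force H(y_1, ..., y_B | x_1, ..., x_B, D) from K posterior samples."""
    B, K, C = probs_B_K_C.shape
    total = torch.zeros((), dtype=probs_B_K_C.dtype)
    # Enumerate every joint label assignment (y_1, ..., y_B): C**B terms
    for ys in itertools.product(range(C), repeat=B):
        # Labels are conditionally independent given theta, so
        # p(y_1..y_B | theta_k) = prod_i p(y_i | x_i, theta_k)
        p_K = torch.stack([probs_B_K_C[i, :, y] for i, y in enumerate(ys)]).prod(dim=0)
        # Average over the K samples to marginalise theta, then add -p log p
        total = total + torch.special.entr(p_K.mean())
    return total


def batchbald_score(probs_N_K_C: torch.Tensor, batch: List[int]) -> torch.Tensor:
    """Joint entropy of the batch minus the sum of its expected conditional entropies."""
    probs_B_K_C = probs_N_K_C[batch]
    conditional = torch.special.entr(probs_B_K_C).sum(dim=2).mean(dim=1).sum()
    return exact_joint_entropy(probs_B_K_C) - conditional


def greedy_batchbald(probs_N_K_C: torch.Tensor, batch_size: int) -> List[int]:
    """Greedily add the point that maximises the score of the extended batch."""
    N = probs_N_K_C.shape[0]
    chosen: List[int] = []
    for _ in range(min(batch_size, N)):
        remaining = [n for n in range(N) if n not in chosen]
        scores = [batchbald_score(probs_N_K_C, chosen + [n]).item() for n in remaining]
        chosen.append(remaining[max(range(len(remaining)), key=scores.__getitem__)])
    return chosen


# Hypothetical pool: K = 4 sampled models, C = 2 classes
hi, lo = 0.99, 0.01
probs_N_K_C = torch.tensor([
    [[hi, lo], [hi, lo], [lo, hi], [lo, hi]],                  # point 0: models {0,1} vs {2,3}
    [[hi, lo], [hi, lo], [lo, hi], [lo, hi]],                  # point 1: duplicate of point 0
    [[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]],  # point 2: models {0,2} vs {1,3}
], dtype=torch.double)

print(greedy_batchbald(probs_N_K_C, 2))  # [0, 2]: the duplicate adds little information
```

For a batch of one point the score reduces to the plain BALD score, and the second pick skips the duplicate even though it has the higher individual BALD score.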

### Example

```python
get_batchbald_batch(ys_ws.log().double(), 4, 1000, dtype=torch.double)
```



```
CandidateBatch(scores=[0.030715639666234917, 0.05961958627158337, 0.08691070514744714, 0.11275304532467878], indices=[0, 3, 1, 2])
```

## BALD (top-k selection)

```python
def BaldSelectorFunction(log_probs_N_K_C: torch.Tensor, 
                         batch_size: int) -> CandidateBatch:
    """Plain BALD: score each point independently and keep the top `batch_size`."""

    ### Set Up ###
    N, K, C = log_probs_N_K_C.shape
    batch_size = min(batch_size, N)

    ### Compute Uncertainty Metrics ###
    EnsembleEntropy = ComputeEntropyFunction(log_probs_N_K_C)
    ConditionalEntropy = ComputeConditionalEntropyFunction(log_probs_N_K_C)
    # BALD score = H(y|x,D) - E_theta[H(y|x,theta)]
    UncertaintyMetrics = EnsembleEntropy - ConditionalEntropy

    ### Get Top Scores ###
    TopScores, TopIndices = torch.topk(UncertaintyMetrics, batch_size)

    return CandidateBatch(TopScores.tolist(), TopIndices.tolist())
```
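Because this selector ranks points independently, it cannot tell when two candidates carry the same information. A hedged toy illustration with our own hypothetical data (not from the notebook): point 1 duplicates point 0, and point 2 splits the four sampled models along a different axis with slightly less confidence. Top-2 BALD returns the point and its duplicate, although a joint, BatchBALD-style selection on this pool would pair point 0 with point 2. `bald_scores` is our own vectorized helper.

```python
import torch


def bald_scores(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Independent BALD score per point: H(y|x,D) - E_theta[H(y|x,theta)]."""
    # Entropy of the prediction averaged over the K samples
    marginal_N = torch.special.entr(probs_N_K_C.mean(dim=1)).sum(dim=1)
    # Entropy of each sample's prediction, averaged over the K samples
    expected_N = torch.special.entr(probs_N_K_C).sum(dim=2).mean(dim=1)
    return marginal_N - expected_N


# Hypothetical pool: K = 4 sampled models, C = 2 classes
hi, lo = 0.99, 0.01
probs_N_K_C = torch.tensor([
    [[hi, lo], [hi, lo], [lo, hi], [lo, hi]],                  # point 0
    [[hi, lo], [hi, lo], [lo, hi], [lo, hi]],                  # point 1: duplicate of point 0
    [[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]],  # point 2: different split
], dtype=torch.double)

scores_N = bald_scores(probs_N_K_C)
top_scores, top_indices = torch.topk(scores_N, 2)
print(sorted(top_indices.tolist()))  # [0, 1]: the point and its duplicate
```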

### Example

```python
BaldResults = BaldSelectorFunction(ys_ws.log().double(), 1)
print(BaldResults.scores)
print(BaldResults.indices)
```


```
[0.030715639666234917]
[0]
```


