In [1]:
# default_exp batchbald

In [2]:
# hide
import blackhc.project.script
from nbdev.showdoc import *

Neither src found as subdirectory in %s nor was a notebooks directory found!
%load_ext autoreload
%autoreload 2


# BatchBALD Algorithm
> Greedy algorithm and score computation

First, we will implement two helper functions to compute the conditional entropies $H[y_i|w]$ and the entropies $H[y_i]$.
Then, we will implement BatchBALD and BALD.

In [3]:
# exports
import math
from dataclasses import dataclass
from typing import List

import torch
from toma import toma
from tqdm.auto import tqdm

from batchbald_redux import joint_entropy

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [4]:
K = 20  # number of inference samples per input (e.g. MC-dropout passes or ensemble members)

In [5]:
import numpy as np


def get_mixture_prob_dist(p1, p2, m):
    """Return the mixture ``(1 - m) * p1 + m * p2`` of two categorical distributions.

    ``m`` is the mixing weight: ``m=0`` gives ``p1`` and ``m=1`` gives ``p2``.
    """
    return (1.0 - m) * np.asarray(p1) + m * np.asarray(p2)


# One row per toy object: the two endpoint distributions (over 4 classes) that the
# object's K inference samples interpolate between.
mixture_endpoints = [
    ([0.7, 0.1, 0.1, 0.1], [0.3, 0.3, 0.2, 0.2]),
    ([0.1, 0.7, 0.1, 0.1], [0.2, 0.3, 0.3, 0.2]),
    ([0.1, 0.1, 0.7, 0.1], [0.2, 0.2, 0.3, 0.3]),
    ([0.1, 0.1, 0.1, 0.7], [0.3, 0.2, 0.2, 0.3]),
]

# yN_ws: K sampled distributions for object N, sweeping the mixing weight from 0 to 1.
y1_ws, y2_ws, y3_ws, y4_ws = [
    [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]
    for p1, p2 in mixture_endpoints
]


def nested_to_tensor(l):
    """Convert each element of ``l`` with ``torch.as_tensor`` and stack them along a new dim 0."""
    return torch.stack(list(map(torch.as_tensor, l)))


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

  return torch.stack(list(map(torch.as_tensor, l)))


In [6]:
# 4 classes probs (sum = 1) x 20 inference samples (dropouts, models)
# y1_ws -- just one object

In [7]:
# hide

p = [0.25] * 4
yu_ws = [p] * K  # K references to the same uniform distribution
yus_ws = nested_to_tensor([yu_ws] * 4)  # same layout as ys_ws, every entry 0.25

In [8]:
ys_ws.size()  # 4 objects x K=20 inference samples x 4 classes

torch.Size([4, 20, 4])

## Conditional Entropies and Batched Entropies

To start with, we write two functions to compute the conditional entropy $H[y_i|w]$ and the entropy $H[y_i]$ for each input sample.

In [9]:
def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Conditional entropy H[y_i | w] for each of the N points, from probabilities.

    For every point, the entropy (in nats) of each of the K sampled predictive
    distributions is taken over the C classes (with 0 * log 0 := 0) and then
    averaged over the K samples.

    Returns:
        float64 tensor of shape (N,).
    """
    N, K, C = probs_N_K_C.shape
    entropies_N = torch.empty(N, dtype=torch.double)  # result buffer, filled chunk by chunk
    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    # toma runs the function below immediately on row chunks [start, end) of probs_N_K_C,
    # starting at 1024 rows per chunk and shrinking the chunk if memory runs out.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        plogp_n_K_C = torch.log(probs_n_K_C) * probs_n_K_C
        plogp_n_K_C[probs_n_K_C == 0] = 0.0  # log(0) = -inf and -inf * 0 = nan; use the limit 0
        total_n = plogp_n_K_C.sum(dim=(1, 2))  # sum over samples and classes
        entropies_N[start:end].copy_(total_n.neg() / K)  # ... then average over the K samples
        pbar.update(end - start)

    pbar.close()
    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Entropy of the mean prediction, H[y_i], for each of the N points, from probabilities.

    The K sampled distributions are first averaged into one predictive distribution;
    its entropy (in nats) over the C classes is then taken, with 0 * log 0 := 0.

    Returns:
        float64 tensor of shape (N,).
    """
    N, K, C = probs_N_K_C.shape
    entropies_N = torch.empty(N, dtype=torch.double)
    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        predictive_n_C = probs_n_K_C.mean(dim=1)  # average over the K inference samples first
        plogp_n_C = torch.log(predictive_n_C) * predictive_n_C
        plogp_n_C[predictive_n_C == 0] = 0.0  # 0 * log 0 would be nan; its limit is 0
        entropies_N[start:end].copy_(plogp_n_C.sum(dim=1).neg())  # sum over classes
        pbar.update(end - start)

    pbar.close()
    return entropies_N

In [10]:
# Make sure everything is computed correctly.
# yus_ws is uniform over 4 classes in every sample, so both entropies are exactly
# log(4) ≈ 1.3863 nats; use a tight tolerance so that real errors are not masked.
uniform_entropy = [math.log(4)] * 4
assert np.allclose(compute_conditional_entropy(yus_ws), uniform_entropy)
assert np.allclose(compute_entropy(yus_ws), uniform_entropy)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

However, our neural networks usually use a `log_softmax` as the final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use versions that take `log_probs` rather than `probs`.

In [11]:
# exports


def compute_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Conditional entropy H[y_i | w] for each of the N points, from log-probabilities.

    For every point, the entropy (in nats) of each of the K sampled predictive
    distributions is taken over the C classes and then averaged over the K samples.
    Entries with log-probability -inf (probability 0) contribute 0, i.e. 0 * log 0 := 0,
    matching the probability-based version.

    Args:
        log_probs_N_K_C: log-probabilities (e.g. `log_softmax` outputs), N x K x C.

    Returns:
        float64 tensor of shape (N,).
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        probs_n_K_C = torch.exp(log_probs_n_K_C)
        nats_n_K_C = log_probs_n_K_C * probs_n_K_C
        # A zero probability gives -inf * 0 = nan, which would poison the whole sum.
        nats_n_K_C[probs_n_K_C == 0] = 0.0

        # Sum over samples and classes, then average over the K samples.
        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.close()

    return entropies_N


def compute_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Entropy of the mean prediction, H[y_i], for each of the N points, from log-probabilities.

    The K sampled distributions are averaged in log space (logsumexp - log K) into one
    predictive distribution, whose entropy (in nats) over the C classes is returned.
    Classes whose mean probability is 0 contribute 0, i.e. 0 * log 0 := 0.

    Args:
        log_probs_N_K_C: log-probabilities (e.g. `log_softmax` outputs), N x K x C.

    Returns:
        float64 tensor of shape (N,).
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        # log of the mean probability over the K samples, computed stably in log space.
        mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
        mean_probs_n_C = torch.exp(mean_log_probs_n_C)
        nats_n_C = mean_log_probs_n_C * mean_probs_n_C
        # A class with zero mean probability gives -inf * 0 = nan; use the limit 0.
        nats_n_C[mean_probs_n_C == 0] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))  # sum over classes
        pbar.update(end - start)

    pbar.close()

    return entropies_N

In [12]:
# hide

# Make sure everything is computed correctly.
# Uniform predictions over 4 classes: both entropies are exactly log(4) ≈ 1.3863 nats.
assert np.allclose(compute_conditional_entropy(yus_ws.log()), [math.log(4)] * 4)
assert np.allclose(compute_entropy(yus_ws.log()), [math.log(4)] * 4)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [13]:
conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

# The four toy objects differ only by a permutation of the classes, so they share one value.
assert np.allclose(conditional_entropies, np.full(4, 1.2069), atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [14]:
entropies = compute_entropy(ys_ws.log())

print(entropies)

# Same value for all four objects (they are class permutations of one another).
assert np.allclose(entropies, np.full(4, 1.2376), atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

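We don't use this last optimization here: we also subtract $\sum_{b<B} H[y_b|w]$, so that the returned scores are the actual BatchBALD values $I[(y_b)_B;w]$ rather than values shifted by a constant.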


### In the Paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [15]:
# exports


@dataclass
class CandidateBatch:
    """Result of a batch acquisition: the chosen pool indices and their acquisition scores.

    `scores[i]` is the score associated with pool index `indices[i]`.
    """

    scores: List[float]  # acquisition scores, aligned with `indices`
    indices: List[int]  # indices into the pool (first dimension of the scored tensor)


def get_batchbald_batch(
    log_probs_N_K_C: torch.Tensor, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a batch of up to `batch_size` pool points that maximises BatchBALD.

    At each step, every pool point n is scored as
        H[y_picked..., y_n] - (H[y_n | w] + sum of H[y_b | w] over the picked points)
    and the best unpicked point is added to the batch.

    Args:
        log_probs_N_K_C: log-probabilities, N pool points x K inference samples x C classes.
        batch_size: number of points to acquire (clipped to N).
        num_samples: forwarded to `joint_entropy.DynamicJointEntropy`.
        dtype, device: forwarded to `joint_entropy.DynamicJointEntropy`.

    Returns:
        CandidateBatch whose `indices[i]` is the point picked at step i and `scores[i]`
        the batch score after that pick.
    """
    N, K, C = log_probs_N_K_C.shape
    batch_size = min(batch_size, N)

    if batch_size == 0:
        return CandidateBatch([], [])

    candidate_indices = []
    candidate_scores = []

    conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)

    # Only the first batch_size - 1 picks are ever added as variables (see loop below).
    batch_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, K, C, dtype=dtype, device=device
    )

    # Kept on the CPU. pin_memory=True needs a CUDA runtime, so only pin when one exists;
    # pinned host memory presumably speeds up writing GPU results into this buffer.
    scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for _ in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if candidate_indices:
            # Extend the joint distribution with the point chosen in the previous step.
            newest = candidate_indices[-1]
            batch_joint_entropy.add_variables(log_probs_N_K_C[newest : newest + 1])

        # Sum of H[y_b | w] over the already-picked points (0 on the first step).
        picked_conditional_entropy = conditional_entropies_N[candidate_indices].sum()

        # Presumably fills scores_N[n] with H[y_picked..., y_n] for every pool point n.
        batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + picked_conditional_entropy
        scores_N[candidate_indices] = -float("inf")  # never pick a point twice

        best_score, best_index = scores_N.max(dim=0)

        candidate_indices.append(best_index.item())
        candidate_scores.append(best_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

### Example

In [16]:
get_batchbald_batch(ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## BALD

BALD is the same as BatchBALD, except that we evaluate points individually, by computing $I[y_i;w]$ for each, and then take the top $B$ scorers.

In [17]:
# exports


def get_bald_batch(log_probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Select the `batch_size` points with the highest individual BALD score I[y_i; w].

    Each point is scored independently as H[y_i] - H[y_i | w]; unlike BatchBALD, the
    interactions between the selected points are ignored.

    Args:
        log_probs_N_K_C: log-probabilities, N points x K inference samples x C classes.
        batch_size: number of points to return (clipped to N).
        dtype, device: accepted for signature parity with `get_batchbald_batch`; unused here.

    Returns:
        CandidateBatch with scores in descending order and the matching pool indices.
    """
    N, _, _ = log_probs_N_K_C.shape  # unpacking also checks the input is 3-D

    batch_size = min(batch_size, N)

    scores_N = -compute_conditional_entropy(log_probs_N_K_C)
    scores_N += compute_entropy(log_probs_N_K_C)

    candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

### Example

In [18]:
get_bald_batch(ys_ws.log().double(), batch_size=4)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234917, 0.030715639666234695], indices=[1, 2, 0, 3])

## Large Batch BALD

In [19]:
# # way #1

# def compute_entropy_vec(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
#     N, K, C = log_probs_N_K_C.shape

#     entropies_N = torch.empty(N, dtype=torch.double)

#     pbar = tqdm(total=N, desc="Entropy", leave=False)

#     @toma.execute.chunked(log_probs_N_K_C, 1024)
#     def compute(log_probs_n_K_C, start: int, end: int):
# #         print("log_probs_n_K_C.shape:", log_probs_n_K_C.shape)
#         mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
# #         nats_n_C = torch.zeros(N, C) #?
#         n = log_probs_n_K_C.shape[0]
#         nats_n = torch.zeros(n) # N

# #         for i in range(N-1):
#         for i in range(n):
# #             for j in range(i+1, N):
#             for j in range(n): # N can be bigger than batch_size -> error # one of possible solut-s -- tensor view
        
# #                 print("i, j:", (i, j))
# #                 print("eq:", torch.exp(mean_log_probs_n_C[i][:]) * torch.exp(mean_log_probs_n_C[j][:])*\
# #                 (mean_log_probs_n_C[i][:] + mean_log_probs_n_C[j][:]) - torch.exp(mean_log_probs_n_C[i][:])*\
# #                 mean_log_probs_n_C[i][:] - torch.exp(mean_log_probs_n_C[j][:]) * mean_log_probs_n_C[j][:])
# #                 nats_n_C
#                 if i != j:
# #                     print("i, j:", (i, j))
# #                     print("mean_log_probs_n_C.shape:", mean_log_probs_n_C.shape)
#                     nats_n[i] += torch.sum(torch.exp(mean_log_probs_n_C[i][:]) * torch.exp(mean_log_probs_n_C[j][:])*\
#                     (mean_log_probs_n_C[i][:] + mean_log_probs_n_C[j][:]) - torch.exp(mean_log_probs_n_C[i][:])*\
#                     mean_log_probs_n_C[i][:] - torch.exp(mean_log_probs_n_C[j][:]) * mean_log_probs_n_C[j][:])
# #                 print("nats_n_C:", nats_n_C)

# #                 print("nats_n:", nats_n) ###
    
# #         print("nats_n_C.shape:", nats_n_C.shape)
# #         print("nats_n_C:", nats_n_C)
# #         print("mean_log_probs_n_C:", mean_log_probs_n_C)
#         entropies_N[start:end].copy_(nats_n)
# #         -torch.sum(nats_n_C, dim=1)

# #         print("entropies_N:", entropies_N) ###
    
#         pbar.update(end - start)

#     pbar.close()

#     return entropies_N

In [20]:
# tensor way

# def compute_entropy_vec(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
#     N, K, C = log_probs_N_K_C.shape

#     entropies_N = torch.empty(N, dtype=torch.double)

#     pbar = tqdm(total=N, desc="Entropy", leave=False)

#     @toma.execute.chunked(log_probs_N_K_C, 1024)
#     def compute(log_probs_n_K_C, start: int, end: int):
#         mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
#         n = log_probs_n_K_C.shape[0]

#         a = torch.matmul(torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0), mean_log_probs_n_C.t()).fill_diagonal_(0.0)
#         b = torch.matmul(torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0), mean_log_probs_n_C.t())*torch.eye(mean_log_probs_n_C.shape[0])
#         c = torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t())*torch.eye(mean_log_probs_n_C.shape[0])*n
#         d = torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0)
        
#         nats_n = (torch.matmul(a, b) - c - d).sum(dim=1)

#         pbar.update(end - start)

#     pbar.close()

#     return entropies_N

In [21]:
# import copy

# log_probs_n_K_C = copy.copy(ys_ws.log().double())
# mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
# print(mean_log_probs_n_C.shape)

# a = torch.matmul(torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0), mean_log_probs_n_C.t()).fill_diagonal_(0.0)
# b = torch.matmul(torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0), mean_log_probs_n_C.t())*torch.eye(mean_log_probs_n_C.shape[0])
# c = torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t())*torch.eye(mean_log_probs_n_C.shape[0])*3 # how many times
# d = torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t())*torch.ones_like(mean_log_probs_n_C).fill_diagonal_(0.0) # how many times

# (torch.matmul(a, b) - c - d).sum(dim=1)

In [22]:
# torch.matmul(torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t().fill_diagonal_(0.0)), mean_log_probs_n_C).fill_diagonal_(0.0) +\
# torch.matmul(torch.exp(mean_log_probs_n_C), torch.exp(mean_log_probs_n_C).t()).fill_diagonal_(0.0) -\
# torch.diag(torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t()))*n -\
# torch.matmul(torch.exp(mean_log_probs_n_C), mean_log_probs_n_C.t()).fill_diagonal_(0.0)

In [None]:
def compute_entropy_vec(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Experimental vectorised pairwise term that `get_lbb_batch` subtracts from BALD scores.

    NOTE(review): work in progress; the intended formula is not documented in the notebook.
    For each chunk of n points it builds, per class c,
        a[j, c] = (1/K) * sum_k sum_i p_i^k(c) * p_j^k(c)      (i over the chunk, incl. i == j)
    then zeroes a[j, j] for j < min(n, C), and returns b[j] - c[j] - d[j] with
        b[j] = sum_c a[j, c] * log a[j, c]                      (0 * log 0 := 0)
        c[j] = sum_c a[j, c] * log pbar_j(c)                    (pbar = mean over K of p^k)
        d[j] = sum_{i != j} sum_c a[j, c] * log pbar_i(c)
    TODO confirm: the a[j, j] mask uses a point index as a class index, and pairs are only
    formed within one toma chunk, so results depend on the chunk size once N > 1024.

    Returns:
        float64 tensor of shape (N,).
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        n = log_probs_n_K_C.shape[0]
        device = log_probs_n_K_C.device  # build masks on the same device as the data
        mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)

        # (n, K, C) -> (C, K, n); batched outer product over the points -> (C, K, n, n).
        probs_C_K_n = torch.exp(log_probs_n_K_C).permute(2, 1, 0)
        outer_C_K_n_n = torch.matmul(probs_C_K_n[:, :, :, None], probs_C_K_n[:, :, None, :])
        # (C, K, n_i, n_j) -> (n_j, n_i, K, C); sum over n_i and K, divide by K -> (n, C).
        a = outer_C_K_n_n.permute(3, 2, 1, 0).sum(dim=(1, 2)) / K
        # masked_fill needs a bool mask (uint8 masks from .byte() are deprecated/rejected).
        a = a.masked_fill(torch.eye(n, C, dtype=torch.bool, device=device), 0.0)

        a_log_a = a * torch.log(a)
        # The mask above always creates zeros, and 0 * log 0 is nan; use the limit 0.
        a_log_a[a == 0] = 0.0
        b = a_log_a.sum(dim=1)  # (n,)

        # cross_n_n[j, i] = sum_c a[j, c] * mean_log_probs[i, c]
        cross_n_n = torch.matmul(a, mean_log_probs_n_C.t())
        diag_n_n = torch.eye(n, n, dtype=torch.bool, device=device)
        c = torch.diagonal(cross_n_n)  # (n,) i == j terms
        d = cross_n_n.masked_fill(diag_n_n, 0.0).sum(dim=1)  # (n,) i != j terms

        nats_n = b - c - d
        entropies_N[start:end].copy_(nats_n)

        pbar.update(end - start)

    pbar.close()

    return entropies_N

In [23]:
def get_lbb_batch(log_probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Experimental "Large Batch BALD": top-`batch_size` points by BALD minus a pairwise term.

    Each point is scored as H[y_i] - H[y_i | w] - compute_entropy_vec(...)[i], and the highest
    scorers are returned.

    Args:
        log_probs_N_K_C: log-probabilities, N points x K inference samples x C classes.
        batch_size: number of points to return (clipped to N).
        dtype, device: accepted for signature parity with `get_batchbald_batch`; unused here.

    Returns:
        CandidateBatch with scores in descending order and the matching pool indices.
    """
    N, _, _ = log_probs_N_K_C.shape  # unpacking also checks the input is 3-D

    batch_size = min(batch_size, N)

    scores_N = -compute_conditional_entropy(log_probs_N_K_C)
    scores_N += compute_entropy(log_probs_N_K_C)

    scores_N -= compute_entropy_vec(log_probs_N_K_C)

    candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

In [24]:
# way #2

# def get_lbb_batch(log_probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None) -> CandidateBatch:
#     N, K, C = log_probs_N_K_C.shape

#     batch_size = min(batch_size, N)

#     candidate_indices = []
#     candidate_scores = []

#     scores_N = -compute_conditional_entropy(log_probs_N_K_C)
#     scores_N += compute_entropy(log_probs_N_K_C)
    
#     log_probs_N_C = torch.sum(log_probs_N_K_C, dim=1)/K  # shold be according to log 
#     nats_N_C = torch.zeros(N, C)
    
#     for i in range(N-1):
#         for j in range(i+1, N):
#             nats_N_C += torch.exp(log_probs_N_C[i][:])*torch.exp(log_probs_N_C[j][:])*\
#             (log_probs_N_C[i][:] + log_probs_N_C[j][:]) - torch.exp(log_probs_N_C[i][:])*\
#             log_probs_N_C[i][:] - torch.exp(log_probs_N_C[j][:]) * log_probs_N_C[j][:]
            
#     scores_N += torch.sum(nats_N_C, dim=1)

#     candiate_scores, candidate_indices = torch.topk(scores_N, batch_size)

#     return CandidateBatch(candiate_scores.tolist(), candidate_indices.tolist())

In [25]:
get_lbb_batch(ys_ws.log().double(), batch_size=4)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[1.2683128078987922, 1.2683128078987917, 0.030715639666234917, 0.030715639666234917], indices=[2, 3, 1, 0])

In [28]:
# a = torch.ones(3, 3)
# torch.triu(a, diagonal=1)

In [27]:
# torch.mm(prob(y[i]), prob(y[j]))

# making upper triang matr without diag
# r,c = np.triu_indices(A.shape[0], 1)
# A[r,c] = 0.0
# torch.triu_indices()

#     cond_entropy_ytheta = -torch.sum(log_probs_K_C * torch.exp(log_probs_K_C), dim=0)/K
#     entropy = log_probs_C * torch.exp(log_probs_C)
#     cond_entropy_yij = -torch.sum(log_probs_C_C * torch.exp(log_probs_C_C), dim=0)/C
    # sum for all y_i, y_j