In [7]:
import torch
from astra.torch.al.acquisitions.base import MCAcquisition, EnsembleAcquisition


class MaxEntropyAcquisition(MCAcquisition, EnsembleAcquisition):
    def acquire_scores(self, logits: torch.Tensor):
        probs = torch.softmax(logits, dim=2)
        entropy = -torch.sum(probs * torch.log(probs), dim=2)
        score = torch.sum(entropy, dim=0)
        return score


# Create an instance of MaxEntropyAcquisition
max_entropy_acquisition = MaxEntropyAcquisition()

logits = torch.tensor(
    [
        [[0.2, 0.8], [0.7, 0.3], [0.4, 0.6]],
        [[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],
        [[0.3, 0.7], [0.5, 0.5], [0.9, 0.1]],
    ],
    dtype=torch.float32,
)


# Calculate acquisition scores using the ensemble context
ensemble_scores = max_entropy_acquisition.acquire_scores(logits)

# Calculate acquisition scores using the Monte Carlo context
mc_scores = max_entropy_acquisition.acquire_scores(mc_logits)

# Print the results
print("Ensemble Acquisition Scores:")
print(ensemble_scores)

Ensemble Acquisition Scores:
tensor([2.0118, 2.0402, 1.9574])
