In [47]:
import torch

def codebook_usage(encoding_indices, num_codebooks):
    """
    Computes the codebook usage as a histogram of codebook occurrences.

    All dimensions of ``encoding_indices`` are flattened, so any shape is
    accepted (e.g. (B, T) or (B, H, W)). Each entry is counted once.

    Args:
        encoding_indices (Tensor): Integer tensor of codebook indices, typically
            of shape (B, T). Values must lie in [0, num_codebooks). Negative
            values are rejected by ``torch.bincount`` (RuntimeError).
        num_codebooks (int): Total number of codebook entries.

    Returns:
        Tensor: Float histogram of shape (num_codebooks,), where element i is
        the number of times index i occurs in ``encoding_indices``.

    Raises:
        ValueError: If any index is >= num_codebooks. Without this check,
            bincount would silently return a histogram longer than
            num_codebooks.
    """
    # reshape (not view) also works for non-contiguous inputs such as
    # transposed or sliced index tensors; view(-1) raises on those.
    flattened_indices = encoding_indices.reshape(-1)  # Shape: (numel,)

    # Guard the documented output length. An empty input has no max, and
    # bincount with minlength already yields all zeros in that case.
    if flattened_indices.numel() > 0:
        max_index = int(flattened_indices.max())
        if max_index >= num_codebooks:
            raise ValueError(
                f"encoding_indices contains index {max_index}, "
                f"but num_codebooks is {num_codebooks}"
            )

    # Count occurrences; minlength pads unused trailing entries with zeros.
    histogram = torch.bincount(flattened_indices, minlength=num_codebooks).float()

    return histogram

# Dummy example: histogram a small, seeded batch of random codebook indices.
if __name__ == "__main__":
    # Toy dimensions: batch size B and sequence length T
    B, T = 4, 6
    # Size of the codebook the indices are drawn from
    num_codebooks = 10

    # Seed first so the drawn indices (and the printed output) are reproducible
    torch.manual_seed(42)
    # Integer indices sampled from [0, num_codebooks)
    encoding_indices = torch.randint(low=0, high=num_codebooks, size=(B, T))

    histogram = codebook_usage(encoding_indices, num_codebooks)

    # Each call prints a header line followed by the tensor on the next line
    print("Encoding Indices:", encoding_indices, sep="\n")
    print("\nCodebook Usage Histogram:", histogram, sep="\n")

Encoding Indices:
tensor([[2, 7, 6, 4, 6, 5],
        [0, 4, 0, 3, 8, 4],
        [0, 4, 1, 2, 5, 5],
        [7, 6, 9, 6, 3, 1]])

Codebook Usage Histogram:
tensor([3., 2., 2., 2., 4., 3., 4., 2., 1., 1.])


In [50]:
print(histogram.add(histogram))

tensor([6., 4., 4., 4., 8., 6., 8., 4., 2., 2.])


In [51]:
histogram.size()

torch.Size([10])